diff --git a/spring-test/src/main/java/org/springframework/mock/web/reactive/function/server/MockServerRequest.java b/spring-test/src/main/java/org/springframework/mock/web/reactive/function/server/MockServerRequest.java index 842e2e8f49..80ccb56d9d 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/reactive/function/server/MockServerRequest.java +++ b/spring-test/src/main/java/org/springframework/mock/web/reactive/function/server/MockServerRequest.java @@ -40,6 +40,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.HttpRange; import org.springframework.http.HttpRequest; import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.PathContainer; import org.springframework.http.server.RequestPath; @@ -50,7 +51,9 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.BodyExtractor; +import org.springframework.web.reactive.function.server.HandlerStrategies; import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; import org.springframework.web.util.UriBuilder; import org.springframework.web.util.UriComponentsBuilder; @@ -91,12 +94,18 @@ public class MockServerRequest implements ServerRequest { @Nullable private final InetSocketAddress remoteAddress; + private final List> messageReaders; + + @Nullable + private final ServerWebExchange exchange; + private MockServerRequest(HttpMethod method, URI uri, String contextPath, MockHeaders headers, MultiValueMap cookies, @Nullable Object body, Map attributes, MultiValueMap queryParams, Map pathVariables, @Nullable WebSession session, @Nullable Principal principal, - @Nullable InetSocketAddress remoteAddress) { + @Nullable InetSocketAddress remoteAddress, List> messageReaders, + @Nullable ServerWebExchange exchange) { this.method = method; this.uri = uri; @@ -110,6 +119,8 @@ public class MockServerRequest implements ServerRequest { this.session = session; this.principal = principal; this.remoteAddress = remoteAddress; + this.messageReaders = messageReaders; + this.exchange = exchange; } @@ -153,6 +164,11 @@ public class MockServerRequest implements ServerRequest { return Optional.ofNullable(this.remoteAddress); } + @Override + public List> messageReaders() { + return this.messageReaders; + } + @Override @SuppressWarnings("unchecked") public S body(BodyExtractor extractor) { @@ -220,7 +236,6 @@ public class MockServerRequest implements ServerRequest { return Mono.justOrEmpty(this.principal); } - @Override @SuppressWarnings("unchecked") public Mono> formData() { @@ -235,6 +250,12 @@ public class MockServerRequest implements ServerRequest { return (Mono>) this.body; } + @Override + public ServerWebExchange exchange() { + Assert.state(this.exchange != null, "No exchange"); + return this.exchange; + } + public static Builder builder() { return new BuilderImpl(); } @@ -282,6 +303,10 @@ public class MockServerRequest implements ServerRequest { Builder remoteAddress(InetSocketAddress remoteAddress); + Builder messageReaders(List> messageReaders); + + Builder exchange(ServerWebExchange exchange); + MockServerRequest body(Object body); MockServerRequest build(); @@ -318,6 +343,11 @@ public class MockServerRequest implements ServerRequest { @Nullable private InetSocketAddress remoteAddress; + private List> messageReaders = HandlerStrategies.withDefaults().messageReaders(); + + @Nullable + private ServerWebExchange exchange; + @Override public Builder method(HttpMethod method) { Assert.notNull(method, "'method' must not be null"); @@ -440,19 +470,35 @@ public class MockServerRequest implements ServerRequest { return this; } + @Override + public Builder messageReaders(List> messageReaders) { + Assert.notNull(messageReaders, "'messageReaders' must not be null"); + this.messageReaders = messageReaders; + return this; + } + + @Override + public Builder exchange(ServerWebExchange exchange) { + Assert.notNull(exchange, "'exchange' must not be null"); + this.exchange = exchange; + return this; + } + @Override public MockServerRequest body(Object body) { this.body = body; return new MockServerRequest(this.method, this.uri, this.contextPath, this.headers, this.cookies, this.body, this.attributes, this.queryParams, this.pathVariables, - this.session, this.principal, this.remoteAddress); + this.session, this.principal, this.remoteAddress, this.messageReaders, + this.exchange); } @Override public MockServerRequest build() { return new MockServerRequest(this.method, this.uri, this.contextPath, this.headers, this.cookies, null, this.attributes, this.queryParams, this.pathVariables, - this.session, this.principal, this.remoteAddress); + this.session, this.principal, this.remoteAddress, this.messageReaders, + this.exchange); } } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequest.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequest.java index dd6c4aaf0e..cfcbf0f8ec 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequest.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequest.java @@ -120,6 +120,11 @@ class DefaultServerRequest implements ServerRequest { return Optional.ofNullable(request().getRemoteAddress()); } + @Override + public List> messageReaders() { + return this.messageReaders; + } + @Override public T body(BodyExtractor extractor) { return body(extractor, Collections.emptyMap()); @@ -208,7 +213,8 @@ class DefaultServerRequest implements ServerRequest { return this.exchange.getRequest(); } - ServerWebExchange exchange() { + @Override + public ServerWebExchange exchange() { return this.exchange; } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequestBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequestBuilder.java new file mode 100644 index 0000000000..cab2e6958a --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/DefaultServerRequestBuilder.java @@ -0,0 +1,444 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.security.Principal; +import java.time.Instant; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.i18n.LocaleContext; +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.InvalidMediaTypeException; +import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageReader; +import org.springframework.http.codec.multipart.Part; +import org.springframework.http.server.RequestPath; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import org.springframework.web.util.UriUtils; + +/** + * Default {@link ServerRequest.Builder} implementation. + * + * @author Arjen Poutsma + * @since 5.1 + */ +class DefaultServerRequestBuilder implements ServerRequest.Builder { + + private final HttpHeaders headers = new HttpHeaders(); + + private final MultiValueMap cookies = new LinkedMultiValueMap<>(); + + private final Map attributes = new LinkedHashMap<>(); + + private final List> messageReaders; + + private ServerWebExchange exchange; + + private HttpMethod method = HttpMethod.GET; + + @Nullable + private URI uri; + + private Flux body = Flux.empty(); + + public DefaultServerRequestBuilder(ServerRequest other) { + Assert.notNull(other, "ServerRequest must not be null"); + this.messageReaders = other.messageReaders(); + this.exchange = other.exchange(); + method(other.method()); + uri(other.uri()); + headers(headers -> headers.addAll(other.headers().asHttpHeaders())); + cookies(cookies -> cookies.addAll(other.cookies())); + } + + @Override + public ServerRequest.Builder method(HttpMethod method) { + Assert.notNull(method, "'method' must not be null"); + this.method = method; + return this; + } + + @Override + public ServerRequest.Builder uri(URI uri) { + Assert.notNull(uri, "'uri' must not be null"); + this.uri = uri; + return this; + } + + @Override + public ServerRequest.Builder header(String headerName, String... headerValues) { + for (String headerValue : headerValues) { + this.headers.add(headerName, headerValue); + } + return this; + } + + @Override + public ServerRequest.Builder headers(Consumer headersConsumer) { + Assert.notNull(headersConsumer, "'headersConsumer' must not be null"); + headersConsumer.accept(this.headers); + return this; + } + + @Override + public ServerRequest.Builder cookie(String name, String... values) { + for (String value : values) { + this.cookies.add(name, new HttpCookie(name, value)); + } + return this; + } + + @Override + public ServerRequest.Builder cookies( + Consumer> cookiesConsumer) { + + Assert.notNull(cookiesConsumer, "'cookiesConsumer' must not be null"); + cookiesConsumer.accept(this.cookies); + return this; + } + + @Override + public ServerRequest.Builder body(Flux body) { + Assert.notNull(body, "'body' must not be null"); + releaseBody(); + this.body = body; + return this; + } + + @Override + public ServerRequest.Builder body(String body) { + Assert.notNull(body, "'body' must not be null"); + releaseBody(); + DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory(); + this.body = Flux.just(body). + map(s -> { + byte[] bytes = body.getBytes(StandardCharsets.UTF_8); + return dataBufferFactory.wrap(bytes); + }); + return this; + } + + private void releaseBody() { + this.body.subscribe(DataBufferUtils.releaseConsumer()); + } + + @Override + public ServerRequest.Builder attribute(String name, Object value) { + Assert.notNull(name, "'name' must not be null"); + this.attributes.put(name, value); + return this; + } + + @Override + public ServerRequest.Builder attributes(Consumer> attributesConsumer) { + Assert.notNull(attributesConsumer, "'attributesConsumer' must not be null"); + attributesConsumer.accept(this.attributes); + return this; + } + + @Override + public ServerRequest build() { + ServerHttpRequest serverHttpRequest = new BuiltServerHttpRequest(this.method, this.uri, + this.headers, this.cookies, this.body); + ServerWebExchange exchange = new DelegatingServerWebExchange(serverHttpRequest, + this.exchange, this.messageReaders); + return new DefaultServerRequest(exchange, this.messageReaders); + } + + private static class BuiltServerHttpRequest implements ServerHttpRequest { + + private static final Pattern QUERY_PATTERN = Pattern.compile("([^&=]+)(=?)([^&]+)?"); + + + private final HttpMethod method; + + private final URI uri; + + private final RequestPath path; + + private final MultiValueMap queryParams; + + private final HttpHeaders headers; + + private final MultiValueMap cookies; + + private final Flux body; + + public BuiltServerHttpRequest(HttpMethod method, URI uri, + HttpHeaders headers, + MultiValueMap cookies, + Flux body) { + this.method = method; + this.uri = uri; + this.path = RequestPath.parse(uri, null); + this.headers = HttpHeaders.readOnlyHttpHeaders(headers); + this.cookies = unmodifiableCopy(cookies); + this.queryParams = parseQueryParams(uri); + this.body = body; + } + + private static MultiValueMap unmodifiableCopy(MultiValueMap original) { + return CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap<>(original)); + } + + private static MultiValueMap parseQueryParams(URI uri) { + MultiValueMap queryParams = new LinkedMultiValueMap<>(); + String query = uri.getRawQuery(); + if (query != null) { + Matcher matcher = QUERY_PATTERN.matcher(query); + while (matcher.find()) { + String name = UriUtils.decode(matcher.group(1), StandardCharsets.UTF_8); + String eq = matcher.group(2); + String value = matcher.group(3); + if (value != null) { + value = UriUtils.decode(value, StandardCharsets.UTF_8); + } + else { + value = StringUtils.hasLength(eq) ? "" : null; + } + queryParams.add(name, value); + } + } + return queryParams; + } + + @Nullable + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + public String getMethodValue() { + return this.method.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public RequestPath getPath() { + return this.path; + } + + @Override + public HttpHeaders getHeaders() { + return this.headers; + } + + @Override + public MultiValueMap getCookies() { + return this.cookies; + } + + @Override + public MultiValueMap getQueryParams() { + return this.queryParams; + } + + @Override + public Flux getBody() { + return this.body; + } + } + + private static class DelegatingServerWebExchange implements ServerWebExchange { + + private static final ResolvableType FORM_DATA_TYPE = + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class); + + private static final ResolvableType MULTIPART_DATA_TYPE = ResolvableType.forClassWithGenerics( + MultiValueMap.class, String.class, Part.class); + + private static final Mono> EMPTY_FORM_DATA = + Mono.just(CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap(0))) + .cache(); + + private static final Mono> EMPTY_MULTIPART_DATA = + Mono.just(CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap(0))) + .cache(); + + private final ServerHttpRequest request; + + private final ServerWebExchange delegate; + + private final Mono> formDataMono; + + private final Mono> multipartDataMono; + + public DelegatingServerWebExchange(ServerHttpRequest request, ServerWebExchange delegate, + List> messageReaders) { + this.request = request; + this.delegate = delegate; + this.formDataMono = initFormData(request, messageReaders); + this.multipartDataMono = initMultipartData(request, messageReaders); + } + + @SuppressWarnings("unchecked") + private static Mono> initFormData(ServerHttpRequest request, + List> readers) { + + try { + MediaType contentType = request.getHeaders().getContentType(); + if (MediaType.APPLICATION_FORM_URLENCODED.isCompatibleWith(contentType)) { + return ((HttpMessageReader>) readers.stream() + .filter(reader -> reader.canRead(FORM_DATA_TYPE, MediaType.APPLICATION_FORM_URLENCODED)) + .findFirst() + .orElseThrow(() -> new IllegalStateException("No form data HttpMessageReader."))) + .readMono(FORM_DATA_TYPE, request, Collections.emptyMap()) + .switchIfEmpty(EMPTY_FORM_DATA) + .cache(); + } + } + catch (InvalidMediaTypeException ex) { + // Ignore + } + return EMPTY_FORM_DATA; + } + + @SuppressWarnings("unchecked") + private static Mono> initMultipartData(ServerHttpRequest request, + List> readers) { + + try { + MediaType contentType = request.getHeaders().getContentType(); + if (MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) { + return ((HttpMessageReader>) readers.stream() + .filter(reader -> reader.canRead(MULTIPART_DATA_TYPE, MediaType.MULTIPART_FORM_DATA)) + .findFirst() + .orElseThrow(() -> new IllegalStateException("No multipart HttpMessageReader."))) + .readMono(MULTIPART_DATA_TYPE, request, Collections.emptyMap()) + .switchIfEmpty(EMPTY_MULTIPART_DATA) + .cache(); + } + } + catch (InvalidMediaTypeException ex) { + // Ignore + } + return EMPTY_MULTIPART_DATA; + } + @Override + public ServerHttpRequest getRequest() { + return this.request; + } + + @Override + public Mono> getFormData() { + return this.formDataMono; + } + + @Override + public Mono> getMultipartData() { + return this.multipartDataMono; + } + + // Delegating methods + + @Override + public ServerHttpResponse getResponse() { + return this.delegate.getResponse(); + } + + @Override + public Map getAttributes() { + return this.delegate.getAttributes(); + } + + @Override + public Mono getSession() { + return this.delegate.getSession(); + } + + @Override + public Mono getPrincipal() { + return this.delegate.getPrincipal(); + } + + + @Override + public LocaleContext getLocaleContext() { + return this.delegate.getLocaleContext(); + } + + @Nullable + @Override + public ApplicationContext getApplicationContext() { + return this.delegate.getApplicationContext(); + } + + @Override + public boolean isNotModified() { + return this.delegate.isNotModified(); + } + + @Override + public boolean checkNotModified(Instant lastModified) { + return this.delegate.checkNotModified(lastModified); + } + + @Override + public boolean checkNotModified(String etag) { + return this.delegate.checkNotModified(etag); + } + + @Override + public boolean checkNotModified(@Nullable String etag, Instant lastModified) { + return this.delegate.checkNotModified(etag, lastModified); + } + + @Override + public String transformUrl(String url) { + return this.delegate.transformUrl(url); + } + + @Override + public void addUrlTransformer(Function transformer) { + this.delegate.addUrlTransformer(transformer); + } + } +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java index f3419f6405..72e0a5dc72 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java @@ -40,6 +40,7 @@ import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpCookie; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.PathContainer; import org.springframework.http.server.reactive.ServerHttpRequest; @@ -47,6 +48,7 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.BodyExtractor; +import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; import org.springframework.web.util.UriBuilder; import org.springframework.web.util.UriUtils; @@ -524,6 +526,11 @@ public abstract class RequestPredicates { return this.request.remoteAddress(); } + @Override + public List> messageReaders() { + return this.request.messageReaders(); + } + @Override public T body(BodyExtractor extractor) { return this.request.body(extractor); @@ -604,6 +611,11 @@ public abstract class RequestPredicates { return this.request.multipartData(); } + @Override + public ServerWebExchange exchange() { + return this.request.exchange(); + } + @Override public String toString() { return method() + " " + path(); diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerRequest.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerRequest.java index d363c396e2..878133e49b 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerRequest.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ServerRequest.java @@ -25,11 +25,13 @@ import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.OptionalLong; +import java.util.function.Consumer; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.buffer.DataBuffer; import org.springframework.http.HttpCookie; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -41,6 +43,7 @@ import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.PathContainer; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.BodyExtractor; @@ -120,6 +123,11 @@ public interface ServerRequest { */ Optional remoteAddress(); + /** + * Return the readers used to convert the body of this request. + */ + List> messageReaders(); + /** * Extract the body with the given {@code BodyExtractor}. * @param extractor the {@code BodyExtractor} that reads from the request @@ -271,7 +279,16 @@ public interface ServerRequest { */ Mono> multipartData(); + /** + * Returns the web exchange that this request is based on. Manipulating the exchange directly, + * instead of using the methods provided on {@code ServerRequest} and {@code ServerResponse}, + * can lead to irregular results. + * + * @return the web exchange + */ + ServerWebExchange exchange(); + // Static methods /** * Create a new {@code ServerRequest} based on the given {@code ServerWebExchange} and @@ -284,6 +301,15 @@ public interface ServerRequest { return new DefaultServerRequest(exchange, messageReaders); } + /** + * Create a builder with the status, headers, and cookies of the given request. + * @param other the response to copy the status, headers, and cookies from + * @return the created builder + */ + static Builder from(ServerRequest other) { + Assert.notNull(other, "'other' must not be null"); + return new DefaultServerRequestBuilder(other); + } /** * Represents the headers of the HTTP request. @@ -349,4 +375,108 @@ public interface ServerRequest { HttpHeaders asHttpHeaders(); } + + /** + * Defines a builder for a request. + */ + interface Builder { + + /** + * Set the method of the request. + * @param method the new method + * @return this builder + */ + Builder method(HttpMethod method); + + /** + * Set the uri of the request. + * @param uri the new uri + * @return this builder + */ + Builder uri(URI uri); + + /** + * Add the given header value(s) under the given name. + * @param headerName the header name + * @param headerValues the header value(s) + * @return this builder + * @see HttpHeaders#add(String, String) + */ + Builder header(String headerName, String... headerValues); + + /** + * Manipulate this request's headers with the given consumer. The + * headers provided to the consumer are "live", so that the consumer can be used to + * {@linkplain HttpHeaders#set(String, String) overwrite} existing header values, + * {@linkplain HttpHeaders#remove(Object) remove} values, or use any of the other + * {@link HttpHeaders} methods. + * @param headersConsumer a function that consumes the {@code HttpHeaders} + * @return this builder + */ + Builder headers(Consumer headersConsumer); + + /** + * Add a cookie with the given name and value(s). + * @param name the cookie name + * @param values the cookie value(s) + * @return this builder + */ + Builder cookie(String name, String... values); + + /** + * Manipulate this request's cookies with the given consumer. The + * map provided to the consumer is "live", so that the consumer can be used to + * {@linkplain MultiValueMap#set(Object, Object) overwrite} existing header values, + * {@linkplain MultiValueMap#remove(Object) remove} values, or use any of the other + * {@link MultiValueMap} methods. + * @param cookiesConsumer a function that consumes the cookies map + * @return this builder + */ + Builder cookies(Consumer> cookiesConsumer); + + /** + * Sets the body of the request. Calling this methods will + * {@linkplain org.springframework.core.io.buffer.DataBufferUtils#release(DataBuffer) release} + * the existing body of the builder. + * @param body the new body. + * @return this builder + */ + Builder body(Flux body); + + /** + * Sets the body of the request to the UTF-8 encoded bytes of the given string. + * Calling this methods will + * {@linkplain org.springframework.core.io.buffer.DataBufferUtils#release(DataBuffer) release} + * the existing body of the builder. + * @param body the new body. + * @return this builder + */ + Builder body(String body); + + /** + * Adds an attribute with the given name and value. + * @param name the attribute name + * @param value the attribute value + * @return this builder + */ + Builder attribute(String name, Object value); + + /** + * Manipulate this request's attributes with the given consumer. The map provided to the + * consumer is "live", so that the consumer can be used to + * {@linkplain Map#put(Object, Object) overwrite} existing header values, + * {@linkplain Map#remove(Object) remove} values, or use any of the other + * {@link Map} methods. + * @param attributesConsumer a function that consumes the attributes map + * @return this builder + */ + Builder attributes(Consumer> attributesConsumer); + + /** + * Builds the request. + * @return the built request + */ + ServerRequest build(); + } + } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/ServerRequestWrapper.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/ServerRequestWrapper.java index 606ab037d7..0ff4b0c67d 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/ServerRequestWrapper.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/ServerRequestWrapper.java @@ -35,6 +35,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpRange; import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.PathContainer; import org.springframework.http.server.reactive.ServerHttpRequest; @@ -42,6 +43,7 @@ import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.BodyExtractor; import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; import org.springframework.web.util.UriBuilder; @@ -121,6 +123,11 @@ public class ServerRequestWrapper implements ServerRequest { return this.delegate.remoteAddress(); } + @Override + public List> messageReaders() { + return this.delegate.messageReaders(); + } + @Override public T body(BodyExtractor extractor) { return this.delegate.body(extractor); @@ -201,6 +208,10 @@ public class ServerRequestWrapper implements ServerRequest { return this.delegate.multipartData(); } + @Override + public ServerWebExchange exchange() { + return this.delegate.exchange(); + } /** * Implementation of the {@code Headers} interface that can be subclassed diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/DefaultServerRequestBuilderTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/DefaultServerRequestBuilderTests.java new file mode 100644 index 0000000000..b5e1ef3870 --- /dev/null +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/DefaultServerRequestBuilderTests.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.nio.charset.StandardCharsets; + +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpMethod; +import org.springframework.http.ResponseCookie; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.web.test.server.MockServerWebExchange; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + */ +public class DefaultServerRequestBuilderTests { + + private DataBufferFactory dataBufferFactory; + + @Before + public void createBufferFactory() { + this.dataBufferFactory = new DefaultDataBufferFactory(); + } + + @Test + public void from() throws Exception { + + MockServerHttpRequest request = MockServerHttpRequest.post("http://example.com") + .header("foo", "bar") + .build(); + MockServerWebExchange exchange = MockServerWebExchange.from(request); + + ServerRequest other = + ServerRequest.create(exchange, HandlerStrategies.withDefaults().messageReaders()); + + Flux body = Flux.just("baz") + .map(s -> s.getBytes(StandardCharsets.UTF_8)) + .map(dataBufferFactory::wrap); + + ServerRequest result = ServerRequest.from(other) + .method(HttpMethod.HEAD) + .headers(httpHeaders -> httpHeaders.set("foo", "baar")) + .cookies(cookies -> cookies.set("baz", ResponseCookie.from("baz", "quux").build())) + .body(body) + .build(); + + assertEquals(HttpMethod.HEAD, result.method()); + assertEquals(1, result.headers().asHttpHeaders().size()); + assertEquals("baar", result.headers().asHttpHeaders().getFirst("foo")); + assertEquals(1, result.cookies().size()); + assertEquals("quux", result.cookies().getFirst("baz").getValue()); + + StepVerifier.create(result.bodyToFlux(String.class)) + .expectNext("baz") + .verifyComplete(); + } + +} \ No newline at end of file diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/MockServerRequest.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/MockServerRequest.java index d7bf19a8e6..e48b16deb0 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/MockServerRequest.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/MockServerRequest.java @@ -40,6 +40,7 @@ import org.springframework.http.HttpMethod; import org.springframework.http.HttpRange; import org.springframework.http.HttpRequest; import org.springframework.http.MediaType; +import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.PathContainer; import org.springframework.http.server.RequestPath; @@ -50,6 +51,7 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.BodyExtractor; +import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; import org.springframework.web.util.UriBuilder; import org.springframework.web.util.UriComponentsBuilder; @@ -90,12 +92,18 @@ public class MockServerRequest implements ServerRequest { @Nullable private final InetSocketAddress remoteAddress; + private final List> messageReaders; + + @Nullable + private final ServerWebExchange exchange; + private MockServerRequest(HttpMethod method, URI uri, String contextPath, MockHeaders headers, MultiValueMap cookies, @Nullable Object body, Map attributes, MultiValueMap queryParams, Map pathVariables, @Nullable WebSession session, @Nullable Principal principal, - @Nullable InetSocketAddress remoteAddress) { + @Nullable InetSocketAddress remoteAddress, List> messageReaders, + @Nullable ServerWebExchange exchange) { this.method = method; this.uri = uri; @@ -109,6 +117,8 @@ public class MockServerRequest implements ServerRequest { this.session = session; this.principal = principal; this.remoteAddress = remoteAddress; + this.messageReaders = messageReaders; + this.exchange = exchange; } @@ -152,6 +162,11 @@ public class MockServerRequest implements ServerRequest { return Optional.ofNullable(this.remoteAddress); } + @Override + public List> messageReaders() { + return this.messageReaders; + } + @Override @SuppressWarnings("unchecked") public S body(BodyExtractor extractor) { @@ -233,6 +248,12 @@ public class MockServerRequest implements ServerRequest { return (Mono>) this.body; } + @Override + public ServerWebExchange exchange() { + Assert.state(this.exchange != null, "No exchange"); + return this.exchange; + } + public static Builder builder() { return new BuilderImpl(); } @@ -271,7 +292,7 @@ public class MockServerRequest implements ServerRequest { Builder session(WebSession session); /** - * @deprecated in favor of {@link #principal(Principal)} + * @deprecated in favor of {@link #principal(Principal)} */ @Deprecated Builder session(Principal principal); @@ -280,6 +301,10 @@ public class MockServerRequest implements ServerRequest { Builder remoteAddress(InetSocketAddress remoteAddress); + Builder messageReaders(List> messageReaders); + + Builder exchange(ServerWebExchange exchange); + MockServerRequest body(Object body); MockServerRequest build(); @@ -316,6 +341,11 @@ public class MockServerRequest implements ServerRequest { @Nullable private InetSocketAddress remoteAddress; + private List> messageReaders = HandlerStrategies.withDefaults().messageReaders(); + + @Nullable + private ServerWebExchange exchange; + @Override public Builder method(HttpMethod method) { Assert.notNull(method, "'method' must not be null"); @@ -419,6 +449,7 @@ public class MockServerRequest implements ServerRequest { } @Override + @Deprecated public Builder session(Principal principal) { return principal(principal); } @@ -437,19 +468,35 @@ public class MockServerRequest implements ServerRequest { return this; } + @Override + public Builder messageReaders(List> messageReaders) { + Assert.notNull(messageReaders, "'messageReaders' must not be null"); + this.messageReaders = messageReaders; + return this; + } + + @Override + public Builder exchange(ServerWebExchange exchange) { + Assert.notNull(exchange, "'exchange' must not be null"); + this.exchange = exchange; + return this; + } + @Override public MockServerRequest body(Object body) { this.body = body; return new MockServerRequest(this.method, this.uri, this.contextPath, this.headers, this.cookies, this.body, this.attributes, this.queryParams, this.pathVariables, - this.session, this.principal, this.remoteAddress); + this.session, this.principal, this.remoteAddress, this.messageReaders, + this.exchange); } @Override public MockServerRequest build() { return new MockServerRequest(this.method, this.uri, this.contextPath, this.headers, this.cookies, null, this.attributes, this.queryParams, this.pathVariables, - this.session, this.principal, this.remoteAddress); + this.session, this.principal, this.remoteAddress, this.messageReaders, + this.exchange); } }