From f8ef2e0220414cfb0ecf5dec9dea4e5a68fbce74 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Sun, 10 Jan 2016 07:26:26 -0500 Subject: [PATCH] Add base classes for ServerHttpRequest/Response impls --- .../reactive/AbstractServerHttpRequest.java | 69 ++++++++++++++ .../reactive/AbstractServerHttpResponse.java | 81 ++++++++++++++++ .../reactive/ReactorServerHttpRequest.java | 30 ++---- .../reactive/ReactorServerHttpResponse.java | 31 ++---- .../reactive/RxNettyServerHttpRequest.java | 30 ++---- .../reactive/RxNettyServerHttpResponse.java | 26 +---- .../reactive/ServletServerHttpRequest.java | 94 ++++++++----------- .../reactive/ServletServerHttpResponse.java | 48 +++------- .../reactive/UndertowServerHttpRequest.java | 36 ++----- .../reactive/UndertowServerHttpResponse.java | 30 +----- 10 files changed, 243 insertions(+), 232 deletions(-) create mode 100644 spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpRequest.java create mode 100644 spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpRequest.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpRequest.java new file mode 100644 index 00000000000..e7a35a8e6b4 --- /dev/null +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpRequest.java @@ -0,0 +1,69 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.server.reactive; + +import java.net.URI; +import java.net.URISyntaxException; + +import org.springframework.http.HttpHeaders; + +/** + * Common base class for {@link ServerHttpRequest} implementations. + * + * @author Rossen Stoyanchev + */ +public abstract class AbstractServerHttpRequest implements ServerHttpRequest { + + private URI uri; + + private HttpHeaders headers; + + + @Override + public URI getURI() { + if (this.uri == null) { + try { + this.uri = initUri(); + } + catch (URISyntaxException ex) { + throw new IllegalStateException("Could not get URI: " + ex.getMessage(), ex); + } + } + return this.uri; + } + + /** + * Initialize a URI that represents the request. + * Invoked lazily on the first call to {@link #getURI()} and then cached. + * @throws URISyntaxException + */ + protected abstract URI initUri() throws URISyntaxException; + + @Override + public HttpHeaders getHeaders() { + if (this.headers == null) { + this.headers = HttpHeaders.readOnlyHttpHeaders(initHeaders()); + } + return this.headers; + } + + /** + * Initialize the headers from the underlying request. + * Invoked lazily on the first call to {@link #getHeaders()} and then cached. + */ + protected abstract HttpHeaders initHeaders(); + +} diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java new file mode 100644 index 00000000000..2c87f89c466 --- /dev/null +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.server.reactive; + +import java.nio.ByteBuffer; + +import org.reactivestreams.Publisher; +import reactor.Flux; +import reactor.Mono; + +import org.springframework.http.HttpHeaders; + +/** + * Base class for {@link ServerHttpResponse} implementations. + * + * @author Rossen Stoyanchev + */ +public abstract class AbstractServerHttpResponse implements ServerHttpResponse { + + private final HttpHeaders headers; + + private boolean headersWritten = false; + + + protected AbstractServerHttpResponse() { + this.headers = new HttpHeaders(); + } + + + @Override + public HttpHeaders getHeaders() { + return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); + } + + @Override + public Mono setBody(Publisher publisher) { + return Flux.from(publisher).lift(new WriteWithOperator<>(writeWithPublisher -> { + writeHeaders(); + return setBodyInternal(writeWithPublisher); + })).after(); + } + + /** + * Implement this method to write to the underlying the response. + * @param publisher the publisher to write with + */ + protected abstract Mono setBodyInternal(Publisher publisher); + + @Override + public void writeHeaders() { + if (!this.headersWritten) { + try { + writeHeadersInternal(); + } + finally { + this.headersWritten = true; + } + } + } + + /** + * Implement this method to apply header changes from {@link #getHeaders()} + * to the underlying response. This method is protected from being called + * more than once. + */ + protected abstract void writeHeadersInternal(); + +} diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java index 6d569c93525..e1f62514c78 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java @@ -32,14 +32,10 @@ import org.springframework.util.Assert; * * @author Stephane Maldini */ -public class ReactorServerHttpRequest implements ServerHttpRequest { +public class ReactorServerHttpRequest extends AbstractServerHttpRequest { private final HttpChannel channel; - private URI uri; - - private HttpHeaders headers; - public ReactorServerHttpRequest(HttpChannel request) { Assert.notNull("'request' must not be null."); @@ -57,27 +53,17 @@ public class ReactorServerHttpRequest implements ServerHttpRequest { } @Override - public URI getURI() { - if (this.uri == null) { - try { - this.uri = new URI(this.channel.uri()); - } - catch (URISyntaxException ex) { - throw new IllegalStateException("Could not get URI: " + ex.getMessage(), ex); - } - } - return this.uri; + protected URI initUri() throws URISyntaxException { + return new URI(this.channel.uri()); } @Override - public HttpHeaders getHeaders() { - if (this.headers == null) { - this.headers = new HttpHeaders(); - for (String name : this.channel.headers().names()) { - this.headers.put(name, this.channel.headers().getAll(name)); - } + protected HttpHeaders initHeaders() { + HttpHeaders headers = new HttpHeaders(); + for (String name : this.channel.headers().names()) { + headers.put(name, this.channel.headers().getAll(name)); } - return this.headers; + return headers; } @Override diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java index f1c7af00ef6..64a44b94f65 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java @@ -24,7 +24,6 @@ import reactor.io.buffer.Buffer; import reactor.io.net.http.HttpChannel; import reactor.io.net.http.model.Status; -import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.util.Assert; @@ -34,19 +33,14 @@ import org.springframework.util.Assert; * @author Stephane Maldini * @author Rossen Stoyanchev */ -public class ReactorServerHttpResponse implements ServerHttpResponse { +public class ReactorServerHttpResponse extends AbstractServerHttpResponse { private final HttpChannel channel; - private final HttpHeaders headers; - - private boolean headersWritten = false; - public ReactorServerHttpResponse(HttpChannel response) { Assert.notNull("'response' must not be null."); this.channel = response; - this.headers = new HttpHeaders(); } @@ -60,29 +54,16 @@ public class ReactorServerHttpResponse implements ServerHttpResponse { } @Override - public HttpHeaders getHeaders() { - return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); - } - - @Override - public Mono setBody(Publisher publisher) { - return Flux.from(publisher).lift(new WriteWithOperator<>(this::setBodyInternal)).after(); - } - protected Mono setBodyInternal(Publisher publisher) { - writeHeaders(); - return Mono.from(getReactorChannel().writeWith(Flux.from(publisher).map(Buffer::new))); + return Mono.from(this.channel.writeWith(Flux.from(publisher).map(Buffer::new))); } @Override - public void writeHeaders() { - if (!this.headersWritten) { - for (String name : this.headers.keySet()) { - for (String value : this.headers.get(name)) { - this.channel.responseHeaders().add(name, value); - } + protected void writeHeadersInternal() { + for (String name : getHeaders().keySet()) { + for (String value : getHeaders().get(name)) { + this.channel.responseHeaders().add(name, value); } - this.headersWritten = true; } } diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/RxNettyServerHttpRequest.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/RxNettyServerHttpRequest.java index 7e4683eed50..afdabae6685 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/RxNettyServerHttpRequest.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/RxNettyServerHttpRequest.java @@ -36,14 +36,10 @@ import org.springframework.util.Assert; * @author Rossen Stoyanchev * @author Stephane Maldini */ -public class RxNettyServerHttpRequest implements ServerHttpRequest { +public class RxNettyServerHttpRequest extends AbstractServerHttpRequest { private final HttpServerRequest request; - private URI uri; - - private HttpHeaders headers; - public RxNettyServerHttpRequest(HttpServerRequest request) { Assert.notNull("'request', request must not be null."); @@ -61,27 +57,17 @@ public class RxNettyServerHttpRequest implements ServerHttpRequest { } @Override - public URI getURI() { - if (this.uri == null) { - try { - this.uri = new URI(this.getRxNettyRequest().getUri()); - } - catch (URISyntaxException ex) { - throw new IllegalStateException("Could not get URI: " + ex.getMessage(), ex); - } - } - return this.uri; + protected URI initUri() throws URISyntaxException { + return new URI(this.getRxNettyRequest().getUri()); } @Override - public HttpHeaders getHeaders() { - if (this.headers == null) { - this.headers = new HttpHeaders(); - for (String name : this.getRxNettyRequest().getHeaderNames()) { - this.headers.put(name, this.getRxNettyRequest().getAllHeaderValues(name)); - } + protected HttpHeaders initHeaders() { + HttpHeaders headers = new HttpHeaders(); + for (String name : this.getRxNettyRequest().getHeaderNames()) { + headers.put(name, this.getRxNettyRequest().getAllHeaderValues(name)); } - return this.headers; + return headers; } @Override diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/RxNettyServerHttpResponse.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/RxNettyServerHttpResponse.java index c60eaaded8d..5f47e78ced5 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/RxNettyServerHttpResponse.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/RxNettyServerHttpResponse.java @@ -36,19 +36,14 @@ import org.springframework.util.Assert; * @author Rossen Stoyanchev * @author Stephane Maldini */ -public class RxNettyServerHttpResponse implements ServerHttpResponse { +public class RxNettyServerHttpResponse extends AbstractServerHttpResponse { private final HttpServerResponse response; - private final HttpHeaders headers; - - private boolean headersWritten = false; - public RxNettyServerHttpResponse(HttpServerResponse response) { Assert.notNull("'response', response must not be null."); this.response = response; - this.headers = new HttpHeaders(); } @@ -62,17 +57,7 @@ public class RxNettyServerHttpResponse implements ServerHttpResponse { } @Override - public HttpHeaders getHeaders() { - return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); - } - - @Override - public Mono setBody(Publisher publisher) { - return Flux.from(publisher).lift(new WriteWithOperator<>(this::setBodyInternal)).after(); - } - protected Mono setBodyInternal(Publisher publisher) { - writeHeaders(); Observable content = RxJava1Converter.from(publisher).map(this::toBytes); Observable completion = getRxNettyResponse().writeBytes(content); return RxJava1Converter.from(completion).after(); @@ -85,13 +70,10 @@ public class RxNettyServerHttpResponse implements ServerHttpResponse { } @Override - public void writeHeaders() { - if (!this.headersWritten) { - for (String name : this.headers.keySet()) { - for (String value : this.headers.get(name)) + protected void writeHeadersInternal() { + for (String name : getHeaders().keySet()) { + for (String value : getHeaders().get(name)) this.response.addHeader(name, value); - } - this.headersWritten = true; } } diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java index d5c14618608..a076768c7f1 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java @@ -39,14 +39,10 @@ import org.springframework.util.StringUtils; * * @author Rossen Stoyanchev */ -public class ServletServerHttpRequest implements ServerHttpRequest { +public class ServletServerHttpRequest extends AbstractServerHttpRequest { private final HttpServletRequest request; - private URI uri; - - private HttpHeaders headers; - private final Flux requestBodyPublisher; @@ -68,62 +64,48 @@ public class ServletServerHttpRequest implements ServerHttpRequest { } @Override - public URI getURI() { - if (this.uri == null) { - try { - this.uri = new URI(getServletRequest().getScheme(), null, - getServletRequest().getServerName(), - getServletRequest().getServerPort(), - getServletRequest().getRequestURI(), - getServletRequest().getQueryString(), null); - } - catch (URISyntaxException ex) { - throw new IllegalStateException("Could not get HttpServletRequest URI: " + ex.getMessage(), ex); - } - } - return this.uri; + protected URI initUri() throws URISyntaxException { + return new URI(getServletRequest().getScheme(), null, + getServletRequest().getServerName(), + getServletRequest().getServerPort(), + getServletRequest().getRequestURI(), + getServletRequest().getQueryString(), null); } @Override - public HttpHeaders getHeaders() { - if (this.headers == null) { - this.headers = new HttpHeaders(); - for (Enumeration names = getServletRequest().getHeaderNames(); names.hasMoreElements(); ) { - String headerName = (String) names.nextElement(); - for (Enumeration headerValues = getServletRequest().getHeaders(headerName); - headerValues.hasMoreElements(); ) { - String headerValue = (String) headerValues.nextElement(); - this.headers.add(headerName, headerValue); - } - } - // HttpServletRequest exposes some headers as properties: we should include those if not already present - MediaType contentType = this.headers.getContentType(); - if (contentType == null) { - String requestContentType = getServletRequest().getContentType(); - if (StringUtils.hasLength(requestContentType)) { - contentType = MediaType.parseMediaType(requestContentType); - this.headers.setContentType(contentType); - } - } - if (contentType != null && contentType.getCharSet() == null) { - String requestEncoding = getServletRequest().getCharacterEncoding(); - if (StringUtils.hasLength(requestEncoding)) { - Charset charSet = Charset.forName(requestEncoding); - Map params = new LinkedCaseInsensitiveMap<>(); - params.putAll(contentType.getParameters()); - params.put("charset", charSet.toString()); - MediaType newContentType = new MediaType(contentType.getType(), contentType.getSubtype(), params); - this.headers.setContentType(newContentType); - } - } - if (this.headers.getContentLength() == -1) { - int requestContentLength = getServletRequest().getContentLength(); - if (requestContentLength != -1) { - this.headers.setContentLength(requestContentLength); - } + protected HttpHeaders initHeaders() { + HttpHeaders headers = new HttpHeaders(); + for (Enumeration names = getServletRequest().getHeaderNames(); names.hasMoreElements(); ) { + String name = (String) names.nextElement(); + for (Enumeration values = getServletRequest().getHeaders(name); values.hasMoreElements(); ) { + headers.add(name, (String) values.nextElement()); } } - return this.headers; + MediaType contentType = headers.getContentType(); + if (contentType == null) { + String requestContentType = getServletRequest().getContentType(); + if (StringUtils.hasLength(requestContentType)) { + contentType = MediaType.parseMediaType(requestContentType); + headers.setContentType(contentType); + } + } + if (contentType != null && contentType.getCharSet() == null) { + String encoding = getServletRequest().getCharacterEncoding(); + if (StringUtils.hasLength(encoding)) { + Charset charset = Charset.forName(encoding); + Map params = new LinkedCaseInsensitiveMap<>(); + params.putAll(contentType.getParameters()); + params.put("charset", charset.toString()); + headers.setContentType(new MediaType(contentType.getType(), contentType.getSubtype(), params)); + } + } + if (headers.getContentLength() == -1) { + int contentLength = getServletRequest().getContentLength(); + if (contentLength != -1) { + headers.setContentLength(contentLength); + } + } + return headers; } @Override diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java index 27b67989554..a292b44f9ba 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -24,10 +24,8 @@ import java.util.function.Function; import javax.servlet.http.HttpServletResponse; import org.reactivestreams.Publisher; -import reactor.Flux; import reactor.Mono; -import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.util.Assert; @@ -37,16 +35,12 @@ import org.springframework.util.Assert; * * @author Rossen Stoyanchev */ -public class ServletServerHttpResponse implements ServerHttpResponse { +public class ServletServerHttpResponse extends AbstractServerHttpResponse { private final HttpServletResponse response; private final Function, Mono> responseBodyWriter; - private final HttpHeaders headers; - - private boolean headersWritten = false; - public ServletServerHttpResponse(HttpServletResponse response, Function, Mono> responseBodyWriter) { @@ -55,7 +49,6 @@ public class ServletServerHttpResponse implements ServerHttpResponse { Assert.notNull(responseBodyWriter, "'responseBodyWriter' must not be null"); this.response = response; this.responseBodyWriter = responseBodyWriter; - this.headers = new HttpHeaders(); } @@ -69,38 +62,25 @@ public class ServletServerHttpResponse implements ServerHttpResponse { } @Override - public HttpHeaders getHeaders() { - return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); - } - - @Override - public Mono setBody(final Publisher publisher) { - return Flux.from(publisher).lift(new WriteWithOperator<>(this::setBodyInternal)).after(); - } - protected Mono setBodyInternal(Publisher publisher) { - writeHeaders(); return this.responseBodyWriter.apply(publisher); } @Override - public void writeHeaders() { - if (!this.headersWritten) { - for (Map.Entry> entry : this.headers.entrySet()) { - String headerName = entry.getKey(); - for (String headerValue : entry.getValue()) { - this.response.addHeader(headerName, headerValue); - } + protected void writeHeadersInternal() { + for (Map.Entry> entry : getHeaders().entrySet()) { + String headerName = entry.getKey(); + for (String headerValue : entry.getValue()) { + this.response.addHeader(headerName, headerValue); } - MediaType contentType = this.headers.getContentType(); - if (this.response.getContentType() == null && contentType != null) { - this.response.setContentType(contentType.toString()); - } - Charset charset = (contentType != null ? contentType.getCharSet() : null); - if (this.response.getCharacterEncoding() == null && charset != null) { - this.response.setCharacterEncoding(charset.name()); - } - this.headersWritten = true; + } + MediaType contentType = getHeaders().getContentType(); + if (this.response.getContentType() == null && contentType != null) { + this.response.setContentType(contentType.toString()); + } + Charset charset = (contentType != null ? contentType.getCharSet() : null); + if (this.response.getCharacterEncoding() == null && charset != null) { + this.response.setCharacterEncoding(charset.name()); } } diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java index 98c766e5898..6b9b3c32f49 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java @@ -35,14 +35,10 @@ import org.springframework.util.Assert; * @author Marek Hawrylczak * @author Rossen Stoyanchev */ -public class UndertowServerHttpRequest implements ServerHttpRequest { +public class UndertowServerHttpRequest extends AbstractServerHttpRequest { private final HttpServerExchange exchange; - private URI uri; - - private HttpHeaders headers; - private final Flux body; @@ -64,31 +60,19 @@ public class UndertowServerHttpRequest implements ServerHttpRequest { } @Override - public URI getURI() { - if (this.uri == null) { - try { - return new URI(this.getUndertowExchange().getRequestScheme(), null, - this.getUndertowExchange().getHostName(), - this.getUndertowExchange().getHostPort(), - this.getUndertowExchange().getRequestURI(), - this.getUndertowExchange().getQueryString(), null); - } - catch (URISyntaxException ex) { - throw new IllegalStateException("Could not get URI: " + ex.getMessage(), ex); - } - } - return this.uri; + protected URI initUri() throws URISyntaxException { + return new URI(this.exchange.getRequestScheme(), null, + this.exchange.getHostName(), this.exchange.getHostPort(), + this.exchange.getRequestURI(), this.exchange.getQueryString(), null); } @Override - public HttpHeaders getHeaders() { - if (this.headers == null) { - this.headers = new HttpHeaders(); - for (HeaderValues values : this.getUndertowExchange().getRequestHeaders()) { - this.headers.put(values.getHeaderName().toString(), values); - } + protected HttpHeaders initHeaders() { + HttpHeaders headers = new HttpHeaders(); + for (HeaderValues values : this.getUndertowExchange().getRequestHeaders()) { + headers.put(values.getHeaderName().toString(), values); } - return this.headers; + return headers; } @Override diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java index 08032cf1e23..b32255f461a 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java @@ -24,10 +24,8 @@ import java.util.function.Function; import io.undertow.server.HttpServerExchange; import io.undertow.util.HttpString; import org.reactivestreams.Publisher; -import reactor.Flux; import reactor.Mono; -import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.util.Assert; @@ -37,16 +35,12 @@ import org.springframework.util.Assert; * @author Marek Hawrylczak * @author Rossen Stoyanchev */ -public class UndertowServerHttpResponse implements ServerHttpResponse { +public class UndertowServerHttpResponse extends AbstractServerHttpResponse { private final HttpServerExchange exchange; private final Function, Mono> responseBodyWriter; - private final HttpHeaders headers; - - private boolean headersWritten = false; - public UndertowServerHttpResponse(HttpServerExchange exchange, Function, Mono> responseBodyWriter) { @@ -55,7 +49,6 @@ public class UndertowServerHttpResponse implements ServerHttpResponse { Assert.notNull(responseBodyWriter, "'responseBodyWriter' must not be null"); this.exchange = exchange; this.responseBodyWriter = responseBodyWriter; - this.headers = new HttpHeaders(); } @@ -70,28 +63,15 @@ public class UndertowServerHttpResponse implements ServerHttpResponse { } @Override - public HttpHeaders getHeaders() { - return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); - } - - @Override - public Mono setBody(Publisher publisher) { - return Flux.from(publisher).lift(new WriteWithOperator<>(this::setBodyInternal)).after(); - } - protected Mono setBodyInternal(Publisher publisher) { - writeHeaders(); return this.responseBodyWriter.apply(publisher); } @Override - public void writeHeaders() { - if (!this.headersWritten) { - for (Map.Entry> entry : this.headers.entrySet()) { - HttpString headerName = HttpString.tryFromString(entry.getKey()); - this.exchange.getResponseHeaders().addAll(headerName, entry.getValue()); - } - this.headersWritten = true; + protected void writeHeadersInternal() { + for (Map.Entry> entry : getHeaders().entrySet()) { + HttpString headerName = HttpString.tryFromString(entry.getKey()); + this.exchange.getResponseHeaders().addAll(headerName, entry.getValue()); } }