Response headers always in sync with native response

ServerHttpResponse implementations now immediately propagate
HttpHeaders changes as they so there is no need to call applyHeaders().

The writeHeaders from ServerHttpResponse is also removed. RxNetty and
Reactor Net both support implicitly completing if the handler
completes without explicitly writing the headers or the response body.
This commit is contained in:
Rossen Stoyanchev 2015-12-29 17:35:19 -05:00
parent 34eb6d5426
commit 6b05d17248
12 changed files with 267 additions and 145 deletions

View File

@ -0,0 +1,99 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
/**
* Extension of HttpHeaders (to be merged into HttpHeaders) that allows the
* registration of {@link HeaderChangeListener}. For use with HTTP response
* implementations that can keep track of changes made headers and keep the
* underlying server headers always in sync.
*
* @author Rossen Stoyanchev
*/
public class ExtendedHttpHeaders extends HttpHeaders {
private final List<HeaderChangeListener> listeners = new ArrayList<>(1);
public void registerChangeListener(HeaderChangeListener listener) {
this.listeners.add(listener);
}
@Override
public void add(String name, String value) {
for (HeaderChangeListener listener : this.listeners) {
listener.headerAdded(name, value);
}
super.add(name, value);
}
@Override
public void set(String name, String value) {
List<String> values = new LinkedList<String>();
values.add(value);
put(name, values);
}
@Override
public List<String> put(String key, List<String> values) {
for (HeaderChangeListener listener : this.listeners) {
listener.headerPut(key, values);
}
return super.put(key, values);
}
@Override
public List<String> remove(Object key) {
for (HeaderChangeListener listener : this.listeners) {
listener.headerRemoved((String) key);
}
return super.remove(key);
}
@Override
public void putAll(Map<? extends String, ? extends List<String>> map) {
for (Entry<? extends String, ? extends List<String>> entry : map.entrySet()) {
put(entry.getKey(), entry.getValue());
}
super.putAll(map);
}
@Override
public void clear() {
for (Entry<? extends String, ? extends List<String>> entry : super.entrySet()) {
remove(entry.getKey(), entry.getValue());
}
super.clear();
}
public interface HeaderChangeListener {
void headerAdded(String name, String value);
void headerPut(String key, List<String> values);
void headerRemoved(String key);
}
}

View File

@ -16,6 +16,7 @@
package org.springframework.http.server.reactive;
import org.reactivestreams.Publisher;
import reactor.Publishers;
import org.springframework.http.HttpStatus;
@ -30,7 +31,7 @@ public class InternalServerErrorExceptionHandler implements HttpExceptionHandler
@Override
public Publisher<Void> handle(ServerHttpRequest request, ServerHttpResponse response, Throwable ex) {
response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR);
return response.writeHeaders();
return Publishers.empty();
}
}

View File

@ -16,6 +16,7 @@
package org.springframework.http.server.reactive;
import java.nio.ByteBuffer;
import java.util.List;
import org.reactivestreams.Publisher;
import reactor.Publishers;
@ -23,12 +24,14 @@ import reactor.io.buffer.Buffer;
import reactor.io.net.http.HttpChannel;
import reactor.io.net.http.model.Status;
import org.springframework.http.ExtendedHttpHeaders;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.util.Assert;
/**
* @author Stephane Maldini
* @author Rossen Stoyanchev
*/
public class ReactorServerHttpResponse implements ServerHttpResponse {
@ -36,13 +39,17 @@ public class ReactorServerHttpResponse implements ServerHttpResponse {
private final HttpHeaders headers;
private boolean headersWritten = false;
public ReactorServerHttpResponse(HttpChannel<?, Buffer> response) {
Assert.notNull("'response', response must not be null.");
this.channel = response;
this.headers = new HttpHeaders();
this.headers = initHttpHeaders();
}
private HttpHeaders initHttpHeaders() {
ExtendedHttpHeaders headers = new ExtendedHttpHeaders();
headers.registerChangeListener(new ReactorHeaderChangeListener());
return headers;
}
@ -53,34 +60,33 @@ public class ReactorServerHttpResponse implements ServerHttpResponse {
@Override
public HttpHeaders getHeaders() {
return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers);
}
@Override
public Publisher<Void> writeHeaders() {
if (this.headersWritten) {
return Publishers.empty();
}
applyHeaders();
return this.channel.writeHeaders();
return this.headers;
}
@Override
public Publisher<Void> setBody(Publisher<ByteBuffer> publisher) {
return Publishers.lift(publisher, new WriteWithOperator<>(writePublisher -> {
applyHeaders();
return this.channel.writeWith(Publishers.map(writePublisher, Buffer::new));
}));
return Publishers.lift(publisher, new WriteWithOperator<>(writePublisher ->
this.channel.writeWith(Publishers.map(writePublisher, Buffer::new))));
}
private void applyHeaders() {
if (!this.headersWritten) {
for (String name : this.headers.keySet()) {
for (String value : this.headers.get(name)) {
this.channel.responseHeaders().add(name, value);
}
}
this.headersWritten = true;
private class ReactorHeaderChangeListener implements ExtendedHttpHeaders.HeaderChangeListener {
@Override
public void headerAdded(String name, String value) {
channel.responseHeaders().add(name, value);
}
@Override
public void headerPut(String key, List<String> values) {
channel.responseHeaders().remove(key);
channel.responseHeaders().add(key, values);
}
@Override
public void headerRemoved(String key) {
channel.responseHeaders().remove(key);
}
}
}

View File

@ -78,7 +78,9 @@ public class RxNettyServerHttpRequest implements ServerHttpRequest {
@Override
public Publisher<ByteBuffer> getBody() {
Observable<ByteBuffer> bytesContent = this.request.getContent().map(ByteBuf::nioBuffer);
Observable<ByteBuffer> bytesContent = this.request.getContent()
.concatWith(Observable.empty())
.map(ByteBuf::nioBuffer);
return RxJava1Converter.from(bytesContent);
}

View File

@ -17,15 +17,16 @@
package org.springframework.http.server.reactive;
import java.nio.ByteBuffer;
import java.util.List;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.reactivex.netty.protocol.http.server.HttpServerResponse;
import org.reactivestreams.Publisher;
import reactor.Publishers;
import reactor.core.publisher.convert.RxJava1Converter;
import reactor.io.buffer.Buffer;
import rx.Observable;
import org.springframework.http.ExtendedHttpHeaders;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.util.Assert;
@ -40,13 +41,17 @@ public class RxNettyServerHttpResponse implements ServerHttpResponse {
private final HttpHeaders headers;
private boolean headersWritten = false;
public RxNettyServerHttpResponse(HttpServerResponse<?> response) {
Assert.notNull("'response', response must not be null.");
this.response = response;
this.headers = new HttpHeaders();
this.headers = initHttpHeaders();
}
private HttpHeaders initHttpHeaders() {
ExtendedHttpHeaders headers = new ExtendedHttpHeaders();
headers.registerChangeListener(new RxNettyHeaderChangeListener());
return headers;
}
@ -57,36 +62,42 @@ public class RxNettyServerHttpResponse implements ServerHttpResponse {
@Override
public HttpHeaders getHeaders() {
return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers);
}
@Override
public Publisher<Void> writeHeaders() {
if (this.headersWritten) {
return Publishers.empty();
}
applyHeaders();
return RxJava1Converter.from(this.response.sendHeaders());
return this.headers;
}
@Override
public Publisher<Void> setBody(Publisher<ByteBuffer> publisher) {
return Publishers.lift(publisher, new WriteWithOperator<>(writePublisher -> {
applyHeaders();
Observable<byte[]> observable = RxJava1Converter.from(writePublisher)
.map(buffer -> new Buffer(buffer).asBytes());
.map(buffer -> {
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);
return bytes;
});
return RxJava1Converter.from(this.response.writeBytes(observable));
}));
}
private void applyHeaders() {
if (!this.headersWritten) {
for (String name : this.headers.keySet()) {
for (String value : this.headers.get(name)) {
this.response.addHeader(name, value);
}
private class RxNettyHeaderChangeListener implements ExtendedHttpHeaders.HeaderChangeListener {
@Override
public void headerAdded(String name, String value) {
response.addHeader(name, value);
}
@Override
public void headerPut(String key, List<String> values) {
response.removeHeader(key);
for (String value : values) {
response.addHeader(key, value);
}
this.headersWritten = true;
}
@Override
public void headerRemoved(String key) {
response.removeHeader(key);
}
}
}

View File

@ -16,8 +16,6 @@
package org.springframework.http.server.reactive;
import org.reactivestreams.Publisher;
import org.springframework.http.HttpStatus;
import org.springframework.http.ReactiveHttpOutputMessage;
@ -34,11 +32,4 @@ public interface ServerHttpResponse extends ReactiveHttpOutputMessage {
*/
void setStatusCode(HttpStatus status);
/**
* Write the response headers. This method must be invoked to send responses without body.
* @return A {@code Publisher<Void>} used to signal the demand, and receive a notification
* when the handling is complete (success or error) including the flush of the data on the
* network.
*/
Publisher<Void> writeHeaders();
}

View File

@ -18,9 +18,7 @@ package org.springframework.http.server.reactive;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map;
import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import javax.servlet.http.HttpServletResponse;
@ -32,9 +30,9 @@ import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.Publishers;
import org.springframework.http.ExtendedHttpHeaders;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.util.Assert;
/**
@ -51,16 +49,20 @@ public class ServletServerHttpResponse implements ServerHttpResponse {
private final ResponseBodySubscriber subscriber;
private boolean headersWritten = false;
public ServletServerHttpResponse(HttpServletResponse response, ServletAsyncContextSynchronizer synchronizer) {
Assert.notNull(response, "'response' must not be null");
this.response = response;
this.headers = new HttpHeaders();
this.headers = initHttpHeaders();
this.subscriber = new ResponseBodySubscriber(synchronizer);
}
private HttpHeaders initHttpHeaders() {
ExtendedHttpHeaders headers = new ExtendedHttpHeaders();
headers.registerChangeListener(new ServletHeaderChangeListener());
return headers;
}
@Override
public void setStatusCode(HttpStatus status) {
@ -69,48 +71,41 @@ public class ServletServerHttpResponse implements ServerHttpResponse {
@Override
public HttpHeaders getHeaders() {
return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers);
return this.headers;
}
WriteListener getWriteListener() {
return this.subscriber;
}
@Override
public Publisher<Void> writeHeaders() {
applyHeaders();
return Publishers.empty();
}
@Override
public Publisher<Void> setBody(final Publisher<ByteBuffer> publisher) {
return Publishers.lift(publisher, new WriteWithOperator<>(writePublisher -> {
applyHeaders();
return (s -> writePublisher.subscribe(subscriber));
}));
return Publishers.lift(publisher, new WriteWithOperator<>(writePublisher ->
(s -> writePublisher.subscribe(subscriber))));
}
private void applyHeaders() {
if (!this.headersWritten) {
for (Map.Entry<String, List<String>> entry : this.headers.entrySet()) {
String headerName = entry.getKey();
for (String headerValue : entry.getValue()) {
this.response.addHeader(headerName, headerValue);
}
private class ServletHeaderChangeListener implements ExtendedHttpHeaders.HeaderChangeListener {
@Override
public void headerAdded(String name, String value) {
response.addHeader(name, value);
}
@Override
public void headerPut(String key, List<String> values) {
// We can only add but not remove
for (String value : values) {
response.addHeader(key, value);
}
MediaType contentType = this.headers.getContentType();
if (this.response.getContentType() == null && contentType != null) {
this.response.setContentType(contentType.toString());
}
Charset charset = (contentType != null ? contentType.getCharSet() : null);
if (this.response.getCharacterEncoding() == null && charset != null) {
this.response.setCharacterEncoding(charset.name());
}
this.headersWritten = true;
}
@Override
public void headerRemoved(String key) {
// No Servlet support for removing headers
}
}
private static class ResponseBodySubscriber implements WriteListener, Subscriber<ByteBuffer> {
private final ServletAsyncContextSynchronizer synchronizer;

View File

@ -19,16 +19,11 @@ package org.springframework.http.server.reactive;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.util.Assert;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.HttpString;
@ -41,6 +36,11 @@ import org.xnio.channels.StreamSinkChannel;
import reactor.Publishers;
import reactor.core.subscriber.BaseSubscriber;
import org.springframework.http.ExtendedHttpHeaders;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.util.Assert;
import static org.xnio.ChannelListeners.closingChannelExceptionHandler;
import static org.xnio.ChannelListeners.flushingChannelListener;
import static org.xnio.IoUtils.safeClose;
@ -58,14 +58,19 @@ public class UndertowServerHttpResponse implements ServerHttpResponse {
private final ResponseBodySubscriber bodySubscriber = new ResponseBodySubscriber();
private final HttpHeaders headers = new HttpHeaders();
private boolean headersWritten = false;
private final HttpHeaders headers;
public UndertowServerHttpResponse(HttpServerExchange exchange) {
Assert.notNull(exchange, "'exchange' is required.");
this.exchange = exchange;
this.headers = initHttpHeaders();
}
private HttpHeaders initHttpHeaders() {
ExtendedHttpHeaders headers = new ExtendedHttpHeaders();
headers.registerChangeListener(new UndertowHeaderChangeListener());
return headers;
}
@ -77,44 +82,34 @@ public class UndertowServerHttpResponse implements ServerHttpResponse {
@Override
public HttpHeaders getHeaders() {
return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers);
}
@Override
public Publisher<Void> writeHeaders() {
applyHeaders();
return s -> s.onSubscribe(new Subscription() {
@Override
public void request(long n) {
s.onComplete();
}
@Override
public void cancel() {
}
});
}
private void applyHeaders() {
if (!this.headersWritten) {
for (Map.Entry<String, List<String>> entry : this.headers.entrySet()) {
HttpString headerName = HttpString.tryFromString(entry.getKey());
this.exchange.getResponseHeaders().addAll(headerName, entry.getValue());
}
this.headersWritten = true;
}
return this.headers;
}
@Override
public Publisher<Void> setBody(Publisher<ByteBuffer> publisher) {
return Publishers.lift(publisher, new WriteWithOperator<>(writePublisher -> {
applyHeaders();
return (subscriber -> writePublisher.subscribe(bodySubscriber));
}));
return Publishers.lift(publisher, new WriteWithOperator<>(writePublisher ->
(subscriber -> writePublisher.subscribe(bodySubscriber))));
}
private class UndertowHeaderChangeListener implements ExtendedHttpHeaders.HeaderChangeListener {
@Override
public void headerAdded(String name, String value) {
exchange.getResponseHeaders().add(HttpString.tryFromString(name), value);
}
@Override
public void headerPut(String key, List<String> values) {
exchange.getResponseHeaders().putAll(HttpString.tryFromString(key), values);
}
@Override
public void headerRemoved(String key) {
exchange.getResponseHeaders().remove(key);
}
}
private class ResponseBodySubscriber extends BaseSubscriber<ByteBuffer>
implements ChannelListener<StreamSinkChannel> {
@ -266,4 +261,5 @@ public class UndertowServerHttpResponse implements ServerHttpResponse {
}
}
}
}

View File

@ -35,7 +35,7 @@ public class ResponseStatusExceptionHandler implements HttpExceptionHandler {
public Publisher<Void> handle(ServerHttpRequest request, ServerHttpResponse response, Throwable ex) {
if (ex instanceof ResponseStatusException) {
response.setStatusCode(((ResponseStatusException) ex).getHttpStatus());
return response.writeHeaders();
return Publishers.empty();
}
return Publishers.error(ex);
}

View File

@ -16,8 +16,6 @@
package org.springframework.web.reactive.handler;
import java.util.Arrays;
import org.reactivestreams.Publisher;
import reactor.Publishers;
@ -74,15 +72,19 @@ public class SimpleHandlerResultHandler implements Ordered, HandlerResultHandler
((this.conversionService != null) && this.conversionService.canConvert(type.getRawClass(), Publisher.class));
}
@SuppressWarnings("unchecked")
@Override
public Publisher<Void> handleResult(ServerHttpRequest request,
ServerHttpResponse response, HandlerResult result) {
Object value = result.getResult();
if (Void.TYPE.equals(result.getResultType().getRawClass())) {
return response.writeHeaders();
return Publishers.empty();
}
Publisher<Void> completion = (value instanceof Publisher ? (Publisher<Void>)value : this.conversionService.convert(value, Publisher.class));
return Publishers.concat(Publishers.from(Arrays.asList(completion, response.writeHeaders())));
return (value instanceof Publisher ? (Publisher<Void>)value :
this.conversionService.convert(value, Publisher.class));
}
}

View File

@ -49,11 +49,6 @@ public class MockServerHttpResponse implements ServerHttpResponse {
return this.headers;
}
@Override
public Publisher<Void> writeHeaders() {
return Publishers.empty();
}
@Override
public Publisher<Void> setBody(Publisher<ByteBuffer> body) {
this.body = body;

View File

@ -23,6 +23,7 @@ import java.util.Map;
import org.junit.Test;
import org.reactivestreams.Publisher;
import reactor.Publishers;
import reactor.io.buffer.Buffer;
import reactor.rx.Streams;
@ -67,7 +68,7 @@ public class SimpleUrlHandlerMappingIntegrationTests extends AbstractHttpHandler
}
@Test
public void testFoo() throws Exception {
public void testFooHandler() throws Exception {
RestTemplate restTemplate = new RestTemplate();
@ -80,7 +81,7 @@ public class SimpleUrlHandlerMappingIntegrationTests extends AbstractHttpHandler
}
@Test
public void testBar() throws Exception {
public void testBarHandler() throws Exception {
RestTemplate restTemplate = new RestTemplate();
@ -92,6 +93,19 @@ public class SimpleUrlHandlerMappingIntegrationTests extends AbstractHttpHandler
assertArrayEquals("bar".getBytes(UTF_8), response.getBody());
}
@Test
public void testHeaderSettingHandler() throws Exception {
RestTemplate restTemplate = new RestTemplate();
URI url = new URI("http://localhost:" + port + "/header");
RequestEntity<Void> request = RequestEntity.get(url).build();
ResponseEntity<byte[]> response = restTemplate.exchange(request, byte[].class);
assertEquals(HttpStatus.OK, response.getStatusCode());
assertEquals("bar", response.getHeaders().getFirst("foo"));
}
@Test
public void testNotFound() throws Exception {
@ -114,6 +128,7 @@ public class SimpleUrlHandlerMappingIntegrationTests extends AbstractHttpHandler
Map<String, Object> map = new HashMap<>();
map.put("/foo", new FooHandler());
map.put("/bar", new BarHandler());
map.put("/header", new HeaderSettingHandler());
setHandlers(map);
}
}
@ -134,4 +149,13 @@ public class SimpleUrlHandlerMappingIntegrationTests extends AbstractHttpHandler
}
}
private static class HeaderSettingHandler implements HttpHandler {
@Override
public Publisher<Void> handle(ServerHttpRequest request, ServerHttpResponse response) {
response.getHeaders().add("foo", "bar");
return Publishers.empty();
}
}
}