This commit is contained in:
Rossen Stoyanchev 2015-11-12 11:52:06 -05:00
parent f1bec5f1e4
commit 141d75791d
6 changed files with 162 additions and 158 deletions

View File

@ -42,84 +42,67 @@ class RequestBodyPublisher implements Publisher<ByteBuffer> {
private static final AtomicLongFieldUpdater<RequestBodySubscription> DEMAND = private static final AtomicLongFieldUpdater<RequestBodySubscription> DEMAND =
AtomicLongFieldUpdater.newUpdater(RequestBodySubscription.class, "demand"); AtomicLongFieldUpdater.newUpdater(RequestBodySubscription.class, "demand");
private final HttpServerExchange exchange; private final HttpServerExchange exchange;
private Subscriber<? super ByteBuffer> subscriber; private Subscriber<? super ByteBuffer> subscriber;
public RequestBodyPublisher(HttpServerExchange exchange) { public RequestBodyPublisher(HttpServerExchange exchange) {
Assert.notNull(exchange, "'exchange' is required."); Assert.notNull(exchange, "'exchange' is required.");
this.exchange = exchange; this.exchange = exchange;
} }
@Override @Override
public void subscribe(Subscriber<? super ByteBuffer> s) { public void subscribe(Subscriber<? super ByteBuffer> subscriber) {
if (s == null) { if (subscriber == null) {
throw SpecificationExceptions.spec_2_13_exception(); throw SpecificationExceptions.spec_2_13_exception();
} }
if (this.subscriber != null) { if (this.subscriber != null) {
s.onError(new IllegalStateException("Only one subscriber allowed")); subscriber.onError(new IllegalStateException("Only one subscriber allowed"));
} }
this.subscriber = s; this.subscriber = subscriber;
this.subscriber.onSubscribe(new RequestBodySubscription()); this.subscriber.onSubscribe(new RequestBodySubscription());
} }
private class RequestBodySubscription
implements Subscription, Runnable, ChannelListener<StreamSourceChannel> { private class RequestBodySubscription implements Subscription, Runnable,
ChannelListener<StreamSourceChannel> {
volatile long demand; volatile long demand;
private PooledByteBuffer pooledBuffer; private PooledByteBuffer pooledBuffer;
private StreamSourceChannel channel; private StreamSourceChannel channel;
private boolean subscriptionClosed; private boolean subscriptionClosed;
private boolean draining; private boolean draining;
@Override
public void cancel() {
this.subscriptionClosed = true;
close();
}
@Override @Override
public void request(long n) { public void request(long n) {
BackpressureUtils.checkRequest(n, subscriber); BackpressureUtils.checkRequest(n, subscriber);
if (this.subscriptionClosed) { if (this.subscriptionClosed) {
return; return;
} }
BackpressureUtils.getAndAdd(DEMAND, this, n); BackpressureUtils.getAndAdd(DEMAND, this, n);
scheduleNextMessage(); scheduleNextMessage();
} }
private void scheduleNextMessage() { private void scheduleNextMessage() {
exchange.dispatch(exchange.isInIoThread() ? exchange.dispatch(exchange.isInIoThread() ? SameThreadExecutor.INSTANCE :
SameThreadExecutor.INSTANCE : exchange.getIoThread(), this); exchange.getIoThread(), this);
} }
private void doOnNext(ByteBuffer buffer) { @Override
this.draining = false; public void cancel() {
buffer.flip();
subscriber.onNext(buffer);
}
private void doOnComplete() {
this.subscriptionClosed = true; this.subscriptionClosed = true;
try {
subscriber.onComplete();
}
finally {
close(); close();
} }
}
private void doOnError(Throwable t) {
this.subscriptionClosed = true;
try {
subscriber.onError(t);
}
finally {
close();
}
}
private void close() { private void close() {
if (this.pooledBuffer != null) { if (this.pooledBuffer != null) {
@ -137,7 +120,6 @@ class RequestBodyPublisher implements Publisher<ByteBuffer> {
if (this.subscriptionClosed || this.draining) { if (this.subscriptionClosed || this.draining) {
return; return;
} }
if (0 == BackpressureUtils.getAndSub(DEMAND, this, 1)) { if (0 == BackpressureUtils.getAndSub(DEMAND, this, 1)) {
return; return;
} }
@ -152,8 +134,7 @@ class RequestBodyPublisher implements Publisher<ByteBuffer> {
return; return;
} }
else { else {
throw new IllegalStateException( throw new IllegalStateException("Failed to acquire channel!");
"Another party already acquired the channel!");
} }
} }
} }
@ -198,6 +179,32 @@ class RequestBodyPublisher implements Publisher<ByteBuffer> {
} }
} }
private void doOnNext(ByteBuffer buffer) {
this.draining = false;
buffer.flip();
subscriber.onNext(buffer);
}
private void doOnComplete() {
this.subscriptionClosed = true;
try {
subscriber.onComplete();
}
finally {
close();
}
}
private void doOnError(Throwable t) {
this.subscriptionClosed = true;
try {
subscriber.onError(t);
}
finally {
close();
}
}
@Override @Override
public void handleEvent(StreamSourceChannel channel) { public void handleEvent(StreamSourceChannel channel) {
if (this.subscriptionClosed) { if (this.subscriptionClosed) {
@ -237,4 +244,5 @@ class RequestBodyPublisher implements Publisher<ByteBuffer> {
} }
} }
} }
} }

View File

@ -16,11 +16,6 @@
package org.springframework.reactive.web.http.undertow; package org.springframework.reactive.web.http.undertow;
import static org.springframework.http.HttpStatus.INTERNAL_SERVER_ERROR;
import static org.xnio.ChannelListeners.closingChannelExceptionHandler;
import static org.xnio.ChannelListeners.flushingChannelListener;
import static org.xnio.IoUtils.safeClose;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.Queue; import java.util.Queue;
@ -37,30 +32,43 @@ import org.xnio.ChannelListener;
import org.xnio.channels.StreamSinkChannel; import org.xnio.channels.StreamSinkChannel;
import reactor.core.subscriber.BaseSubscriber; import reactor.core.subscriber.BaseSubscriber;
import static org.xnio.ChannelListeners.closingChannelExceptionHandler;
import static org.xnio.ChannelListeners.flushingChannelListener;
import static org.xnio.IoUtils.safeClose;
/** /**
* @author Marek Hawrylczak * @author Marek Hawrylczak
* @author Rossen Stoyanchev
*/ */
class ResponseBodySubscriber extends BaseSubscriber<ByteBuffer> class ResponseBodySubscriber extends BaseSubscriber<ByteBuffer>
implements ChannelListener<StreamSinkChannel> { implements ChannelListener<StreamSinkChannel> {
private static final Log logger = LogFactory.getLog(ResponseBodySubscriber.class); private static final Log logger = LogFactory.getLog(ResponseBodySubscriber.class);
private final HttpServerExchange exchange; private final HttpServerExchange exchange;
private final Queue<PooledByteBuffer> buffers;
private final AtomicInteger writing = new AtomicInteger();
private final AtomicBoolean closing = new AtomicBoolean();
private StreamSinkChannel responseChannel;
private Subscription subscription; private Subscription subscription;
private final Queue<PooledByteBuffer> buffers;
private final AtomicInteger writing = new AtomicInteger();
private final AtomicBoolean closing = new AtomicBoolean();
private StreamSinkChannel responseChannel;
public ResponseBodySubscriber(HttpServerExchange exchange) { public ResponseBodySubscriber(HttpServerExchange exchange) {
this.exchange = exchange; this.exchange = exchange;
this.buffers = new ConcurrentLinkedQueue<>(); this.buffers = new ConcurrentLinkedQueue<>();
} }
@Override @Override
public void onSubscribe(Subscription s) { public void onSubscribe(Subscription subscription) {
super.onSubscribe(s); super.onSubscribe(subscription);
this.subscription = s; this.subscription = subscription;
this.subscription.request(1); this.subscription.request(1);
} }
@ -78,6 +86,7 @@ class ResponseBodySubscriber extends BaseSubscriber<ByteBuffer>
do { do {
c = this.responseChannel.write(buffer); c = this.responseChannel.write(buffer);
} while (buffer.hasRemaining() && c > 0); } while (buffer.hasRemaining() && c > 0);
if (buffer.hasRemaining()) { if (buffer.hasRemaining()) {
this.writing.incrementAndGet(); this.writing.incrementAndGet();
enqueue(buffer); enqueue(buffer);
@ -102,13 +111,11 @@ class ResponseBodySubscriber extends BaseSubscriber<ByteBuffer>
private void enqueue(ByteBuffer src) { private void enqueue(ByteBuffer src) {
do { do {
PooledByteBuffer pooledBuffer = PooledByteBuffer buffer = this.exchange.getConnection().getByteBufferPool().allocate();
this.exchange.getConnection().getByteBufferPool().allocate(); ByteBuffer dst = buffer.getBuffer();
ByteBuffer dst = pooledBuffer.getBuffer();
copy(dst, src); copy(dst, src);
dst.flip(); dst.flip();
this.buffers.add(pooledBuffer); this.buffers.add(buffer);
} while (src.remaining() > 0); } while (src.remaining() > 0);
} }
@ -128,10 +135,12 @@ class ResponseBodySubscriber extends BaseSubscriber<ByteBuffer>
do { do {
c = channel.write(buffer); c = channel.write(buffer);
} while (buffer.hasRemaining() && c > 0); } while (buffer.hasRemaining() && c > 0);
if (!buffer.hasRemaining()) { if (!buffer.hasRemaining()) {
safeClose(this.buffers.remove()); safeClose(this.buffers.remove());
} }
} while (!this.buffers.isEmpty() && c > 0); } while (!this.buffers.isEmpty() && c > 0);
if (!this.buffers.isEmpty()) { if (!this.buffers.isEmpty()) {
channel.resumeWrites(); channel.resumeWrites();
} }
@ -152,20 +161,17 @@ class ResponseBodySubscriber extends BaseSubscriber<ByteBuffer>
} }
@Override @Override
public void onError(Throwable t) { public void onError(Throwable ex) {
super.onError(t); super.onError(ex);
if (!this.exchange.isResponseStarted() && logger.error("ResponseBodySubscriber error", ex);
this.exchange.getStatusCode() < INTERNAL_SERVER_ERROR.value()) { if (!this.exchange.isResponseStarted() && this.exchange.getStatusCode() < 500) {
this.exchange.setStatusCode(500);
this.exchange.setStatusCode(INTERNAL_SERVER_ERROR.value());
} }
logger.error("ResponseBodySubscriber error", t);
} }
@Override @Override
public void onComplete() { public void onComplete() {
super.onComplete(); super.onComplete();
if (this.responseChannel != null) { if (this.responseChannel != null) {
this.closing.set(true); this.closing.set(true);
closeIfDone(); closeIfDone();
@ -185,10 +191,8 @@ class ResponseBodySubscriber extends BaseSubscriber<ByteBuffer>
this.responseChannel.shutdownWrites(); this.responseChannel.shutdownWrites();
if (!this.responseChannel.flush()) { if (!this.responseChannel.flush()) {
this.responseChannel.getWriteSetter().set( this.responseChannel.getWriteSetter().set(flushingChannelListener(
flushingChannelListener( o -> safeClose(this.responseChannel), closingChannelExceptionHandler()));
o -> safeClose(this.responseChannel),
closingChannelExceptionHandler()));
this.responseChannel.resumeWrites(); this.responseChannel.resumeWrites();
} }
this.responseChannel = null; this.responseChannel = null;

View File

@ -16,41 +16,48 @@
package org.springframework.reactive.web.http.undertow; package org.springframework.reactive.web.http.undertow;
import static org.springframework.http.HttpStatus.INTERNAL_SERVER_ERROR;
import org.springframework.http.server.ReactiveServerHttpRequest; import org.springframework.http.server.ReactiveServerHttpRequest;
import org.springframework.http.server.ReactiveServerHttpResponse; import org.springframework.http.server.ReactiveServerHttpResponse;
import org.springframework.reactive.web.http.HttpHandler; import org.springframework.reactive.web.http.HttpHandler;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import io.undertow.server.HttpServerExchange; import io.undertow.server.HttpServerExchange;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Subscriber; import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription; import org.reactivestreams.Subscription;
/** /**
* @author Marek Hawrylczak * @author Marek Hawrylczak
* @author Rossen Stoyanchev
*/ */
class RequestHandlerAdapter implements io.undertow.server.HttpHandler { class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandler {
private final HttpHandler httpHandler; private static Log logger = LogFactory.getLog(UndertowHttpHandlerAdapter.class);
public RequestHandlerAdapter(HttpHandler httpHandler) {
Assert.notNull(httpHandler, "'httpHandler' is required."); private final HttpHandler delegate;
this.httpHandler = httpHandler;
public UndertowHttpHandlerAdapter(HttpHandler delegate) {
Assert.notNull(delegate, "'delegate' is required.");
this.delegate = delegate;
} }
@Override @Override
public void handleRequest(HttpServerExchange exchange) throws Exception { public void handleRequest(HttpServerExchange exchange) throws Exception {
RequestBodyPublisher requestBodyPublisher = new RequestBodyPublisher(exchange); RequestBodyPublisher requestPublisher = new RequestBodyPublisher(exchange);
ReactiveServerHttpRequest request = ReactiveServerHttpRequest request = new UndertowServerHttpRequest(exchange, requestPublisher);
new UndertowServerHttpRequest(exchange, requestBodyPublisher);
ResponseBodySubscriber responseBodySubscriber = new ResponseBodySubscriber(exchange); ResponseBodySubscriber responseSubscriber = new ResponseBodySubscriber(exchange);
ReactiveServerHttpResponse response = ReactiveServerHttpResponse response = new UndertowServerHttpResponse(exchange, responseSubscriber);
new UndertowServerHttpResponse(exchange, responseBodySubscriber);
exchange.dispatch(); exchange.dispatch();
this.httpHandler.handle(request, response).subscribe(new Subscriber<Void>() {
this.delegate.handle(request, response).subscribe(new Subscriber<Void>() {
@Override @Override
public void onSubscribe(Subscription subscription) { public void onSubscribe(Subscription subscription) {
subscription.request(Long.MAX_VALUE); subscription.request(Long.MAX_VALUE);
@ -58,14 +65,16 @@ class RequestHandlerAdapter implements io.undertow.server.HttpHandler {
@Override @Override
public void onNext(Void aVoid) { public void onNext(Void aVoid) {
// no op
} }
@Override @Override
public void onError(Throwable t) { public void onError(Throwable ex) {
if (!exchange.isResponseStarted() && if (exchange.isResponseStarted() || exchange.getStatusCode() > 500) {
exchange.getStatusCode() < INTERNAL_SERVER_ERROR.value()) { logger.error("Error from request handling. Completing the request.", ex);
}
exchange.setStatusCode(INTERNAL_SERVER_ERROR.value()); else {
exchange.setStatusCode(500);
} }
exchange.endExchange(); exchange.endExchange();
} }
@ -76,4 +85,5 @@ class RequestHandlerAdapter implements io.undertow.server.HttpHandler {
} }
}); });
} }
} }

View File

@ -27,44 +27,42 @@ import io.undertow.server.HttpHandler;
/** /**
* @author Marek Hawrylczak * @author Marek Hawrylczak
*/ */
public class UndertowHttpServer extends HttpServerSupport public class UndertowHttpServer extends HttpServerSupport implements InitializingBean, HttpServer {
implements InitializingBean, HttpServer {
private Undertow undertowServer; private Undertow server;
private boolean running; private boolean running;
@Override @Override
public void afterPropertiesSet() throws Exception { public void afterPropertiesSet() throws Exception {
Assert.notNull(getHttpHandler()); Assert.notNull(getHttpHandler());
HttpHandler handler = new UndertowHttpHandlerAdapter(getHttpHandler());
HttpHandler handler = new RequestHandlerAdapter(getHttpHandler()); int port = (getPort() != -1 ? getPort() : 8080);
this.server = Undertow.builder().addHttpListener(port, "localhost")
this.undertowServer = Undertow.builder() .setHandler(handler).build();
.addHttpListener(getPort() != -1 ? getPort() : 8080, "localhost")
.setHandler(handler)
.build();
} }
@Override @Override
public void start() { public void start() {
if (!running) { if (!this.running) {
this.undertowServer.start(); this.server.start();
running = true; this.running = true;
} }
} }
@Override @Override
public void stop() { public void stop() {
if (running) { if (this.running) {
this.undertowServer.stop(); this.server.stop();
running = false; this.running = false;
} }
} }
@Override @Override
public boolean isRunning() { public boolean isRunning() {
return running; return this.running;
} }
} }

View File

@ -20,37 +20,32 @@ import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.ReactiveServerHttpRequest;
import org.springframework.util.StringUtils;
import io.undertow.server.HttpServerExchange; import io.undertow.server.HttpServerExchange;
import io.undertow.util.HeaderValues; import io.undertow.util.HeaderValues;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.ReactiveServerHttpRequest;
/** /**
* @author Marek Hawrylczak * @author Marek Hawrylczak
* @author Rossen Stoyanchev
*/ */
class UndertowServerHttpRequest implements ReactiveServerHttpRequest { class UndertowServerHttpRequest implements ReactiveServerHttpRequest {
private final HttpServerExchange exchange; private final HttpServerExchange exchange;
private final Publisher<ByteBuffer> requestBodyPublisher; private final Publisher<ByteBuffer> body;
private HttpHeaders headers; private HttpHeaders headers;
public UndertowServerHttpRequest(HttpServerExchange exchange,
Publisher<ByteBuffer> requestBodyPublisher) {
public UndertowServerHttpRequest(HttpServerExchange exchange, Publisher<ByteBuffer> body) {
this.exchange = exchange; this.exchange = exchange;
this.requestBodyPublisher = requestBodyPublisher; this.body = body;
} }
@Override
public Publisher<ByteBuffer> getBody() {
return this.requestBodyPublisher;
}
@Override @Override
public HttpMethod getMethod() { public HttpMethod getMethod() {
@ -60,11 +55,9 @@ class UndertowServerHttpRequest implements ReactiveServerHttpRequest {
@Override @Override
public URI getURI() { public URI getURI() {
try { try {
StringBuilder uri = new StringBuilder(this.exchange.getRequestPath()); return new URI(this.exchange.getRequestScheme(), null, this.exchange.getHostName(),
if (StringUtils.hasLength(this.exchange.getQueryString())) { this.exchange.getHostPort(), this.exchange.getRequestURI(),
uri.append('?').append(this.exchange.getQueryString()); this.exchange.getQueryString(), null);
}
return new URI(uri.toString());
} }
catch (URISyntaxException ex) { catch (URISyntaxException ex) {
throw new IllegalStateException("Could not get URI: " + ex.getMessage(), ex); throw new IllegalStateException("Could not get URI: " + ex.getMessage(), ex);
@ -83,4 +76,10 @@ class UndertowServerHttpRequest implements ReactiveServerHttpRequest {
} }
return this.headers; return this.headers;
} }
@Override
public Publisher<ByteBuffer> getBody() {
return this.body;
}
} }

View File

@ -23,65 +23,50 @@ import java.util.Map;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.server.ReactiveServerHttpResponse; import org.springframework.http.server.ReactiveServerHttpResponse;
import org.springframework.util.Assert;
import io.undertow.server.HttpServerExchange; import io.undertow.server.HttpServerExchange;
import io.undertow.util.HttpString; import io.undertow.util.HttpString;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription; import org.reactivestreams.Subscription;
import reactor.rx.Streams;
/** /**
* @author Marek Hawrylczak * @author Marek Hawrylczak
* @author Rossen Stoyanchev
*/ */
class UndertowServerHttpResponse implements ReactiveServerHttpResponse { class UndertowServerHttpResponse implements ReactiveServerHttpResponse {
private final HttpServerExchange exchange;
private final HttpHeaders headers;
private final ResponseBodySubscriber responseBodySubscriber; private final HttpServerExchange exchange;
private final ResponseBodySubscriber bodySubscriber;
private final HttpHeaders headers = new HttpHeaders();
private boolean headersWritten = false; private boolean headersWritten = false;
public UndertowServerHttpResponse(HttpServerExchange exchange,
ResponseBodySubscriber responseBodySubscriber) {
public UndertowServerHttpResponse(HttpServerExchange exchange, ResponseBodySubscriber body) {
this.exchange = exchange; this.exchange = exchange;
this.responseBodySubscriber = responseBodySubscriber; this.bodySubscriber = body;
this.headers = new HttpHeaders();
} }
@Override @Override
public void setStatusCode(HttpStatus status) { public void setStatusCode(HttpStatus status) {
Assert.notNull(status);
this.exchange.setStatusCode(status.value()); this.exchange.setStatusCode(status.value());
} }
@Override
public Publisher<Void> setBody(Publisher<ByteBuffer> contentPublisher) {
applyHeaders();
return s -> s.onSubscribe(new Subscription() {
@Override
public void request(long n) {
Streams.wrap(contentPublisher)
.finallyDo(byteBufferSignal -> {
if (byteBufferSignal.isOnComplete()) {
s.onComplete();
}
else {
s.onError(byteBufferSignal.getThrowable());
}
}
).subscribe(responseBodySubscriber);
}
@Override @Override
public void cancel() { public Publisher<Void> setBody(Publisher<ByteBuffer> bodyPublisher) {
} applyHeaders();
}); return (subscriber -> bodyPublisher.subscribe(bodySubscriber));
} }
@Override @Override
public HttpHeaders getHeaders() { public HttpHeaders getHeaders() {
return (this.headersWritten ? return (this.headersWritten ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers);
HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers);
} }
@Override @Override
@ -102,12 +87,12 @@ class UndertowServerHttpResponse implements ReactiveServerHttpResponse {
private void applyHeaders() { private void applyHeaders() {
if (!this.headersWritten) { if (!this.headersWritten) {
for (Map.Entry<String, List<String>> entry : this.headers.entrySet()) { for (Map.Entry<String, List<String>> entry : this.headers.entrySet()) {
String headerName = entry.getKey(); HttpString headerName = HttpString.tryFromString(entry.getKey());
this.exchange.getResponseHeaders() this.exchange.getResponseHeaders().addAll(headerName, entry.getValue());
.addAll(HttpString.tryFromString(headerName), entry.getValue());
} }
this.headersWritten = true; this.headersWritten = true;
} }
} }
} }