diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodyPublisher.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodyPublisher.java new file mode 100644 index 00000000000..cdd3fffc169 --- /dev/null +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodyPublisher.java @@ -0,0 +1,208 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicLong; + +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.util.BackpressureUtils; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.util.Assert; + +/** + * Abstract base class for {@code Publisher} implementations that bridge between + * event-listener APIs and Reactive Streams. Specifically, base class for the Servlet 3.1 + * and Undertow support. + * + * @author Arjen Poutsma + * @see ServletServerHttpRequest + * @see UndertowHttpHandlerAdapter + */ +abstract class AbstractResponseBodyPublisher implements Publisher { + + private ResponseBodySubscription subscription; + + private volatile boolean stalled; + + @Override + public void subscribe(Subscriber subscriber) { + Objects.requireNonNull(subscriber); + Assert.state(this.subscription == null, "Only a single subscriber allowed"); + + this.subscription = new ResponseBodySubscription(subscriber); + subscriber.onSubscribe(this.subscription); + } + + /** + * Publishes the given signal to the subscriber. + * @param dataBuffer the signal to publish + * @see Subscriber#onNext(Object) + */ + protected final void publishOnNext(DataBuffer dataBuffer) { + Assert.state(this.subscription != null); + this.subscription.publishOnNext(dataBuffer); + } + + /** + * Publishes the given error to the subscriber. + * @param t the error to publish + * @see Subscriber#onError(Throwable) + */ + protected final void publishOnError(Throwable t) { + if (this.subscription != null) { + this.subscription.publishOnError(t); + } + } + + /** + * Publishes the complete signal to the subscriber. + * @see Subscriber#onComplete() + */ + protected final void publishOnComplete() { + if (this.subscription != null) { + this.subscription.publishOnComplete(); + } + } + + /** + * Returns true if the {@code Subscriber} associated with this {@code Publisher} has + * cancelled its {@code Subscription}. + * @return {@code true} if a subscriber has been registered and its subscription has + * been cancelled; {@code false} otherwise + * @see ResponseBodySubscription#isCancelled() + * @see Subscription#cancel() + */ + protected final boolean isSubscriptionCancelled() { + return (this.subscription != null && this.subscription.isCancelled()); + } + + /** + * Checks the subscription for demand, and marks this publisher as "stalled" if there + * is none. The next time the subscriber {@linkplain Subscription#request(long) + * requests} more events, the {@link #noLongerStalled()} method is called. + * @return {@code true} if there is demand; {@code false} otherwise + */ + protected final boolean checkSubscriptionForDemand() { + if (this.subscription == null || !this.subscription.hasDemand()) { + this.stalled = true; + return false; + } + else { + return true; + } + } + + /** + * Abstract template method called when this publisher is no longer "stalled". Used in + * sub-classes to resume reading from the request. + */ + protected abstract void noLongerStalled(); + + private final class ResponseBodySubscription implements Subscription { + + private final Subscriber subscriber; + + private final AtomicLong demand = new AtomicLong(); + + private boolean cancelled; + + public ResponseBodySubscription(Subscriber subscriber) { + Assert.notNull(subscriber, "'subscriber' must not be null"); + + this.subscriber = subscriber; + } + + @Override + public final void cancel() { + this.cancelled = true; + } + + /** + * Indicates whether this subscription has been cancelled. + * @see #cancel() + */ + protected final boolean isCancelled() { + return this.cancelled; + } + + @Override + public final void request(long n) { + if (!isCancelled() && BackpressureUtils.checkRequest(n, this.subscriber)) { + long demand = BackpressureUtils.addAndGet(this.demand, n); + + if (stalled && demand > 0) { + stalled = false; + noLongerStalled(); + } + } + } + + /** + * Indicates whether this subscription has demand. + * @see #request(long) + */ + protected final boolean hasDemand() { + return this.demand.get() > 0; + } + + /** + * Publishes the given signal to the subscriber wrapped by this subscription, if + * it has not been cancelled. If there is {@linkplain #hasDemand() no demand} for + * the signal, an exception will be thrown. + * @param dataBuffer the signal to publish + * @see Subscriber#onNext(Object) + */ + protected final void publishOnNext(DataBuffer dataBuffer) { + if (!isCancelled()) { + if (hasDemand()) { + BackpressureUtils.getAndSub(this.demand, 1L); + this.subscriber.onNext(dataBuffer); + } + else { + throw new IllegalStateException("No demand for: " + dataBuffer); + } + } + } + + /** + * Publishes the given error to the subscriber wrapped by this subscription, if it + * has not been cancelled. + * @param t the error to publish + * @see Subscriber#onError(Throwable) + */ + protected final void publishOnError(Throwable t) { + if (!isCancelled()) { + this.subscriber.onError(t); + } + } + + /** + * Publishes the complete signal to the subscriber wrapped by this subscription, + * if it has not been cancelled. + * @see Subscriber#onComplete() + */ + protected final void publishOnComplete() { + if (!isCancelled()) { + this.subscriber.onComplete(); + } + } + } +} diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletAsyncContextSynchronizer.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletAsyncContextSynchronizer.java index 283b598e9ed..ce507c9729b 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletAsyncContextSynchronizer.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletAsyncContextSynchronizer.java @@ -16,11 +16,8 @@ package org.springframework.http.server.reactive; -import java.io.IOException; import java.util.concurrent.atomic.AtomicInteger; import javax.servlet.AsyncContext; -import javax.servlet.ServletInputStream; -import javax.servlet.ServletOutputStream; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; @@ -56,32 +53,20 @@ final class ServletAsyncContextSynchronizer { this.asyncContext = asyncContext; } + /** + * Returns the request of this synchronizer. + */ public ServletRequest getRequest() { return this.asyncContext.getRequest(); } + /** + * Returns the response of this synchronizer. + */ public ServletResponse getResponse() { return this.asyncContext.getResponse(); } - /** - * Returns the input stream of this synchronizer. - * @return the input stream - * @throws IOException if an input or output exception occurred - */ - public ServletInputStream getInputStream() throws IOException { - return getRequest().getInputStream(); - } - - /** - * Returns the output stream of this synchronizer. - * @return the output stream - * @throws IOException if an input or output exception occurred - */ - public ServletOutputStream getOutputStream() throws IOException { - return getResponse().getOutputStream(); - } - /** * Completes the reading side of the asynchronous operation. When both this method and * {@link #writeComplete()} have been called, the {@code AsyncContext} will be diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java index 289d20b8c80..f615e0495e6 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java @@ -22,7 +22,6 @@ import java.net.URISyntaxException; import java.nio.charset.Charset; import java.util.Enumeration; import java.util.Map; -import java.util.concurrent.atomic.AtomicLong; import javax.servlet.ReadListener; import javax.servlet.ServletInputStream; import javax.servlet.http.Cookie; @@ -30,9 +29,6 @@ import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import reactor.core.publisher.Flux; import org.springframework.core.io.buffer.DataBuffer; @@ -68,7 +64,6 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { RequestBodyPublisher bodyPublisher = new RequestBodyPublisher(synchronizer, allocator, bufferSize); this.requestBodyPublisher = Flux.from(bodyPublisher); - this.request.getInputStream().setReadListener(bodyPublisher); } @@ -142,8 +137,10 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { return this.requestBodyPublisher; } - private static class RequestBodyPublisher - implements ReadListener, Publisher { + private static class RequestBodyPublisher extends AbstractResponseBodyPublisher { + + private final RequestBodyReadListener readListener = + new RequestBodyReadListener(); private final ServletAsyncContextSynchronizer synchronizer; @@ -151,184 +148,78 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { private final byte[] buffer; - private final DemandCounter demand = new DemandCounter(); - - private Subscriber subscriber; - - private boolean stalled; - - private boolean cancelled; - public RequestBodyPublisher(ServletAsyncContextSynchronizer synchronizer, - DataBufferAllocator allocator, int bufferSize) { + DataBufferAllocator allocator, int bufferSize) throws IOException { this.synchronizer = synchronizer; this.allocator = allocator; this.buffer = new byte[bufferSize]; + synchronizer.getRequest().getInputStream().setReadListener(readListener); } @Override - public void subscribe(Subscriber subscriber) { - if (subscriber == null) { - throw new NullPointerException(); + protected void noLongerStalled() { + try { + readListener.onDataAvailable(); } - else if (this.subscriber != null) { - subscriber.onError( - new IllegalStateException("Only one subscriber allowed")); - } - this.subscriber = subscriber; - this.subscriber.onSubscribe(new RequestBodySubscription()); - } - - @Override - public void onDataAvailable() throws IOException { - if (cancelled) { - return; - } - ServletInputStream input = this.synchronizer.getInputStream(); - logger.trace("onDataAvailable: " + input); - - while (true) { - logger.trace("Demand: " + this.demand); - - if (!demand.hasDemand()) { - stalled = true; - break; - } - - boolean ready = input.isReady(); - logger.trace( - "Input ready: " + ready + " finished: " + input.isFinished()); - - if (!ready) { - break; - } - - int read = input.read(buffer); - logger.trace("Input read:" + read); - - if (read == -1) { - break; - } - else if (read > 0) { - this.demand.decrement(); - - DataBuffer dataBuffer = allocator.allocateBuffer(read); - dataBuffer.write(this.buffer, 0, read); - - this.subscriber.onNext(dataBuffer); - - } + catch (IOException ex) { + readListener.onError(ex); } } - @Override - public void onAllDataRead() throws IOException { - if (cancelled) { - return; - } - logger.trace("All data read"); - this.synchronizer.readComplete(); - if (this.subscriber != null) { - this.subscriber.onComplete(); - } - } - - @Override - public void onError(Throwable t) { - if (cancelled) { - return; - } - logger.trace("RequestBodyPublisher Error", t); - this.synchronizer.readComplete(); - if (this.subscriber != null) { - this.subscriber.onError(t); - } - } - - private class RequestBodySubscription implements Subscription { + private class RequestBodyReadListener implements ReadListener { @Override - public void request(long n) { - if (cancelled) { + public void onDataAvailable() throws IOException { + if (isSubscriptionCancelled()) { return; } - logger.trace("Updating demand " + demand + " by " + n); + logger.trace("onDataAvailable"); + ServletInputStream input = synchronizer.getRequest().getInputStream(); - demand.increase(n); - - logger.trace("Stalled: " + stalled); - - if (stalled) { - stalled = false; - try { - onDataAvailable(); + while (true) { + if (!checkSubscriptionForDemand()) { + break; } - catch (IOException ex) { - onError(ex); + + boolean ready = input.isReady(); + logger.trace( + "Input ready: " + ready + " finished: " + input.isFinished()); + + if (!ready) { + break; + } + + int read = input.read(buffer); + logger.trace("Input read:" + read); + + if (read == -1) { + break; + } + else if (read > 0) { + DataBuffer dataBuffer = allocator.allocateBuffer(read); + dataBuffer.write(buffer, 0, read); + + publishOnNext(dataBuffer); } } } @Override - public void cancel() { - if (cancelled) { - return; - } - cancelled = true; + public void onAllDataRead() throws IOException { + logger.trace("All data read"); synchronizer.readComplete(); - demand.reset(); - } - } - /** - * Small utility class for keeping track of Reactive Streams demand. - */ - private static final class DemandCounter { - - private final AtomicLong demand = new AtomicLong(); - - /** - * Increases the demand by the given number - * @param n the positive number to increase demand by - * @return the increased demand - * @see Subscription#request(long) - */ - public long increase(long n) { - Assert.isTrue(n > 0, "'n' must be higher than 0"); - return demand - .updateAndGet(d -> d != Long.MAX_VALUE ? d + n : Long.MAX_VALUE); - } - - /** - * Decreases the demand by one. - * @return the decremented demand - */ - public long decrement() { - return demand - .updateAndGet(d -> d != Long.MAX_VALUE ? d - 1 : Long.MAX_VALUE); - } - - /** - * Indicates whether this counter has demand, i.e. whether it is higher than - * 0. - * @return {@code true} if this counter has demand; {@code false} otherwise - */ - public boolean hasDemand() { - return this.demand.get() > 0; - } - - /** - * Resets this counter to 0. - * @see Subscription#cancel() - */ - public void reset() { - this.demand.set(0); + publishOnComplete(); } @Override - public String toString() { - return demand.toString(); + public void onError(Throwable t) { + logger.trace("RequestBodyReadListener Error", t); + synchronizer.readComplete(); + + publishOnError(t); } } + } } diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java index 35cfbb00fed..237fe48994c 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -32,6 +32,7 @@ import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import reactor.core.publisher.Mono; +import reactor.core.util.BackpressureUtils; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferAllocator; @@ -61,7 +62,6 @@ public class ServletServerHttpResponse extends AbstractServerHttpResponse { this.response = (HttpServletResponse) synchronizer.getResponse(); this.responseBodySubscriber = new ResponseBodySubscriber(synchronizer, bufferSize); - this.response.getOutputStream().setWriteListener(responseBodySubscriber); } public HttpServletResponse getServletResponse() { @@ -118,39 +118,46 @@ public class ServletServerHttpResponse extends AbstractServerHttpResponse { } } - private static class ResponseBodySubscriber - implements WriteListener, Subscriber { + private static class ResponseBodySubscriber implements Subscriber { + + private final ResponseBodyWriteListener writeListener = + new ResponseBodyWriteListener(); private final ServletAsyncContextSynchronizer synchronizer; private final int bufferSize; + private volatile DataBuffer dataBuffer; + + private volatile boolean completed = false; + private Subscription subscription; - private DataBuffer dataBuffer; - - private volatile boolean subscriberComplete = false; - public ResponseBodySubscriber(ServletAsyncContextSynchronizer synchronizer, - int bufferSize) { + int bufferSize) throws IOException { this.synchronizer = synchronizer; this.bufferSize = bufferSize; + synchronizer.getResponse().getOutputStream().setWriteListener(writeListener); } @Override public void onSubscribe(Subscription subscription) { - this.subscription = subscription; - this.subscription.request(1); + logger.trace("onSubscribe. Subscription: " + subscription); + if (BackpressureUtils.validate(this.subscription, subscription)) { + this.subscription = subscription; + this.subscription.request(1); + } } @Override public void onNext(DataBuffer dataBuffer) { - Assert.isNull(this.dataBuffer); + Assert.state(this.dataBuffer == null); + logger.trace("onNext. buffer: " + dataBuffer); this.dataBuffer = dataBuffer; try { - onWritePossible(); + this.writeListener.onWritePossible(); } catch (IOException e) { onError(e); @@ -158,66 +165,93 @@ public class ServletServerHttpResponse extends AbstractServerHttpResponse { } @Override - public void onComplete() { - logger.trace("onComplete. buffer: " + dataBuffer); - - this.subscriberComplete = true; - - if (dataBuffer == null) { - this.synchronizer.writeComplete(); - } - } - - @Override - public void onWritePossible() throws IOException { - ServletOutputStream output = this.synchronizer.getOutputStream(); - - boolean ready = output.isReady(); - logger.trace("onWritePossible. ready: " + ready + " buffer: " + dataBuffer); - - if (ready) { - if (this.dataBuffer != null) { - int toBeWritten = this.dataBuffer.readableByteCount(); - InputStream input = this.dataBuffer.asInputStream(); - int writeCount = write(input, output); - logger.trace("written: " + writeCount + " total: " + toBeWritten); - if (writeCount == toBeWritten) { - this.dataBuffer = null; - if (!this.subscriberComplete) { - this.subscription.request(1); - } - else { - this.synchronizer.writeComplete(); - } - } - } - else if (this.subscription != null) { - this.subscription.request(1); - } - } - } - - private int write(InputStream in, ServletOutputStream output) throws IOException { - int byteCount = 0; - byte[] buffer = new byte[bufferSize]; - int bytesRead = -1; - while (output.isReady() && (bytesRead = in.read(buffer)) != -1) { - output.write(buffer, 0, bytesRead); - byteCount += bytesRead; - } - return byteCount; - } - - @Override - public void onError(Throwable ex) { - if (this.subscription != null) { - this.subscription.cancel(); - } - logger.error("ResponseBodySubscriber error", ex); + public void onError(Throwable t) { + logger.error("onError", t); HttpServletResponse response = (HttpServletResponse) this.synchronizer.getResponse(); response.setStatus(HttpStatus.INTERNAL_SERVER_ERROR.value()); this.synchronizer.complete(); + + } + + @Override + public void onComplete() { + logger.trace("onComplete. buffer: " + this.dataBuffer); + + this.completed = true; + + if (this.dataBuffer != null) { + try { + this.writeListener.onWritePossible(); + } + catch (IOException ex) { + onError(ex); + } + } + + if (this.dataBuffer == null) { + this.synchronizer.writeComplete(); + } + } + + private class ResponseBodyWriteListener implements WriteListener { + + @Override + public void onWritePossible() throws IOException { + logger.trace("onWritePossible"); + ServletOutputStream output = synchronizer.getResponse().getOutputStream(); + + boolean ready = output.isReady(); + logger.trace("ready: " + ready + " buffer: " + dataBuffer); + + if (ready) { + if (dataBuffer != null) { + + int total = dataBuffer.readableByteCount(); + int written = writeDataBuffer(); + + logger.trace("written: " + written + " total: " + total); + if (written == total) { + releaseBuffer(); + if (!completed) { + subscription.request(1); + } + else { + synchronizer.writeComplete(); + } + } + } + else if (subscription != null) { + subscription.request(1); + } + } + } + + private int writeDataBuffer() throws IOException { + InputStream input = dataBuffer.asInputStream(); + ServletOutputStream output = synchronizer.getResponse().getOutputStream(); + + int bytesWritten = 0; + byte[] buffer = new byte[bufferSize]; + int bytesRead = -1; + + while (output.isReady() && (bytesRead = input.read(buffer)) != -1) { + output.write(buffer, 0, bytesRead); + bytesWritten += bytesRead; + } + + return bytesWritten; + } + + private void releaseBuffer() { + // TODO: call PooledDataBuffer.release() when we it is introduced + dataBuffer = null; + } + + @Override + public void onError(Throwable ex) { + logger.error("ResponseBodyWriteListener error", ex); + } } } } \ No newline at end of file diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java index 340ac750b03..d4d48b5a594 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java @@ -18,18 +18,11 @@ package org.springframework.http.server.reactive; import java.io.IOException; import java.nio.ByteBuffer; -import java.util.Queue; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLongFieldUpdater; import io.undertow.connector.PooledByteBuffer; import io.undertow.server.HttpServerExchange; -import io.undertow.util.SameThreadExecutor; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import org.xnio.ChannelListener; @@ -38,9 +31,7 @@ import org.xnio.IoUtils; import org.xnio.channels.StreamSinkChannel; import org.xnio.channels.StreamSourceChannel; import reactor.core.publisher.Mono; -import reactor.core.subscriber.BaseSubscriber; import reactor.core.util.BackpressureUtils; -import reactor.core.util.Exceptions; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferAllocator; @@ -75,14 +66,14 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle RequestBodyPublisher requestBody = new RequestBodyPublisher(exchange, allocator); ServerHttpRequest request = new UndertowServerHttpRequest(exchange, requestBody); - ResponseBodySubscriber responseBodySubscriber = new ResponseBodySubscriber(exchange); + ResponseBodySubscriber responseBodySubscriber = + new ResponseBodySubscriber(exchange); + ServerHttpResponse response = new UndertowServerHttpResponse(exchange, publisher -> Mono .from(subscriber -> publisher.subscribe(responseBodySubscriber)), allocator); - exchange.dispatch(); - this.delegate.handle(request, response).subscribe(new Subscriber() { @Override @@ -113,375 +104,212 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle }); } - private static class RequestBodyPublisher implements Publisher { + private static class RequestBodyPublisher extends AbstractResponseBodyPublisher { - private static final AtomicLongFieldUpdater DEMAND = - AtomicLongFieldUpdater.newUpdater(RequestBodySubscription.class, "demand"); + private static final Log logger = LogFactory.getLog(RequestBodyPublisher.class); + private final ChannelListener listener = + new RequestBodyListener(); - private final HttpServerExchange exchange; + private final StreamSourceChannel requestChannel; private final DataBufferAllocator allocator; - private Subscriber subscriber; + private final PooledByteBuffer pooledByteBuffer; public RequestBodyPublisher(HttpServerExchange exchange, DataBufferAllocator allocator) { - this.exchange = exchange; + this.requestChannel = exchange.getRequestChannel(); + this.requestChannel.getReadSetter().set(listener); + this.requestChannel.resumeReads(); + this.pooledByteBuffer = + exchange.getConnection().getByteBufferPool().allocate(); this.allocator = allocator; } - @Override - public void subscribe(Subscriber subscriber) { - if (subscriber == null) { - throw Exceptions.argumentIsNullException(); + private void close() { + if (this.pooledByteBuffer != null) { + IoUtils.safeClose(this.pooledByteBuffer); } - if (this.subscriber != null) { - subscriber.onError(new IllegalStateException("Only one subscriber allowed")); + if (this.requestChannel != null) { + IoUtils.safeClose(this.requestChannel); } - - this.subscriber = subscriber; - this.subscriber.onSubscribe(new RequestBodySubscription()); } + @Override + protected void noLongerStalled() { + listener.handleEvent(requestChannel); + } - private class RequestBodySubscription implements Subscription, Runnable, - ChannelListener { - - volatile long demand; - - private PooledByteBuffer pooledBuffer; - - private StreamSourceChannel channel; - - private boolean subscriptionClosed; - - private boolean draining; - - - @Override - public void request(long n) { - BackpressureUtils.checkRequest(n, subscriber); - if (this.subscriptionClosed) { - return; - } - BackpressureUtils.getAndAdd(DEMAND, this, n); - scheduleNextMessage(); - } - - private void scheduleNextMessage() { - exchange.dispatch(exchange.isInIoThread() ? SameThreadExecutor.INSTANCE : - exchange.getIoThread(), this); - } - - @Override - public void cancel() { - this.subscriptionClosed = true; - close(); - } - - private void close() { - if (this.pooledBuffer != null) { - IoUtils.safeClose(this.pooledBuffer); - this.pooledBuffer = null; - } - if (this.channel != null) { - IoUtils.safeClose(this.channel); - this.channel = null; - } - } - - @Override - public void run() { - if (this.subscriptionClosed || this.draining) { - return; - } - if (0 == BackpressureUtils.getAndSub(DEMAND, this, 1)) { - return; - } - - this.draining = true; - - if (this.channel == null) { - this.channel = exchange.getRequestChannel(); - - if (this.channel == null) { - if (exchange.isRequestComplete()) { - return; - } - else { - throw new IllegalStateException("Failed to acquire channel!"); - } - } - } - if (this.pooledBuffer == null) { - this.pooledBuffer = exchange.getConnection().getByteBufferPool().allocate(); - } - else { - this.pooledBuffer.getBuffer().clear(); - } - - try { - ByteBuffer buffer = this.pooledBuffer.getBuffer(); - int count; - do { - count = this.channel.read(buffer); - if (count == 0) { - this.channel.getReadSetter().set(this); - this.channel.resumeReads(); - } - else if (count == -1) { - if (buffer.position() > 0) { - doOnNext(buffer); - } - doOnComplete(); - } - else { - if (buffer.remaining() == 0) { - if (this.demand == 0) { - this.channel.suspendReads(); - } - doOnNext(buffer); - if (this.demand > 0) { - scheduleNextMessage(); - } - break; - } - } - } while (count > 0); - } - catch (IOException e) { - doOnError(e); - } - } - - private void doOnNext(ByteBuffer buffer) { - this.draining = false; - buffer.flip(); - DataBuffer dataBuffer = allocator.wrap(buffer); - subscriber.onNext(dataBuffer); - } - - private void doOnComplete() { - this.subscriptionClosed = true; - try { - subscriber.onComplete(); - } - finally { - close(); - } - } - - private void doOnError(Throwable t) { - this.subscriptionClosed = true; - try { - subscriber.onError(t); - } - finally { - close(); - } - } + private class RequestBodyListener + implements ChannelListener { @Override public void handleEvent(StreamSourceChannel channel) { - if (this.subscriptionClosed) { + if (isSubscriptionCancelled()) { return; } - + logger.trace("handleEvent"); + ByteBuffer byteBuffer = pooledByteBuffer.getBuffer(); try { - ByteBuffer buffer = this.pooledBuffer.getBuffer(); - int count; - do { - count = channel.read(buffer); - if (count == 0) { - return; + while (true) { + if (!checkSubscriptionForDemand()) { + break; } - else if (count == -1) { - if (buffer.position() > 0) { - doOnNext(buffer); - } - doOnComplete(); + int read = channel.read(byteBuffer); + logger.trace("Input read:" + read); + + if (read == -1) { + publishOnComplete(); + close(); + break; + } + else if (read == 0) { + // input not ready, wait until we are invoked again + break; } else { - if (buffer.remaining() == 0) { - if (this.demand == 0) { - channel.suspendReads(); - } - doOnNext(buffer); - if (this.demand > 0) { - scheduleNextMessage(); - } - break; - } + byteBuffer.flip(); + DataBuffer dataBuffer = allocator.wrap(byteBuffer); + publishOnNext(dataBuffer); } - } while (count > 0); + } } - catch (IOException e) { - doOnError(e); + catch (IOException ex) { + publishOnError(ex); } } } + } - private static class ResponseBodySubscriber - implements ChannelListener, BaseSubscriber{ + private static class ResponseBodySubscriber implements Subscriber { + + private static final Log logger = LogFactory.getLog(ResponseBodySubscriber.class); + + private final ChannelListener listener = + new ResponseBodyListener(); private final HttpServerExchange exchange; + private final StreamSinkChannel responseChannel; + + private volatile ByteBuffer byteBuffer; + + private volatile boolean completed = false; + private Subscription subscription; - private final Queue buffers = new ConcurrentLinkedQueue<>(); - - private final AtomicInteger writing = new AtomicInteger(); - - private final AtomicBoolean closing = new AtomicBoolean(); - - private StreamSinkChannel responseChannel; - - public ResponseBodySubscriber(HttpServerExchange exchange) { this.exchange = exchange; + this.responseChannel = exchange.getResponseChannel(); + this.responseChannel.getWriteSetter().set(listener); + this.responseChannel.resumeWrites(); } @Override public void onSubscribe(Subscription subscription) { - BaseSubscriber.super.onSubscribe(subscription); - this.subscription = subscription; - this.subscription.request(1); + logger.trace("onSubscribe. Subscription: " + subscription); + if (BackpressureUtils.validate(this.subscription, subscription)) { + this.subscription = subscription; + this.subscription.request(1); + } } @Override public void onNext(DataBuffer dataBuffer) { - BaseSubscriber.super.onNext(dataBuffer); + Assert.state(this.byteBuffer == null); + logger.trace("onNext. buffer: " + dataBuffer); - ByteBuffer buffer = dataBuffer.asByteBuffer(); - - if (this.responseChannel == null) { - this.responseChannel = exchange.getResponseChannel(); - } - - this.writing.incrementAndGet(); - try { - int c; - do { - c = this.responseChannel.write(buffer); - } while (buffer.hasRemaining() && c > 0); - - if (buffer.hasRemaining()) { - this.writing.incrementAndGet(); - enqueue(buffer); - this.responseChannel.getWriteSetter().set(this); - this.responseChannel.resumeWrites(); - } - else { - this.subscription.request(1); - } - - } - catch (IOException ex) { - onError(ex); - } - finally { - this.writing.decrementAndGet(); - if (this.closing.get()) { - closeIfDone(); - } - } - } - - private void enqueue(ByteBuffer src) { - do { - PooledByteBuffer buffer = exchange.getConnection().getByteBufferPool().allocate(); - ByteBuffer dst = buffer.getBuffer(); - copy(dst, src); - dst.flip(); - this.buffers.add(buffer); - } while (src.remaining() > 0); - } - - private void copy(ByteBuffer dst, ByteBuffer src) { - int n = Math.min(dst.capacity(), src.remaining()); - for (int i = 0; i < n; i++) { - dst.put(src.get()); - } + this.byteBuffer = dataBuffer.asByteBuffer(); } @Override - public void handleEvent(StreamSinkChannel channel) { - try { - int c; - do { - ByteBuffer buffer = this.buffers.peek().getBuffer(); - do { - c = channel.write(buffer); - } while (buffer.hasRemaining() && c > 0); - - if (!buffer.hasRemaining()) { - IoUtils.safeClose(this.buffers.remove()); - } - } while (!this.buffers.isEmpty() && c > 0); - - if (!this.buffers.isEmpty()) { - channel.resumeWrites(); - } - else { - this.writing.decrementAndGet(); - - if (this.closing.get()) { - closeIfDone(); - } - else { - this.subscription.request(1); - } - } - } - catch (IOException ex) { - onError(ex); - } - } - - @Override - public void onError(Throwable ex) { - BaseSubscriber.super.onError(ex); - logger.error("ResponseBodySubscriber error", ex); + public void onError(Throwable t) { + logger.error("onError", t); if (!exchange.isResponseStarted() && exchange.getStatusCode() < 500) { exchange.setStatusCode(500); } + closeChannel(responseChannel); } @Override public void onComplete() { - if (this.responseChannel != null) { - this.closing.set(true); - closeIfDone(); + logger.trace("onComplete. buffer: " + this.byteBuffer); + + this.completed = true; + + if (this.byteBuffer == null) { + closeChannel(responseChannel); } } - private void closeIfDone() { - if (this.writing.get() == 0) { - if (this.closing.compareAndSet(true, false)) { - closeChannel(); - } - } - } - - private void closeChannel() { + private void closeChannel(StreamSinkChannel channel) { try { - this.responseChannel.shutdownWrites(); + channel.shutdownWrites(); - if (!this.responseChannel.flush()) { - this.responseChannel.getWriteSetter().set(ChannelListeners - .flushingChannelListener( - o -> IoUtils.safeClose(this.responseChannel), + if (!channel.flush()) { + channel.getWriteSetter().set(ChannelListeners + .flushingChannelListener(o -> IoUtils.safeClose(channel), ChannelListeners.closingChannelExceptionHandler())); - this.responseChannel.resumeWrites(); + channel.resumeWrites(); } - this.responseChannel = null; } - catch (IOException ex) { - onError(ex); + catch (IOException ignored) { + logger.error(ignored, ignored); + } } + + private class ResponseBodyListener implements ChannelListener { + + @Override + public void handleEvent(StreamSinkChannel channel) { + if (byteBuffer != null) { + try { + int total = byteBuffer.remaining(); + int written = writeByteBuffer(channel); + + logger.trace("written: " + written + " total: " + total); + + if (written == total) { + releaseBuffer(); + if (!completed) { + subscription.request(1); + } + else { + closeChannel(channel); + } + } + } + catch (IOException ex) { + onError(ex); + } + } + else if (subscription != null) { + subscription.request(1); + } + + } + + private void releaseBuffer() { + byteBuffer = null; + + } + + private int writeByteBuffer(StreamSinkChannel channel) throws IOException { + int written; + int totalWritten = 0; + do { + written = channel.write(byteBuffer); + totalWritten += written; + } + while (byteBuffer.hasRemaining() && written > 0); + return totalWritten; + } + + } + } }