Refactoring of Servlet 3.1 and Undertow support

- Introduce abstract base class for Servlet 3.1 and Undertow support
- Simplify Undertow support
This commit is contained in:
Arjen Poutsma 2016-03-17 15:40:16 +01:00
parent f7c6c69e51
commit d20b0003c6
5 changed files with 504 additions and 558 deletions

View File

@ -0,0 +1,208 @@
/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.server.reactive;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicLong;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.util.BackpressureUtils;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.util.Assert;
/**
* Abstract base class for {@code Publisher} implementations that bridge between
* event-listener APIs and Reactive Streams. Specifically, base class for the Servlet 3.1
* and Undertow support.
*
* @author Arjen Poutsma
* @see ServletServerHttpRequest
* @see UndertowHttpHandlerAdapter
*/
abstract class AbstractResponseBodyPublisher implements Publisher<DataBuffer> {
private ResponseBodySubscription subscription;
private volatile boolean stalled;
@Override
public void subscribe(Subscriber<? super DataBuffer> subscriber) {
Objects.requireNonNull(subscriber);
Assert.state(this.subscription == null, "Only a single subscriber allowed");
this.subscription = new ResponseBodySubscription(subscriber);
subscriber.onSubscribe(this.subscription);
}
/**
* Publishes the given signal to the subscriber.
* @param dataBuffer the signal to publish
* @see Subscriber#onNext(Object)
*/
protected final void publishOnNext(DataBuffer dataBuffer) {
Assert.state(this.subscription != null);
this.subscription.publishOnNext(dataBuffer);
}
/**
* Publishes the given error to the subscriber.
* @param t the error to publish
* @see Subscriber#onError(Throwable)
*/
protected final void publishOnError(Throwable t) {
if (this.subscription != null) {
this.subscription.publishOnError(t);
}
}
/**
* Publishes the complete signal to the subscriber.
* @see Subscriber#onComplete()
*/
protected final void publishOnComplete() {
if (this.subscription != null) {
this.subscription.publishOnComplete();
}
}
/**
* Returns true if the {@code Subscriber} associated with this {@code Publisher} has
* cancelled its {@code Subscription}.
* @return {@code true} if a subscriber has been registered and its subscription has
* been cancelled; {@code false} otherwise
* @see ResponseBodySubscription#isCancelled()
* @see Subscription#cancel()
*/
protected final boolean isSubscriptionCancelled() {
return (this.subscription != null && this.subscription.isCancelled());
}
/**
* Checks the subscription for demand, and marks this publisher as "stalled" if there
* is none. The next time the subscriber {@linkplain Subscription#request(long)
* requests} more events, the {@link #noLongerStalled()} method is called.
* @return {@code true} if there is demand; {@code false} otherwise
*/
protected final boolean checkSubscriptionForDemand() {
if (this.subscription == null || !this.subscription.hasDemand()) {
this.stalled = true;
return false;
}
else {
return true;
}
}
/**
* Abstract template method called when this publisher is no longer "stalled". Used in
* sub-classes to resume reading from the request.
*/
protected abstract void noLongerStalled();
private final class ResponseBodySubscription implements Subscription {
private final Subscriber<? super DataBuffer> subscriber;
private final AtomicLong demand = new AtomicLong();
private boolean cancelled;
public ResponseBodySubscription(Subscriber<? super DataBuffer> subscriber) {
Assert.notNull(subscriber, "'subscriber' must not be null");
this.subscriber = subscriber;
}
@Override
public final void cancel() {
this.cancelled = true;
}
/**
* Indicates whether this subscription has been cancelled.
* @see #cancel()
*/
protected final boolean isCancelled() {
return this.cancelled;
}
@Override
public final void request(long n) {
if (!isCancelled() && BackpressureUtils.checkRequest(n, this.subscriber)) {
long demand = BackpressureUtils.addAndGet(this.demand, n);
if (stalled && demand > 0) {
stalled = false;
noLongerStalled();
}
}
}
/**
* Indicates whether this subscription has demand.
* @see #request(long)
*/
protected final boolean hasDemand() {
return this.demand.get() > 0;
}
/**
* Publishes the given signal to the subscriber wrapped by this subscription, if
* it has not been cancelled. If there is {@linkplain #hasDemand() no demand} for
* the signal, an exception will be thrown.
* @param dataBuffer the signal to publish
* @see Subscriber#onNext(Object)
*/
protected final void publishOnNext(DataBuffer dataBuffer) {
if (!isCancelled()) {
if (hasDemand()) {
BackpressureUtils.getAndSub(this.demand, 1L);
this.subscriber.onNext(dataBuffer);
}
else {
throw new IllegalStateException("No demand for: " + dataBuffer);
}
}
}
/**
* Publishes the given error to the subscriber wrapped by this subscription, if it
* has not been cancelled.
* @param t the error to publish
* @see Subscriber#onError(Throwable)
*/
protected final void publishOnError(Throwable t) {
if (!isCancelled()) {
this.subscriber.onError(t);
}
}
/**
* Publishes the complete signal to the subscriber wrapped by this subscription,
* if it has not been cancelled.
* @see Subscriber#onComplete()
*/
protected final void publishOnComplete() {
if (!isCancelled()) {
this.subscriber.onComplete();
}
}
}
}

View File

@ -16,11 +16,8 @@
package org.springframework.http.server.reactive;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;
import javax.servlet.AsyncContext;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
@ -56,32 +53,20 @@ final class ServletAsyncContextSynchronizer {
this.asyncContext = asyncContext;
}
/**
* Returns the request of this synchronizer.
*/
public ServletRequest getRequest() {
return this.asyncContext.getRequest();
}
/**
* Returns the response of this synchronizer.
*/
public ServletResponse getResponse() {
return this.asyncContext.getResponse();
}
/**
* Returns the input stream of this synchronizer.
* @return the input stream
* @throws IOException if an input or output exception occurred
*/
public ServletInputStream getInputStream() throws IOException {
return getRequest().getInputStream();
}
/**
* Returns the output stream of this synchronizer.
* @return the output stream
* @throws IOException if an input or output exception occurred
*/
public ServletOutputStream getOutputStream() throws IOException {
return getResponse().getOutputStream();
}
/**
* Completes the reading side of the asynchronous operation. When both this method and
* {@link #writeComplete()} have been called, the {@code AsyncContext} will be

View File

@ -22,7 +22,6 @@ import java.net.URISyntaxException;
import java.nio.charset.Charset;
import java.util.Enumeration;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.Cookie;
@ -30,9 +29,6 @@ import javax.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.publisher.Flux;
import org.springframework.core.io.buffer.DataBuffer;
@ -68,7 +64,6 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest {
RequestBodyPublisher bodyPublisher =
new RequestBodyPublisher(synchronizer, allocator, bufferSize);
this.requestBodyPublisher = Flux.from(bodyPublisher);
this.request.getInputStream().setReadListener(bodyPublisher);
}
@ -142,8 +137,10 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest {
return this.requestBodyPublisher;
}
private static class RequestBodyPublisher
implements ReadListener, Publisher<DataBuffer> {
private static class RequestBodyPublisher extends AbstractResponseBodyPublisher {
private final RequestBodyReadListener readListener =
new RequestBodyReadListener();
private final ServletAsyncContextSynchronizer synchronizer;
@ -151,184 +148,78 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest {
private final byte[] buffer;
private final DemandCounter demand = new DemandCounter();
private Subscriber<? super DataBuffer> subscriber;
private boolean stalled;
private boolean cancelled;
public RequestBodyPublisher(ServletAsyncContextSynchronizer synchronizer,
DataBufferAllocator allocator, int bufferSize) {
DataBufferAllocator allocator, int bufferSize) throws IOException {
this.synchronizer = synchronizer;
this.allocator = allocator;
this.buffer = new byte[bufferSize];
synchronizer.getRequest().getInputStream().setReadListener(readListener);
}
@Override
public void subscribe(Subscriber<? super DataBuffer> subscriber) {
if (subscriber == null) {
throw new NullPointerException();
protected void noLongerStalled() {
try {
readListener.onDataAvailable();
}
else if (this.subscriber != null) {
subscriber.onError(
new IllegalStateException("Only one subscriber allowed"));
}
this.subscriber = subscriber;
this.subscriber.onSubscribe(new RequestBodySubscription());
}
@Override
public void onDataAvailable() throws IOException {
if (cancelled) {
return;
}
ServletInputStream input = this.synchronizer.getInputStream();
logger.trace("onDataAvailable: " + input);
while (true) {
logger.trace("Demand: " + this.demand);
if (!demand.hasDemand()) {
stalled = true;
break;
}
boolean ready = input.isReady();
logger.trace(
"Input ready: " + ready + " finished: " + input.isFinished());
if (!ready) {
break;
}
int read = input.read(buffer);
logger.trace("Input read:" + read);
if (read == -1) {
break;
}
else if (read > 0) {
this.demand.decrement();
DataBuffer dataBuffer = allocator.allocateBuffer(read);
dataBuffer.write(this.buffer, 0, read);
this.subscriber.onNext(dataBuffer);
}
catch (IOException ex) {
readListener.onError(ex);
}
}
@Override
public void onAllDataRead() throws IOException {
if (cancelled) {
return;
}
logger.trace("All data read");
this.synchronizer.readComplete();
if (this.subscriber != null) {
this.subscriber.onComplete();
}
}
@Override
public void onError(Throwable t) {
if (cancelled) {
return;
}
logger.trace("RequestBodyPublisher Error", t);
this.synchronizer.readComplete();
if (this.subscriber != null) {
this.subscriber.onError(t);
}
}
private class RequestBodySubscription implements Subscription {
private class RequestBodyReadListener implements ReadListener {
@Override
public void request(long n) {
if (cancelled) {
public void onDataAvailable() throws IOException {
if (isSubscriptionCancelled()) {
return;
}
logger.trace("Updating demand " + demand + " by " + n);
logger.trace("onDataAvailable");
ServletInputStream input = synchronizer.getRequest().getInputStream();
demand.increase(n);
logger.trace("Stalled: " + stalled);
if (stalled) {
stalled = false;
try {
onDataAvailable();
while (true) {
if (!checkSubscriptionForDemand()) {
break;
}
catch (IOException ex) {
onError(ex);
boolean ready = input.isReady();
logger.trace(
"Input ready: " + ready + " finished: " + input.isFinished());
if (!ready) {
break;
}
int read = input.read(buffer);
logger.trace("Input read:" + read);
if (read == -1) {
break;
}
else if (read > 0) {
DataBuffer dataBuffer = allocator.allocateBuffer(read);
dataBuffer.write(buffer, 0, read);
publishOnNext(dataBuffer);
}
}
}
@Override
public void cancel() {
if (cancelled) {
return;
}
cancelled = true;
public void onAllDataRead() throws IOException {
logger.trace("All data read");
synchronizer.readComplete();
demand.reset();
}
}
/**
* Small utility class for keeping track of Reactive Streams demand.
*/
private static final class DemandCounter {
private final AtomicLong demand = new AtomicLong();
/**
* Increases the demand by the given number
* @param n the positive number to increase demand by
* @return the increased demand
* @see Subscription#request(long)
*/
public long increase(long n) {
Assert.isTrue(n > 0, "'n' must be higher than 0");
return demand
.updateAndGet(d -> d != Long.MAX_VALUE ? d + n : Long.MAX_VALUE);
}
/**
* Decreases the demand by one.
* @return the decremented demand
*/
public long decrement() {
return demand
.updateAndGet(d -> d != Long.MAX_VALUE ? d - 1 : Long.MAX_VALUE);
}
/**
* Indicates whether this counter has demand, i.e. whether it is higher than
* 0.
* @return {@code true} if this counter has demand; {@code false} otherwise
*/
public boolean hasDemand() {
return this.demand.get() > 0;
}
/**
* Resets this counter to 0.
* @see Subscription#cancel()
*/
public void reset() {
this.demand.set(0);
publishOnComplete();
}
@Override
public String toString() {
return demand.toString();
public void onError(Throwable t) {
logger.trace("RequestBodyReadListener Error", t);
synchronizer.readComplete();
publishOnError(t);
}
}
}
}

View File

@ -32,6 +32,7 @@ import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.publisher.Mono;
import reactor.core.util.BackpressureUtils;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferAllocator;
@ -61,7 +62,6 @@ public class ServletServerHttpResponse extends AbstractServerHttpResponse {
this.response = (HttpServletResponse) synchronizer.getResponse();
this.responseBodySubscriber =
new ResponseBodySubscriber(synchronizer, bufferSize);
this.response.getOutputStream().setWriteListener(responseBodySubscriber);
}
public HttpServletResponse getServletResponse() {
@ -118,39 +118,46 @@ public class ServletServerHttpResponse extends AbstractServerHttpResponse {
}
}
private static class ResponseBodySubscriber
implements WriteListener, Subscriber<DataBuffer> {
private static class ResponseBodySubscriber implements Subscriber<DataBuffer> {
private final ResponseBodyWriteListener writeListener =
new ResponseBodyWriteListener();
private final ServletAsyncContextSynchronizer synchronizer;
private final int bufferSize;
private volatile DataBuffer dataBuffer;
private volatile boolean completed = false;
private Subscription subscription;
private DataBuffer dataBuffer;
private volatile boolean subscriberComplete = false;
public ResponseBodySubscriber(ServletAsyncContextSynchronizer synchronizer,
int bufferSize) {
int bufferSize) throws IOException {
this.synchronizer = synchronizer;
this.bufferSize = bufferSize;
synchronizer.getResponse().getOutputStream().setWriteListener(writeListener);
}
@Override
public void onSubscribe(Subscription subscription) {
this.subscription = subscription;
this.subscription.request(1);
logger.trace("onSubscribe. Subscription: " + subscription);
if (BackpressureUtils.validate(this.subscription, subscription)) {
this.subscription = subscription;
this.subscription.request(1);
}
}
@Override
public void onNext(DataBuffer dataBuffer) {
Assert.isNull(this.dataBuffer);
Assert.state(this.dataBuffer == null);
logger.trace("onNext. buffer: " + dataBuffer);
this.dataBuffer = dataBuffer;
try {
onWritePossible();
this.writeListener.onWritePossible();
}
catch (IOException e) {
onError(e);
@ -158,66 +165,93 @@ public class ServletServerHttpResponse extends AbstractServerHttpResponse {
}
@Override
public void onComplete() {
logger.trace("onComplete. buffer: " + dataBuffer);
this.subscriberComplete = true;
if (dataBuffer == null) {
this.synchronizer.writeComplete();
}
}
@Override
public void onWritePossible() throws IOException {
ServletOutputStream output = this.synchronizer.getOutputStream();
boolean ready = output.isReady();
logger.trace("onWritePossible. ready: " + ready + " buffer: " + dataBuffer);
if (ready) {
if (this.dataBuffer != null) {
int toBeWritten = this.dataBuffer.readableByteCount();
InputStream input = this.dataBuffer.asInputStream();
int writeCount = write(input, output);
logger.trace("written: " + writeCount + " total: " + toBeWritten);
if (writeCount == toBeWritten) {
this.dataBuffer = null;
if (!this.subscriberComplete) {
this.subscription.request(1);
}
else {
this.synchronizer.writeComplete();
}
}
}
else if (this.subscription != null) {
this.subscription.request(1);
}
}
}
private int write(InputStream in, ServletOutputStream output) throws IOException {
int byteCount = 0;
byte[] buffer = new byte[bufferSize];
int bytesRead = -1;
while (output.isReady() && (bytesRead = in.read(buffer)) != -1) {
output.write(buffer, 0, bytesRead);
byteCount += bytesRead;
}
return byteCount;
}
@Override
public void onError(Throwable ex) {
if (this.subscription != null) {
this.subscription.cancel();
}
logger.error("ResponseBodySubscriber error", ex);
public void onError(Throwable t) {
logger.error("onError", t);
HttpServletResponse response =
(HttpServletResponse) this.synchronizer.getResponse();
response.setStatus(HttpStatus.INTERNAL_SERVER_ERROR.value());
this.synchronizer.complete();
}
@Override
public void onComplete() {
logger.trace("onComplete. buffer: " + this.dataBuffer);
this.completed = true;
if (this.dataBuffer != null) {
try {
this.writeListener.onWritePossible();
}
catch (IOException ex) {
onError(ex);
}
}
if (this.dataBuffer == null) {
this.synchronizer.writeComplete();
}
}
private class ResponseBodyWriteListener implements WriteListener {
@Override
public void onWritePossible() throws IOException {
logger.trace("onWritePossible");
ServletOutputStream output = synchronizer.getResponse().getOutputStream();
boolean ready = output.isReady();
logger.trace("ready: " + ready + " buffer: " + dataBuffer);
if (ready) {
if (dataBuffer != null) {
int total = dataBuffer.readableByteCount();
int written = writeDataBuffer();
logger.trace("written: " + written + " total: " + total);
if (written == total) {
releaseBuffer();
if (!completed) {
subscription.request(1);
}
else {
synchronizer.writeComplete();
}
}
}
else if (subscription != null) {
subscription.request(1);
}
}
}
private int writeDataBuffer() throws IOException {
InputStream input = dataBuffer.asInputStream();
ServletOutputStream output = synchronizer.getResponse().getOutputStream();
int bytesWritten = 0;
byte[] buffer = new byte[bufferSize];
int bytesRead = -1;
while (output.isReady() && (bytesRead = input.read(buffer)) != -1) {
output.write(buffer, 0, bytesRead);
bytesWritten += bytesRead;
}
return bytesWritten;
}
private void releaseBuffer() {
// TODO: call PooledDataBuffer.release() when we it is introduced
dataBuffer = null;
}
@Override
public void onError(Throwable ex) {
logger.error("ResponseBodyWriteListener error", ex);
}
}
}
}

View File

@ -18,18 +18,11 @@ package org.springframework.http.server.reactive;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.SameThreadExecutor;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.xnio.ChannelListener;
@ -38,9 +31,7 @@ import org.xnio.IoUtils;
import org.xnio.channels.StreamSinkChannel;
import org.xnio.channels.StreamSourceChannel;
import reactor.core.publisher.Mono;
import reactor.core.subscriber.BaseSubscriber;
import reactor.core.util.BackpressureUtils;
import reactor.core.util.Exceptions;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferAllocator;
@ -75,14 +66,14 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle
RequestBodyPublisher requestBody = new RequestBodyPublisher(exchange, allocator);
ServerHttpRequest request = new UndertowServerHttpRequest(exchange, requestBody);
ResponseBodySubscriber responseBodySubscriber = new ResponseBodySubscriber(exchange);
ResponseBodySubscriber responseBodySubscriber =
new ResponseBodySubscriber(exchange);
ServerHttpResponse response = new UndertowServerHttpResponse(exchange,
publisher -> Mono
.from(subscriber -> publisher.subscribe(responseBodySubscriber)),
allocator);
exchange.dispatch();
this.delegate.handle(request, response).subscribe(new Subscriber<Void>() {
@Override
@ -113,375 +104,212 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle
});
}
private static class RequestBodyPublisher implements Publisher<DataBuffer> {
private static class RequestBodyPublisher extends AbstractResponseBodyPublisher {
private static final AtomicLongFieldUpdater<RequestBodySubscription> DEMAND =
AtomicLongFieldUpdater.newUpdater(RequestBodySubscription.class, "demand");
private static final Log logger = LogFactory.getLog(RequestBodyPublisher.class);
private final ChannelListener<StreamSourceChannel> listener =
new RequestBodyListener();
private final HttpServerExchange exchange;
private final StreamSourceChannel requestChannel;
private final DataBufferAllocator allocator;
private Subscriber<? super DataBuffer> subscriber;
private final PooledByteBuffer pooledByteBuffer;
public RequestBodyPublisher(HttpServerExchange exchange,
DataBufferAllocator allocator) {
this.exchange = exchange;
this.requestChannel = exchange.getRequestChannel();
this.requestChannel.getReadSetter().set(listener);
this.requestChannel.resumeReads();
this.pooledByteBuffer =
exchange.getConnection().getByteBufferPool().allocate();
this.allocator = allocator;
}
@Override
public void subscribe(Subscriber<? super DataBuffer> subscriber) {
if (subscriber == null) {
throw Exceptions.argumentIsNullException();
private void close() {
if (this.pooledByteBuffer != null) {
IoUtils.safeClose(this.pooledByteBuffer);
}
if (this.subscriber != null) {
subscriber.onError(new IllegalStateException("Only one subscriber allowed"));
if (this.requestChannel != null) {
IoUtils.safeClose(this.requestChannel);
}
this.subscriber = subscriber;
this.subscriber.onSubscribe(new RequestBodySubscription());
}
@Override
protected void noLongerStalled() {
listener.handleEvent(requestChannel);
}
private class RequestBodySubscription implements Subscription, Runnable,
ChannelListener<StreamSourceChannel> {
volatile long demand;
private PooledByteBuffer pooledBuffer;
private StreamSourceChannel channel;
private boolean subscriptionClosed;
private boolean draining;
@Override
public void request(long n) {
BackpressureUtils.checkRequest(n, subscriber);
if (this.subscriptionClosed) {
return;
}
BackpressureUtils.getAndAdd(DEMAND, this, n);
scheduleNextMessage();
}
private void scheduleNextMessage() {
exchange.dispatch(exchange.isInIoThread() ? SameThreadExecutor.INSTANCE :
exchange.getIoThread(), this);
}
@Override
public void cancel() {
this.subscriptionClosed = true;
close();
}
private void close() {
if (this.pooledBuffer != null) {
IoUtils.safeClose(this.pooledBuffer);
this.pooledBuffer = null;
}
if (this.channel != null) {
IoUtils.safeClose(this.channel);
this.channel = null;
}
}
@Override
public void run() {
if (this.subscriptionClosed || this.draining) {
return;
}
if (0 == BackpressureUtils.getAndSub(DEMAND, this, 1)) {
return;
}
this.draining = true;
if (this.channel == null) {
this.channel = exchange.getRequestChannel();
if (this.channel == null) {
if (exchange.isRequestComplete()) {
return;
}
else {
throw new IllegalStateException("Failed to acquire channel!");
}
}
}
if (this.pooledBuffer == null) {
this.pooledBuffer = exchange.getConnection().getByteBufferPool().allocate();
}
else {
this.pooledBuffer.getBuffer().clear();
}
try {
ByteBuffer buffer = this.pooledBuffer.getBuffer();
int count;
do {
count = this.channel.read(buffer);
if (count == 0) {
this.channel.getReadSetter().set(this);
this.channel.resumeReads();
}
else if (count == -1) {
if (buffer.position() > 0) {
doOnNext(buffer);
}
doOnComplete();
}
else {
if (buffer.remaining() == 0) {
if (this.demand == 0) {
this.channel.suspendReads();
}
doOnNext(buffer);
if (this.demand > 0) {
scheduleNextMessage();
}
break;
}
}
} while (count > 0);
}
catch (IOException e) {
doOnError(e);
}
}
private void doOnNext(ByteBuffer buffer) {
this.draining = false;
buffer.flip();
DataBuffer dataBuffer = allocator.wrap(buffer);
subscriber.onNext(dataBuffer);
}
private void doOnComplete() {
this.subscriptionClosed = true;
try {
subscriber.onComplete();
}
finally {
close();
}
}
private void doOnError(Throwable t) {
this.subscriptionClosed = true;
try {
subscriber.onError(t);
}
finally {
close();
}
}
private class RequestBodyListener
implements ChannelListener<StreamSourceChannel> {
@Override
public void handleEvent(StreamSourceChannel channel) {
if (this.subscriptionClosed) {
if (isSubscriptionCancelled()) {
return;
}
logger.trace("handleEvent");
ByteBuffer byteBuffer = pooledByteBuffer.getBuffer();
try {
ByteBuffer buffer = this.pooledBuffer.getBuffer();
int count;
do {
count = channel.read(buffer);
if (count == 0) {
return;
while (true) {
if (!checkSubscriptionForDemand()) {
break;
}
else if (count == -1) {
if (buffer.position() > 0) {
doOnNext(buffer);
}
doOnComplete();
int read = channel.read(byteBuffer);
logger.trace("Input read:" + read);
if (read == -1) {
publishOnComplete();
close();
break;
}
else if (read == 0) {
// input not ready, wait until we are invoked again
break;
}
else {
if (buffer.remaining() == 0) {
if (this.demand == 0) {
channel.suspendReads();
}
doOnNext(buffer);
if (this.demand > 0) {
scheduleNextMessage();
}
break;
}
byteBuffer.flip();
DataBuffer dataBuffer = allocator.wrap(byteBuffer);
publishOnNext(dataBuffer);
}
} while (count > 0);
}
}
catch (IOException e) {
doOnError(e);
catch (IOException ex) {
publishOnError(ex);
}
}
}
}
private static class ResponseBodySubscriber
implements ChannelListener<StreamSinkChannel>, BaseSubscriber<DataBuffer>{
private static class ResponseBodySubscriber implements Subscriber<DataBuffer> {
private static final Log logger = LogFactory.getLog(ResponseBodySubscriber.class);
private final ChannelListener<StreamSinkChannel> listener =
new ResponseBodyListener();
private final HttpServerExchange exchange;
private final StreamSinkChannel responseChannel;
private volatile ByteBuffer byteBuffer;
private volatile boolean completed = false;
private Subscription subscription;
private final Queue<PooledByteBuffer> buffers = new ConcurrentLinkedQueue<>();
private final AtomicInteger writing = new AtomicInteger();
private final AtomicBoolean closing = new AtomicBoolean();
private StreamSinkChannel responseChannel;
public ResponseBodySubscriber(HttpServerExchange exchange) {
this.exchange = exchange;
this.responseChannel = exchange.getResponseChannel();
this.responseChannel.getWriteSetter().set(listener);
this.responseChannel.resumeWrites();
}
@Override
public void onSubscribe(Subscription subscription) {
BaseSubscriber.super.onSubscribe(subscription);
this.subscription = subscription;
this.subscription.request(1);
logger.trace("onSubscribe. Subscription: " + subscription);
if (BackpressureUtils.validate(this.subscription, subscription)) {
this.subscription = subscription;
this.subscription.request(1);
}
}
@Override
public void onNext(DataBuffer dataBuffer) {
BaseSubscriber.super.onNext(dataBuffer);
Assert.state(this.byteBuffer == null);
logger.trace("onNext. buffer: " + dataBuffer);
ByteBuffer buffer = dataBuffer.asByteBuffer();
if (this.responseChannel == null) {
this.responseChannel = exchange.getResponseChannel();
}
this.writing.incrementAndGet();
try {
int c;
do {
c = this.responseChannel.write(buffer);
} while (buffer.hasRemaining() && c > 0);
if (buffer.hasRemaining()) {
this.writing.incrementAndGet();
enqueue(buffer);
this.responseChannel.getWriteSetter().set(this);
this.responseChannel.resumeWrites();
}
else {
this.subscription.request(1);
}
}
catch (IOException ex) {
onError(ex);
}
finally {
this.writing.decrementAndGet();
if (this.closing.get()) {
closeIfDone();
}
}
}
private void enqueue(ByteBuffer src) {
do {
PooledByteBuffer buffer = exchange.getConnection().getByteBufferPool().allocate();
ByteBuffer dst = buffer.getBuffer();
copy(dst, src);
dst.flip();
this.buffers.add(buffer);
} while (src.remaining() > 0);
}
private void copy(ByteBuffer dst, ByteBuffer src) {
int n = Math.min(dst.capacity(), src.remaining());
for (int i = 0; i < n; i++) {
dst.put(src.get());
}
this.byteBuffer = dataBuffer.asByteBuffer();
}
@Override
public void handleEvent(StreamSinkChannel channel) {
try {
int c;
do {
ByteBuffer buffer = this.buffers.peek().getBuffer();
do {
c = channel.write(buffer);
} while (buffer.hasRemaining() && c > 0);
if (!buffer.hasRemaining()) {
IoUtils.safeClose(this.buffers.remove());
}
} while (!this.buffers.isEmpty() && c > 0);
if (!this.buffers.isEmpty()) {
channel.resumeWrites();
}
else {
this.writing.decrementAndGet();
if (this.closing.get()) {
closeIfDone();
}
else {
this.subscription.request(1);
}
}
}
catch (IOException ex) {
onError(ex);
}
}
@Override
public void onError(Throwable ex) {
BaseSubscriber.super.onError(ex);
logger.error("ResponseBodySubscriber error", ex);
public void onError(Throwable t) {
logger.error("onError", t);
if (!exchange.isResponseStarted() && exchange.getStatusCode() < 500) {
exchange.setStatusCode(500);
}
closeChannel(responseChannel);
}
@Override
public void onComplete() {
if (this.responseChannel != null) {
this.closing.set(true);
closeIfDone();
logger.trace("onComplete. buffer: " + this.byteBuffer);
this.completed = true;
if (this.byteBuffer == null) {
closeChannel(responseChannel);
}
}
private void closeIfDone() {
if (this.writing.get() == 0) {
if (this.closing.compareAndSet(true, false)) {
closeChannel();
}
}
}
private void closeChannel() {
private void closeChannel(StreamSinkChannel channel) {
try {
this.responseChannel.shutdownWrites();
channel.shutdownWrites();
if (!this.responseChannel.flush()) {
this.responseChannel.getWriteSetter().set(ChannelListeners
.flushingChannelListener(
o -> IoUtils.safeClose(this.responseChannel),
if (!channel.flush()) {
channel.getWriteSetter().set(ChannelListeners
.flushingChannelListener(o -> IoUtils.safeClose(channel),
ChannelListeners.closingChannelExceptionHandler()));
this.responseChannel.resumeWrites();
channel.resumeWrites();
}
this.responseChannel = null;
}
catch (IOException ex) {
onError(ex);
catch (IOException ignored) {
logger.error(ignored, ignored);
}
}
private class ResponseBodyListener implements ChannelListener<StreamSinkChannel> {
@Override
public void handleEvent(StreamSinkChannel channel) {
if (byteBuffer != null) {
try {
int total = byteBuffer.remaining();
int written = writeByteBuffer(channel);
logger.trace("written: " + written + " total: " + total);
if (written == total) {
releaseBuffer();
if (!completed) {
subscription.request(1);
}
else {
closeChannel(channel);
}
}
}
catch (IOException ex) {
onError(ex);
}
}
else if (subscription != null) {
subscription.request(1);
}
}
private void releaseBuffer() {
byteBuffer = null;
}
private int writeByteBuffer(StreamSinkChannel channel) throws IOException {
int written;
int totalWritten = 0;
do {
written = channel.write(byteBuffer);
totalWritten += written;
}
while (byteBuffer.hasRemaining() && written > 0);
return totalWritten;
}
}
}
}