diff --git a/spring-core/src/main/java/org/springframework/core/codec/AbstractDataBufferDecoder.java b/spring-core/src/main/java/org/springframework/core/codec/AbstractDataBufferDecoder.java index 3107ea0bb7..b03d8079db 100644 --- a/spring-core/src/main/java/org/springframework/core/codec/AbstractDataBufferDecoder.java +++ b/spring-core/src/main/java/org/springframework/core/codec/AbstractDataBufferDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,17 +54,17 @@ public abstract class AbstractDataBufferDecoder extends AbstractDecoder { @Override - public Flux decode(Publisher inputStream, ResolvableType elementType, + public Flux decode(Publisher input, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { - return Flux.from(inputStream).map(buffer -> decodeDataBuffer(buffer, elementType, mimeType, hints)); + return Flux.from(input).map(buffer -> decodeDataBuffer(buffer, elementType, mimeType, hints)); } @Override - public Mono decodeToMono(Publisher inputStream, ResolvableType elementType, + public Mono decodeToMono(Publisher input, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { - return DataBufferUtils.join(inputStream) + return DataBufferUtils.join(input) .map(buffer -> decodeDataBuffer(buffer, elementType, mimeType, hints)); } diff --git a/spring-core/src/main/java/org/springframework/core/codec/AbstractDecoder.java b/spring-core/src/main/java/org/springframework/core/codec/AbstractDecoder.java index f545c311c7..8ab00206d9 100644 --- a/spring-core/src/main/java/org/springframework/core/codec/AbstractDecoder.java +++ b/spring-core/src/main/java/org/springframework/core/codec/AbstractDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -78,7 +78,12 @@ public abstract class AbstractDecoder implements Decoder { if (mimeType == null) { return true; } - return this.decodableMimeTypes.stream().anyMatch(candidate -> candidate.isCompatibleWith(mimeType)); + for (MimeType candidate : this.decodableMimeTypes) { + if (candidate.isCompatibleWith(mimeType)) { + return true; + } + } + return false; } @Override diff --git a/spring-core/src/main/java/org/springframework/core/codec/AbstractEncoder.java b/spring-core/src/main/java/org/springframework/core/codec/AbstractEncoder.java index 2712f48528..452b4ed5ca 100644 --- a/spring-core/src/main/java/org/springframework/core/codec/AbstractEncoder.java +++ b/spring-core/src/main/java/org/springframework/core/codec/AbstractEncoder.java @@ -74,7 +74,7 @@ public abstract class AbstractEncoder implements Encoder { if (mimeType == null) { return true; } - for(MimeType candidate : this.encodableMimeTypes) { + for (MimeType candidate : this.encodableMimeTypes) { if (candidate.isCompatibleWith(mimeType)) { return true; } diff --git a/spring-core/src/main/java/org/springframework/core/codec/DataBufferDecoder.java b/spring-core/src/main/java/org/springframework/core/codec/DataBufferDecoder.java index ee9ae4ac62..17d6a424ab 100644 --- a/spring-core/src/main/java/org/springframework/core/codec/DataBufferDecoder.java +++ b/spring-core/src/main/java/org/springframework/core/codec/DataBufferDecoder.java @@ -57,10 +57,10 @@ public class DataBufferDecoder extends AbstractDataBufferDecoder { } @Override - public Flux decode(Publisher inputStream, ResolvableType elementType, + public Flux decode(Publisher input, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { - return Flux.from(inputStream); + return Flux.from(input); } @Override diff --git a/spring-core/src/main/java/org/springframework/core/codec/ResourceEncoder.java b/spring-core/src/main/java/org/springframework/core/codec/ResourceEncoder.java index f989cc8cae..58b0a094fb 100644 --- a/spring-core/src/main/java/org/springframework/core/codec/ResourceEncoder.java +++ b/spring-core/src/main/java/org/springframework/core/codec/ResourceEncoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -65,15 +65,14 @@ public class ResourceEncoder extends AbstractSingleValueEncoder { } @Override - protected Flux encode(Resource resource, DataBufferFactory dataBufferFactory, + protected Flux encode(Resource resource, DataBufferFactory bufferFactory, ResolvableType type, @Nullable MimeType mimeType, @Nullable Map hints) { if (logger.isDebugEnabled() && !Hints.isLoggingSuppressed(hints)) { String logPrefix = Hints.getLogPrefix(hints); logger.debug(logPrefix + "Writing [" + resource + "]"); } - - return DataBufferUtils.read(resource, dataBufferFactory, this.bufferSize); + return DataBufferUtils.read(resource, bufferFactory, this.bufferSize); } } diff --git a/spring-core/src/main/java/org/springframework/core/codec/ResourceRegionEncoder.java b/spring-core/src/main/java/org/springframework/core/codec/ResourceRegionEncoder.java index 3695a5514f..3e9d19c20d 100644 --- a/spring-core/src/main/java/org/springframework/core/codec/ResourceRegionEncoder.java +++ b/spring-core/src/main/java/org/springframework/core/codec/ResourceRegionEncoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -76,16 +76,16 @@ public class ResourceRegionEncoder extends AbstractEncoder { } @Override - public Flux encode(Publisher inputStream, + public Flux encode(Publisher input, DataBufferFactory bufferFactory, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { - Assert.notNull(inputStream, "'inputStream' must not be null"); + Assert.notNull(input, "'inputStream' must not be null"); Assert.notNull(bufferFactory, "'bufferFactory' must not be null"); Assert.notNull(elementType, "'elementType' must not be null"); - if (inputStream instanceof Mono) { - return Mono.from(inputStream) + if (input instanceof Mono) { + return Mono.from(input) .flatMapMany(region -> { if (!region.getResource().isReadable()) { return Flux.error(new EncodingException( @@ -96,32 +96,25 @@ public class ResourceRegionEncoder extends AbstractEncoder { } else { final String boundaryString = Hints.getRequiredHint(hints, BOUNDARY_STRING_HINT); - byte[] startBoundary = getAsciiBytes("\r\n--" + boundaryString + "\r\n"); - byte[] contentType = mimeType != null ? getAsciiBytes("Content-Type: " + mimeType + "\r\n") : new byte[0]; + byte[] startBoundary = toAsciiBytes("\r\n--" + boundaryString + "\r\n"); + byte[] contentType = mimeType != null ? toAsciiBytes("Content-Type: " + mimeType + "\r\n") : new byte[0]; - return Flux.from(inputStream). - concatMap(region -> { + return Flux.from(input) + .concatMap(region -> { if (!region.getResource().isReadable()) { return Flux.error(new EncodingException( "Resource " + region.getResource() + " is not readable")); } - else { - return Flux.concat( - getRegionPrefix(bufferFactory, startBoundary, contentType, region), - writeResourceRegion(region, bufferFactory, hints)); - } + Flux prefix = Flux.just( + bufferFactory.wrap(startBoundary), + bufferFactory.wrap(contentType), + bufferFactory.wrap(getContentRangeHeader(region))); // only wrapping, no allocation + + return prefix.concatWith(writeResourceRegion(region, bufferFactory, hints)); }) - .concatWith(getRegionSuffix(bufferFactory, boundaryString)); + .concatWithValues(getRegionSuffix(bufferFactory, boundaryString)); } - } - - private Flux getRegionPrefix(DataBufferFactory bufferFactory, byte[] startBoundary, - byte[] contentType, ResourceRegion region) { - - return Flux.just( - bufferFactory.wrap(startBoundary), - bufferFactory.wrap(contentType), - bufferFactory.wrap(getContentRangeHeader(region))); // only wrapping, no allocation + // No doOnDiscard (no caching after DataBufferUtils#read) } private Flux writeResourceRegion( @@ -140,12 +133,12 @@ public class ResourceRegionEncoder extends AbstractEncoder { return DataBufferUtils.takeUntilByteCount(in, count); } - private Flux getRegionSuffix(DataBufferFactory bufferFactory, String boundaryString) { - byte[] endBoundary = getAsciiBytes("\r\n--" + boundaryString + "--"); - return Flux.just(bufferFactory.wrap(endBoundary)); + private DataBuffer getRegionSuffix(DataBufferFactory bufferFactory, String boundaryString) { + byte[] endBoundary = toAsciiBytes("\r\n--" + boundaryString + "--"); + return bufferFactory.wrap(endBoundary); } - private byte[] getAsciiBytes(String in) { + private byte[] toAsciiBytes(String in) { return in.getBytes(StandardCharsets.US_ASCII); } @@ -155,10 +148,10 @@ public class ResourceRegionEncoder extends AbstractEncoder { OptionalLong contentLength = contentLength(region.getResource()); if (contentLength.isPresent()) { long length = contentLength.getAsLong(); - return getAsciiBytes("Content-Range: bytes " + start + '-' + end + '/' + length + "\r\n\r\n"); + return toAsciiBytes("Content-Range: bytes " + start + '-' + end + '/' + length + "\r\n\r\n"); } else { - return getAsciiBytes("Content-Range: bytes " + start + '-' + end + "\r\n\r\n"); + return toAsciiBytes("Content-Range: bytes " + start + '-' + end + "\r\n\r\n"); } } diff --git a/spring-core/src/main/java/org/springframework/core/codec/StringDecoder.java b/spring-core/src/main/java/org/springframework/core/codec/StringDecoder.java index 7d74ba7bb1..28cf7df55e 100644 --- a/spring-core/src/main/java/org/springframework/core/codec/StringDecoder.java +++ b/spring-core/src/main/java/org/springframework/core/codec/StringDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,7 +25,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.stream.Collectors; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; @@ -88,14 +87,14 @@ public final class StringDecoder extends AbstractDataBufferDecoder { } @Override - public Flux decode(Publisher inputStream, ResolvableType elementType, + public Flux decode(Publisher input, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { List delimiterBytes = getDelimiterBytes(mimeType); - Flux inputFlux = Flux.from(inputStream) - .flatMapIterable(dataBuffer -> splitOnDelimiter(dataBuffer, delimiterBytes)) - .bufferUntil(StringDecoder::isEndFrame) + Flux inputFlux = Flux.from(input) + .flatMapIterable(buffer -> splitOnDelimiter(buffer, delimiterBytes)) + .bufferUntil(buffer -> buffer == END_FRAME) .map(StringDecoder::joinUntilEndFrame) .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); @@ -103,51 +102,60 @@ public final class StringDecoder extends AbstractDataBufferDecoder { } private List getDelimiterBytes(@Nullable MimeType mimeType) { - return this.delimitersCache.computeIfAbsent(getCharset(mimeType), - charset -> this.delimiters.stream() - .map(s -> s.getBytes(charset)) - .collect(Collectors.toList())); + return this.delimitersCache.computeIfAbsent(getCharset(mimeType), charset -> { + List list = new ArrayList<>(); + for (String delimiter : this.delimiters) { + byte[] bytes = delimiter.getBytes(charset); + list.add(bytes); + } + return list; + }); } /** * Split the given data buffer on delimiter boundaries. * The returned Flux contains an {@link #END_FRAME} buffer after each delimiter. */ - private List splitOnDelimiter(DataBuffer dataBuffer, List delimiterBytes) { + private List splitOnDelimiter(DataBuffer buffer, List delimiterBytes) { List frames = new ArrayList<>(); - do { - int length = Integer.MAX_VALUE; - byte[] matchingDelimiter = null; - for (byte[] delimiter : delimiterBytes) { - int index = indexOf(dataBuffer, delimiter); - if (index >= 0 && index < length) { - length = index; - matchingDelimiter = delimiter; + try { + do { + int length = Integer.MAX_VALUE; + byte[] matchingDelimiter = null; + for (byte[] delimiter : delimiterBytes) { + int index = indexOf(buffer, delimiter); + if (index >= 0 && index < length) { + length = index; + matchingDelimiter = delimiter; + } } - } - DataBuffer frame; - int readPosition = dataBuffer.readPosition(); - if (matchingDelimiter != null) { - if (this.stripDelimiter) { - frame = dataBuffer.slice(readPosition, length); + DataBuffer frame; + int readPosition = buffer.readPosition(); + if (matchingDelimiter != null) { + frame = this.stripDelimiter ? + buffer.slice(readPosition, length) : + buffer.slice(readPosition, length + matchingDelimiter.length); + buffer.readPosition(readPosition + length + matchingDelimiter.length); + frames.add(DataBufferUtils.retain(frame)); + frames.add(END_FRAME); } else { - frame = dataBuffer.slice(readPosition, length + matchingDelimiter.length); + frame = buffer.slice(readPosition, buffer.readableByteCount()); + buffer.readPosition(readPosition + buffer.readableByteCount()); + frames.add(DataBufferUtils.retain(frame)); } - dataBuffer.readPosition(readPosition + length + matchingDelimiter.length); - - frames.add(DataBufferUtils.retain(frame)); - frames.add(END_FRAME); - } - else { - frame = dataBuffer.slice(readPosition, dataBuffer.readableByteCount()); - dataBuffer.readPosition(readPosition + dataBuffer.readableByteCount()); - frames.add(DataBufferUtils.retain(frame)); } + while (buffer.readableByteCount() > 0); + } + catch (Throwable ex) { + for (DataBuffer frame : frames) { + DataBufferUtils.release(frame); + } + throw ex; + } + finally { + DataBufferUtils.release(buffer); } - while (dataBuffer.readableByteCount() > 0); - - DataBufferUtils.release(dataBuffer); return frames; } @@ -155,44 +163,38 @@ public final class StringDecoder extends AbstractDataBufferDecoder { * Find the given delimiter in the given data buffer. * @return the index of the delimiter, or -1 if not found. */ - private static int indexOf(DataBuffer dataBuffer, byte[] delimiter) { - for (int i = dataBuffer.readPosition(); i < dataBuffer.writePosition(); i++) { - int dataBufferPos = i; + private static int indexOf(DataBuffer buffer, byte[] delimiter) { + for (int i = buffer.readPosition(); i < buffer.writePosition(); i++) { + int bufferPos = i; int delimiterPos = 0; while (delimiterPos < delimiter.length) { - if (dataBuffer.getByte(dataBufferPos) != delimiter[delimiterPos]) { + if (buffer.getByte(bufferPos) != delimiter[delimiterPos]) { break; } else { - dataBufferPos++; - if (dataBufferPos == dataBuffer.writePosition() && - delimiterPos != delimiter.length - 1) { + bufferPos++; + boolean endOfBuffer = bufferPos == buffer.writePosition(); + boolean endOfDelimiter = delimiterPos == delimiter.length - 1; + if (endOfBuffer && !endOfDelimiter) { return -1; } } delimiterPos++; } if (delimiterPos == delimiter.length) { - return i - dataBuffer.readPosition(); + return i - buffer.readPosition(); } } return -1; } - /** - * Check whether the given buffer is {@link #END_FRAME}. - */ - private static boolean isEndFrame(DataBuffer dataBuffer) { - return dataBuffer == END_FRAME; - } - /** * Join the given list of buffers into a single buffer. */ private static DataBuffer joinUntilEndFrame(List dataBuffers) { if (!dataBuffers.isEmpty()) { int lastIdx = dataBuffers.size() - 1; - if (isEndFrame(dataBuffers.get(lastIdx))) { + if (dataBuffers.get(lastIdx) == END_FRAME) { dataBuffers.remove(lastIdx); } } diff --git a/spring-core/src/test/java/org/springframework/core/codec/ResourceRegionEncoderTests.java b/spring-core/src/test/java/org/springframework/core/codec/ResourceRegionEncoderTests.java index d96b96b522..99f058fd05 100644 --- a/spring-core/src/test/java/org/springframework/core/codec/ResourceRegionEncoderTests.java +++ b/spring-core/src/test/java/org/springframework/core/codec/ResourceRegionEncoderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ package org.springframework.core.codec; import java.util.Collections; import java.util.function.Consumer; +import io.netty.buffer.PooledByteBufAllocator; +import org.junit.After; import org.junit.Test; import org.reactivestreams.Subscription; import reactor.core.publisher.BaseSubscriber; @@ -32,6 +34,7 @@ import org.springframework.core.io.Resource; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.LeakAwareDataBufferFactory; +import org.springframework.core.io.buffer.NettyDataBufferFactory; import org.springframework.core.io.buffer.support.DataBufferTestUtils; import org.springframework.core.io.support.ResourceRegion; import org.springframework.util.MimeType; @@ -48,9 +51,15 @@ public class ResourceRegionEncoderTests { private ResourceRegionEncoder encoder = new ResourceRegionEncoder(); - private LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory(); + private LeakAwareDataBufferFactory bufferFactory = + new LeakAwareDataBufferFactory(new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT)); + @After + public void tearDown() throws Exception { + this.bufferFactory.checkForLeaks(); + } + @Test public void canEncode() { ResolvableType resourceRegion = ResolvableType.forClass(ResourceRegion.class); @@ -79,8 +88,6 @@ public class ResourceRegionEncoderTests { .consumeNextWith(stringConsumer("Spring")) .expectComplete() .verify(); - - this.bufferFactory.checkForLeaks(); } @Test @@ -120,8 +127,6 @@ public class ResourceRegionEncoderTests { .consumeNextWith(stringConsumer("\r\n--" + boundary + "--")) .expectComplete() .verify(); - - this.bufferFactory.checkForLeaks(); } @Test // gh-22107 @@ -144,8 +149,23 @@ public class ResourceRegionEncoderTests { ZeroDemandSubscriber subscriber = new ZeroDemandSubscriber(); flux.subscribe(subscriber); subscriber.cancel(); + } - this.bufferFactory.checkForLeaks(); + @Test // gh-22107 + public void cancelWithoutDemandForSingleResourceRegion() { + Resource resource = new ClassPathResource("ResourceRegionEncoderTests.txt", getClass()); + Mono regions = Mono.just(new ResourceRegion(resource, 0, 6)); + String boundary = MimeTypeUtils.generateMultipartBoundaryString(); + + Flux flux = this.encoder.encode(regions, this.bufferFactory, + ResolvableType.forClass(ResourceRegion.class), + MimeType.valueOf("text/plain"), + Collections.singletonMap(ResourceRegionEncoder.BOUNDARY_STRING_HINT, boundary) + ); + + ZeroDemandSubscriber subscriber = new ZeroDemandSubscriber(); + flux.subscribe(subscriber); + subscriber.cancel(); } @Test @@ -170,14 +190,11 @@ public class ResourceRegionEncoderTests { .consumeNextWith(stringConsumer("Spring")) .expectError(EncodingException.class) .verify(); - - this.bufferFactory.checkForLeaks(); } protected Consumer stringConsumer(String expected) { return dataBuffer -> { - String value = - DataBufferTestUtils.dumpString(dataBuffer, UTF_8); + String value = DataBufferTestUtils.dumpString(dataBuffer, UTF_8); DataBufferUtils.release(dataBuffer); assertEquals(expected, value); }; diff --git a/spring-web/src/main/java/org/springframework/http/codec/EncoderHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/EncoderHttpMessageWriter.java index d1067e29ae..e8058691db 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/EncoderHttpMessageWriter.java +++ b/spring-web/src/main/java/org/springframework/http/codec/EncoderHttpMessageWriter.java @@ -125,13 +125,19 @@ public class EncoderHttpMessageWriter implements HttpMessageWriter { })) .flatMap(buffer -> { headers.setContentLength(buffer.readableByteCount()); - return message.writeWith(Mono.fromCallable(() -> buffer) - .doOnDiscard(PooledDataBuffer.class, PooledDataBuffer::release)); + return message.writeWith( + Mono.fromCallable(() -> buffer) + .doOnDiscard(PooledDataBuffer.class, PooledDataBuffer::release)); }); } - return (isStreamingMediaType(contentType) ? - message.writeAndFlushWith(body.map(Flux::just)) : message.writeWith(body)); + if (isStreamingMediaType(contentType)) { + return message.writeAndFlushWith(body.map(buffer -> + Mono.fromCallable(() -> buffer) + .doOnDiscard(PooledDataBuffer.class, PooledDataBuffer::release))); + } + + return message.writeWith(body); } @Nullable @@ -162,10 +168,16 @@ public class EncoderHttpMessageWriter implements HttpMessageWriter { } private boolean isStreamingMediaType(@Nullable MediaType contentType) { - return (contentType != null && this.encoder instanceof HttpMessageEncoder && - ((HttpMessageEncoder) this.encoder).getStreamingMediaTypes().stream() - .anyMatch(streamingMediaType -> contentType.isCompatibleWith(streamingMediaType) && - contentType.getParameters().entrySet().containsAll(streamingMediaType.getParameters().keySet()))); + if (contentType == null || !(this.encoder instanceof HttpMessageEncoder)) { + return false; + } + for (MediaType mediaType : ((HttpMessageEncoder) this.encoder).getStreamingMediaTypes()) { + if (contentType.isCompatibleWith(mediaType) && + contentType.getParameters().entrySet().containsAll(mediaType.getParameters().keySet())) { + return true; + } + } + return false; } diff --git a/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageReader.java index f9fd993434..01c2d30b33 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageReader.java +++ b/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageReader.java @@ -56,7 +56,7 @@ public class FormHttpMessageReader extends LoggingCodecSupport */ public static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; - private static final ResolvableType MULTIVALUE_TYPE = + private static final ResolvableType MULTIVALUE_STRINGS_TYPE = ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class); @@ -83,9 +83,11 @@ public class FormHttpMessageReader extends LoggingCodecSupport @Override public boolean canRead(ResolvableType elementType, @Nullable MediaType mediaType) { - return ((MULTIVALUE_TYPE.isAssignableFrom(elementType) || - (elementType.hasUnresolvableGenerics() && - MultiValueMap.class.isAssignableFrom(elementType.toClass()))) && + boolean multiValueUnresolved = + elementType.hasUnresolvableGenerics() && + MultiValueMap.class.isAssignableFrom(elementType.toClass()); + + return ((MULTIVALUE_STRINGS_TYPE.isAssignableFrom(elementType) || multiValueUnresolved) && (mediaType == null || MediaType.APPLICATION_FORM_URLENCODED.isCompatibleWith(mediaType))); } diff --git a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageReader.java index c9f772a0c9..d66f59c601 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageReader.java +++ b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageReader.java @@ -164,8 +164,8 @@ public class ServerSentEventHttpMessageReader implements HttpMessageReader input = Mono.just(bufferFactory.wrap(bytes)); - return this.decoder.decodeToMono(input, dataType, MediaType.TEXT_EVENT_STREAM, hints); + DataBuffer buffer = bufferFactory.wrap(bytes); // wrapping only, no allocation + return this.decoder.decodeToMono(Mono.just(buffer), dataType, MediaType.TEXT_EVENT_STREAM, hints); } @Override diff --git a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java index 2a15257393..fec0ed09ec 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java +++ b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java @@ -184,7 +184,7 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter encodeText(CharSequence text, MediaType mediaType, DataBufferFactory bufferFactory) { Assert.notNull(mediaType.getCharset(), "Expected MediaType with charset"); byte[] bytes = text.toString().getBytes(mediaType.getCharset()); - return Mono.fromCallable(() -> bufferFactory.wrap(bytes)); // wrapping, not allocating + return Mono.just(bufferFactory.wrap(bytes)); // wrapping, not allocating } @Override diff --git a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java index ab0348c04e..c1f0e17bf7 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -77,6 +77,7 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder, Method> methodCache = new ConcurrentReferenceHashMap<>(); + private final ExtensionRegistry extensionRegistry; private int maxMessageSize = DEFAULT_MESSAGE_MAX_SIZE; @@ -114,8 +115,12 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder decode(Publisher inputStream, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { + MessageDecoderFunction decoderFunction = + new MessageDecoderFunction(elementType, this.maxMessageSize); + return Flux.from(inputStream) - .flatMapIterable(new MessageDecoderFunction(elementType, this.maxMessageSize)); + .flatMapIterable(decoderFunction) + .doOnTerminate(decoderFunction::discard); } @Override @@ -212,12 +217,13 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder 0); return messages; @@ -286,6 +292,12 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder encode(Publisher inputStream, DataBufferFactory bufferFactory, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { - return Flux - .from(inputStream) - .map(message -> encodeMessage(message, bufferFactory, !(inputStream instanceof Mono))); - } - - private DataBuffer encodeMessage(Message message, DataBufferFactory bufferFactory, boolean streaming) { - DataBuffer buffer = bufferFactory.allocateBuffer(); - OutputStream outputStream = buffer.asOutputStream(); - try { - if (streaming) { - message.writeDelimitedTo(outputStream); - } - else { - message.writeTo(outputStream); - } - return buffer; - } - catch (IOException ex) { - throw new IllegalStateException("Unexpected I/O error while writing to data buffer", ex); - } + return Flux.from(inputStream) + .map(message -> { + DataBuffer buffer = bufferFactory.allocateBuffer(); + boolean release = true; + try { + if (!(inputStream instanceof Mono)) { + message.writeDelimitedTo(buffer.asOutputStream()); + } + else { + message.writeTo(buffer.asOutputStream()); + } + release = false; + return buffer; + } + catch (IOException ex) { + throw new IllegalStateException("Unexpected I/O error while writing to data buffer", ex); + } + finally { + if (release) { + DataBufferUtils.release(buffer); + } + } + }); } @Override diff --git a/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlEncoder.java index c112d369ea..59a11970ad 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlEncoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlEncoder.java @@ -111,13 +111,13 @@ public class Jaxb2XmlEncoder extends AbstractSingleValueEncoder { return Flux.defer(() -> { boolean release = true; DataBuffer buffer = bufferFactory.allocateBuffer(1024); - OutputStream outputStream = buffer.asOutputStream(); - Class clazz = ClassUtils.getUserClass(value); try { + OutputStream outputStream = buffer.asOutputStream(); + Class clazz = ClassUtils.getUserClass(value); Marshaller marshaller = initMarshaller(clazz); marshaller.marshal(value, outputStream); release = false; - return Mono.fromCallable(() -> buffer); // Rely on doOnDiscard in base class + return Mono.fromCallable(() -> buffer); // relying on doOnDiscard in base class } catch (MarshalException ex) { return Flux.error(new EncodingException( diff --git a/spring-web/src/main/java/org/springframework/http/codec/xml/XmlEventDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/xml/XmlEventDecoder.java index e6eeace45c..d44ff23a11 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/xml/XmlEventDecoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/xml/XmlEventDecoder.java @@ -35,7 +35,6 @@ import com.fasterxml.aalto.stax.InputFactoryImpl; import org.reactivestreams.Publisher; import reactor.core.Exceptions; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; import org.springframework.core.ResolvableType; import org.springframework.core.codec.AbstractDecoder; @@ -97,30 +96,32 @@ public class XmlEventDecoder extends AbstractDecoder { @Override @SuppressWarnings({"rawtypes", "unchecked", "cast"}) // on JDK 9 where XMLEventReader is Iterator instead of simply Iterator - public Flux decode(Publisher inputStream, ResolvableType elementType, + public Flux decode(Publisher input, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { - Flux flux = Flux.from(inputStream); if (this.useAalto) { - AaltoDataBufferToXmlEvent aaltoMapper = new AaltoDataBufferToXmlEvent(); - return flux.flatMap(aaltoMapper) - .doFinally(signalType -> aaltoMapper.endOfInput()); + AaltoDataBufferToXmlEvent mapper = new AaltoDataBufferToXmlEvent(); + return Flux.from(input) + .flatMapIterable(mapper) + .doFinally(signalType -> mapper.endOfInput()); } else { - Mono singleBuffer = DataBufferUtils.join(flux); - return singleBuffer.flatMapIterable(dataBuffer -> { - InputStream is = dataBuffer.asInputStream(); - return () -> { - try { - // Explicit cast to (Iterator) is necessary on JDK 9+ since XMLEventReader - // now extends Iterator instead of simply Iterator - return (Iterator) inputFactory.createXMLEventReader(is); - } - catch (XMLStreamException ex) { - throw Exceptions.propagate(ex); - } - }; - }); + return DataBufferUtils.join(input). + flatMapIterable(buffer -> { + try { + InputStream is = buffer.asInputStream(); + Iterator eventReader = inputFactory.createXMLEventReader(is); + List result = new ArrayList<>(); + eventReader.forEachRemaining(event -> result.add((XMLEvent) event)); + return result; + } + catch (XMLStreamException ex) { + throw Exceptions.propagate(ex); + } + finally { + DataBufferUtils.release(buffer); + } + }); } } @@ -128,7 +129,7 @@ public class XmlEventDecoder extends AbstractDecoder { /* * Separate static class to isolate Aalto dependency. */ - private static class AaltoDataBufferToXmlEvent implements Function> { + private static class AaltoDataBufferToXmlEvent implements Function> { private static final AsyncXMLInputFactory inputFactory = StaxUtils.createDefensiveInputFactory(InputFactoryImpl::new); @@ -140,7 +141,7 @@ public class XmlEventDecoder extends AbstractDecoder { @Override - public Publisher apply(DataBuffer dataBuffer) { + public List apply(DataBuffer dataBuffer) { try { this.streamReader.getInputFeeder().feedInput(dataBuffer.asByteBuffer()); List events = new ArrayList<>(); @@ -157,10 +158,10 @@ public class XmlEventDecoder extends AbstractDecoder { } } } - return Flux.fromIterable(events); + return events; } catch (XMLStreamException ex) { - return Mono.error(ex); + throw Exceptions.propagate(ex); } finally { DataBufferUtils.release(dataBuffer); diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java index d570992d8e..d3541ff92f 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java @@ -180,8 +180,7 @@ public abstract class AbstractServerHttpResponse implements ServerHttpResponse { @Override public final Mono writeAndFlushWith(Publisher> body) { - return new ChannelSendOperator<>(body, - writePublisher -> doCommit(() -> writeAndFlushWithInternal(writePublisher))) + return new ChannelSendOperator<>(body, inner -> doCommit(() -> writeAndFlushWithInternal(inner))) .doOnError(t -> removeContentLength()); } diff --git a/spring-web/src/test/java/org/springframework/http/codec/CodecDataBufferLeakTests.java b/spring-web/src/test/java/org/springframework/http/codec/CancelWithoutDemandCodecTests.java similarity index 76% rename from spring-web/src/test/java/org/springframework/http/codec/CodecDataBufferLeakTests.java rename to spring-web/src/test/java/org/springframework/http/codec/CancelWithoutDemandCodecTests.java index 484020cb40..509e135542 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/CodecDataBufferLeakTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/CancelWithoutDemandCodecTests.java @@ -20,6 +20,8 @@ import java.util.Collections; import java.util.List; import java.util.function.Supplier; +import com.google.protobuf.Message; +import org.junit.After; import org.junit.Test; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; @@ -38,18 +40,28 @@ import org.springframework.http.ReactiveHttpOutputMessage; import org.springframework.http.client.MultipartBodyBuilder; import org.springframework.http.codec.json.Jackson2JsonEncoder; import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; +import org.springframework.http.codec.protobuf.ProtobufDecoder; +import org.springframework.http.codec.protobuf.ProtobufEncoder; import org.springframework.http.codec.xml.Jaxb2XmlEncoder; +import org.springframework.protobuf.Msg; +import org.springframework.protobuf.SecondMsg; +import org.springframework.util.MimeType; /** * Test scenarios for data buffer leaks. * @author Rossen Stoyanchev - * @since 5.2 */ -public class CodecDataBufferLeakTests { +public class CancelWithoutDemandCodecTests { private final LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory(); + @After + public void tearDown() throws Exception { + this.bufferFactory.checkForLeaks(); + } + + @Test // gh-22107 public void cancelWithEncoderHttpMessageWriterAndSingleValue() { CharSequenceEncoder encoder = CharSequenceEncoder.allMimeTypes(); @@ -58,8 +70,6 @@ public class CodecDataBufferLeakTests { writer.write(Mono.just("foo"), ResolvableType.forType(String.class), MediaType.TEXT_PLAIN, outputMessage, Collections.emptyMap()).block(Duration.ofSeconds(5)); - - this.bufferFactory.checkForLeaks(); } @Test // gh-22107 @@ -73,8 +83,6 @@ public class CodecDataBufferLeakTests { BaseSubscriber subscriber = new ZeroDemandSubscriber(); flux.subscribe(subscriber); // Assume sync execution (e.g. encoding with Flux.just).. subscriber.cancel(); - - this.bufferFactory.checkForLeaks(); } @Test // gh-22107 @@ -88,8 +96,39 @@ public class CodecDataBufferLeakTests { BaseSubscriber subscriber = new ZeroDemandSubscriber(); flux.subscribe(subscriber); // Assume sync execution (e.g. encoding with Flux.just).. subscriber.cancel(); + } - this.bufferFactory.checkForLeaks(); + @Test // gh-22543 + public void cancelWithProtobufEncoder() { + ProtobufEncoder encoder = new ProtobufEncoder(); + Msg msg = Msg.newBuilder().setFoo("Foo").setBlah(SecondMsg.newBuilder().setBlah(123).build()).build(); + + Flux flux = encoder.encode(Mono.just(msg), + this.bufferFactory, ResolvableType.forClass(Msg.class), + new MimeType("application", "x-protobuf"), Collections.emptyMap()); + + BaseSubscriber subscriber = new ZeroDemandSubscriber(); + flux.subscribe(subscriber); // Assume sync execution (e.g. encoding with Flux.just).. + subscriber.cancel(); + } + + @Test // gh-22731 + public void cancelWithProtobufDecoder() throws InterruptedException { + ProtobufDecoder decoder = new ProtobufDecoder(); + + Mono input = Mono.fromCallable(() -> { + Msg msg = Msg.newBuilder().setFoo("Foo").build(); + byte[] bytes = msg.toByteArray(); + DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length); + buffer.write(bytes); + return buffer; + }); + + Flux messages = decoder.decode(input, ResolvableType.forType(Msg.class), + new MimeType("application", "x-protobuf"), Collections.emptyMap()); + ZeroDemandMessageSubscriber subscriber = new ZeroDemandMessageSubscriber(); + messages.subscribe(subscriber); + subscriber.cancel(); } @Test // gh-22107 @@ -104,8 +143,6 @@ public class CodecDataBufferLeakTests { writer.write(Mono.just(builder.build()), null, MediaType.MULTIPART_FORM_DATA, outputMessage, Collections.emptyMap()).block(Duration.ofSeconds(5)); - - this.bufferFactory.checkForLeaks(); } @Test // gh-22107 @@ -116,8 +153,6 @@ public class CodecDataBufferLeakTests { writer.write(Mono.just(event), ResolvableType.forClass(ServerSentEvent.class), MediaType.TEXT_EVENT_STREAM, outputMessage, Collections.emptyMap()).block(Duration.ofSeconds(5)); - - this.bufferFactory.checkForLeaks(); } @@ -183,4 +218,13 @@ public class CodecDataBufferLeakTests { // Just subscribe without requesting } } + + + private static class ZeroDemandMessageSubscriber extends BaseSubscriber { + + @Override + protected void hookOnSubscribe(Subscription subscription) { + // Just subscribe without requesting + } + } } diff --git a/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java b/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java index 68a737b07e..07e944fc38 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/protobuf/ProtobufDecoderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,10 +35,10 @@ import org.springframework.protobuf.Msg; import org.springframework.protobuf.SecondMsg; import org.springframework.util.MimeType; -import static java.util.Collections.emptyMap; +import static java.util.Collections.*; import static org.junit.Assert.*; -import static org.springframework.core.ResolvableType.forClass; -import static org.springframework.core.io.buffer.DataBufferUtils.release; +import static org.springframework.core.ResolvableType.*; +import static org.springframework.core.io.buffer.DataBufferUtils.*; /** * Unit tests for {@link ProtobufDecoder}. @@ -223,11 +223,11 @@ public class ProtobufDecoderTests extends AbstractDecoderTestCase dataBuffer(Msg msg) { - return Mono.defer(() -> { + return Mono.fromCallable(() -> { byte[] bytes = msg.toByteArray(); DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length); buffer.write(bytes); - return Mono.just(buffer); + return buffer; }); }