diff --git a/spring-core/src/main/java/org/springframework/core/codec/StringDecoder.java b/spring-core/src/main/java/org/springframework/core/codec/StringDecoder.java index 1bd9f3d8be..b267a54c43 100644 --- a/spring-core/src/main/java/org/springframework/core/codec/StringDecoder.java +++ b/spring-core/src/main/java/org/springframework/core/codec/StringDecoder.java @@ -16,6 +16,9 @@ package org.springframework.core.codec; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; import java.nio.CharBuffer; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; @@ -25,13 +28,17 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.function.IntPredicate; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import org.springframework.core.ResolvableType; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.io.buffer.PooledDataBuffer; import org.springframework.core.log.LogFormatUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -88,8 +95,15 @@ public final class StringDecoder extends AbstractDataBufferDecoder { byte[][] delimiterBytes = getDelimiterBytes(mimeType); - Flux inputFlux = - DataBufferUtils.split(input, delimiterBytes, this.stripDelimiter); + Flux inputFlux = Flux.defer(() -> { + DataBufferUtils.Matcher matcher = DataBufferUtils.matcher(delimiterBytes); + return Flux.from(input) + .concatMapIterable(buffer -> endFrameAfterDelimiter(buffer, matcher)) + .bufferUntil(buffer -> buffer instanceof EndFrameBuffer) + .map(buffers -> joinAndStrip(buffers, this.stripDelimiter)) + .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); + + }); return super.decode(inputFlux, elementType, mimeType, hints); } @@ -128,6 +142,69 @@ public final class StringDecoder extends AbstractDataBufferDecoder { } } + /** + * Finds the first match and longest delimiter, {@link EndFrameBuffer} just after it. + * + * @param dataBuffer the buffer to find delimiters in + * @param matcher used to find the first delimiters + * @return a flux of buffers, containing {@link EndFrameBuffer} after each delimiter that was + * found in {@code dataBuffer}. Returns Flux, because returning List (w/ flatMapIterable) + * results in memory leaks due to pre-fetching. + */ + private static List endFrameAfterDelimiter(DataBuffer dataBuffer, DataBufferUtils.Matcher matcher) { + List result = new ArrayList<>(); + do { + int endIdx = matcher.match(dataBuffer); + if (endIdx != -1) { + int readPosition = dataBuffer.readPosition(); + int length = endIdx - readPosition + 1; + result.add(dataBuffer.retainedSlice(readPosition, length)); + result.add(new EndFrameBuffer(matcher.delimiter())); + dataBuffer.readPosition(endIdx + 1); + } + else { + result.add(DataBufferUtils.retain(dataBuffer)); + break; + } + } + while (dataBuffer.readableByteCount() > 0); + + DataBufferUtils.release(dataBuffer); + return result; + } + + /** + * Joins the given list of buffers. If the list ends with a {@link EndFrameBuffer}, it is + * removed. If {@code stripDelimiter} is {@code true} and the resulting buffer ends with + * a delimiter, it is removed. + * @param dataBuffers the data buffers to join + * @param stripDelimiter whether to strip the delimiter + * @return the joined buffer + */ + private static DataBuffer joinAndStrip(List dataBuffers, + boolean stripDelimiter) { + + Assert.state(!dataBuffers.isEmpty(), "DataBuffers should not be empty"); + + byte[] matchingDelimiter = null; + + int lastIdx = dataBuffers.size() - 1; + DataBuffer lastBuffer = dataBuffers.get(lastIdx); + if (lastBuffer instanceof EndFrameBuffer) { + matchingDelimiter = ((EndFrameBuffer) lastBuffer).delimiter(); + dataBuffers.remove(lastIdx); + } + + DataBuffer result = dataBuffers.get(0).factory().join(dataBuffers); + + if (stripDelimiter && matchingDelimiter != null) { + result.writePosition(result.writePosition() - matchingDelimiter.length); + } + return result; + } + + + /** * Create a {@code StringDecoder} for {@code "text/plain"}. @@ -186,4 +263,167 @@ public final class StringDecoder extends AbstractDataBufferDecoder { new MimeType("text", "plain", DEFAULT_CHARSET), MimeTypeUtils.ALL); } + + private static class EndFrameBuffer implements DataBuffer { + + private static final DataBuffer BUFFER = new DefaultDataBufferFactory().wrap(new byte[0]); + + private byte[] delimiter; + + + public EndFrameBuffer(byte[] delimiter) { + this.delimiter = delimiter; + } + + public byte[] delimiter() { + return this.delimiter; + } + + @Override + public DataBufferFactory factory() { + return BUFFER.factory(); + } + + @Override + public int indexOf(IntPredicate predicate, int fromIndex) { + return BUFFER.indexOf(predicate, fromIndex); + } + + @Override + public int lastIndexOf(IntPredicate predicate, int fromIndex) { + return BUFFER.lastIndexOf(predicate, fromIndex); + } + + @Override + public int readableByteCount() { + return BUFFER.readableByteCount(); + } + + @Override + public int writableByteCount() { + return BUFFER.writableByteCount(); + } + + @Override + public int capacity() { + return BUFFER.capacity(); + } + + @Override + public DataBuffer capacity(int capacity) { + return BUFFER.capacity(capacity); + } + + @Override + public DataBuffer ensureCapacity(int capacity) { + return BUFFER.ensureCapacity(capacity); + } + + @Override + public int readPosition() { + return BUFFER.readPosition(); + } + + @Override + public DataBuffer readPosition(int readPosition) { + return BUFFER.readPosition(readPosition); + } + + @Override + public int writePosition() { + return BUFFER.writePosition(); + } + + @Override + public DataBuffer writePosition(int writePosition) { + return BUFFER.writePosition(writePosition); + } + + @Override + public byte getByte(int index) { + return BUFFER.getByte(index); + } + + @Override + public byte read() { + return BUFFER.read(); + } + + @Override + public DataBuffer read(byte[] destination) { + return BUFFER.read(destination); + } + + @Override + public DataBuffer read(byte[] destination, int offset, int length) { + return BUFFER.read(destination, offset, length); + } + + @Override + public DataBuffer write(byte b) { + return BUFFER.write(b); + } + + @Override + public DataBuffer write(byte[] source) { + return BUFFER.write(source); + } + + @Override + public DataBuffer write(byte[] source, int offset, int length) { + return BUFFER.write(source, offset, length); + } + + @Override + public DataBuffer write(DataBuffer... buffers) { + return BUFFER.write(buffers); + } + + @Override + public DataBuffer write(ByteBuffer... buffers) { + return BUFFER.write(buffers); + } + + @Override + public DataBuffer write(CharSequence charSequence, Charset charset) { + return BUFFER.write(charSequence, charset); + } + + @Override + public DataBuffer slice(int index, int length) { + return BUFFER.slice(index, length); + } + + @Override + public DataBuffer retainedSlice(int index, int length) { + return BUFFER.retainedSlice(index, length); + } + + @Override + public ByteBuffer asByteBuffer() { + return BUFFER.asByteBuffer(); + } + + @Override + public ByteBuffer asByteBuffer(int index, int length) { + return BUFFER.asByteBuffer(index, length); + } + + @Override + public InputStream asInputStream() { + return BUFFER.asInputStream(); + } + + @Override + public InputStream asInputStream(boolean releaseOnClose) { + return BUFFER.asInputStream(releaseOnClose); + } + + @Override + public OutputStream asOutputStream() { + return BUFFER.asOutputStream(); + } + } + + }