diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferLimitException.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferLimitException.java index ee606aed57f..c03839056bb 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferLimitException.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferLimitException.java @@ -21,7 +21,7 @@ package org.springframework.core.io.buffer; * This can be raised when data buffers are cached and aggregated, e.g. * {@link DataBufferUtils#join}. Or it could also be raised when data buffers * have been released but a parsed representation is being aggregated, e.g. async - * parsing with Jackson. + * parsing with Jackson, SSE parsing and aggregating lines per event. * * @author Rossen Stoyanchev * @since 5.1.11 diff --git a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageReader.java index a6de190d510..7677c48ebfc 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageReader.java +++ b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageReader.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ import org.springframework.core.codec.Decoder; import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferLimitException; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.MediaType; import org.springframework.http.ReactiveHttpInputMessage; @@ -48,14 +49,16 @@ public class ServerSentEventHttpMessageReader implements HttpMessageReader decoder; + private final StringDecoder lineDecoder = StringDecoder.textPlainOnly(); + + + /** * Constructor without a {@code Decoder}. In this mode only {@code String} @@ -82,6 +85,29 @@ public class ServerSentEventHttpMessageReader implements HttpMessageReaderNote that the {@link #getDecoder() data decoder}, if provided, must + * also be customized accordingly to raise the limit if necessary in order + * to be able to parse the data portion of the event. + *

By default this is set to 256K. + * @param byteCount the max number of bytes to buffer, or -1 for unlimited + * @since 5.1.13 + */ + public void setMaxInMemorySize(int byteCount) { + this.lineDecoder.setMaxInMemorySize(byteCount); + } + + /** + * Return the {@link #setMaxInMemorySize configured} byte count limit. + * @since 5.1.13 + */ + public int getMaxInMemorySize() { + return this.lineDecoder.getMaxInMemorySize(); + } + + @Override public List getReadableMediaTypes() { return Collections.singletonList(MediaType.TEXT_EVENT_STREAM); @@ -101,12 +127,15 @@ public class ServerSentEventHttpMessageReader implements HttpMessageReader read( ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) { + LimitTracker limitTracker = new LimitTracker(); + boolean shouldWrap = isServerSentEvent(elementType); ResolvableType valueType = (shouldWrap ? elementType.getGeneric() : elementType); - return stringDecoder.decode(message.getBody(), STRING_TYPE, null, hints) + return this.lineDecoder.decode(message.getBody(), STRING_TYPE, null, hints) + .doOnNext(limitTracker::afterLineParsed) .bufferUntil(String::isEmpty) - .concatMap(lines -> Mono.justOrEmpty(buildEvent(lines, valueType, shouldWrap, hints))); + .map(lines -> buildEvent(lines, valueType, shouldWrap, hints)); } @Nullable @@ -172,16 +201,47 @@ public class ServerSentEventHttpMessageReader implements HttpMessageReader readMono( ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) { - // We're ahead of String + "*/*" - // Let's see if we can aggregate the output (lest we time out)... + // In order of readers, we're ahead of String + "*/*" + // If this is called, simply delegate to StringDecoder if (elementType.resolve() == String.class) { Flux body = message.getBody(); - return stringDecoder.decodeToMono(body, elementType, null, null).cast(Object.class); + return this.lineDecoder.decodeToMono(body, elementType, null, null).cast(Object.class); } return Mono.error(new UnsupportedOperationException( "ServerSentEventHttpMessageReader only supports reading stream of events as a Flux")); } + + private class LimitTracker { + + private int accumulated = 0; + + + public void afterLineParsed(String line) { + if (getMaxInMemorySize() < 0) { + return; + } + if (line.isEmpty()) { + this.accumulated = 0; + } + if (line.length() > Integer.MAX_VALUE - this.accumulated) { + raiseLimitException(); + } + else { + this.accumulated += line.length(); + if (this.accumulated > getMaxInMemorySize()) { + raiseLimitException(); + } + } + } + + private void raiseLimitException() { + // Do not release here, it's likely down via doOnDiscard.. + throw new DataBufferLimitException( + "Exceeded limit on max bytes to buffer : " + getMaxInMemorySize()); + } + } + } diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java b/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java index 72ac9859395..18626783889 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java +++ b/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -238,9 +238,6 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs, CodecConfigure if (codec instanceof DecoderHttpMessageReader) { codec = ((DecoderHttpMessageReader) codec).getDecoder(); } - else if (codec instanceof ServerSentEventHttpMessageReader) { - codec = ((ServerSentEventHttpMessageReader) codec).getDecoder(); - } if (codec == null) { return; @@ -269,6 +266,10 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs, CodecConfigure if (codec instanceof FormHttpMessageReader) { ((FormHttpMessageReader) codec).setMaxInMemorySize(size); } + if (codec instanceof ServerSentEventHttpMessageReader) { + ((ServerSentEventHttpMessageReader) codec).setMaxInMemorySize(size); + initCodec(((ServerSentEventHttpMessageReader) codec).getDecoder()); + } if (synchronossMultipartPresent) { if (codec instanceof SynchronossPartHttpMessageReader) { ((SynchronossPartHttpMessageReader) codec).setMaxInMemorySize(size); diff --git a/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageReaderTests.java b/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageReaderTests.java index 01d9a34376f..c5a5983f07b 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageReaderTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageReaderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ import reactor.test.StepVerifier; import org.springframework.core.ResolvableType; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; import org.springframework.core.testfixture.io.buffer.AbstractLeakCheckingTests; import org.springframework.http.MediaType; import org.springframework.http.codec.json.Jackson2JsonDecoder; @@ -42,20 +43,21 @@ import static org.assertj.core.api.Assertions.assertThat; */ public class ServerSentEventHttpMessageReaderTests extends AbstractLeakCheckingTests { - private ServerSentEventHttpMessageReader messageReader = - new ServerSentEventHttpMessageReader(new Jackson2JsonDecoder()); + private Jackson2JsonDecoder jsonDecoder = new Jackson2JsonDecoder(); + + private ServerSentEventHttpMessageReader reader = new ServerSentEventHttpMessageReader(this.jsonDecoder); @Test public void cantRead() { - assertThat(messageReader.canRead(ResolvableType.forClass(Object.class), new MediaType("foo", "bar"))).isFalse(); - assertThat(messageReader.canRead(ResolvableType.forClass(Object.class), null)).isFalse(); + assertThat(reader.canRead(ResolvableType.forClass(Object.class), new MediaType("foo", "bar"))).isFalse(); + assertThat(reader.canRead(ResolvableType.forClass(Object.class), null)).isFalse(); } @Test public void canRead() { - assertThat(messageReader.canRead(ResolvableType.forClass(Object.class), new MediaType("text", "event-stream"))).isTrue(); - assertThat(messageReader.canRead(ResolvableType.forClass(ServerSentEvent.class), new MediaType("foo", "bar"))).isTrue(); + assertThat(reader.canRead(ResolvableType.forClass(Object.class), new MediaType("text", "event-stream"))).isTrue(); + assertThat(reader.canRead(ResolvableType.forClass(ServerSentEvent.class), new MediaType("foo", "bar"))).isTrue(); } @Test @@ -66,7 +68,7 @@ public class ServerSentEventHttpMessageReaderTests extends AbstractLeakCheckingT "id:c42\nevent:foo\nretry:123\n:bla\n:bla bla\n:bla bla bla\ndata:bar\n\n" + "id:c43\nevent:bar\nretry:456\ndata:baz\n\n"))); - Flux events = this.messageReader + Flux events = this.reader .read(ResolvableType.forClassWithGenerics(ServerSentEvent.class, String.class), request, Collections.emptyMap()).cast(ServerSentEvent.class); @@ -98,7 +100,7 @@ public class ServerSentEventHttpMessageReaderTests extends AbstractLeakCheckingT stringBuffer("ent:foo\nretry:123\n:bla\n:bla bla\n:bla bla bla\ndata:"), stringBuffer("bar\n\nid:c43\nevent:bar\nretry:456\ndata:baz\n\n"))); - Flux events = messageReader + Flux events = reader .read(ResolvableType.forClassWithGenerics(ServerSentEvent.class, String.class), request, Collections.emptyMap()).cast(ServerSentEvent.class); @@ -126,7 +128,7 @@ public class ServerSentEventHttpMessageReaderTests extends AbstractLeakCheckingT MockServerHttpRequest request = MockServerHttpRequest.post("/") .body(Mono.just(stringBuffer("data:foo\ndata:bar\n\ndata:baz\n\n"))); - Flux data = messageReader.read(ResolvableType.forClass(String.class), + Flux data = reader.read(ResolvableType.forClass(String.class), request, Collections.emptyMap()).cast(String.class); StepVerifier.create(data) @@ -143,7 +145,7 @@ public class ServerSentEventHttpMessageReaderTests extends AbstractLeakCheckingT "data:{\"foo\": \"foofoo\", \"bar\": \"barbar\"}\n\n" + "data:{\"foo\": \"foofoofoo\", \"bar\": \"barbarbar\"}\n\n"))); - Flux data = messageReader.read(ResolvableType.forClass(Pojo.class), request, + Flux data = reader.read(ResolvableType.forClass(Pojo.class), request, Collections.emptyMap()).cast(Pojo.class); StepVerifier.create(data) @@ -165,7 +167,7 @@ public class ServerSentEventHttpMessageReaderTests extends AbstractLeakCheckingT MockServerHttpRequest request = MockServerHttpRequest.post("/") .body(Mono.just(stringBuffer(body))); - String actual = messageReader + String actual = reader .readMono(ResolvableType.forClass(String.class), request, Collections.emptyMap()) .cast(String.class) .block(Duration.ZERO); @@ -182,7 +184,7 @@ public class ServerSentEventHttpMessageReaderTests extends AbstractLeakCheckingT MockServerHttpRequest request = MockServerHttpRequest.post("/") .body(body); - Flux data = messageReader.read(ResolvableType.forClass(String.class), + Flux data = reader.read(ResolvableType.forClass(String.class), request, Collections.emptyMap()).cast(String.class); StepVerifier.create(data) @@ -192,6 +194,54 @@ public class ServerSentEventHttpMessageReaderTests extends AbstractLeakCheckingT .verify(); } + @Test + public void maxInMemoryLimit() { + + this.reader.setMaxInMemorySize(17); + + MockServerHttpRequest request = MockServerHttpRequest.post("/") + .body(Flux.just(stringBuffer("data:\"TOO MUCH DATA\"\ndata:bar\n\ndata:baz\n\n"))); + + Flux data = this.reader.read(ResolvableType.forClass(String.class), + request, Collections.emptyMap()).cast(String.class); + + StepVerifier.create(data) + .expectError(DataBufferLimitException.class) + .verify(); + } + + @Test // gh-24312 + public void maxInMemoryLimitAllowsReadingPojoLargerThanDefaultSize() { + + int limit = this.jsonDecoder.getMaxInMemorySize(); + + String fooValue = getStringOfSize(limit) + "and then some more"; + String content = "data:{\"foo\": \"" + fooValue + "\"}\n\n"; + MockServerHttpRequest request = MockServerHttpRequest.post("/").body(Mono.just(stringBuffer(content))); + + Jackson2JsonDecoder jacksonDecoder = new Jackson2JsonDecoder(); + ServerSentEventHttpMessageReader messageReader = new ServerSentEventHttpMessageReader(jacksonDecoder); + + jacksonDecoder.setMaxInMemorySize(limit + 1024); + messageReader.setMaxInMemorySize(limit + 1024); + + Flux data = messageReader.read(ResolvableType.forClass(Pojo.class), request, + Collections.emptyMap()).cast(Pojo.class); + + StepVerifier.create(data) + .consumeNextWith(pojo -> assertThat(pojo.getFoo()).isEqualTo(fooValue)) + .expectComplete() + .verify(); + } + + private static String getStringOfSize(long size) { + StringBuilder content = new StringBuilder("Aa"); + while (content.length() < size) { + content.append(content); + } + return content.toString(); + } + private DataBuffer stringBuffer(String value) { byte[] bytes = value.getBytes(StandardCharsets.UTF_8); DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length); diff --git a/spring-web/src/test/java/org/springframework/http/codec/support/ClientCodecConfigurerTests.java b/spring-web/src/test/java/org/springframework/http/codec/support/ClientCodecConfigurerTests.java index ca119d5de7e..fce509efbb8 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/support/ClientCodecConfigurerTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/support/ClientCodecConfigurerTests.java @@ -140,6 +140,7 @@ public class ClientCodecConfigurerTests { assertThat(((Jaxb2XmlDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size); ServerSentEventHttpMessageReader reader = (ServerSentEventHttpMessageReader) nextReader(readers); + assertThat(reader.getMaxInMemorySize()).isEqualTo(size); assertThat(((Jackson2JsonDecoder) reader.getDecoder()).getMaxInMemorySize()).isEqualTo(size); assertThat(((StringDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size);