maxResponseBody client filter

Issue: SPR-16989
This commit is contained in:
Rossen Stoyanchev 2018-08-02 21:16:06 +03:00
parent 5095ec40b5
commit aec98268fe
2 changed files with 77 additions and 17 deletions

View File

@ -25,11 +25,15 @@ import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.BodyExtractors;
/**
* Static factory methods providing access to built-in implementations of
@ -50,6 +54,21 @@ public abstract class ExchangeFilterFunctions {
public static final String BASIC_AUTHENTICATION_CREDENTIALS_ATTRIBUTE =
ExchangeFilterFunctions.class.getName() + ".basicAuthenticationCredentials";
/**
* Consume up to the specified number of bytes from the response body and
* cancel if any more data arrives. Internally delegates to
* {@link DataBufferUtils#takeUntilByteCount}.
* @return the filter to limit the response size with
* @since 5.1
*/
public static ExchangeFilterFunction limitResponseSize(long maxByteCount) {
return (request, next) ->
next.exchange(request).map(response -> {
Flux<DataBuffer> body = response.body(BodyExtractors.toDataBuffers());
body = DataBufferUtils.takeUntilByteCount(body, maxByteCount);
return ClientResponse.from(response).body(body).build();
});
}
/**
* Return a filter for HTTP Basic Authentication that adds an authorization

View File

@ -17,28 +17,36 @@
package org.springframework.web.reactive.function.client;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import org.junit.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.core.io.buffer.support.DataBufferTestUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.web.reactive.function.BodyExtractors;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import static org.springframework.http.HttpMethod.GET;
import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials;
/**
* @author Arjen Poutsma
*/
@SuppressWarnings("deprecation")
public class ExchangeFilterFunctionsTests {
private static final URI DEFAULT_URL = URI.create("http://example.com");
@Test
public void andThen() {
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build();
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
ClientResponse response = mock(ClientResponse.class);
ExchangeFunction exchange = r -> Mono.just(response);
@ -68,7 +76,7 @@ public class ExchangeFilterFunctionsTests {
@Test
public void apply() {
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build();
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
ClientResponse response = mock(ClientResponse.class);
ExchangeFunction exchange = r -> Mono.just(response);
@ -86,8 +94,9 @@ public class ExchangeFilterFunctionsTests {
}
@Test
@SuppressWarnings("deprecation")
public void basicAuthenticationUsernamePassword() {
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build();
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
ClientResponse response = mock(ClientResponse.class);
ExchangeFunction exchange = r -> {
@ -109,9 +118,11 @@ public class ExchangeFilterFunctionsTests {
}
@Test
@SuppressWarnings("deprecation")
public void basicAuthenticationAttributes() {
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com"))
.attributes(basicAuthenticationCredentials("foo", "bar"))
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL)
.attributes(org.springframework.web.reactive.function.client.ExchangeFilterFunctions
.Credentials.basicAuthenticationCredentials("foo", "bar"))
.build();
ClientResponse response = mock(ClientResponse.class);
@ -128,8 +139,9 @@ public class ExchangeFilterFunctionsTests {
}
@Test
@SuppressWarnings("deprecation")
public void basicAuthenticationAbsentAttributes() {
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build();
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
ClientResponse response = mock(ClientResponse.class);
ExchangeFunction exchange = r -> {
@ -145,7 +157,7 @@ public class ExchangeFilterFunctionsTests {
@Test
public void statusHandlerMatch() {
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build();
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
ClientResponse response = mock(ClientResponse.class);
when(response.statusCode()).thenReturn(HttpStatus.NOT_FOUND);
@ -163,16 +175,13 @@ public class ExchangeFilterFunctionsTests {
@Test
public void statusHandlerNoMatch() {
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build();
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
ClientResponse response = mock(ClientResponse.class);
when(response.statusCode()).thenReturn(HttpStatus.NOT_FOUND);
ExchangeFunction exchange = r -> Mono.just(response);
ExchangeFilterFunction errorHandler = ExchangeFilterFunctions.statusError(
HttpStatus::is5xxServerError, r -> new MyException());
Mono<ClientResponse> result = errorHandler.filter(request, exchange);
Mono<ClientResponse> result = ExchangeFilterFunctions
.statusError(HttpStatus::is5xxServerError, req -> new MyException())
.filter(request, req -> Mono.just(response));
StepVerifier.create(result)
.expectNext(response)
@ -180,6 +189,38 @@ public class ExchangeFilterFunctionsTests {
.verify();
}
@Test
public void limitResponseSize() {
DefaultDataBufferFactory bufferFactory = new DefaultDataBufferFactory();
DataBuffer b1 = dataBuffer("foo", bufferFactory);
DataBuffer b2 = dataBuffer("bar", bufferFactory);
DataBuffer b3 = dataBuffer("baz", bufferFactory);
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
ClientResponse response = ClientResponse.create(HttpStatus.OK).body(Flux.just(b1, b2, b3)).build();
Mono<ClientResponse> result = ExchangeFilterFunctions.limitResponseSize(5)
.filter(request, req -> Mono.just(response));
StepVerifier.create(result.flatMapMany(res -> res.body(BodyExtractors.toDataBuffers())))
.consumeNextWith(buffer -> assertEquals("foo", string(buffer)))
.consumeNextWith(buffer -> assertEquals("ba", string(buffer)))
.expectComplete()
.verify();
}
private String string(DataBuffer buffer) {
String value = DataBufferTestUtils.dumpString(buffer, StandardCharsets.UTF_8);
DataBufferUtils.release(buffer);
return value;
}
private DataBuffer dataBuffer(String foo, DefaultDataBufferFactory bufferFactory) {
return bufferFactory.wrap(foo.getBytes(StandardCharsets.UTF_8));
}
@SuppressWarnings("serial")
private static class MyException extends Exception {