Introduce single-value request predicates

This commit introduces new HTTP method, Content-Type, and Accept header
request predicates that handle single values. Previously, these
predicates were always dealt with as single-value collections, which
introduced computational overhead.

Closes gh-32244
This commit is contained in:
Arjen Poutsma 2024-02-12 15:44:53 +01:00
parent 5851cdc679
commit 750cb73902
4 changed files with 372 additions and 145 deletions

View File

@ -47,7 +47,6 @@ import org.springframework.http.codec.multipart.Part;
import org.springframework.http.server.PathContainer;
import org.springframework.http.server.RequestPath;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@ -90,7 +89,8 @@ public abstract class RequestPredicates {
* @return a predicate that tests against the given HTTP method
*/
public static RequestPredicate method(HttpMethod httpMethod) {
return new HttpMethodPredicate(httpMethod);
Assert.notNull(httpMethod, "HttpMethod must not be null");
return new SingleHttpMethodPredicate(httpMethod);
}
/**
@ -101,7 +101,13 @@ public abstract class RequestPredicates {
* @since 5.1
*/
public static RequestPredicate methods(HttpMethod... httpMethods) {
return new HttpMethodPredicate(httpMethods);
Assert.notEmpty(httpMethods, "HttpMethods must not be empty");
if (httpMethods.length == 1) {
return new SingleHttpMethodPredicate(httpMethods[0]);
}
else {
return new MultipleHttpMethodsPredicate(httpMethods);
}
}
/**
@ -151,7 +157,12 @@ public abstract class RequestPredicates {
*/
public static RequestPredicate contentType(MediaType... mediaTypes) {
Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty");
return new ContentTypePredicate(mediaTypes);
if (mediaTypes.length == 1) {
return new SingleContentTypePredicate(mediaTypes[0]);
}
else {
return new MultipleContentTypesPredicate(mediaTypes);
}
}
/**
@ -163,7 +174,12 @@ public abstract class RequestPredicates {
*/
public static RequestPredicate accept(MediaType... mediaTypes) {
Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty");
return new AcceptPredicate(mediaTypes);
if (mediaTypes.length == 1) {
return new SingleAcceptPredicate(mediaTypes[0]);
}
else {
return new MultipleAcceptsPredicate(mediaTypes);
}
}
/**
@ -529,29 +545,23 @@ public abstract class RequestPredicates {
}
private static class HttpMethodPredicate implements RequestPredicate {
private static class SingleHttpMethodPredicate implements RequestPredicate {
private final Set<HttpMethod> httpMethods;
private final HttpMethod httpMethod;
public HttpMethodPredicate(HttpMethod httpMethod) {
Assert.notNull(httpMethod, "HttpMethod must not be null");
this.httpMethods = Set.of(httpMethod);
}
public HttpMethodPredicate(HttpMethod... httpMethods) {
Assert.notEmpty(httpMethods, "HttpMethods must not be empty");
this.httpMethods = new LinkedHashSet<>(Arrays.asList(httpMethods));
public SingleHttpMethodPredicate(HttpMethod httpMethod) {
this.httpMethod = httpMethod;
}
@Override
public boolean test(ServerRequest request) {
HttpMethod method = method(request);
boolean match = this.httpMethods.contains(method);
traceMatch("Method", this.httpMethods, method, match);
boolean match = this.httpMethod.equals(method);
traceMatch("Method", this.httpMethod, method, match);
return match;
}
private static HttpMethod method(ServerRequest request) {
static HttpMethod method(ServerRequest request) {
if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) {
String accessControlRequestMethod =
request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
@ -562,6 +572,34 @@ public abstract class RequestPredicates {
return request.method();
}
@Override
public void accept(Visitor visitor) {
visitor.method(Set.of(this.httpMethod));
}
@Override
public String toString() {
return this.httpMethod.toString();
}
}
private static class MultipleHttpMethodsPredicate implements RequestPredicate {
private final Set<HttpMethod> httpMethods;
public MultipleHttpMethodsPredicate(HttpMethod[] httpMethods) {
this.httpMethods = new LinkedHashSet<>(Arrays.asList(httpMethods));
}
@Override
public boolean test(ServerRequest request) {
HttpMethod method = SingleHttpMethodPredicate.method(request);
boolean match = this.httpMethods.contains(method);
traceMatch("Method", this.httpMethods, method, match);
return match;
}
@Override
public void accept(Visitor visitor) {
visitor.method(Collections.unmodifiableSet(this.httpMethods));
@ -569,12 +607,7 @@ public abstract class RequestPredicates {
@Override
public String toString() {
if (this.httpMethods.size() == 1) {
return this.httpMethods.iterator().next().toString();
}
else {
return this.httpMethods.toString();
}
return this.httpMethods.toString();
}
}
@ -669,20 +702,46 @@ public abstract class RequestPredicates {
}
private static class ContentTypePredicate extends HeadersPredicate {
private static class SingleContentTypePredicate extends HeadersPredicate {
private final Set<MediaType> mediaTypes;
private final MediaType mediaType;
public ContentTypePredicate(MediaType... mediaTypes) {
this(Set.of(mediaTypes));
public SingleContentTypePredicate(MediaType mediaType) {
super(headers -> {
MediaType contentType = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
boolean match = mediaType.includes(contentType);
traceMatch("Content-Type", mediaType, contentType, match);
return match;
});
this.mediaType = mediaType;
}
private ContentTypePredicate(Set<MediaType> mediaTypes) {
@Override
public void accept(Visitor visitor) {
visitor.header(HttpHeaders.CONTENT_TYPE, this.mediaType.toString());
}
@Override
public String toString() {
return "Content-Type: " + this.mediaType;
}
}
private static class MultipleContentTypesPredicate extends HeadersPredicate {
private final MediaType[] mediaTypes;
public MultipleContentTypesPredicate(MediaType[] mediaTypes) {
super(headers -> {
MediaType contentType =
headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
boolean match = mediaTypes.stream()
.anyMatch(mediaType -> mediaType.includes(contentType));
MediaType contentType = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
boolean match = false;
for (MediaType mediaType : mediaTypes) {
if (mediaType.includes(contentType)) {
match = true;
break;
}
}
traceMatch("Content-Type", mediaTypes, contentType, match);
return match;
});
@ -691,44 +750,37 @@ public abstract class RequestPredicates {
@Override
public void accept(Visitor visitor) {
visitor.header(HttpHeaders.CONTENT_TYPE,
(this.mediaTypes.size() == 1) ?
this.mediaTypes.iterator().next().toString() :
this.mediaTypes.toString());
visitor.header(HttpHeaders.CONTENT_TYPE, Arrays.toString(this.mediaTypes));
}
@Override
public String toString() {
return String.format("Content-Type: %s",
(this.mediaTypes.size() == 1) ?
this.mediaTypes.iterator().next().toString() :
this.mediaTypes.toString());
return "Content-Type: " + Arrays.toString(this.mediaTypes);
}
}
private static class AcceptPredicate extends HeadersPredicate {
private static class SingleAcceptPredicate extends HeadersPredicate {
private final Set<MediaType> mediaTypes;
private final MediaType mediaType;
public AcceptPredicate(MediaType... mediaTypes) {
this(Set.of(mediaTypes));
}
private AcceptPredicate(Set<MediaType> mediaTypes) {
public SingleAcceptPredicate(MediaType mediaType) {
super(headers -> {
List<MediaType> acceptedMediaTypes = acceptedMediaTypes(headers);
boolean match = acceptedMediaTypes.stream()
.anyMatch(acceptedMediaType -> mediaTypes.stream()
.anyMatch(acceptedMediaType::isCompatibleWith));
traceMatch("Accept", mediaTypes, acceptedMediaTypes, match);
boolean match = false;
for (MediaType acceptedMediaType : acceptedMediaTypes) {
if (acceptedMediaType.isCompatibleWith(mediaType)) {
match = true;
break;
}
}
traceMatch("Accept", mediaType, acceptedMediaTypes, match);
return match;
});
this.mediaTypes = mediaTypes;
this.mediaType = mediaType;
}
@NonNull
private static List<MediaType> acceptedMediaTypes(ServerRequest.Headers headers) {
static List<MediaType> acceptedMediaTypes(ServerRequest.Headers headers) {
List<MediaType> acceptedMediaTypes = headers.accept();
if (acceptedMediaTypes.isEmpty()) {
acceptedMediaTypes = Collections.singletonList(MediaType.ALL);
@ -741,18 +793,47 @@ public abstract class RequestPredicates {
@Override
public void accept(Visitor visitor) {
visitor.header(HttpHeaders.ACCEPT,
(this.mediaTypes.size() == 1) ?
this.mediaTypes.iterator().next().toString() :
this.mediaTypes.toString());
visitor.header(HttpHeaders.ACCEPT, this.mediaType.toString());
}
@Override
public String toString() {
return String.format("Accept: %s",
(this.mediaTypes.size() == 1) ?
this.mediaTypes.iterator().next().toString() :
this.mediaTypes.toString());
return "Accept: " + this.mediaType;
}
}
private static class MultipleAcceptsPredicate extends HeadersPredicate {
private final MediaType[] mediaTypes;
public MultipleAcceptsPredicate(MediaType[] mediaTypes) {
super(headers -> {
List<MediaType> acceptedMediaTypes = SingleAcceptPredicate.acceptedMediaTypes(headers);
boolean match = false;
outer:
for (MediaType acceptedMediaType : acceptedMediaTypes) {
for (MediaType mediaType : mediaTypes) {
if (acceptedMediaType.isCompatibleWith(mediaType)) {
match = true;
break outer;
}
}
}
traceMatch("Accept", mediaTypes, acceptedMediaTypes, match);
return match;
});
this.mediaTypes = mediaTypes;
}
@Override
public void accept(Visitor visitor) {
visitor.header(HttpHeaders.ACCEPT, Arrays.toString(this.mediaTypes));
}
@Override
public String toString() {
return "Accept: " + Arrays.toString(this.mediaTypes);
}
}

View File

@ -219,11 +219,10 @@ class RequestPredicatesTests {
@Test
void contentType() {
MediaType json = MediaType.APPLICATION_JSON;
RequestPredicate predicate = RequestPredicates.contentType(json);
void singleContentType() {
RequestPredicate predicate = RequestPredicates.contentType(MediaType.APPLICATION_JSON);
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
.header(HttpHeaders.CONTENT_TYPE, json.toString())
.header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
@ -236,15 +235,58 @@ class RequestPredicatesTests {
}
@Test
void accept() {
MediaType json = MediaType.APPLICATION_JSON;
RequestPredicate predicate = RequestPredicates.accept(json);
void multipleContentTypes() {
RequestPredicate predicate = RequestPredicates.contentType(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN);
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
.header(HttpHeaders.ACCEPT, json.toString())
.header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
mockRequest = MockServerHttpRequest.get("https://example.com")
.header(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_PLAIN_VALUE)
.build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
mockRequest = MockServerHttpRequest.get("https://example.com")
.header(HttpHeaders.CONTENT_TYPE, "foo/bar")
.build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isFalse();
}
@Test
void singleAccept() {
RequestPredicate predicate = RequestPredicates.accept(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN);
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
mockRequest = MockServerHttpRequest.get("https://example.com")
.header(HttpHeaders.ACCEPT, "foo/bar")
.build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isFalse();
}
@Test
void multipleAccepts() {
RequestPredicate predicate = RequestPredicates.accept(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN);
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.build();
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
mockRequest = MockServerHttpRequest.get("https://example.com")
.header(HttpHeaders.ACCEPT, MediaType.TEXT_PLAIN_VALUE)
.build();
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
assertThat(predicate.test(request)).isTrue();
mockRequest = MockServerHttpRequest.get("https://example.com")
.header(HttpHeaders.ACCEPT, "foo/bar")
.build();

View File

@ -48,7 +48,6 @@ import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.PathContainer;
import org.springframework.http.server.RequestPath;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@ -90,7 +89,8 @@ public abstract class RequestPredicates {
* @return a predicate that tests against the given HTTP method
*/
public static RequestPredicate method(HttpMethod httpMethod) {
return new HttpMethodPredicate(httpMethod);
Assert.notNull(httpMethod, "HttpMethod must not be null");
return new SingleHttpMethodPredicate(httpMethod);
}
/**
@ -100,7 +100,13 @@ public abstract class RequestPredicates {
* @return a predicate that tests against the given HTTP methods
*/
public static RequestPredicate methods(HttpMethod... httpMethods) {
return new HttpMethodPredicate(httpMethods);
Assert.notEmpty(httpMethods, "HttpMethods must not be empty");
if (httpMethods.length == 1) {
return new SingleHttpMethodPredicate(httpMethods[0]);
}
else {
return new MultipleHttpMethodsPredicate(httpMethods);
}
}
/**
@ -150,7 +156,12 @@ public abstract class RequestPredicates {
*/
public static RequestPredicate contentType(MediaType... mediaTypes) {
Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty");
return new ContentTypePredicate(mediaTypes);
if (mediaTypes.length == 1) {
return new SingleContentTypePredicate(mediaTypes[0]);
}
else {
return new MultipleContentTypesPredicate(mediaTypes);
}
}
/**
@ -162,7 +173,12 @@ public abstract class RequestPredicates {
*/
public static RequestPredicate accept(MediaType... mediaTypes) {
Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty");
return new AcceptPredicate(mediaTypes);
if (mediaTypes.length == 1) {
return new SingleAcceptPredicate(mediaTypes[0]);
}
else {
return new MultipleAcceptsPredicate(mediaTypes);
}
}
/**
@ -527,29 +543,23 @@ public abstract class RequestPredicates {
}
private static class HttpMethodPredicate implements RequestPredicate {
private static class SingleHttpMethodPredicate implements RequestPredicate {
private final Set<HttpMethod> httpMethods;
private final HttpMethod httpMethod;
public HttpMethodPredicate(HttpMethod httpMethod) {
Assert.notNull(httpMethod, "HttpMethod must not be null");
this.httpMethods = Set.of(httpMethod);
}
public HttpMethodPredicate(HttpMethod... httpMethods) {
Assert.notEmpty(httpMethods, "HttpMethods must not be empty");
this.httpMethods = new LinkedHashSet<>(Arrays.asList(httpMethods));
public SingleHttpMethodPredicate(HttpMethod httpMethod) {
this.httpMethod = httpMethod;
}
@Override
public boolean test(ServerRequest request) {
HttpMethod method = method(request);
boolean match = this.httpMethods.contains(method);
traceMatch("Method", this.httpMethods, method, match);
boolean match = this.httpMethod.equals(method);
traceMatch("Method", this.httpMethod, method, match);
return match;
}
private static HttpMethod method(ServerRequest request) {
static HttpMethod method(ServerRequest request) {
if (CorsUtils.isPreFlightRequest(request.servletRequest())) {
String accessControlRequestMethod =
request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
@ -560,6 +570,34 @@ public abstract class RequestPredicates {
return request.method();
}
@Override
public void accept(Visitor visitor) {
visitor.method(Set.of(this.httpMethod));
}
@Override
public String toString() {
return this.httpMethod.toString();
}
}
private static class MultipleHttpMethodsPredicate implements RequestPredicate {
private final Set<HttpMethod> httpMethods;
public MultipleHttpMethodsPredicate(HttpMethod[] httpMethods) {
this.httpMethods = new LinkedHashSet<>(Arrays.asList(httpMethods));
}
@Override
public boolean test(ServerRequest request) {
HttpMethod method = SingleHttpMethodPredicate.method(request);
boolean match = this.httpMethods.contains(method);
traceMatch("Method", this.httpMethods, method, match);
return match;
}
@Override
public void accept(Visitor visitor) {
visitor.method(Collections.unmodifiableSet(this.httpMethods));
@ -567,12 +605,7 @@ public abstract class RequestPredicates {
@Override
public String toString() {
if (this.httpMethods.size() == 1) {
return this.httpMethods.iterator().next().toString();
}
else {
return this.httpMethods.toString();
}
return this.httpMethods.toString();
}
}
@ -667,20 +700,46 @@ public abstract class RequestPredicates {
}
private static class ContentTypePredicate extends HeadersPredicate {
private static class SingleContentTypePredicate extends HeadersPredicate {
private final Set<MediaType> mediaTypes;
private final MediaType mediaType;
public ContentTypePredicate(MediaType... mediaTypes) {
this(Set.of(mediaTypes));
public SingleContentTypePredicate(MediaType mediaType) {
super(headers -> {
MediaType contentType = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
boolean match = mediaType.includes(contentType);
traceMatch("Content-Type", mediaType, contentType, match);
return match;
});
this.mediaType = mediaType;
}
private ContentTypePredicate(Set<MediaType> mediaTypes) {
@Override
public void accept(Visitor visitor) {
visitor.header(HttpHeaders.CONTENT_TYPE, this.mediaType.toString());
}
@Override
public String toString() {
return "Content-Type: " + this.mediaType;
}
}
private static class MultipleContentTypesPredicate extends HeadersPredicate {
private final MediaType[] mediaTypes;
public MultipleContentTypesPredicate(MediaType[] mediaTypes) {
super(headers -> {
MediaType contentType =
headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
boolean match = mediaTypes.stream()
.anyMatch(mediaType -> mediaType.includes(contentType));
MediaType contentType = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
boolean match = false;
for (MediaType mediaType : mediaTypes) {
if (mediaType.includes(contentType)) {
match = true;
break;
}
}
traceMatch("Content-Type", mediaTypes, contentType, match);
return match;
});
@ -689,44 +748,37 @@ public abstract class RequestPredicates {
@Override
public void accept(Visitor visitor) {
visitor.header(HttpHeaders.CONTENT_TYPE,
(this.mediaTypes.size() == 1) ?
this.mediaTypes.iterator().next().toString() :
this.mediaTypes.toString());
visitor.header(HttpHeaders.CONTENT_TYPE, Arrays.toString(this.mediaTypes));
}
@Override
public String toString() {
return String.format("Content-Type: %s",
(this.mediaTypes.size() == 1) ?
this.mediaTypes.iterator().next().toString() :
this.mediaTypes.toString());
return "Content-Type: " + Arrays.toString(this.mediaTypes);
}
}
private static class AcceptPredicate extends HeadersPredicate {
private static class SingleAcceptPredicate extends HeadersPredicate {
private final Set<MediaType> mediaTypes;
private final MediaType mediaType;
public AcceptPredicate(MediaType... mediaTypes) {
this(Set.of(mediaTypes));
}
private AcceptPredicate(Set<MediaType> mediaTypes) {
public SingleAcceptPredicate(MediaType mediaType) {
super(headers -> {
List<MediaType> acceptedMediaTypes = acceptedMediaTypes(headers);
boolean match = acceptedMediaTypes.stream()
.anyMatch(acceptedMediaType -> mediaTypes.stream()
.anyMatch(acceptedMediaType::isCompatibleWith));
traceMatch("Accept", mediaTypes, acceptedMediaTypes, match);
boolean match = false;
for (MediaType acceptedMediaType : acceptedMediaTypes) {
if (acceptedMediaType.isCompatibleWith(mediaType)) {
match = true;
break;
}
}
traceMatch("Accept", mediaType, acceptedMediaTypes, match);
return match;
});
this.mediaTypes = mediaTypes;
this.mediaType = mediaType;
}
@NonNull
private static List<MediaType> acceptedMediaTypes(ServerRequest.Headers headers) {
static List<MediaType> acceptedMediaTypes(ServerRequest.Headers headers) {
List<MediaType> acceptedMediaTypes = headers.accept();
if (acceptedMediaTypes.isEmpty()) {
acceptedMediaTypes = Collections.singletonList(MediaType.ALL);
@ -739,18 +791,47 @@ public abstract class RequestPredicates {
@Override
public void accept(Visitor visitor) {
visitor.header(HttpHeaders.ACCEPT,
(this.mediaTypes.size() == 1) ?
this.mediaTypes.iterator().next().toString() :
this.mediaTypes.toString());
visitor.header(HttpHeaders.ACCEPT, this.mediaType.toString());
}
@Override
public String toString() {
return String.format("Accept: %s",
(this.mediaTypes.size() == 1) ?
this.mediaTypes.iterator().next().toString() :
this.mediaTypes.toString());
return "Accept: " + this.mediaType;
}
}
private static class MultipleAcceptsPredicate extends HeadersPredicate {
private final MediaType[] mediaTypes;
public MultipleAcceptsPredicate(MediaType[] mediaTypes) {
super(headers -> {
List<MediaType> acceptedMediaTypes = SingleAcceptPredicate.acceptedMediaTypes(headers);
boolean match = false;
outer:
for (MediaType acceptedMediaType : acceptedMediaTypes) {
for (MediaType mediaType : mediaTypes) {
if (acceptedMediaType.isCompatibleWith(mediaType)) {
match = true;
break outer;
}
}
}
traceMatch("Accept", mediaTypes, acceptedMediaTypes, match);
return match;
});
this.mediaTypes = mediaTypes;
}
@Override
public void accept(Visitor visitor) {
visitor.header(HttpHeaders.ACCEPT, Arrays.toString(this.mediaTypes));
}
@Override
public String toString() {
return "Accept: " + Arrays.toString(this.mediaTypes);
}
}

View File

@ -31,7 +31,6 @@ import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
import org.springframework.web.util.pattern.PathPatternParser;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.http.MediaType.TEXT_XML_VALUE;
/**
* @author Arjen Poutsma
@ -179,22 +178,46 @@ class RequestPredicatesTests {
}
@Test
void contentType() {
MediaType json = MediaType.APPLICATION_JSON;
RequestPredicate predicate = RequestPredicates.contentType(json);
ServerRequest request = initRequest("GET", "/path", req -> req.setContentType(json.toString()));
void singleContentType() {
RequestPredicate predicate = RequestPredicates.contentType(MediaType.APPLICATION_JSON);
ServerRequest request = initRequest("GET", "/path", r -> r.setContentType(MediaType.APPLICATION_JSON_VALUE));
assertThat(predicate.test(request)).isTrue();
assertThat(predicate.test(initRequest("GET", ""))).isFalse();
assertThat(predicate.test(initRequest("GET", "", r -> r.setContentType(MediaType.TEXT_XML_VALUE)))).isFalse();
}
@Test
void accept() {
MediaType json = MediaType.APPLICATION_JSON;
RequestPredicate predicate = RequestPredicates.accept(json);
ServerRequest request = initRequest("GET", "/path", req -> req.addHeader("Accept", json.toString()));
void multipleContentTypes() {
RequestPredicate predicate = RequestPredicates.contentType(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN);
ServerRequest request = initRequest("GET", "/path", r -> r.setContentType(MediaType.APPLICATION_JSON_VALUE));
assertThat(predicate.test(request)).isTrue();
request = initRequest("GET", "", req -> req.addHeader("Accept", TEXT_XML_VALUE));
request = initRequest("GET", "/path", r -> r.setContentType(MediaType.TEXT_PLAIN_VALUE));
assertThat(predicate.test(request)).isTrue();
assertThat(predicate.test(initRequest("GET", "", r -> r.setContentType(MediaType.TEXT_XML_VALUE)))).isFalse();
}
@Test
void singleAccept() {
RequestPredicate predicate = RequestPredicates.accept(MediaType.APPLICATION_JSON);
ServerRequest request = initRequest("GET", "/path", r -> r.addHeader("Accept", MediaType.APPLICATION_JSON_VALUE));
assertThat(predicate.test(request)).isTrue();
request = initRequest("GET", "", req -> req.addHeader("Accept", MediaType.TEXT_XML_VALUE));
assertThat(predicate.test(request)).isFalse();
}
@Test
void multipleAccepts() {
RequestPredicate predicate = RequestPredicates.accept(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN);
ServerRequest request = initRequest("GET", "/path", r -> r.addHeader("Accept", MediaType.APPLICATION_JSON_VALUE));
assertThat(predicate.test(request)).isTrue();
request = initRequest("GET", "/path", r -> r.addHeader("Accept", MediaType.TEXT_PLAIN_VALUE));
assertThat(predicate.test(request)).isTrue();
request = initRequest("GET", "", req -> req.addHeader("Accept", MediaType.TEXT_XML_VALUE));
assertThat(predicate.test(request)).isFalse();
}