From 750cb739023a4ceb52e91ea66ce09f98638a8821 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Mon, 12 Feb 2024 15:44:53 +0100 Subject: [PATCH] Introduce single-value request predicates This commit introduces new HTTP method, Content-Type, and Accept header request predicates that handle single values. Previously, these predicates were always dealt with as single-value collections, which introduced computational overhead. Closes gh-32244 --- .../function/server/RequestPredicates.java | 207 ++++++++++++------ .../server/RequestPredicatesTests.java | 58 ++++- .../servlet/function/RequestPredicates.java | 207 ++++++++++++------ .../function/RequestPredicatesTests.java | 45 +++- 4 files changed, 372 insertions(+), 145 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java index b37e58bc02..856f9e3f51 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java @@ -47,7 +47,6 @@ import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.PathContainer; import org.springframework.http.server.RequestPath; import org.springframework.http.server.reactive.ServerHttpRequest; -import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -90,7 +89,8 @@ public abstract class RequestPredicates { * @return a predicate that tests against the given HTTP method */ public static RequestPredicate method(HttpMethod httpMethod) { - return new HttpMethodPredicate(httpMethod); + Assert.notNull(httpMethod, "HttpMethod must not be null"); + return new SingleHttpMethodPredicate(httpMethod); } /** @@ -101,7 +101,13 @@ public abstract class RequestPredicates { * @since 5.1 */ public static RequestPredicate methods(HttpMethod... httpMethods) { - return new HttpMethodPredicate(httpMethods); + Assert.notEmpty(httpMethods, "HttpMethods must not be empty"); + if (httpMethods.length == 1) { + return new SingleHttpMethodPredicate(httpMethods[0]); + } + else { + return new MultipleHttpMethodsPredicate(httpMethods); + } } /** @@ -151,7 +157,12 @@ public abstract class RequestPredicates { */ public static RequestPredicate contentType(MediaType... mediaTypes) { Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); - return new ContentTypePredicate(mediaTypes); + if (mediaTypes.length == 1) { + return new SingleContentTypePredicate(mediaTypes[0]); + } + else { + return new MultipleContentTypesPredicate(mediaTypes); + } } /** @@ -163,7 +174,12 @@ public abstract class RequestPredicates { */ public static RequestPredicate accept(MediaType... mediaTypes) { Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); - return new AcceptPredicate(mediaTypes); + if (mediaTypes.length == 1) { + return new SingleAcceptPredicate(mediaTypes[0]); + } + else { + return new MultipleAcceptsPredicate(mediaTypes); + } } /** @@ -529,29 +545,23 @@ public abstract class RequestPredicates { } - private static class HttpMethodPredicate implements RequestPredicate { + private static class SingleHttpMethodPredicate implements RequestPredicate { - private final Set httpMethods; + private final HttpMethod httpMethod; - public HttpMethodPredicate(HttpMethod httpMethod) { - Assert.notNull(httpMethod, "HttpMethod must not be null"); - this.httpMethods = Set.of(httpMethod); - } - - public HttpMethodPredicate(HttpMethod... httpMethods) { - Assert.notEmpty(httpMethods, "HttpMethods must not be empty"); - this.httpMethods = new LinkedHashSet<>(Arrays.asList(httpMethods)); + public SingleHttpMethodPredicate(HttpMethod httpMethod) { + this.httpMethod = httpMethod; } @Override public boolean test(ServerRequest request) { HttpMethod method = method(request); - boolean match = this.httpMethods.contains(method); - traceMatch("Method", this.httpMethods, method, match); + boolean match = this.httpMethod.equals(method); + traceMatch("Method", this.httpMethod, method, match); return match; } - private static HttpMethod method(ServerRequest request) { + static HttpMethod method(ServerRequest request) { if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) { String accessControlRequestMethod = request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD); @@ -562,6 +572,34 @@ public abstract class RequestPredicates { return request.method(); } + @Override + public void accept(Visitor visitor) { + visitor.method(Set.of(this.httpMethod)); + } + + @Override + public String toString() { + return this.httpMethod.toString(); + } + } + + + private static class MultipleHttpMethodsPredicate implements RequestPredicate { + + private final Set httpMethods; + + public MultipleHttpMethodsPredicate(HttpMethod[] httpMethods) { + this.httpMethods = new LinkedHashSet<>(Arrays.asList(httpMethods)); + } + + @Override + public boolean test(ServerRequest request) { + HttpMethod method = SingleHttpMethodPredicate.method(request); + boolean match = this.httpMethods.contains(method); + traceMatch("Method", this.httpMethods, method, match); + return match; + } + @Override public void accept(Visitor visitor) { visitor.method(Collections.unmodifiableSet(this.httpMethods)); @@ -569,12 +607,7 @@ public abstract class RequestPredicates { @Override public String toString() { - if (this.httpMethods.size() == 1) { - return this.httpMethods.iterator().next().toString(); - } - else { - return this.httpMethods.toString(); - } + return this.httpMethods.toString(); } } @@ -669,20 +702,46 @@ public abstract class RequestPredicates { } - private static class ContentTypePredicate extends HeadersPredicate { + private static class SingleContentTypePredicate extends HeadersPredicate { - private final Set mediaTypes; + private final MediaType mediaType; - public ContentTypePredicate(MediaType... mediaTypes) { - this(Set.of(mediaTypes)); + public SingleContentTypePredicate(MediaType mediaType) { + super(headers -> { + MediaType contentType = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); + boolean match = mediaType.includes(contentType); + traceMatch("Content-Type", mediaType, contentType, match); + return match; + }); + this.mediaType = mediaType; } - private ContentTypePredicate(Set mediaTypes) { + @Override + public void accept(Visitor visitor) { + visitor.header(HttpHeaders.CONTENT_TYPE, this.mediaType.toString()); + } + + @Override + public String toString() { + return "Content-Type: " + this.mediaType; + } + } + + + private static class MultipleContentTypesPredicate extends HeadersPredicate { + + private final MediaType[] mediaTypes; + + public MultipleContentTypesPredicate(MediaType[] mediaTypes) { super(headers -> { - MediaType contentType = - headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); - boolean match = mediaTypes.stream() - .anyMatch(mediaType -> mediaType.includes(contentType)); + MediaType contentType = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); + boolean match = false; + for (MediaType mediaType : mediaTypes) { + if (mediaType.includes(contentType)) { + match = true; + break; + } + } traceMatch("Content-Type", mediaTypes, contentType, match); return match; }); @@ -691,44 +750,37 @@ public abstract class RequestPredicates { @Override public void accept(Visitor visitor) { - visitor.header(HttpHeaders.CONTENT_TYPE, - (this.mediaTypes.size() == 1) ? - this.mediaTypes.iterator().next().toString() : - this.mediaTypes.toString()); + visitor.header(HttpHeaders.CONTENT_TYPE, Arrays.toString(this.mediaTypes)); } @Override public String toString() { - return String.format("Content-Type: %s", - (this.mediaTypes.size() == 1) ? - this.mediaTypes.iterator().next().toString() : - this.mediaTypes.toString()); + return "Content-Type: " + Arrays.toString(this.mediaTypes); } } - private static class AcceptPredicate extends HeadersPredicate { + private static class SingleAcceptPredicate extends HeadersPredicate { - private final Set mediaTypes; + private final MediaType mediaType; - public AcceptPredicate(MediaType... mediaTypes) { - this(Set.of(mediaTypes)); - } - - private AcceptPredicate(Set mediaTypes) { + public SingleAcceptPredicate(MediaType mediaType) { super(headers -> { List acceptedMediaTypes = acceptedMediaTypes(headers); - boolean match = acceptedMediaTypes.stream() - .anyMatch(acceptedMediaType -> mediaTypes.stream() - .anyMatch(acceptedMediaType::isCompatibleWith)); - traceMatch("Accept", mediaTypes, acceptedMediaTypes, match); + boolean match = false; + for (MediaType acceptedMediaType : acceptedMediaTypes) { + if (acceptedMediaType.isCompatibleWith(mediaType)) { + match = true; + break; + } + } + traceMatch("Accept", mediaType, acceptedMediaTypes, match); return match; }); - this.mediaTypes = mediaTypes; + this.mediaType = mediaType; } - @NonNull - private static List acceptedMediaTypes(ServerRequest.Headers headers) { + static List acceptedMediaTypes(ServerRequest.Headers headers) { List acceptedMediaTypes = headers.accept(); if (acceptedMediaTypes.isEmpty()) { acceptedMediaTypes = Collections.singletonList(MediaType.ALL); @@ -741,18 +793,47 @@ public abstract class RequestPredicates { @Override public void accept(Visitor visitor) { - visitor.header(HttpHeaders.ACCEPT, - (this.mediaTypes.size() == 1) ? - this.mediaTypes.iterator().next().toString() : - this.mediaTypes.toString()); + visitor.header(HttpHeaders.ACCEPT, this.mediaType.toString()); } @Override public String toString() { - return String.format("Accept: %s", - (this.mediaTypes.size() == 1) ? - this.mediaTypes.iterator().next().toString() : - this.mediaTypes.toString()); + return "Accept: " + this.mediaType; + } + } + + + private static class MultipleAcceptsPredicate extends HeadersPredicate { + + private final MediaType[] mediaTypes; + + public MultipleAcceptsPredicate(MediaType[] mediaTypes) { + super(headers -> { + List acceptedMediaTypes = SingleAcceptPredicate.acceptedMediaTypes(headers); + boolean match = false; + outer: + for (MediaType acceptedMediaType : acceptedMediaTypes) { + for (MediaType mediaType : mediaTypes) { + if (acceptedMediaType.isCompatibleWith(mediaType)) { + match = true; + break outer; + } + } + } + traceMatch("Accept", mediaTypes, acceptedMediaTypes, match); + return match; + }); + this.mediaTypes = mediaTypes; + } + + @Override + public void accept(Visitor visitor) { + visitor.header(HttpHeaders.ACCEPT, Arrays.toString(this.mediaTypes)); + } + + @Override + public String toString() { + return "Accept: " + Arrays.toString(this.mediaTypes); } } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java index 2f855513da..68140d02a0 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java @@ -219,11 +219,10 @@ class RequestPredicatesTests { @Test - void contentType() { - MediaType json = MediaType.APPLICATION_JSON; - RequestPredicate predicate = RequestPredicates.contentType(json); + void singleContentType() { + RequestPredicate predicate = RequestPredicates.contentType(MediaType.APPLICATION_JSON); MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com") - .header(HttpHeaders.CONTENT_TYPE, json.toString()) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) .build(); ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); @@ -236,15 +235,58 @@ class RequestPredicatesTests { } @Test - void accept() { - MediaType json = MediaType.APPLICATION_JSON; - RequestPredicate predicate = RequestPredicates.accept(json); + void multipleContentTypes() { + RequestPredicate predicate = RequestPredicates.contentType(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN); MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com") - .header(HttpHeaders.ACCEPT, json.toString()) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) .build(); ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); assertThat(predicate.test(request)).isTrue(); + mockRequest = MockServerHttpRequest.get("https://example.com") + .header(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_PLAIN_VALUE) + .build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); + assertThat(predicate.test(request)).isTrue(); + + mockRequest = MockServerHttpRequest.get("https://example.com") + .header(HttpHeaders.CONTENT_TYPE, "foo/bar") + .build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); + assertThat(predicate.test(request)).isFalse(); + } + + @Test + void singleAccept() { + RequestPredicate predicate = RequestPredicates.accept(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN); + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com") + .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) + .build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); + assertThat(predicate.test(request)).isTrue(); + + mockRequest = MockServerHttpRequest.get("https://example.com") + .header(HttpHeaders.ACCEPT, "foo/bar") + .build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); + assertThat(predicate.test(request)).isFalse(); + } + + @Test + void multipleAccepts() { + RequestPredicate predicate = RequestPredicates.accept(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN); + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com") + .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) + .build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); + assertThat(predicate.test(request)).isTrue(); + + mockRequest = MockServerHttpRequest.get("https://example.com") + .header(HttpHeaders.ACCEPT, MediaType.TEXT_PLAIN_VALUE) + .build(); + request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); + assertThat(predicate.test(request)).isTrue(); + mockRequest = MockServerHttpRequest.get("https://example.com") .header(HttpHeaders.ACCEPT, "foo/bar") .build(); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java index d326bf0547..5cd3e45a59 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java @@ -48,7 +48,6 @@ import org.springframework.http.MediaType; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.PathContainer; import org.springframework.http.server.RequestPath; -import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -90,7 +89,8 @@ public abstract class RequestPredicates { * @return a predicate that tests against the given HTTP method */ public static RequestPredicate method(HttpMethod httpMethod) { - return new HttpMethodPredicate(httpMethod); + Assert.notNull(httpMethod, "HttpMethod must not be null"); + return new SingleHttpMethodPredicate(httpMethod); } /** @@ -100,7 +100,13 @@ public abstract class RequestPredicates { * @return a predicate that tests against the given HTTP methods */ public static RequestPredicate methods(HttpMethod... httpMethods) { - return new HttpMethodPredicate(httpMethods); + Assert.notEmpty(httpMethods, "HttpMethods must not be empty"); + if (httpMethods.length == 1) { + return new SingleHttpMethodPredicate(httpMethods[0]); + } + else { + return new MultipleHttpMethodsPredicate(httpMethods); + } } /** @@ -150,7 +156,12 @@ public abstract class RequestPredicates { */ public static RequestPredicate contentType(MediaType... mediaTypes) { Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); - return new ContentTypePredicate(mediaTypes); + if (mediaTypes.length == 1) { + return new SingleContentTypePredicate(mediaTypes[0]); + } + else { + return new MultipleContentTypesPredicate(mediaTypes); + } } /** @@ -162,7 +173,12 @@ public abstract class RequestPredicates { */ public static RequestPredicate accept(MediaType... mediaTypes) { Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); - return new AcceptPredicate(mediaTypes); + if (mediaTypes.length == 1) { + return new SingleAcceptPredicate(mediaTypes[0]); + } + else { + return new MultipleAcceptsPredicate(mediaTypes); + } } /** @@ -527,29 +543,23 @@ public abstract class RequestPredicates { } - private static class HttpMethodPredicate implements RequestPredicate { + private static class SingleHttpMethodPredicate implements RequestPredicate { - private final Set httpMethods; + private final HttpMethod httpMethod; - public HttpMethodPredicate(HttpMethod httpMethod) { - Assert.notNull(httpMethod, "HttpMethod must not be null"); - this.httpMethods = Set.of(httpMethod); - } - - public HttpMethodPredicate(HttpMethod... httpMethods) { - Assert.notEmpty(httpMethods, "HttpMethods must not be empty"); - this.httpMethods = new LinkedHashSet<>(Arrays.asList(httpMethods)); + public SingleHttpMethodPredicate(HttpMethod httpMethod) { + this.httpMethod = httpMethod; } @Override public boolean test(ServerRequest request) { HttpMethod method = method(request); - boolean match = this.httpMethods.contains(method); - traceMatch("Method", this.httpMethods, method, match); + boolean match = this.httpMethod.equals(method); + traceMatch("Method", this.httpMethod, method, match); return match; } - private static HttpMethod method(ServerRequest request) { + static HttpMethod method(ServerRequest request) { if (CorsUtils.isPreFlightRequest(request.servletRequest())) { String accessControlRequestMethod = request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD); @@ -560,6 +570,34 @@ public abstract class RequestPredicates { return request.method(); } + @Override + public void accept(Visitor visitor) { + visitor.method(Set.of(this.httpMethod)); + } + + @Override + public String toString() { + return this.httpMethod.toString(); + } + } + + + private static class MultipleHttpMethodsPredicate implements RequestPredicate { + + private final Set httpMethods; + + public MultipleHttpMethodsPredicate(HttpMethod[] httpMethods) { + this.httpMethods = new LinkedHashSet<>(Arrays.asList(httpMethods)); + } + + @Override + public boolean test(ServerRequest request) { + HttpMethod method = SingleHttpMethodPredicate.method(request); + boolean match = this.httpMethods.contains(method); + traceMatch("Method", this.httpMethods, method, match); + return match; + } + @Override public void accept(Visitor visitor) { visitor.method(Collections.unmodifiableSet(this.httpMethods)); @@ -567,12 +605,7 @@ public abstract class RequestPredicates { @Override public String toString() { - if (this.httpMethods.size() == 1) { - return this.httpMethods.iterator().next().toString(); - } - else { - return this.httpMethods.toString(); - } + return this.httpMethods.toString(); } } @@ -667,20 +700,46 @@ public abstract class RequestPredicates { } - private static class ContentTypePredicate extends HeadersPredicate { + private static class SingleContentTypePredicate extends HeadersPredicate { - private final Set mediaTypes; + private final MediaType mediaType; - public ContentTypePredicate(MediaType... mediaTypes) { - this(Set.of(mediaTypes)); + public SingleContentTypePredicate(MediaType mediaType) { + super(headers -> { + MediaType contentType = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); + boolean match = mediaType.includes(contentType); + traceMatch("Content-Type", mediaType, contentType, match); + return match; + }); + this.mediaType = mediaType; } - private ContentTypePredicate(Set mediaTypes) { + @Override + public void accept(Visitor visitor) { + visitor.header(HttpHeaders.CONTENT_TYPE, this.mediaType.toString()); + } + + @Override + public String toString() { + return "Content-Type: " + this.mediaType; + } + } + + + private static class MultipleContentTypesPredicate extends HeadersPredicate { + + private final MediaType[] mediaTypes; + + public MultipleContentTypesPredicate(MediaType[] mediaTypes) { super(headers -> { - MediaType contentType = - headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); - boolean match = mediaTypes.stream() - .anyMatch(mediaType -> mediaType.includes(contentType)); + MediaType contentType = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); + boolean match = false; + for (MediaType mediaType : mediaTypes) { + if (mediaType.includes(contentType)) { + match = true; + break; + } + } traceMatch("Content-Type", mediaTypes, contentType, match); return match; }); @@ -689,44 +748,37 @@ public abstract class RequestPredicates { @Override public void accept(Visitor visitor) { - visitor.header(HttpHeaders.CONTENT_TYPE, - (this.mediaTypes.size() == 1) ? - this.mediaTypes.iterator().next().toString() : - this.mediaTypes.toString()); + visitor.header(HttpHeaders.CONTENT_TYPE, Arrays.toString(this.mediaTypes)); } @Override public String toString() { - return String.format("Content-Type: %s", - (this.mediaTypes.size() == 1) ? - this.mediaTypes.iterator().next().toString() : - this.mediaTypes.toString()); + return "Content-Type: " + Arrays.toString(this.mediaTypes); } } - private static class AcceptPredicate extends HeadersPredicate { + private static class SingleAcceptPredicate extends HeadersPredicate { - private final Set mediaTypes; + private final MediaType mediaType; - public AcceptPredicate(MediaType... mediaTypes) { - this(Set.of(mediaTypes)); - } - - private AcceptPredicate(Set mediaTypes) { + public SingleAcceptPredicate(MediaType mediaType) { super(headers -> { List acceptedMediaTypes = acceptedMediaTypes(headers); - boolean match = acceptedMediaTypes.stream() - .anyMatch(acceptedMediaType -> mediaTypes.stream() - .anyMatch(acceptedMediaType::isCompatibleWith)); - traceMatch("Accept", mediaTypes, acceptedMediaTypes, match); + boolean match = false; + for (MediaType acceptedMediaType : acceptedMediaTypes) { + if (acceptedMediaType.isCompatibleWith(mediaType)) { + match = true; + break; + } + } + traceMatch("Accept", mediaType, acceptedMediaTypes, match); return match; }); - this.mediaTypes = mediaTypes; + this.mediaType = mediaType; } - @NonNull - private static List acceptedMediaTypes(ServerRequest.Headers headers) { + static List acceptedMediaTypes(ServerRequest.Headers headers) { List acceptedMediaTypes = headers.accept(); if (acceptedMediaTypes.isEmpty()) { acceptedMediaTypes = Collections.singletonList(MediaType.ALL); @@ -739,18 +791,47 @@ public abstract class RequestPredicates { @Override public void accept(Visitor visitor) { - visitor.header(HttpHeaders.ACCEPT, - (this.mediaTypes.size() == 1) ? - this.mediaTypes.iterator().next().toString() : - this.mediaTypes.toString()); + visitor.header(HttpHeaders.ACCEPT, this.mediaType.toString()); } @Override public String toString() { - return String.format("Accept: %s", - (this.mediaTypes.size() == 1) ? - this.mediaTypes.iterator().next().toString() : - this.mediaTypes.toString()); + return "Accept: " + this.mediaType; + } + } + + + private static class MultipleAcceptsPredicate extends HeadersPredicate { + + private final MediaType[] mediaTypes; + + public MultipleAcceptsPredicate(MediaType[] mediaTypes) { + super(headers -> { + List acceptedMediaTypes = SingleAcceptPredicate.acceptedMediaTypes(headers); + boolean match = false; + outer: + for (MediaType acceptedMediaType : acceptedMediaTypes) { + for (MediaType mediaType : mediaTypes) { + if (acceptedMediaType.isCompatibleWith(mediaType)) { + match = true; + break outer; + } + } + } + traceMatch("Accept", mediaTypes, acceptedMediaTypes, match); + return match; + }); + this.mediaTypes = mediaTypes; + } + + @Override + public void accept(Visitor visitor) { + visitor.header(HttpHeaders.ACCEPT, Arrays.toString(this.mediaTypes)); + } + + @Override + public String toString() { + return "Accept: " + Arrays.toString(this.mediaTypes); } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java index 4da80ab5e8..4b56e61ed5 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java @@ -31,7 +31,6 @@ import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.util.pattern.PathPatternParser; import static org.assertj.core.api.Assertions.assertThat; -import static org.springframework.http.MediaType.TEXT_XML_VALUE; /** * @author Arjen Poutsma @@ -179,22 +178,46 @@ class RequestPredicatesTests { } @Test - void contentType() { - MediaType json = MediaType.APPLICATION_JSON; - RequestPredicate predicate = RequestPredicates.contentType(json); - ServerRequest request = initRequest("GET", "/path", req -> req.setContentType(json.toString())); + void singleContentType() { + RequestPredicate predicate = RequestPredicates.contentType(MediaType.APPLICATION_JSON); + ServerRequest request = initRequest("GET", "/path", r -> r.setContentType(MediaType.APPLICATION_JSON_VALUE)); assertThat(predicate.test(request)).isTrue(); - assertThat(predicate.test(initRequest("GET", ""))).isFalse(); + + assertThat(predicate.test(initRequest("GET", "", r -> r.setContentType(MediaType.TEXT_XML_VALUE)))).isFalse(); } @Test - void accept() { - MediaType json = MediaType.APPLICATION_JSON; - RequestPredicate predicate = RequestPredicates.accept(json); - ServerRequest request = initRequest("GET", "/path", req -> req.addHeader("Accept", json.toString())); + void multipleContentTypes() { + RequestPredicate predicate = RequestPredicates.contentType(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN); + ServerRequest request = initRequest("GET", "/path", r -> r.setContentType(MediaType.APPLICATION_JSON_VALUE)); assertThat(predicate.test(request)).isTrue(); - request = initRequest("GET", "", req -> req.addHeader("Accept", TEXT_XML_VALUE)); + request = initRequest("GET", "/path", r -> r.setContentType(MediaType.TEXT_PLAIN_VALUE)); + assertThat(predicate.test(request)).isTrue(); + + assertThat(predicate.test(initRequest("GET", "", r -> r.setContentType(MediaType.TEXT_XML_VALUE)))).isFalse(); + } + + @Test + void singleAccept() { + RequestPredicate predicate = RequestPredicates.accept(MediaType.APPLICATION_JSON); + ServerRequest request = initRequest("GET", "/path", r -> r.addHeader("Accept", MediaType.APPLICATION_JSON_VALUE)); + assertThat(predicate.test(request)).isTrue(); + + request = initRequest("GET", "", req -> req.addHeader("Accept", MediaType.TEXT_XML_VALUE)); + assertThat(predicate.test(request)).isFalse(); + } + + @Test + void multipleAccepts() { + RequestPredicate predicate = RequestPredicates.accept(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN); + ServerRequest request = initRequest("GET", "/path", r -> r.addHeader("Accept", MediaType.APPLICATION_JSON_VALUE)); + assertThat(predicate.test(request)).isTrue(); + + request = initRequest("GET", "/path", r -> r.addHeader("Accept", MediaType.TEXT_PLAIN_VALUE)); + assertThat(predicate.test(request)).isTrue(); + + request = initRequest("GET", "", req -> req.addHeader("Accept", MediaType.TEXT_XML_VALUE)); assertThat(predicate.test(request)).isFalse(); }