diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt index f82a6b8bd9e..a0e366b19df 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -586,6 +586,30 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct } } + /** + * Add an attribute with the given name and value to the last route built with this builder. + * @param name the attribute name + * @param value the attribute value +- * @since 6.0 + */ + fun withAttribute(name: String, value: Any) { + builder.withAttribute(name, value) + } + + /** + * Manipulate the attributes of the last route built with the given consumer. + * + * The map provided to the consumer is "live", so that the consumer can be used + * to [overwrite][MutableMap.put] existing attributes, + * [remove][MutableMap.remove] attributes, or use any of the other + * [MutableMap] methods. + * @param attributesConsumer a function that consumes the attributes map + * @since 6.0 + */ + fun withAttributes(attributesConsumer: (MutableMap) -> Unit) { + builder.withAttributes(attributesConsumer) + } + /** * Return a composed routing function created from all the registered routes. */ diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDsl.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDsl.kt index 38e2fbe8771..c6f35ac2f23 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDsl.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDsl.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -702,6 +702,30 @@ class RouterFunctionDsl internal constructor (private val init: RouterFunctionDs builder.onError({it is E}, responseProvider) } + /** + * Add an attribute with the given name and value to the last route built with this builder. + * @param name the attribute name + * @param value the attribute value + * @since 6.0 + */ + fun withAttribute(name: String, value: Any) { + builder.withAttribute(name, value) + } + + /** + * Manipulate the attributes of the last route built with the given consumer. + * + * The map provided to the consumer is "live", so that the consumer can be used + * to [overwrite][MutableMap.put] existing attributes, + * [remove][MutableMap.remove] attributes, or use any of the other + * [MutableMap] methods. + * @param attributesConsumer a function that consumes the attributes map + * @since 6.0 + */ + fun withAttributes(attributesConsumer: (MutableMap) -> Unit) { + builder.withAttributes(attributesConsumer) + } + /** * Return a composed routing function created from all the registered routes. * @since 5.1 diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/AttributesTestVisitor.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/AttributesTestVisitor.java index 7a97e8c3631..8d8a4957172 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/AttributesTestVisitor.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/AttributesTestVisitor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,43 +16,60 @@ package org.springframework.web.reactive.function.server; +import java.util.Deque; +import java.util.LinkedList; +import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; import reactor.core.publisher.Mono; import org.springframework.core.io.Resource; import org.springframework.lang.Nullable; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.entry; - /** * @author Arjen Poutsma */ class AttributesTestVisitor implements RouterFunctions.Visitor { + private Deque> nestedAttributes = new LinkedList<>(); + @Nullable private Map attributes; + private List>> routerFunctionsAttributes = new LinkedList<>(); + private int visitCount; + public List>> routerFunctionsAttributes() { + return this.routerFunctionsAttributes; + } + public int visitCount() { return this.visitCount; } @Override public void startNested(RequestPredicate predicate) { + nestedAttributes.addFirst(attributes); + attributes = null; } @Override public void endNested(RequestPredicate predicate) { + attributes = nestedAttributes.removeFirst(); } @Override public void route(RequestPredicate predicate, HandlerFunction handlerFunction) { - assertThat(this.attributes).isNotNull(); - this.attributes = null; + Stream> current = Optional.ofNullable(attributes).stream(); + Stream> nested = nestedAttributes.stream().filter(Objects::nonNull); + routerFunctionsAttributes.add(Stream.concat(current, nested).collect(Collectors.toUnmodifiableList())); + attributes = null; } @Override @@ -61,7 +78,6 @@ class AttributesTestVisitor implements RouterFunctions.Visitor { @Override public void attributes(Map attributes) { - assertThat(attributes).containsExactly(entry("foo", "bar"), entry("baz", "qux")); this.attributes = attributes; this.visitCount++; } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionBuilderTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionBuilderTests.java index 85b22f56593..b91d2840ad4 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionBuilderTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionBuilderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ package org.springframework.web.reactive.function.server; import java.io.IOException; import java.util.Collections; +import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; @@ -236,12 +238,28 @@ public class RouterFunctionBuilderTests { atts.put("foo", "bar"); atts.put("baz", "qux"); }) + .path("/atts", b1 -> b1 + .GET("/3", request -> ServerResponse.ok().build()) + .withAttribute("foo", "bar") + .GET("/4", request -> ServerResponse.ok().build()) + .withAttribute("baz", "qux") + .path("/5", b2 -> b2 + .GET(request -> ServerResponse.ok().build()) + .withAttribute("foo", "n3")) + .withAttribute("foo", "n2") + ) + .withAttribute("foo", "n1") .build(); AttributesTestVisitor visitor = new AttributesTestVisitor(); route.accept(visitor); - assertThat(visitor.visitCount()).isEqualTo(2); + assertThat(visitor.routerFunctionsAttributes()).containsExactly( + List.of(Map.of("foo", "bar", "baz", "qux")), + List.of(Map.of("foo", "bar", "baz", "qux")), + List.of(Map.of("foo", "bar"), Map.of("foo", "n1")), + List.of(Map.of("baz", "qux"), Map.of("foo", "n1")), + List.of(Map.of("foo", "n3"), Map.of("foo", "n2"), Map.of("foo", "n1")) + ); + assertThat(visitor.visitCount()).isEqualTo(7); } - - } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionTests.java index 4fcd6dbf1c7..49561cf8659 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ package org.springframework.web.reactive.function.server; import java.util.Collections; +import java.util.List; +import java.util.Map; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; @@ -26,7 +28,10 @@ import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRe import org.springframework.web.testfixture.server.MockServerWebExchange; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.http.HttpMethod.GET; import static org.springframework.web.reactive.function.server.RequestPredicates.GET; +import static org.springframework.web.reactive.function.server.RequestPredicates.method; +import static org.springframework.web.reactive.function.server.RequestPredicates.path; /** * @author Arjen Poutsma @@ -137,11 +142,28 @@ public class RouterFunctionTests { .withAttributes(atts -> { atts.put("foo", "bar"); atts.put("baz", "qux"); - })); + })) + .and(RouterFunctions.nest(path("/atts"), + RouterFunctions.route(GET("/3"), request -> ServerResponse.ok().build()) + .withAttribute("foo", "bar") + .and(RouterFunctions.route(GET("/4"), request -> ServerResponse.ok().build()) + .withAttribute("baz", "qux")) + .and(RouterFunctions.nest(path("/5"), + RouterFunctions.route(method(GET), request -> ServerResponse.ok().build()) + .withAttribute("foo", "n3")) + .withAttribute("foo", "n2"))) + .withAttribute("foo", "n1")); AttributesTestVisitor visitor = new AttributesTestVisitor(); route.accept(visitor); - assertThat(visitor.visitCount()).isEqualTo(2); + assertThat(visitor.routerFunctionsAttributes()).containsExactly( + List.of(Map.of("foo", "bar", "baz", "qux")), + List.of(Map.of("foo", "bar", "baz", "qux")), + List.of(Map.of("foo", "bar"), Map.of("foo", "n1")), + List.of(Map.of("baz", "qux"), Map.of("foo", "n1")), + List.of(Map.of("foo", "n3"), Map.of("foo", "n2"), Map.of("foo", "n1")) + ); + assertThat(visitor.visitCount()).isEqualTo(7); } diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt index bdeae8b00af..041956cefa3 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.web.reactive.function.server +import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.junit.jupiter.api.Test import org.springframework.core.io.ClassPathResource @@ -25,6 +26,7 @@ import org.springframework.http.HttpStatus import org.springframework.http.MediaType.* import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.* import org.springframework.web.testfixture.server.MockServerWebExchange +import org.springframework.web.reactive.function.server.AttributesTestVisitor import reactor.test.StepVerifier /** @@ -163,6 +165,20 @@ class CoRouterFunctionDslTests { .verifyComplete() } + @Test + fun attributes() { + val visitor = AttributesTestVisitor() + attributesRouter.accept(visitor) + assertThat(visitor.routerFunctionsAttributes()).containsExactly( + listOf(mapOf("foo" to "bar", "baz" to "qux")), + listOf(mapOf("foo" to "bar", "baz" to "qux")), + listOf(mapOf("foo" to "bar"), mapOf("foo" to "n1")), + listOf(mapOf("baz" to "qux"), mapOf("foo" to "n1")), + listOf(mapOf("foo" to "n3"), mapOf("foo" to "n2"), mapOf("foo" to "n1")) + ); + assertThat(visitor.visitCount()).isEqualTo(7); + } + private fun sampleRouter() = coRouter { (GET("/foo/") or GET("/foos/")) { req -> handle(req) } "/api".nest { @@ -231,6 +247,39 @@ class CoRouterFunctionDslTests { } } + private val attributesRouter = router { + GET("/atts/1") { + ok().build() + } + withAttribute("foo", "bar") + withAttribute("baz", "qux") + GET("/atts/2") { + ok().build() + } + withAttributes { atts -> + atts["foo"] = "bar" + atts["baz"] = "qux" + } + "/atts".nest { + GET("/3") { + ok().build() + } + withAttribute("foo", "bar") + GET("/4") { + ok().build() + } + withAttribute("baz", "qux") + "/5".nest { + GET { + ok().build() + } + withAttribute("foo", "n3") + } + withAttribute("foo", "n2") + } + withAttribute("foo", "n1") + } + @Suppress("UNUSED_PARAMETER") private suspend fun handleFromClass(req: ServerRequest) = ServerResponse.ok().buildAndAwait() } diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDslTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDslTests.kt index a050776f48a..4392b04fbf1 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDslTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/RouterFunctionDslTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.web.reactive.function.server +import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.junit.jupiter.api.Test import org.springframework.core.io.ClassPathResource @@ -25,6 +26,7 @@ import org.springframework.http.HttpStatus import org.springframework.http.MediaType.* import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.* import org.springframework.web.testfixture.server.MockServerWebExchange +import org.springframework.web.reactive.function.server.AttributesTestVisitor import reactor.core.publisher.Mono import reactor.test.StepVerifier @@ -153,6 +155,19 @@ class RouterFunctionDslTests { } } + @Test + fun attributes() { + val visitor = AttributesTestVisitor() + attributesRouter.accept(visitor) + assertThat(visitor.routerFunctionsAttributes()).containsExactly( + listOf(mapOf("foo" to "bar", "baz" to "qux")), + listOf(mapOf("foo" to "bar", "baz" to "qux")), + listOf(mapOf("foo" to "bar"), mapOf("foo" to "n1")), + listOf(mapOf("baz" to "qux"), mapOf("foo" to "n1")), + listOf(mapOf("foo" to "n3"), mapOf("foo" to "n2"), mapOf("foo" to "n1")) + ); + assertThat(visitor.visitCount()).isEqualTo(7); + } private fun sampleRouter() = router { (GET("/foo/") or GET("/foos/")) { req -> handle(req) } @@ -210,6 +225,39 @@ class RouterFunctionDslTests { } } + private val attributesRouter = router { + GET("/atts/1") { + ok().build() + } + withAttribute("foo", "bar") + withAttribute("baz", "qux") + GET("/atts/2") { + ok().build() + } + withAttributes { atts -> + atts["foo"] = "bar" + atts["baz"] = "qux" + } + "/atts".nest { + GET("/3") { + ok().build() + } + withAttribute("foo", "bar") + GET("/4") { + ok().build() + } + withAttribute("baz", "qux") + "/5".nest { + GET { + ok().build() + } + withAttribute("foo", "n3") + } + withAttribute("foo", "n2") + } + withAttribute("foo", "n1") + } + @Suppress("UNUSED_PARAMETER") private fun handleFromClass(req: ServerRequest) = ServerResponse.ok().build() } diff --git a/spring-webmvc/src/main/kotlin/org/springframework/web/servlet/function/RouterFunctionDsl.kt b/spring-webmvc/src/main/kotlin/org/springframework/web/servlet/function/RouterFunctionDsl.kt index 391e824cb2b..f48f7ebdf74 100644 --- a/spring-webmvc/src/main/kotlin/org/springframework/web/servlet/function/RouterFunctionDsl.kt +++ b/spring-webmvc/src/main/kotlin/org/springframework/web/servlet/function/RouterFunctionDsl.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import org.springframework.http.HttpMethod import org.springframework.http.HttpStatusCode import org.springframework.http.MediaType import java.net.URI -import java.util.* +import java.util.Optional import java.util.function.Supplier /** @@ -699,6 +699,30 @@ class RouterFunctionDsl internal constructor (private val init: (RouterFunctionD builder.onError({it is E}, responseProvider) } + /** + * Add an attribute with the given name and value to the last route built with this builder. + * @param name the attribute name + * @param value the attribute value + * @since 6.0 + */ + fun withAttribute(name: String, value: Any) { + builder.withAttribute(name, value) + } + + /** + * Manipulate the attributes of the last route built with the given consumer. + * + * The map provided to the consumer is "live", so that the consumer can be used + * to [overwrite][MutableMap.put] existing attributes, + * [remove][MutableMap.remove] attributes, or use any of the other + * [MutableMap] methods. + * @param attributesConsumer a function that consumes the attributes map + * @since 6.0 + */ + fun withAttributes(attributesConsumer: (MutableMap) -> Unit) { + builder.withAttributes(attributesConsumer) + } + /** * Return a composed routing function created from all the registered routes. */ diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/AttributesTestVisitor.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/AttributesTestVisitor.java index 84e031518b4..0cf0e64db7e 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/AttributesTestVisitor.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/AttributesTestVisitor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,42 +16,58 @@ package org.springframework.web.servlet.function; +import java.util.Deque; +import java.util.LinkedList; +import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.springframework.core.io.Resource; import org.springframework.lang.Nullable; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.entry; - /** * @author Arjen Poutsma */ class AttributesTestVisitor implements RouterFunctions.Visitor { + private Deque> nestedAttributes = new LinkedList<>(); + @Nullable private Map attributes; + private List>> routerFunctionsAttributes = new LinkedList<>(); + private int visitCount; + public List>> routerFunctionsAttributes() { + return this.routerFunctionsAttributes; + } + public int visitCount() { return this.visitCount; } @Override public void startNested(RequestPredicate predicate) { + nestedAttributes.addFirst(attributes); + attributes = null; } @Override public void endNested(RequestPredicate predicate) { + attributes = nestedAttributes.removeFirst(); } @Override public void route(RequestPredicate predicate, HandlerFunction handlerFunction) { - assertThat(this.attributes).isNotNull(); - this.attributes = null; + Stream> current = Optional.ofNullable(attributes).stream(); + Stream> nested = nestedAttributes.stream().filter(Objects::nonNull); + routerFunctionsAttributes.add(Stream.concat(current, nested).collect(Collectors.toUnmodifiableList())); + attributes = null; } @Override @@ -60,7 +76,6 @@ class AttributesTestVisitor implements RouterFunctions.Visitor { @Override public void attributes(Map attributes) { - assertThat(attributes).containsExactly(entry("foo", "bar"), entry("baz", "qux")); this.attributes = attributes; this.visitCount++; } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java index d535e00ee6b..c9a497a2a4d 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ package org.springframework.web.servlet.function; import java.io.IOException; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; @@ -226,12 +228,28 @@ class RouterFunctionBuilderTests { atts.put("foo", "bar"); atts.put("baz", "qux"); }) + .path("/atts", b1 -> b1 + .GET("/3", request -> ServerResponse.ok().build()) + .withAttribute("foo", "bar") + .GET("/4", request -> ServerResponse.ok().build()) + .withAttribute("baz", "qux") + .path("/5", b2 -> b2 + .GET(request -> ServerResponse.ok().build()) + .withAttribute("foo", "n3")) + .withAttribute("foo", "n2") + ) + .withAttribute("foo", "n1") .build(); AttributesTestVisitor visitor = new AttributesTestVisitor(); route.accept(visitor); - assertThat(visitor.visitCount()).isEqualTo(2); + assertThat(visitor.routerFunctionsAttributes()).containsExactly( + List.of(Map.of("foo", "bar", "baz", "qux")), + List.of(Map.of("foo", "bar", "baz", "qux")), + List.of(Map.of("foo", "bar"), Map.of("foo", "n1")), + List.of(Map.of("baz", "qux"), Map.of("foo", "n1")), + List.of(Map.of("foo", "n3"), Map.of("foo", "n2"), Map.of("foo", "n1")) + ); + assertThat(visitor.visitCount()).isEqualTo(7); } - - } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionTests.java index e3cba1677f9..e175051a5af 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ package org.springframework.web.servlet.function; import java.util.Collections; +import java.util.List; +import java.util.Map; import java.util.Optional; import org.junit.jupiter.api.Test; @@ -24,7 +26,10 @@ import org.junit.jupiter.api.Test; import org.springframework.web.servlet.handler.PathPatternsTestUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.http.HttpMethod.GET; import static org.springframework.web.servlet.function.RequestPredicates.GET; +import static org.springframework.web.servlet.function.RequestPredicates.method; +import static org.springframework.web.servlet.function.RequestPredicates.path; /** * @author Arjen Poutsma @@ -120,11 +125,28 @@ class RouterFunctionTests { .withAttributes(atts -> { atts.put("foo", "bar"); atts.put("baz", "qux"); - })); + })) + .and(RouterFunctions.nest(path("/atts"), + RouterFunctions.route(GET("/3"), request -> ServerResponse.ok().build()) + .withAttribute("foo", "bar") + .and(RouterFunctions.route(GET("/4"), request -> ServerResponse.ok().build()) + .withAttribute("baz", "qux")) + .and(RouterFunctions.nest(path("/5"), + RouterFunctions.route(method(GET), request -> ServerResponse.ok().build()) + .withAttribute("foo", "n3")) + .withAttribute("foo", "n2"))) + .withAttribute("foo", "n1")); AttributesTestVisitor visitor = new AttributesTestVisitor(); route.accept(visitor); - assertThat(visitor.visitCount()).isEqualTo(2); + assertThat(visitor.routerFunctionsAttributes()).containsExactly( + List.of(Map.of("foo", "bar", "baz", "qux")), + List.of(Map.of("foo", "bar", "baz", "qux")), + List.of(Map.of("foo", "bar"), Map.of("foo", "n1")), + List.of(Map.of("baz", "qux"), Map.of("foo", "n1")), + List.of(Map.of("foo", "n3"), Map.of("foo", "n2"), Map.of("foo", "n1")) + ); + assertThat(visitor.visitCount()).isEqualTo(7); } diff --git a/spring-webmvc/src/test/kotlin/org/springframework/web/servlet/function/RouterFunctionDslTests.kt b/spring-webmvc/src/test/kotlin/org/springframework/web/servlet/function/RouterFunctionDslTests.kt index 750d05d01e3..ccfb300b39e 100644 --- a/spring-webmvc/src/test/kotlin/org/springframework/web/servlet/function/RouterFunctionDslTests.kt +++ b/spring-webmvc/src/test/kotlin/org/springframework/web/servlet/function/RouterFunctionDslTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -134,6 +134,20 @@ class RouterFunctionDslTests { assertThat(sampleRouter().route(request).get().handle(request).headers().getFirst("foo")).isEqualTo("bar") } + @Test + fun attributes() { + val visitor = AttributesTestVisitor() + attributesRouter.accept(visitor) + assertThat(visitor.routerFunctionsAttributes()).containsExactly( + listOf(mapOf("foo" to "bar", "baz" to "qux")), + listOf(mapOf("foo" to "bar", "baz" to "qux")), + listOf(mapOf("foo" to "bar"), mapOf("foo" to "n1")), + listOf(mapOf("baz" to "qux"), mapOf("foo" to "n1")), + listOf(mapOf("foo" to "n3"), mapOf("foo" to "n2"), mapOf("foo" to "n1")) + ); + assertThat(visitor.visitCount()).isEqualTo(7); + } + private fun sampleRouter() = router { (GET("/foo/") or GET("/foos/")) { req -> handle(req) } "/api".nest { @@ -202,6 +216,39 @@ class RouterFunctionDslTests { } } + private val attributesRouter = router { + GET("/atts/1") { + ok().build() + } + withAttribute("foo", "bar") + withAttribute("baz", "qux") + GET("/atts/2") { + ok().build() + } + withAttributes { atts -> + atts["foo"] = "bar" + atts["baz"] = "qux" + } + "/atts".nest { + GET("/3") { + ok().build() + } + withAttribute("foo", "bar") + GET("/4") { + ok().build() + } + withAttribute("baz", "qux") + "/5".nest { + GET { + ok().build() + } + withAttribute("foo", "n3") + } + withAttribute("foo", "n2") + } + withAttribute("foo", "n1") + } + @Suppress("UNUSED_PARAMETER") private fun handleFromClass(req: ServerRequest) = ServerResponse.ok().build() }