diff --git a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java index e4b4981812..07c86e2de4 100644 --- a/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java +++ b/test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java @@ -26,7 +26,6 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; -import org.springframework.security.test.context.TestSecurityContextHolder; import org.springframework.test.web.reactive.server.MockServerConfigurer; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.test.web.reactive.server.WebTestClientConfigurer; @@ -39,6 +38,7 @@ import reactor.core.publisher.Mono; import java.util.Collection; import java.util.List; import java.util.function.Consumer; +import java.util.function.Supplier; /** * Test utilities for working with Spring Security and @@ -58,7 +58,6 @@ public class SecurityMockServerConfigurers { public void beforeServerCreated(WebHttpHandlerBuilder builder) { builder.filters( filters -> { filters.add(0, new MutatorFilter()); - filters.add(0, new SetupMutatorFilter(TestSecurityContextHolder.getContext())); }); } }; @@ -71,7 +70,7 @@ public class SecurityMockServerConfigurers { * @return the {@link WebTestClientConfigurer}} to use */ public static T mockAuthentication(Authentication authentication) { - return (T) new MutatorWebTestClientConfigurer(authentication); + return (T) new MutatorWebTestClientConfigurer(() -> Mono.just(authentication).map(SecurityContextImpl::new)); } /** @@ -216,21 +215,11 @@ public class SecurityMockServerConfigurers { } private static class MutatorWebTestClientConfigurer implements WebTestClientConfigurer, MockServerConfigurer { - private final Mono context; + private final Supplier> context; - private MutatorWebTestClientConfigurer(Mono context) { + private MutatorWebTestClientConfigurer(Supplier> context) { this.context = context; } - - private MutatorWebTestClientConfigurer(SecurityContext context) { - this(Mono.just(context)); - } - - private MutatorWebTestClientConfigurer(Authentication authentication) { - this(new SecurityContextImpl(authentication)); - } - - @Override public void beforeServerCreated(WebHttpHandlerBuilder builder) { builder.filters(addSetupMutatorFilter()); @@ -247,20 +236,12 @@ public class SecurityMockServerConfigurers { } private static class SetupMutatorFilter implements WebFilter { - private final Mono context; + private final Supplier> context; - private SetupMutatorFilter(Mono context) { + private SetupMutatorFilter(Supplier> context) { this.context = context; } - private SetupMutatorFilter(SecurityContext context) { - this(Mono.just(context)); - } - - private SetupMutatorFilter(Authentication authentication) { - this(new SecurityContextImpl(authentication)); - } - @Override public Mono filter(ServerWebExchange exchange, WebFilterChain webFilterChain) { exchange.getAttributes().computeIfAbsent(MutatorFilter.ATTRIBUTE_NAME, key -> this.context); @@ -273,11 +254,11 @@ public class SecurityMockServerConfigurers { @Override public Mono filter(ServerWebExchange exchange, WebFilterChain webFilterChain) { - Mono context = exchange.getAttribute(ATTRIBUTE_NAME); + Supplier> context = exchange.getAttribute(ATTRIBUTE_NAME); if(context != null) { exchange.getAttributes().remove(ATTRIBUTE_NAME); return webFilterChain.filter(exchange) - .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(context)); + .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(context.get())); } return webFilterChain.filter(exchange); } diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersAnnotatedTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersAnnotatedTests.java index 4d5fba24d5..cc21d9bbc1 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersAnnotatedTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersAnnotatedTests.java @@ -20,16 +20,18 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.test.context.TestSecurityContextHolder; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.support.WithMockUser; +import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.web.reactive.server.WebTestClient; import java.security.Principal; -import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockPrincipal; +import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockAuthentication; import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity; /** @@ -42,6 +44,7 @@ public class SecurityMockServerConfigurersAnnotatedTests extends AbstractMockSer WebTestClient client = WebTestClient .bindToController(controller) + .webFilter(new SecurityContextServerWebExchangeWebFilter()) .apply(springSecurity()) .configureClient() .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) @@ -62,11 +65,12 @@ public class SecurityMockServerConfigurersAnnotatedTests extends AbstractMockSer @Test @WithMockUser public void withMockUserWhenGlobalMockPrincipalThenOverridesAnnotation() { - Principal principal = () -> "principal"; + TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", "ROLE_USER"); client = WebTestClient .bindToController(controller) + .webFilter(new SecurityContextServerWebExchangeWebFilter()) .apply(springSecurity()) - .apply(mockPrincipal(principal)) + .apply(mockAuthentication(authentication)) .configureClient() .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) .build(); @@ -76,33 +80,33 @@ public class SecurityMockServerConfigurersAnnotatedTests extends AbstractMockSer .exchange() .expectStatus().isOk(); - controller.assertPrincipalIsEqualTo(principal); + controller.assertPrincipalIsEqualTo(authentication); } @Test @WithMockUser public void withMockUserWhenMutateWithMockPrincipalThenOverridesAnnotation() { - Principal principal = () -> "principal"; + TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", "ROLE_USER"); client - .mutateWith(mockPrincipal(principal)) + .mutateWith(mockAuthentication(authentication)) .get() .exchange() .expectStatus().isOk(); - controller.assertPrincipalIsEqualTo(principal); + controller.assertPrincipalIsEqualTo(authentication); } @Test @WithMockUser public void withMockUserWhenMutateWithMockPrincipalAndNoMutateThenOverridesAnnotationAndUsesAnnotation() { - Principal principal = () -> "principal"; + TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", "ROLE_USER"); client - .mutateWith(mockPrincipal(principal)) + .mutateWith(mockAuthentication(authentication)) .get() .exchange() .expectStatus().isOk(); - controller.assertPrincipalIsEqualTo(principal); + controller.assertPrincipalIsEqualTo(authentication); client @@ -110,7 +114,6 @@ public class SecurityMockServerConfigurersAnnotatedTests extends AbstractMockSer .exchange() .expectStatus().isOk(); - principal = controller.removePrincipal(); - assertPrincipalCreatedFromUserDetails(principal, userBuilder.build()); + assertPrincipalCreatedFromUserDetails(controller.removePrincipal(), userBuilder.build()); } } diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersClassAnnotatedTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersClassAnnotatedTests.java index 71db75bc53..acda259626 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersClassAnnotatedTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersClassAnnotatedTests.java @@ -24,6 +24,7 @@ import org.springframework.security.core.Authentication; import org.springframework.security.test.context.TestSecurityContextHolder; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.support.WithMockUser; +import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.test.context.junit4.SpringRunner; import org.springframework.test.web.reactive.server.WebTestClient; @@ -43,6 +44,7 @@ import static org.springframework.security.test.web.reactive.server.SecurityMock public class SecurityMockServerConfigurersClassAnnotatedTests extends AbstractMockServerConfigurersTests { WebTestClient client = WebTestClient .bindToController(controller) + .webFilter(new SecurityContextServerWebExchangeWebFilter()) .apply(springSecurity()) .configureClient() .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) diff --git a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java index bcbfd90943..d97f79683a 100644 --- a/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java +++ b/test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java @@ -22,6 +22,7 @@ import org.springframework.http.MediaType; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.test.web.reactive.server.WebTestClient; import java.security.Principal; @@ -35,55 +36,12 @@ import static org.springframework.security.test.web.reactive.server.SecurityMock public class SecurityMockServerConfigurersTests extends AbstractMockServerConfigurersTests { WebTestClient client = WebTestClient .bindToController(controller) + .webFilter(new SecurityContextServerWebExchangeWebFilter()) .apply(springSecurity()) .configureClient() .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) .build(); - @Test - public void mockPrincipalWhenLocalThenSuccess() { - Principal principal = () -> "principal"; - client - .mutateWith(mockPrincipal(principal)) - .get() - .exchange() - .expectStatus().isOk(); - - controller.assertPrincipalIsEqualTo(principal); - } - - @Test - public void mockPrincipalWhenGlobalTheWorks() { - Principal principal = () -> "principal"; - client = WebTestClient - .bindToController(controller) - .apply(springSecurity()) - .apply(mockPrincipal(principal)) - .configureClient() - .defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .build(); - - client - .get() - .exchange() - .expectStatus().isOk(); - - controller.assertPrincipalIsEqualTo(principal); - } - - @Test - public void mockPrincipalWhenMultipleInvocationsThenLastInvocationWins() { - Principal principal = () -> "principal"; - client - .mutateWith(mockPrincipal(() -> "will be overridden")) - .mutateWith(mockPrincipal(principal)) - .get() - .exchange() - .expectStatus().isOk(); - - controller.assertPrincipalIsEqualTo(principal); - } - @Test public void mockAuthenticationWhenLocalThenSuccess() { TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", "ROLE_USER"); @@ -100,6 +58,7 @@ public class SecurityMockServerConfigurersTests extends AbstractMockServerConfig TestingAuthenticationToken authentication = new TestingAuthenticationToken("authentication", "secret", "ROLE_USER"); client = WebTestClient .bindToController(controller) + .webFilter(new SecurityContextServerWebExchangeWebFilter()) .apply(springSecurity()) .apply(mockAuthentication(authentication)) .configureClient() @@ -129,6 +88,7 @@ public class SecurityMockServerConfigurersTests extends AbstractMockServerConfig public void mockUserWhenGlobalThenSuccess() { client = WebTestClient .bindToController(controller) + .webFilter(new SecurityContextServerWebExchangeWebFilter()) .apply(springSecurity()) .apply(mockUser()) .configureClient()