diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index a46e7841b3..f98b89b617 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -3214,6 +3214,8 @@ public class ServerHttpSecurity { private ReactiveAuthenticationManager authenticationManager; + private ServerAuthenticationConverter serverAuthenticationConverter; + private X509Spec() { } @@ -3227,11 +3229,17 @@ public class ServerHttpSecurity { return this; } + public X509Spec serverAuthenticationConverter(ServerAuthenticationConverter serverAuthenticationConverter) { + this.serverAuthenticationConverter = serverAuthenticationConverter; + return this; + } + protected void configure(ServerHttpSecurity http) { ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); X509PrincipalExtractor principalExtractor = getPrincipalExtractor(); + ServerAuthenticationConverter converter = getServerAuthenticationConverter(principalExtractor); AuthenticationWebFilter filter = new AuthenticationWebFilter(authenticationManager); - filter.setServerAuthenticationConverter(new ServerX509AuthenticationConverter(principalExtractor)); + filter.setServerAuthenticationConverter(serverAuthenticationConverter); http.addFilterAt(filter, SecurityWebFiltersOrder.AUTHENTICATION); } @@ -3250,6 +3258,13 @@ public class ServerHttpSecurity { return new ReactivePreAuthenticatedAuthenticationManager(userDetailsService); } + private ServerAuthenticationConverter getServerAuthenticationConverter(X509PrincipalExtractor extractor) { + if (this.serverAuthenticationConverter != null) { + return this.serverAuthenticationConverter; + } + return new ServerX509AuthenticationConverter(extractor); + } + } public final class OAuth2LoginSpec { diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index 0aedcb2de8..4f4763bafb 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -60,6 +60,7 @@ import org.springframework.security.web.server.authentication.AnonymousAuthentic import org.springframework.security.web.server.authentication.DelegatingServerAuthenticationSuccessHandler; import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint; +import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter; @@ -497,6 +498,17 @@ public class ServerHttpSecurityTests { assertThat(x509WebFilter).isNotNull(); } + @Test + public void x509WithConverterAndNoExtractorThenAddsX509Filter() { + ServerAuthenticationConverter mockConverter = mock(ServerAuthenticationConverter.class); + this.http.x509((x509) -> x509.serverAuthenticationConverter(mockConverter)); + SecurityWebFilterChain securityWebFilterChain = this.http.build(); + WebFilter x509WebFilter = securityWebFilterChain.getWebFilters() + .filter(filter -> matchesX509Converter(filter, mockConverter)) + .blockFirst(); + assertThat(x509WebFilter).isNotNull(); + } + @Test public void addsX509FilterWhenX509AuthenticationIsConfiguredWithDefaults() { this.http.x509(withDefaults()); @@ -769,6 +781,17 @@ public class ServerHttpSecurityTests { } } + private boolean matchesX509Converter(WebFilter filter, ServerAuthenticationConverter expectedConverter) { + try { + Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter"); + return converter.equals(expectedConverter); + } + catch (IllegalArgumentException ex) { + // field doesn't exist + return false; + } + } + private Optional getWebFilter(SecurityWebFilterChain filterChain, Class filterClass) { return (Optional) filterChain.getWebFilters() .filter(Objects::nonNull)