From 93cda9496945589bd81a292a62d281a5a83ca02b Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 6 Sep 2019 07:22:06 -0400 Subject: [PATCH] Add attributes Consumer to OAuth2AuthorizationContext Fixes gh-7385 --- ...ClientServiceOAuth2AuthorizedClientManager.java | 8 +++++++- .../oauth2/client/OAuth2AuthorizationContext.java | 14 +++++++++----- .../web/DefaultOAuth2AuthorizedClientManager.java | 8 +++++++- ...faultReactiveOAuth2AuthorizedClientManager.java | 13 +++++++++++-- .../client/OAuth2AuthorizationContextTests.java | 6 ++++-- 5 files changed, 38 insertions(+), 11 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java index 1072a973fb..0ceabd2cad 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java @@ -21,6 +21,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import java.util.Collections; @@ -83,7 +84,12 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen } OAuth2AuthorizationContext authorizationContext = contextBuilder .principal(principal) - .attributes(this.contextAttributesMapper.apply(authorizeRequest)) + .attributes(attributes -> { + Map contextAttributes = this.contextAttributesMapper.apply(authorizeRequest); + if (!CollectionUtils.isEmpty(contextAttributes)) { + attributes.putAll(contextAttributes); + } + }) .build(); authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java index ffaa445bb2..8bac099ae7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; +import java.util.function.Consumer; /** * A context that holds authorization-specific state and is used by an {@link OAuth2AuthorizedClientProvider} @@ -161,13 +162,16 @@ public final class OAuth2AuthorizationContext { } /** - * Sets the attributes associated to the context. + * Provides a {@link Consumer} access to the attributes associated to the context. * - * @param attributes the attributes associated to the context - * @return the {@link Builder} + * @param attributesConsumer a {@link Consumer} of the attributes associated to the context + * @return the {@link OAuth2AuthorizeRequest.Builder} */ - public Builder attributes(Map attributes) { - this.attributes = attributes; + public Builder attributes(Consumer> attributesConsumer) { + if (this.attributes == null) { + this.attributes = new HashMap<>(); + } + attributesConsumer.accept(this.attributes); return this; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java index 688d6ffb63..19719dc7c4 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java @@ -26,6 +26,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; @@ -95,7 +96,12 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori } OAuth2AuthorizationContext authorizationContext = contextBuilder .principal(principal) - .attributes(this.contextAttributesMapper.apply(authorizeRequest)) + .attributes(attributes -> { + Map contextAttributes = this.contextAttributesMapper.apply(authorizeRequest); + if (!CollectionUtils.isEmpty(contextAttributes)) { + attributes.putAll(contextAttributes); + } + }) .build(); authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java index 43ea8d10f2..a04b8ee04c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java @@ -26,6 +26,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; @@ -106,7 +107,11 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React .flatMap(this.contextAttributesMapper::apply) .map(attrs -> OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) .principal(authorizeRequest.getPrincipal()) - .attributes(attrs) + .attributes(attributes -> { + if (!CollectionUtils.isEmpty(attrs)) { + attributes.putAll(attrs); + } + }) .build()); } @@ -116,7 +121,11 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React .flatMap(this.contextAttributesMapper::apply) .map(attrs -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration) .principal(authorizeRequest.getPrincipal()) - .attributes(attrs) + .attributes(attributes -> { + if (!CollectionUtils.isEmpty(attrs)) { + attributes.putAll(attrs); + } + }) .build()); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java index efa307459c..9749b6ded7 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java @@ -68,8 +68,10 @@ public class OAuth2AuthorizationContextTests { public void withAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() { OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) .principal(this.principal) - .attribute("attribute1", "value1") - .attribute("attribute2", "value2") + .attributes(attributes -> { + attributes.put("attribute1", "value1"); + attributes.put("attribute2", "value2"); + }) .build(); assertThat(authorizationContext.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);