From 1d57a084aa49c3ff6f3abbeae1e7f7784bcafd64 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Tue, 31 Jul 2018 17:07:09 -0500 Subject: [PATCH] Add ServerOAuth2AuthorizedClientRepository Fixes: gh-5621 --- ...erverOAuth2AuthorizedClientRepository.java | 109 +++++++++ ...erverOAuth2AuthorizedClientRepository.java | 81 +++++++ ...erverOAuth2AuthorizedClientRepository.java | 96 ++++++++ ...OAuth2AuthorizedClientRepositoryTests.java | 132 +++++++++++ ...OAuth2AuthorizedClientRepositoryTests.java | 208 ++++++++++++++++++ 5 files changed, 626 insertions(+) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientRepository.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepositoryTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository.java new file mode 100644 index 0000000000..e89838c158 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository.java @@ -0,0 +1,109 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.server; + +import org.springframework.security.authentication.AuthenticationTrustResolver; +import org.springframework.security.authentication.AuthenticationTrustResolverImpl; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +/** + * An implementation of an {@link ServerOAuth2AuthorizedClientRepository} that + * delegates to the provided {@link ServerOAuth2AuthorizedClientRepository} if the current + * {@code Principal} is authenticated, otherwise, + * to the default (or provided) {@link ServerOAuth2AuthorizedClientRepository} + * if the current request is unauthenticated (or anonymous). + * The default {@code ReactiveOAuth2AuthorizedClientRepository} is + * {@link WebSessionServerOAuth2AuthorizedClientRepository}. + * + * @author Rob Winch + * @since 5.1 + * @see OAuth2AuthorizedClientRepository + * @see OAuth2AuthorizedClient + * @see OAuth2AuthorizedClientService + * @see HttpSessionOAuth2AuthorizedClientRepository + */ +public final class AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository + implements ServerOAuth2AuthorizedClientRepository { + private final AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl(); + private final ReactiveOAuth2AuthorizedClientService authorizedClientService; + private ServerOAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository = new WebSessionServerOAuth2AuthorizedClientRepository(); + + /** + * Creates an instance + * + * @param authorizedClientService the authorized client service + */ + public AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(ReactiveOAuth2AuthorizedClientService authorizedClientService) { + Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); + this.authorizedClientService = authorizedClientService; + } + + /** + * Sets the {@link ServerOAuth2AuthorizedClientRepository} used for requests that are unauthenticated (or anonymous). + * The default is {@link WebSessionServerOAuth2AuthorizedClientRepository}. + * + * @param anonymousAuthorizedClientRepository the repository used for requests that are unauthenticated (or anonymous) + */ + public final void setAnonymousAuthorizedClientRepository( + ServerOAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository) { + Assert.notNull(anonymousAuthorizedClientRepository, "anonymousAuthorizedClientRepository cannot be null"); + this.anonymousAuthorizedClientRepository = anonymousAuthorizedClientRepository; + } + + @Override + public Mono loadAuthorizedClient(String clientRegistrationId, Authentication principal, + ServerWebExchange exchange) { + if (this.isPrincipalAuthenticated(principal)) { + return this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName()); + } else { + return this.anonymousAuthorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange); + } + } + + @Override + public Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, + ServerWebExchange exchange) { + if (this.isPrincipalAuthenticated(principal)) { + return this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); + } else { + return this.anonymousAuthorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, exchange); + } + } + + @Override + public Mono removeAuthorizedClient(String clientRegistrationId, Authentication principal, + ServerWebExchange exchange) { + if (this.isPrincipalAuthenticated(principal)) { + return this.authorizedClientService.removeAuthorizedClient(clientRegistrationId, principal.getName()); + } else { + return this.anonymousAuthorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, exchange); + } + } + + private boolean isPrincipalAuthenticated(Authentication authentication) { + return authentication != null && + !this.authenticationTrustResolver.isAnonymous(authentication) && + authentication.isAuthenticated(); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientRepository.java new file mode 100644 index 0000000000..432b25a168 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientRepository.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.server; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +/** + * Implementations of this interface are responsible for the persistence + * of {@link OAuth2AuthorizedClient Authorized Client(s)} between requests. + * + *

+ * The primary purpose of an {@link OAuth2AuthorizedClient Authorized Client} + * is to associate an {@link OAuth2AuthorizedClient#getAccessToken() Access Token} credential + * to a {@link OAuth2AuthorizedClient#getClientRegistration() Client} and Resource Owner, + * who is the {@link OAuth2AuthorizedClient#getPrincipalName() Principal} + * that originally granted the authorization. + * + * @author Rob Winch + * @since 5.1 + * @see OAuth2AuthorizedClient + * @see ClientRegistration + * @see Authentication + * @see OAuth2AccessToken + */ +public interface ServerOAuth2AuthorizedClientRepository { + + /** + * Returns the {@link OAuth2AuthorizedClient} associated to the + * provided client registration identifier and End-User {@link Authentication} (Resource Owner) + * or {@code null} if not available. + * + * @param clientRegistrationId the identifier for the client's registration + * @param principal the End-User {@link Authentication} (Resource Owner) + * @param exchange the {@code ServerWebExchange} + * @param a type of OAuth2AuthorizedClient + * @return the {@link OAuth2AuthorizedClient} or {@code null} if not available + */ + Mono loadAuthorizedClient(String clientRegistrationId, + Authentication principal, ServerWebExchange exchange); + + /** + * Saves the {@link OAuth2AuthorizedClient} associating it to + * the provided End-User {@link Authentication} (Resource Owner). + * + * @param authorizedClient the authorized client + * @param principal the End-User {@link Authentication} (Resource Owner) + * @param exchange the {@code ServerWebExchange} + */ + Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, + Authentication principal, ServerWebExchange exchange); + + /** + * Removes the {@link OAuth2AuthorizedClient} associated to the + * provided client registration identifier and End-User {@link Authentication} (Resource Owner). + * + * @param clientRegistrationId the identifier for the client's registration + * @param principal the End-User {@link Authentication} (Resource Owner) + * @param exchange the {@code ServerWebExchange} + */ + Mono removeAuthorizedClient(String clientRegistrationId, Authentication principal, + ServerWebExchange exchange); + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java new file mode 100644 index 0000000000..3284235abc --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java @@ -0,0 +1,96 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.server; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import reactor.core.publisher.Mono; + +import java.util.HashMap; +import java.util.Map; + +/** + * An implementation of an {@link OAuth2AuthorizedClientRepository} that stores + * {@link OAuth2AuthorizedClient}'s in the {@code HttpSession}. + * + * @author Rob Winch + * @since 5.1 + * @see OAuth2AuthorizedClientRepository + * @see OAuth2AuthorizedClient + */ +public final class WebSessionServerOAuth2AuthorizedClientRepository + implements ServerOAuth2AuthorizedClientRepository { + private static final String DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME = + WebSessionServerOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS"; + private final String sessionAttributeName = DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME; + + @Override + @SuppressWarnings("unchecked") + public Mono loadAuthorizedClient(String clientRegistrationId, Authentication principal, + ServerWebExchange exchange) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + Assert.notNull(exchange, "exchange cannot be null"); + return exchange.getSession() + .map(this::getAuthorizedClients) + .flatMap(clients -> Mono.justOrEmpty((T) clients.get(clientRegistrationId))); + } + + @Override + public Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, + ServerWebExchange exchange) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + Assert.notNull(exchange, "exchange cannot be null"); + return exchange.getSession() + .doOnSuccess(session -> { + Map authorizedClients = getAuthorizedClients(session); + authorizedClients.put(authorizedClient.getClientRegistration().getRegistrationId(), authorizedClient); + session.getAttributes().put(this.sessionAttributeName, authorizedClients); + }) + .then(Mono.empty()); + } + + @Override + public Mono removeAuthorizedClient(String clientRegistrationId, Authentication principal, + ServerWebExchange exchange) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + Assert.notNull(exchange, "exchange cannot be null"); + return exchange.getSession() + .doOnSuccess(session -> { + Map authorizedClients = getAuthorizedClients(session); + authorizedClients.remove(clientRegistrationId); + if (authorizedClients.isEmpty()) { + session.getAttributes().remove(this.sessionAttributeName); + } else { + session.getAttributes().put(this.sessionAttributeName, authorizedClients); + } + }) + .then(Mono.empty()); + } + + @SuppressWarnings("unchecked") + private Map getAuthorizedClients(WebSession session) { + Map authorizedClients = session == null ? null : + (Map) session.getAttribute(this.sessionAttributeName); + if (authorizedClients == null) { + authorizedClients = new HashMap<>(); + } + return authorizedClients; + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepositoryTests.java new file mode 100644 index 0000000000..5fbe27b4e9 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepositoryTests.java @@ -0,0 +1,132 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.server; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * + * @author Rob Winch + */ +public class AuthenticatedPrincipalServerOAuth2AuthorizedClientRepositoryTests { + private String registrationId = "registrationId"; + private String principalName = "principalName"; + private ReactiveOAuth2AuthorizedClientService authorizedClientService; + private ServerOAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository; + private AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); + + @Before + public void setup() { + this.authorizedClientService = mock(ReactiveOAuth2AuthorizedClientService.class); + this.anonymousAuthorizedClientRepository = mock( + ServerOAuth2AuthorizedClientRepository.class); + this.authorizedClientRepository = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(this.authorizedClientService); + this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(this.anonymousAuthorizedClientRepository); + } + + @Test + public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setAuthorizedClientRepositoryWhenAuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void loadAuthorizedClientWhenAuthenticatedPrincipalThenLoadFromService() { + when(this.authorizedClientService.loadAuthorizedClient(any(), any())).thenReturn(Mono.empty()); + Authentication authentication = this.createAuthenticatedPrincipal(); + this.authorizedClientRepository.loadAuthorizedClient(this.registrationId, authentication, this.exchange).block(); + verify(this.authorizedClientService).loadAuthorizedClient(this.registrationId, this.principalName); + } + + @Test + public void loadAuthorizedClientWhenAnonymousPrincipalThenLoadFromAnonymousRepository() { + when(this.anonymousAuthorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); + Authentication authentication = this.createAnonymousPrincipal(); + this.authorizedClientRepository.loadAuthorizedClient(this.registrationId, authentication, this.exchange).block(); + verify(this.anonymousAuthorizedClientRepository).loadAuthorizedClient(this.registrationId, authentication, this.exchange); + } + + @Test + public void saveAuthorizedClientWhenAuthenticatedPrincipalThenSaveToService() { + when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty()); + Authentication authentication = this.createAuthenticatedPrincipal(); + OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.exchange).block(); + verify(this.authorizedClientService).saveAuthorizedClient(authorizedClient, authentication); + } + + @Test + public void saveAuthorizedClientWhenAnonymousPrincipalThenSaveToAnonymousRepository() { + when(this.anonymousAuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); + Authentication authentication = this.createAnonymousPrincipal(); + OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.exchange).block(); + verify(this.anonymousAuthorizedClientRepository).saveAuthorizedClient(authorizedClient, authentication, this.exchange); + } + + @Test + public void removeAuthorizedClientWhenAuthenticatedPrincipalThenRemoveFromService() { + when(this.authorizedClientService.removeAuthorizedClient(any(), any())).thenReturn(Mono.empty()); + Authentication authentication = this.createAuthenticatedPrincipal(); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.exchange).block(); + verify(this.authorizedClientService).removeAuthorizedClient(this.registrationId, this.principalName); + } + + @Test + public void removeAuthorizedClientWhenAnonymousPrincipalThenRemoveFromAnonymousRepository() { + when(this.anonymousAuthorizedClientRepository.removeAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); + Authentication authentication = this.createAnonymousPrincipal(); + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.exchange).block(); + verify(this.anonymousAuthorizedClientRepository).removeAuthorizedClient(this.registrationId, authentication, this.exchange); + } + + private Authentication createAuthenticatedPrincipal() { + TestingAuthenticationToken authentication = new TestingAuthenticationToken(this.principalName, "password"); + authentication.setAuthenticated(true); + return authentication; + } + + private Authentication createAnonymousPrincipal() { + return new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java new file mode 100644 index 0000000000..d95f15305e --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java @@ -0,0 +1,208 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.server; + +import org.junit.Test; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.web.server.WebSession; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.mock; + +/** + * @author Rob Winch + * @since 5.1 + */ +public class WebSessionServerOAuth2AuthorizedClientRepositoryTests { + private WebSessionServerOAuth2AuthorizedClientRepository authorizedClientRepository = + new WebSessionServerOAuth2AuthorizedClientRepository(); + + private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); + + private ClientRegistration registration1 = TestClientRegistrations.clientRegistration().build(); + + private ClientRegistration registration2 = TestClientRegistrations.clientRegistration2().build(); + + private String registrationId1 = this.registration1.getRegistrationId(); + private String registrationId2 = this.registration2.getRegistrationId(); + private String principalName1 = "principalName-1"; + + + @Test + public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientRepository.loadAuthorizedClient(null, null, this.exchange).block()) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void loadAuthorizedClientWhenPrincipalNameIsNullThenExceptionNotThrown() { + this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, this.exchange).block(); + } + + @Test + public void loadAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, null).block()) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void loadAuthorizedClientWhenClientRegistrationNotFoundThenReturnNull() { + OAuth2AuthorizedClient authorizedClient = + this.authorizedClientRepository.loadAuthorizedClient("registration-not-found", null, this.exchange).block(); + assertThat(authorizedClient).isNull(); + } + + @Test + public void loadAuthorizedClientWhenSavedThenReturnAuthorizedClient() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.exchange).block(); + + OAuth2AuthorizedClient loadedAuthorizedClient = + this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, this.exchange).block(); + assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient); + } + + @Test + public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(null, null, this.exchange).block()) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void saveAuthorizedClientWhenAuthenticationIsNullThenExceptionNotThrown() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.exchange).block(); + } + + @Test + public void saveAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, null).block()) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void saveAuthorizedClientWhenSavedThenSavedToSession() { + OAuth2AuthorizedClient expected = new OAuth2AuthorizedClient( + this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(expected, null, this.exchange).block(); + + OAuth2AuthorizedClient result = this.authorizedClientRepository + .loadAuthorizedClient(this.registrationId2, null, this.exchange).block(); + + assertThat(result).isEqualTo(expected); + } + + @Test + public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient( + null, null, this.exchange)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void removeAuthorizedClientWhenPrincipalNameIsNullThenExceptionNotThrown() { + this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, this.exchange); + } + + @Test + public void removeAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient( + this.registrationId1, null, null)).isInstanceOf(IllegalArgumentException.class); + } + + + @Test + public void removeAuthorizedClientWhenNotSavedThenSessionNotCreated() { + this.authorizedClientRepository.removeAuthorizedClient( + this.registrationId2, null, this.exchange); + assertThat(this.exchange.getSession().block().isStarted()).isFalse(); + } + + @Test + public void removeAuthorizedClientWhenClient1SavedAndClient2RemovedThenClient1NotRemoved() { + OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient( + this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient1, null, this.exchange).block(); + + // Remove registrationId2 (never added so is not removed either) + this.authorizedClientRepository.removeAuthorizedClient( + this.registrationId2, null, this.exchange); + + OAuth2AuthorizedClient loadedAuthorizedClient1 = this.authorizedClientRepository.loadAuthorizedClient( + this.registrationId1, null, this.exchange).block(); + assertThat(loadedAuthorizedClient1).isNotNull(); + assertThat(loadedAuthorizedClient1).isSameAs(authorizedClient1); + } + + @Test + public void removeAuthorizedClientWhenSavedThenRemoved() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.exchange).block(); + OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + this.registrationId2, null, this.exchange).block(); + assertThat(loadedAuthorizedClient).isSameAs(authorizedClient); + this.authorizedClientRepository.removeAuthorizedClient( + this.registrationId2, null, this.exchange).block(); + loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + this.registrationId2, null, this.exchange).block(); + assertThat(loadedAuthorizedClient).isNull(); + } + + @Test + public void removeAuthorizedClientWhenSavedThenRemovedFromSession() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.exchange).block(); + OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + this.registrationId1, null, this.exchange).block(); + assertThat(loadedAuthorizedClient).isSameAs(authorizedClient); + this.authorizedClientRepository.removeAuthorizedClient( + this.registrationId1, null, this.exchange).block(); + + WebSession session = this.exchange.getSession().block(); + assertThat(session).isNotNull(); + assertThat(session.getAttributes()).isEmpty(); + } + + @Test + public void removeAuthorizedClientWhenClient1Client2SavedAndClient1RemovedThenClient2NotRemoved() { + OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient( + this.registration1, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient1, null, this.exchange).block(); + + OAuth2AuthorizedClient authorizedClient2 = new OAuth2AuthorizedClient( + this.registration2, this.principalName1, mock(OAuth2AccessToken.class)); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient2, null, this.exchange).block(); + + this.authorizedClientRepository.removeAuthorizedClient( + this.registrationId1, null, this.exchange).block(); + + OAuth2AuthorizedClient loadedAuthorizedClient2 = this.authorizedClientRepository.loadAuthorizedClient( + this.registrationId2, null, this.exchange).block(); + assertThat(loadedAuthorizedClient2).isNotNull(); + assertThat(loadedAuthorizedClient2).isSameAs(authorizedClient2); + } +}