diff --git a/oauth2/oauth2-client/spring-security-oauth2-client.gradle b/oauth2/oauth2-client/spring-security-oauth2-client.gradle index 6e17822390..a966c27b8c 100644 --- a/oauth2/oauth2-client/spring-security-oauth2-client.gradle +++ b/oauth2/oauth2-client/spring-security-oauth2-client.gradle @@ -13,6 +13,7 @@ dependencies { optional 'com.fasterxml.jackson.core:jackson-databind' optional 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310' optional 'org.springframework:spring-jdbc' + optional 'org.springframework:spring-r2dbc' testCompile project(path: ':spring-security-oauth2-core', configuration: 'tests') testCompile project(path: ':spring-security-oauth2-jose', configuration: 'tests') @@ -22,6 +23,8 @@ dependencies { testCompile 'io.projectreactor:reactor-test' testCompile 'io.projectreactor.tools:blockhound' testCompile 'org.skyscreamer:jsonassert' + testCompile 'io.r2dbc:r2dbc-h2:0.8.4.RELEASE' + testCompile 'io.r2dbc:r2dbc-spi-test:0.8.3.RELEASE' testRuntime 'org.hsqldb:hsqldb' diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/R2dbcReactiveOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/R2dbcReactiveOAuth2AuthorizedClientService.java new file mode 100644 index 0000000000..a4b3f8bbeb --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/R2dbcReactiveOAuth2AuthorizedClientService.java @@ -0,0 +1,387 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; + +import io.r2dbc.spi.Row; +import io.r2dbc.spi.RowMetadata; +import reactor.core.publisher.Mono; + +import org.springframework.dao.DataRetrievalFailureException; +import org.springframework.r2dbc.core.DatabaseClient; +import org.springframework.r2dbc.core.DatabaseClient.GenericExecuteSpec; +import org.springframework.r2dbc.core.Parameter; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * A R2DBC implementation of {@link ReactiveOAuth2AuthorizedClientService} that uses a + * {@link DatabaseClient} for {@link OAuth2AuthorizedClient} persistence. + * + *

+ * NOTE: This {@code ReactiveOAuth2AuthorizedClientService} depends on the table + * definition described in + * "classpath:org/springframework/security/oauth2/client/oauth2-client-schema.sql" and + * therefore MUST be defined in the database schema. + * + * @author Ovidiu Popa + * @since 5.5 + * @see ReactiveOAuth2AuthorizedClientService + * @see OAuth2AuthorizedClient + * @see DatabaseClient + * + */ +public class R2dbcReactiveOAuth2AuthorizedClientService implements ReactiveOAuth2AuthorizedClientService { + + // @formatter:off + private static final String COLUMN_NAMES = + "client_registration_id, " + + "principal_name, " + + "access_token_type, " + + "access_token_value, " + + "access_token_issued_at, " + + "access_token_expires_at, " + + "access_token_scopes, " + + "refresh_token_value, " + + "refresh_token_issued_at"; + // @formatter:on + + private static final String TABLE_NAME = "oauth2_authorized_client"; + + private static final String PK_FILTER = "client_registration_id = :clientRegistrationId AND principal_name = :principalName"; + + // @formatter:off + private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME + + " WHERE " + PK_FILTER; + // @formatter:on + + // @formatter:off + private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + " (" + COLUMN_NAMES + ")" + + "VALUES (:clientRegistrationId, :principalName, :accessTokenType, :accessTokenValue," + + " :accessTokenIssuedAt, :accessTokenExpiresAt, :accessTokenScopes, :refreshTokenValue," + + " :refreshTokenIssuedAt)"; + // @formatter:on + + private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER; + + // @formatter:off + private static final String UPDATE_AUTHORIZED_CLIENT_SQL = "UPDATE " + TABLE_NAME + + " SET access_token_type = :accessTokenType, " + + " access_token_value = :accessTokenValue, " + + " access_token_issued_at = :accessTokenIssuedAt," + + " access_token_expires_at = :accessTokenExpiresAt, " + + " access_token_scopes = :accessTokenScopes," + + " refresh_token_value = :refreshTokenValue, " + + " refresh_token_issued_at = :refreshTokenIssuedAt" + + " WHERE " + + PK_FILTER; + // @formatter:on + + protected final DatabaseClient databaseClient; + + protected final ReactiveClientRegistrationRepository clientRegistrationRepository; + + protected Function> authorizedClientParametersMapper; + + protected BiFunction authorizedClientRowMapper; + + /** + * Constructs a {@code R2dbcReactiveOAuth2AuthorizedClientService} using the provided + * parameters. + * @param databaseClient the database client + * @param clientRegistrationRepository the repository of client registrations + */ + public R2dbcReactiveOAuth2AuthorizedClientService(DatabaseClient databaseClient, + ReactiveClientRegistrationRepository clientRegistrationRepository) { + Assert.notNull(databaseClient, "databaseClient cannot be null"); + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + this.databaseClient = databaseClient; + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientParametersMapper = new OAuth2AuthorizedClientParametersMapper(); + this.authorizedClientRowMapper = new OAuth2AuthorizedClientRowMapper(); + } + + @Override + @SuppressWarnings("unchecked") + public Mono loadAuthorizedClient(String clientRegistrationId, + String principalName) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + Assert.hasText(principalName, "principalName cannot be empty"); + + return (Mono) this.databaseClient.sql(LOAD_AUTHORIZED_CLIENT_SQL) + .bind("clientRegistrationId", clientRegistrationId).bind("principalName", principalName) + .map(this.authorizedClientRowMapper).first().flatMap(this::getAuthorizedClient); + } + + private Mono getAuthorizedClient(OAuth2AuthorizedClientHolder authorizedClientHolder) { + return this.clientRegistrationRepository.findByRegistrationId(authorizedClientHolder.getClientRegistrationId()) + .switchIfEmpty( + Mono.error(dataRetrievalFailureException(authorizedClientHolder.getClientRegistrationId()))) + .map((clientRegistration) -> new OAuth2AuthorizedClient(clientRegistration, + authorizedClientHolder.getPrincipalName(), authorizedClientHolder.getAccessToken(), + authorizedClientHolder.getRefreshToken())); + } + + private static Throwable dataRetrievalFailureException(String clientRegistrationId) { + return new DataRetrievalFailureException("The ClientRegistration with id '" + clientRegistrationId + + "' exists in the data source, however, it was not found in the ReactiveClientRegistrationRepository."); + } + + @Override + public Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + Assert.notNull(principal, "principal cannot be null"); + return this + .loadAuthorizedClient(authorizedClient.getClientRegistration().getRegistrationId(), principal.getName()) + .flatMap((dbAuthorizedClient) -> updateAuthorizedClient(authorizedClient, principal)) + .switchIfEmpty(Mono.defer(() -> insertAuthorizedClient(authorizedClient, principal))).then(); + } + + private Mono updateAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { + GenericExecuteSpec executeSpec = this.databaseClient.sql(UPDATE_AUTHORIZED_CLIENT_SQL); + for (Entry entry : this.authorizedClientParametersMapper + .apply(new OAuth2AuthorizedClientHolder(authorizedClient, principal)).entrySet()) { + executeSpec = executeSpec.bind(entry.getKey(), entry.getValue()); + } + return executeSpec.fetch().rowsUpdated(); + } + + private Mono insertAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { + GenericExecuteSpec executeSpec = this.databaseClient.sql(SAVE_AUTHORIZED_CLIENT_SQL); + for (Entry entry : this.authorizedClientParametersMapper + .apply(new OAuth2AuthorizedClientHolder(authorizedClient, principal)).entrySet()) { + executeSpec = executeSpec.bind(entry.getKey(), entry.getValue()); + } + return executeSpec.fetch().rowsUpdated(); + } + + @Override + public Mono removeAuthorizedClient(String clientRegistrationId, String principalName) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + Assert.hasText(principalName, "principalName cannot be empty"); + return this.databaseClient.sql(REMOVE_AUTHORIZED_CLIENT_SQL).bind("clientRegistrationId", clientRegistrationId) + .bind("principalName", principalName).then(); + } + + /** + * Sets the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to + * a {@code Map} of {@link String} and {@link Parameter}. The default is + * {@link OAuth2AuthorizedClientParametersMapper}. + * @param authorizedClientParametersMapper the {@code Function} used for mapping + * {@link OAuth2AuthorizedClientHolder} to a {@code Map} of {@link String} and + * {@link Parameter} + */ + public final void setAuthorizedClientParametersMapper( + Function> authorizedClientParametersMapper) { + Assert.notNull(authorizedClientParametersMapper, "authorizedClientParametersMapper cannot be null"); + this.authorizedClientParametersMapper = authorizedClientParametersMapper; + } + + /** + * Sets the {@link BiFunction} used for mapping the current {@code io.r2dbc.spi.Row} + * to {@link OAuth2AuthorizedClientHolder}. The default is + * {@link OAuth2AuthorizedClientRowMapper}. + * @param authorizedClientRowMapper the {@link BiFunction} used for mapping the + * current {@code io.r2dbc.spi.Row} to {@link OAuth2AuthorizedClientHolder} + */ + public final void setAuthorizedClientRowMapper( + BiFunction authorizedClientRowMapper) { + Assert.notNull(authorizedClientRowMapper, "authorizedClientRowMapper cannot be null"); + this.authorizedClientRowMapper = authorizedClientRowMapper; + } + + /** + * A holder for {@link OAuth2AuthorizedClient} data and End-User + * {@link Authentication} (Resource Owner). + */ + public static final class OAuth2AuthorizedClientHolder { + + private final String clientRegistrationId; + + private final String principalName; + + private final OAuth2AccessToken accessToken; + + private final OAuth2RefreshToken refreshToken; + + /** + * Constructs an {@code OAuth2AuthorizedClientHolder} using the provided + * parameters. + * @param authorizedClient the authorized client + * @param principal the End-User {@link Authentication} (Resource Owner) + */ + public OAuth2AuthorizedClientHolder(OAuth2AuthorizedClient authorizedClient, Authentication principal) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + Assert.notNull(principal, "principal cannot be null"); + this.clientRegistrationId = authorizedClient.getClientRegistration().getRegistrationId(); + this.principalName = principal.getName(); + this.accessToken = authorizedClient.getAccessToken(); + this.refreshToken = authorizedClient.getRefreshToken(); + } + + /** + * Constructs an {@code OAuth2AuthorizedClientHolder} using the provided + * parameters. + * @param clientRegistrationId the client registration id + * @param principalName the principal name of the End-User (Resource Owner) + * @param accessToken the access token + * @param refreshToken the refresh token + */ + public OAuth2AuthorizedClientHolder(String clientRegistrationId, String principalName, + OAuth2AccessToken accessToken, OAuth2RefreshToken refreshToken) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + Assert.hasText(principalName, "principalName cannot be empty"); + Assert.notNull(accessToken, "accessToken cannot be null"); + this.clientRegistrationId = clientRegistrationId; + this.principalName = principalName; + this.accessToken = accessToken; + this.refreshToken = refreshToken; + } + + public String getClientRegistrationId() { + return this.clientRegistrationId; + } + + public String getPrincipalName() { + return this.principalName; + } + + public OAuth2AccessToken getAccessToken() { + return this.accessToken; + } + + public OAuth2RefreshToken getRefreshToken() { + return this.refreshToken; + } + + } + + /** + * The default {@code Function} that maps {@link OAuth2AuthorizedClientHolder} to a + * {@code Map} of {@link String} and {@link Parameter}. + */ + public static class OAuth2AuthorizedClientParametersMapper + implements Function> { + + @Override + public Map apply(OAuth2AuthorizedClientHolder authorizedClientHolder) { + + final Map parameters = new HashMap<>(); + + final OAuth2AccessToken accessToken = authorizedClientHolder.getAccessToken(); + final OAuth2RefreshToken refreshToken = authorizedClientHolder.getRefreshToken(); + + parameters.put("clientRegistrationId", + Parameter.fromOrEmpty(authorizedClientHolder.getClientRegistrationId(), String.class)); + parameters.put("principalName", + Parameter.fromOrEmpty(authorizedClientHolder.getPrincipalName(), String.class)); + parameters.put("accessTokenType", + Parameter.fromOrEmpty(accessToken.getTokenType().getValue(), String.class)); + parameters.put("accessTokenValue", Parameter.fromOrEmpty( + ByteBuffer.wrap(accessToken.getTokenValue().getBytes(StandardCharsets.UTF_8)), ByteBuffer.class)); + parameters.put("accessTokenIssuedAt", Parameter.fromOrEmpty( + LocalDateTime.ofInstant(accessToken.getIssuedAt(), ZoneOffset.UTC), LocalDateTime.class)); + parameters.put("accessTokenExpiresAt", Parameter.fromOrEmpty( + LocalDateTime.ofInstant(accessToken.getExpiresAt(), ZoneOffset.UTC), LocalDateTime.class)); + String accessTokenScopes = null; + if (!CollectionUtils.isEmpty(accessToken.getScopes())) { + accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getScopes(), ","); + + } + parameters.put("accessTokenScopes", Parameter.fromOrEmpty(accessTokenScopes, String.class)); + ByteBuffer refreshTokenValue = null; + LocalDateTime refreshTokenIssuedAt = null; + if (refreshToken != null) { + refreshTokenValue = ByteBuffer.wrap(refreshToken.getTokenValue().getBytes(StandardCharsets.UTF_8)); + if (refreshToken.getIssuedAt() != null) { + refreshTokenIssuedAt = LocalDateTime.ofInstant(refreshToken.getIssuedAt(), ZoneOffset.UTC); + } + + } + + parameters.put("refreshTokenValue", Parameter.fromOrEmpty(refreshTokenValue, ByteBuffer.class)); + parameters.put("refreshTokenIssuedAt", Parameter.fromOrEmpty(refreshTokenIssuedAt, LocalDateTime.class)); + return parameters; + } + + } + + /** + * The default {@link BiFunction} that maps the current {@code io.r2dbc.spi.Row} to a + * {@link OAuth2AuthorizedClientHolder}. + */ + public static class OAuth2AuthorizedClientRowMapper + implements BiFunction { + + @Override + public OAuth2AuthorizedClientHolder apply(Row row, RowMetadata rowMetadata) { + + String dbClientRegistrationId = row.get("client_registration_id", String.class); + OAuth2AccessToken.TokenType tokenType = null; + if (OAuth2AccessToken.TokenType.BEARER.getValue() + .equalsIgnoreCase(row.get("access_token_type", String.class))) { + tokenType = OAuth2AccessToken.TokenType.BEARER; + } + String tokenValue = new String(row.get("access_token_value", ByteBuffer.class).array(), + StandardCharsets.UTF_8); + Instant issuedAt = row.get("access_token_issued_at", LocalDateTime.class).toInstant(ZoneOffset.UTC); + Instant expiresAt = row.get("access_token_expires_at", LocalDateTime.class).toInstant(ZoneOffset.UTC); + + Set scopes = Collections.emptySet(); + String accessTokenScopes = row.get("access_token_scopes", String.class); + if (accessTokenScopes != null) { + scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes); + } + final OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, issuedAt, expiresAt, + scopes); + + OAuth2RefreshToken refreshToken = null; + ByteBuffer refreshTokenValue = row.get("refresh_token_value", ByteBuffer.class); + if (refreshTokenValue != null) { + tokenValue = new String(refreshTokenValue.array(), StandardCharsets.UTF_8); + issuedAt = null; + LocalDateTime refreshTokenIssuedAt = row.get("refresh_token_issued_at", LocalDateTime.class); + if (refreshTokenIssuedAt != null) { + issuedAt = refreshTokenIssuedAt.toInstant(ZoneOffset.UTC); + } + refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt); + } + + String dbPrincipalName = row.get("principal_name", String.class); + return new OAuth2AuthorizedClientHolder(dbClientRegistrationId, dbPrincipalName, accessToken, refreshToken); + } + + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/R2dbcReactiveOAuth2AuthorizedClientServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/R2dbcReactiveOAuth2AuthorizedClientServiceTests.java new file mode 100644 index 0000000000..41c5e9a0c5 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/R2dbcReactiveOAuth2AuthorizedClientServiceTests.java @@ -0,0 +1,381 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import io.r2dbc.h2.H2ConnectionFactory; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.Result; +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.dao.DataRetrievalFailureException; +import org.springframework.r2dbc.connection.init.CompositeDatabasePopulator; +import org.springframework.r2dbc.connection.init.ConnectionFactoryInitializer; +import org.springframework.r2dbc.connection.init.ResourceDatabasePopulator; +import org.springframework.r2dbc.core.DatabaseClient; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link R2dbcReactiveOAuth2AuthorizedClientService} + * + * @author Ovidiu Popa + * + */ +public class R2dbcReactiveOAuth2AuthorizedClientServiceTests { + + private static final String OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/client/oauth2-client-schema.sql"; + + private ClientRegistration clientRegistration; + + private ReactiveClientRegistrationRepository clientRegistrationRepository; + + private DatabaseClient databaseClient; + + private static int principalId = 1000; + + private R2dbcReactiveOAuth2AuthorizedClientService authorizedClientService; + + @Before + public void setUp() { + final ConnectionFactory connectionFactory = createDb(); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); + given(this.clientRegistrationRepository.findByRegistrationId(anyString())) + .willReturn(Mono.just(this.clientRegistration)); + this.databaseClient = DatabaseClient.create(connectionFactory); + this.authorizedClientService = new R2dbcReactiveOAuth2AuthorizedClientService(this.databaseClient, + this.clientRegistrationRepository); + } + + @Test + public void constructorWhenDatabaseClientIsNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy( + () -> new R2dbcReactiveOAuth2AuthorizedClientService(null, this.clientRegistrationRepository)) + .withMessageContaining("databaseClient cannot be null"); + } + + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> new R2dbcReactiveOAuth2AuthorizedClientService(this.databaseClient, null)) + .withMessageContaining("clientRegistrationRepository cannot be null"); + } + + @Test + public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, "principalName")) + .withMessageContaining("clientRegistrationId cannot be empty"); + } + + @Test + public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), null)) + .withMessageContaining("principalName cannot be empty"); + } + + @Test + public void loadAuthorizedClientWhenDoesNotExistThenReturnNull() { + this.authorizedClientService.loadAuthorizedClient("registration-not-found", "principalName") + .as(StepVerifier::create).expectNextCount(0).verifyComplete(); + } + + @Test + public void loadAuthorizedClientWhenExistsThenReturnAuthorizedClient() { + Authentication principal = createPrincipal(); + OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); + this.authorizedClientService.saveAuthorizedClient(expected, principal).as(StepVerifier::create) + .verifyComplete(); + + this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) + .as(StepVerifier::create).assertNext((authorizedClient) -> { + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration()); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName()); + assertThat(authorizedClient.getAccessToken().getTokenType()) + .isEqualTo(expected.getAccessToken().getTokenType()); + assertThat(authorizedClient.getAccessToken().getTokenValue()) + .isEqualTo(expected.getAccessToken().getTokenValue()); + assertThat(authorizedClient.getAccessToken().getIssuedAt()) + .isEqualTo(expected.getAccessToken().getIssuedAt()); + assertThat(authorizedClient.getAccessToken().getExpiresAt()) + .isEqualTo(expected.getAccessToken().getExpiresAt()); + assertThat(authorizedClient.getAccessToken().getScopes()) + .isEqualTo(expected.getAccessToken().getScopes()); + assertThat(authorizedClient.getRefreshToken().getTokenValue()) + .isEqualTo(expected.getRefreshToken().getTokenValue()); + assertThat(authorizedClient.getRefreshToken().getIssuedAt()) + .isEqualTo(expected.getRefreshToken().getIssuedAt()); + }).verifyComplete(); + } + + @Test + public void loadAuthorizedClientWhenExistsButNotFoundInClientRegistrationRepositoryThenThrowDataRetrievalFailureException() { + given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.empty()); + Authentication principal = createPrincipal(); + OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); + + this.authorizedClientService.saveAuthorizedClient(expected, principal).as(StepVerifier::create) + .verifyComplete(); + + this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) + .as(StepVerifier::create) + .verifyErrorSatisfies((exception) -> assertThat(exception) + .isInstanceOf(DataRetrievalFailureException.class) + .hasMessage("The ClientRegistration with id '" + this.clientRegistration.getRegistrationId() + + "' exists in the data source, however, it was not found in the ReactiveClientRegistrationRepository.")); + } + + @Test + public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { + Authentication principal = createPrincipal(); + + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(null, principal)) + .withMessageContaining("authorizedClient cannot be null"); + } + + @Test + public void saveAuthorizedClientWhenPrincipalIsNullThenThrowIllegalArgumentException() { + Authentication principal = createPrincipal(); + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, null)) + .withMessageContaining("principal cannot be null"); + } + + @Test + public void saveAuthorizedClientWhenSaveThenLoadReturnsSaved() { + Authentication principal = createPrincipal(); + final OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration); + + this.authorizedClientService.saveAuthorizedClient(expected, principal).as(StepVerifier::create) + .verifyComplete(); + + this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) + .as(StepVerifier::create).assertNext((authorizedClient) -> { + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration()); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName()); + assertThat(authorizedClient.getAccessToken().getTokenType()) + .isEqualTo(expected.getAccessToken().getTokenType()); + assertThat(authorizedClient.getAccessToken().getTokenValue()) + .isEqualTo(expected.getAccessToken().getTokenValue()); + assertThat(authorizedClient.getAccessToken().getIssuedAt()) + .isEqualTo(expected.getAccessToken().getIssuedAt()); + assertThat(authorizedClient.getAccessToken().getExpiresAt()) + .isEqualTo(expected.getAccessToken().getExpiresAt()); + assertThat(authorizedClient.getAccessToken().getScopes()) + .isEqualTo(expected.getAccessToken().getScopes()); + assertThat(authorizedClient.getRefreshToken().getTokenValue()) + .isEqualTo(expected.getRefreshToken().getTokenValue()); + assertThat(authorizedClient.getRefreshToken().getIssuedAt()) + .isEqualTo(expected.getRefreshToken().getIssuedAt()); + }).verifyComplete(); + + // Test save/load of NOT NULL attributes only + principal = createPrincipal(); + OAuth2AuthorizedClient updatedExpectedPrincipal = createAuthorizedClient(principal, this.clientRegistration, + true); + this.authorizedClientService.saveAuthorizedClient(updatedExpectedPrincipal, principal).as(StepVerifier::create) + .verifyComplete(); + + this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) + .as(StepVerifier::create).assertNext((authorizedClient) -> { + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()) + .isEqualTo(updatedExpectedPrincipal.getClientRegistration()); + assertThat(authorizedClient.getPrincipalName()) + .isEqualTo(updatedExpectedPrincipal.getPrincipalName()); + assertThat(authorizedClient.getAccessToken().getTokenType()) + .isEqualTo(updatedExpectedPrincipal.getAccessToken().getTokenType()); + assertThat(authorizedClient.getAccessToken().getTokenValue()) + .isEqualTo(updatedExpectedPrincipal.getAccessToken().getTokenValue()); + assertThat(authorizedClient.getAccessToken().getIssuedAt()) + .isEqualTo(updatedExpectedPrincipal.getAccessToken().getIssuedAt()); + assertThat(authorizedClient.getAccessToken().getExpiresAt()) + .isEqualTo(updatedExpectedPrincipal.getAccessToken().getExpiresAt()); + assertThat(authorizedClient.getAccessToken().getScopes()).isEmpty(); + assertThat(authorizedClient.getRefreshToken()).isNull(); + }).verifyComplete(); + } + + @Test + public void saveAuthorizedClientWhenSaveClientWithExistingPrimaryKeyThenUpdate() { + // Given a saved authorized client + Authentication principal = createPrincipal(); + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); + this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal).as(StepVerifier::create) + .verifyComplete(); + + // When a client with the same principal and registration id is saved + OAuth2AuthorizedClient updatedAuthorizedClient = createAuthorizedClient(principal, this.clientRegistration); + this.authorizedClientService.saveAuthorizedClient(updatedAuthorizedClient, principal).as(StepVerifier::create) + .verifyComplete(); + + // Then the saved client is updated + this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) + .as(StepVerifier::create).assertNext((savedClient) -> { + assertThat(savedClient).isNotNull(); + assertThat(savedClient.getClientRegistration()) + .isEqualTo(updatedAuthorizedClient.getClientRegistration()); + assertThat(savedClient.getPrincipalName()).isEqualTo(updatedAuthorizedClient.getPrincipalName()); + assertThat(savedClient.getAccessToken().getTokenType()) + .isEqualTo(updatedAuthorizedClient.getAccessToken().getTokenType()); + assertThat(savedClient.getAccessToken().getTokenValue()) + .isEqualTo(updatedAuthorizedClient.getAccessToken().getTokenValue()); + assertThat(savedClient.getAccessToken().getIssuedAt()) + .isEqualTo(updatedAuthorizedClient.getAccessToken().getIssuedAt()); + assertThat(savedClient.getAccessToken().getExpiresAt()) + .isEqualTo(updatedAuthorizedClient.getAccessToken().getExpiresAt()); + assertThat(savedClient.getAccessToken().getScopes()) + .isEqualTo(updatedAuthorizedClient.getAccessToken().getScopes()); + assertThat(savedClient.getRefreshToken().getTokenValue()) + .isEqualTo(updatedAuthorizedClient.getRefreshToken().getTokenValue()); + assertThat(savedClient.getRefreshToken().getIssuedAt()) + .isEqualTo(updatedAuthorizedClient.getRefreshToken().getIssuedAt()); + }); + } + + @Test + public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(null, "principalName")) + .withMessageContaining("clientRegistrationId cannot be empty"); + } + + @Test + public void removeAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.authorizedClientService + .removeAuthorizedClient(this.clientRegistration.getRegistrationId(), null)) + .withMessageContaining("principalName cannot be empty"); + } + + @Test + public void removeAuthorizedClientWhenExistsThenRemoved() { + Authentication principal = createPrincipal(); + OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration); + + this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal).as(StepVerifier::create) + .verifyComplete(); + + this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) + .as(StepVerifier::create).assertNext((dbAuthorizedClient) -> assertThat(dbAuthorizedClient).isNotNull()) + .verifyComplete(); + + this.authorizedClientService + .removeAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) + .as(StepVerifier::create).verifyComplete(); + + this.authorizedClientService + .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()) + .as(StepVerifier::create).expectNextCount(0).verifyComplete(); + } + + @Test + public void setAuthorizedClientRowMapperWhenNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.authorizedClientService.setAuthorizedClientRowMapper(null)) + .withMessageContaining("authorizedClientRowMapper cannot be nul"); + } + + @Test + public void setAuthorizedClientParametersMapperWhenNullThenThrowIllegalArgumentException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.authorizedClientService.setAuthorizedClientParametersMapper(null)) + .withMessageContaining("authorizedClientParametersMapper cannot be nul"); + } + + private static ConnectionFactory createDb() { + ConnectionFactory connectionFactory = H2ConnectionFactory.inMemory("oauth-test"); + + Mono.from(connectionFactory.create()) + .flatMapMany((connection) -> Flux + .from(connection.createStatement("drop table oauth2_authorized_client").execute()) + .flatMap(Result::getRowsUpdated).onErrorResume((e) -> Mono.empty()) + .thenMany(connection.close())) + .as(StepVerifier::create).verifyComplete(); + ConnectionFactoryInitializer createDb = createDb(OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE); + createDb.setConnectionFactory(connectionFactory); + createDb.afterPropertiesSet(); + return connectionFactory; + } + + private static ConnectionFactoryInitializer createDb(String schema) { + ConnectionFactoryInitializer initializer = new ConnectionFactoryInitializer(); + + CompositeDatabasePopulator populator = new CompositeDatabasePopulator(); + populator.addPopulators(new ResourceDatabasePopulator(new ClassPathResource(schema))); + initializer.setDatabasePopulator(populator); + return initializer; + } + + private static Authentication createPrincipal() { + return new TestingAuthenticationToken("principal-" + principalId++, "password"); + } + + private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal, + ClientRegistration clientRegistration) { + return createAuthorizedClient(principal, clientRegistration, false); + } + + private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal, + ClientRegistration clientRegistration, boolean requiredAttributesOnly) { + OAuth2AccessToken accessToken; + if (!requiredAttributesOnly) { + accessToken = TestOAuth2AccessTokens.scopes("read", "write"); + } + else { + accessToken = TestOAuth2AccessTokens.noScopes(); + } + OAuth2RefreshToken refreshToken = null; + if (!requiredAttributesOnly) { + refreshToken = TestOAuth2RefreshTokens.refreshToken(); + } + return new OAuth2AuthorizedClient(clientRegistration, principal.getName(), accessToken, refreshToken); + } + +}