diff --git a/oauth2/oauth2-client/spring-security-oauth2-client.gradle b/oauth2/oauth2-client/spring-security-oauth2-client.gradle
index 6e17822390..a966c27b8c 100644
--- a/oauth2/oauth2-client/spring-security-oauth2-client.gradle
+++ b/oauth2/oauth2-client/spring-security-oauth2-client.gradle
@@ -13,6 +13,7 @@ dependencies {
optional 'com.fasterxml.jackson.core:jackson-databind'
optional 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310'
optional 'org.springframework:spring-jdbc'
+ optional 'org.springframework:spring-r2dbc'
testCompile project(path: ':spring-security-oauth2-core', configuration: 'tests')
testCompile project(path: ':spring-security-oauth2-jose', configuration: 'tests')
@@ -22,6 +23,8 @@ dependencies {
testCompile 'io.projectreactor:reactor-test'
testCompile 'io.projectreactor.tools:blockhound'
testCompile 'org.skyscreamer:jsonassert'
+ testCompile 'io.r2dbc:r2dbc-h2:0.8.4.RELEASE'
+ testCompile 'io.r2dbc:r2dbc-spi-test:0.8.3.RELEASE'
testRuntime 'org.hsqldb:hsqldb'
diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/R2dbcReactiveOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/R2dbcReactiveOAuth2AuthorizedClientService.java
new file mode 100644
index 0000000000..a4b3f8bbeb
--- /dev/null
+++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/R2dbcReactiveOAuth2AuthorizedClientService.java
@@ -0,0 +1,387 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client;
+
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.time.Instant;
+import java.time.LocalDateTime;
+import java.time.ZoneOffset;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+
+import io.r2dbc.spi.Row;
+import io.r2dbc.spi.RowMetadata;
+import reactor.core.publisher.Mono;
+
+import org.springframework.dao.DataRetrievalFailureException;
+import org.springframework.r2dbc.core.DatabaseClient;
+import org.springframework.r2dbc.core.DatabaseClient.GenericExecuteSpec;
+import org.springframework.r2dbc.core.Parameter;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StringUtils;
+
+/**
+ * A R2DBC implementation of {@link ReactiveOAuth2AuthorizedClientService} that uses a
+ * {@link DatabaseClient} for {@link OAuth2AuthorizedClient} persistence.
+ *
+ *
+ * NOTE: This {@code ReactiveOAuth2AuthorizedClientService} depends on the table
+ * definition described in
+ * "classpath:org/springframework/security/oauth2/client/oauth2-client-schema.sql" and
+ * therefore MUST be defined in the database schema.
+ *
+ * @author Ovidiu Popa
+ * @since 5.5
+ * @see ReactiveOAuth2AuthorizedClientService
+ * @see OAuth2AuthorizedClient
+ * @see DatabaseClient
+ *
+ */
+public class R2dbcReactiveOAuth2AuthorizedClientService implements ReactiveOAuth2AuthorizedClientService {
+
+ // @formatter:off
+ private static final String COLUMN_NAMES =
+ "client_registration_id, " +
+ "principal_name, " +
+ "access_token_type, " +
+ "access_token_value, " +
+ "access_token_issued_at, " +
+ "access_token_expires_at, " +
+ "access_token_scopes, " +
+ "refresh_token_value, " +
+ "refresh_token_issued_at";
+ // @formatter:on
+
+ private static final String TABLE_NAME = "oauth2_authorized_client";
+
+ private static final String PK_FILTER = "client_registration_id = :clientRegistrationId AND principal_name = :principalName";
+
+ // @formatter:off
+ private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME
+ + " WHERE " + PK_FILTER;
+ // @formatter:on
+
+ // @formatter:off
+ private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + " (" + COLUMN_NAMES + ")" +
+ "VALUES (:clientRegistrationId, :principalName, :accessTokenType, :accessTokenValue," +
+ " :accessTokenIssuedAt, :accessTokenExpiresAt, :accessTokenScopes, :refreshTokenValue," +
+ " :refreshTokenIssuedAt)";
+ // @formatter:on
+
+ private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER;
+
+ // @formatter:off
+ private static final String UPDATE_AUTHORIZED_CLIENT_SQL = "UPDATE " + TABLE_NAME +
+ " SET access_token_type = :accessTokenType, " +
+ " access_token_value = :accessTokenValue, " +
+ " access_token_issued_at = :accessTokenIssuedAt," +
+ " access_token_expires_at = :accessTokenExpiresAt, " +
+ " access_token_scopes = :accessTokenScopes," +
+ " refresh_token_value = :refreshTokenValue, " +
+ " refresh_token_issued_at = :refreshTokenIssuedAt" +
+ " WHERE " +
+ PK_FILTER;
+ // @formatter:on
+
+ protected final DatabaseClient databaseClient;
+
+ protected final ReactiveClientRegistrationRepository clientRegistrationRepository;
+
+ protected Function> authorizedClientParametersMapper;
+
+ protected BiFunction authorizedClientRowMapper;
+
+ /**
+ * Constructs a {@code R2dbcReactiveOAuth2AuthorizedClientService} using the provided
+ * parameters.
+ * @param databaseClient the database client
+ * @param clientRegistrationRepository the repository of client registrations
+ */
+ public R2dbcReactiveOAuth2AuthorizedClientService(DatabaseClient databaseClient,
+ ReactiveClientRegistrationRepository clientRegistrationRepository) {
+ Assert.notNull(databaseClient, "databaseClient cannot be null");
+ Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
+ this.databaseClient = databaseClient;
+ this.clientRegistrationRepository = clientRegistrationRepository;
+ this.authorizedClientParametersMapper = new OAuth2AuthorizedClientParametersMapper();
+ this.authorizedClientRowMapper = new OAuth2AuthorizedClientRowMapper();
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public Mono loadAuthorizedClient(String clientRegistrationId,
+ String principalName) {
+ Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
+ Assert.hasText(principalName, "principalName cannot be empty");
+
+ return (Mono) this.databaseClient.sql(LOAD_AUTHORIZED_CLIENT_SQL)
+ .bind("clientRegistrationId", clientRegistrationId).bind("principalName", principalName)
+ .map(this.authorizedClientRowMapper).first().flatMap(this::getAuthorizedClient);
+ }
+
+ private Mono getAuthorizedClient(OAuth2AuthorizedClientHolder authorizedClientHolder) {
+ return this.clientRegistrationRepository.findByRegistrationId(authorizedClientHolder.getClientRegistrationId())
+ .switchIfEmpty(
+ Mono.error(dataRetrievalFailureException(authorizedClientHolder.getClientRegistrationId())))
+ .map((clientRegistration) -> new OAuth2AuthorizedClient(clientRegistration,
+ authorizedClientHolder.getPrincipalName(), authorizedClientHolder.getAccessToken(),
+ authorizedClientHolder.getRefreshToken()));
+ }
+
+ private static Throwable dataRetrievalFailureException(String clientRegistrationId) {
+ return new DataRetrievalFailureException("The ClientRegistration with id '" + clientRegistrationId
+ + "' exists in the data source, however, it was not found in the ReactiveClientRegistrationRepository.");
+ }
+
+ @Override
+ public Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
+ Assert.notNull(authorizedClient, "authorizedClient cannot be null");
+ Assert.notNull(principal, "principal cannot be null");
+ return this
+ .loadAuthorizedClient(authorizedClient.getClientRegistration().getRegistrationId(), principal.getName())
+ .flatMap((dbAuthorizedClient) -> updateAuthorizedClient(authorizedClient, principal))
+ .switchIfEmpty(Mono.defer(() -> insertAuthorizedClient(authorizedClient, principal))).then();
+ }
+
+ private Mono updateAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
+ GenericExecuteSpec executeSpec = this.databaseClient.sql(UPDATE_AUTHORIZED_CLIENT_SQL);
+ for (Entry entry : this.authorizedClientParametersMapper
+ .apply(new OAuth2AuthorizedClientHolder(authorizedClient, principal)).entrySet()) {
+ executeSpec = executeSpec.bind(entry.getKey(), entry.getValue());
+ }
+ return executeSpec.fetch().rowsUpdated();
+ }
+
+ private Mono insertAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
+ GenericExecuteSpec executeSpec = this.databaseClient.sql(SAVE_AUTHORIZED_CLIENT_SQL);
+ for (Entry entry : this.authorizedClientParametersMapper
+ .apply(new OAuth2AuthorizedClientHolder(authorizedClient, principal)).entrySet()) {
+ executeSpec = executeSpec.bind(entry.getKey(), entry.getValue());
+ }
+ return executeSpec.fetch().rowsUpdated();
+ }
+
+ @Override
+ public Mono removeAuthorizedClient(String clientRegistrationId, String principalName) {
+ Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
+ Assert.hasText(principalName, "principalName cannot be empty");
+ return this.databaseClient.sql(REMOVE_AUTHORIZED_CLIENT_SQL).bind("clientRegistrationId", clientRegistrationId)
+ .bind("principalName", principalName).then();
+ }
+
+ /**
+ * Sets the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to
+ * a {@code Map} of {@link String} and {@link Parameter}. The default is
+ * {@link OAuth2AuthorizedClientParametersMapper}.
+ * @param authorizedClientParametersMapper the {@code Function} used for mapping
+ * {@link OAuth2AuthorizedClientHolder} to a {@code Map} of {@link String} and
+ * {@link Parameter}
+ */
+ public final void setAuthorizedClientParametersMapper(
+ Function> authorizedClientParametersMapper) {
+ Assert.notNull(authorizedClientParametersMapper, "authorizedClientParametersMapper cannot be null");
+ this.authorizedClientParametersMapper = authorizedClientParametersMapper;
+ }
+
+ /**
+ * Sets the {@link BiFunction} used for mapping the current {@code io.r2dbc.spi.Row}
+ * to {@link OAuth2AuthorizedClientHolder}. The default is
+ * {@link OAuth2AuthorizedClientRowMapper}.
+ * @param authorizedClientRowMapper the {@link BiFunction} used for mapping the
+ * current {@code io.r2dbc.spi.Row} to {@link OAuth2AuthorizedClientHolder}
+ */
+ public final void setAuthorizedClientRowMapper(
+ BiFunction authorizedClientRowMapper) {
+ Assert.notNull(authorizedClientRowMapper, "authorizedClientRowMapper cannot be null");
+ this.authorizedClientRowMapper = authorizedClientRowMapper;
+ }
+
+ /**
+ * A holder for {@link OAuth2AuthorizedClient} data and End-User
+ * {@link Authentication} (Resource Owner).
+ */
+ public static final class OAuth2AuthorizedClientHolder {
+
+ private final String clientRegistrationId;
+
+ private final String principalName;
+
+ private final OAuth2AccessToken accessToken;
+
+ private final OAuth2RefreshToken refreshToken;
+
+ /**
+ * Constructs an {@code OAuth2AuthorizedClientHolder} using the provided
+ * parameters.
+ * @param authorizedClient the authorized client
+ * @param principal the End-User {@link Authentication} (Resource Owner)
+ */
+ public OAuth2AuthorizedClientHolder(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
+ Assert.notNull(authorizedClient, "authorizedClient cannot be null");
+ Assert.notNull(principal, "principal cannot be null");
+ this.clientRegistrationId = authorizedClient.getClientRegistration().getRegistrationId();
+ this.principalName = principal.getName();
+ this.accessToken = authorizedClient.getAccessToken();
+ this.refreshToken = authorizedClient.getRefreshToken();
+ }
+
+ /**
+ * Constructs an {@code OAuth2AuthorizedClientHolder} using the provided
+ * parameters.
+ * @param clientRegistrationId the client registration id
+ * @param principalName the principal name of the End-User (Resource Owner)
+ * @param accessToken the access token
+ * @param refreshToken the refresh token
+ */
+ public OAuth2AuthorizedClientHolder(String clientRegistrationId, String principalName,
+ OAuth2AccessToken accessToken, OAuth2RefreshToken refreshToken) {
+ Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
+ Assert.hasText(principalName, "principalName cannot be empty");
+ Assert.notNull(accessToken, "accessToken cannot be null");
+ this.clientRegistrationId = clientRegistrationId;
+ this.principalName = principalName;
+ this.accessToken = accessToken;
+ this.refreshToken = refreshToken;
+ }
+
+ public String getClientRegistrationId() {
+ return this.clientRegistrationId;
+ }
+
+ public String getPrincipalName() {
+ return this.principalName;
+ }
+
+ public OAuth2AccessToken getAccessToken() {
+ return this.accessToken;
+ }
+
+ public OAuth2RefreshToken getRefreshToken() {
+ return this.refreshToken;
+ }
+
+ }
+
+ /**
+ * The default {@code Function} that maps {@link OAuth2AuthorizedClientHolder} to a
+ * {@code Map} of {@link String} and {@link Parameter}.
+ */
+ public static class OAuth2AuthorizedClientParametersMapper
+ implements Function> {
+
+ @Override
+ public Map apply(OAuth2AuthorizedClientHolder authorizedClientHolder) {
+
+ final Map parameters = new HashMap<>();
+
+ final OAuth2AccessToken accessToken = authorizedClientHolder.getAccessToken();
+ final OAuth2RefreshToken refreshToken = authorizedClientHolder.getRefreshToken();
+
+ parameters.put("clientRegistrationId",
+ Parameter.fromOrEmpty(authorizedClientHolder.getClientRegistrationId(), String.class));
+ parameters.put("principalName",
+ Parameter.fromOrEmpty(authorizedClientHolder.getPrincipalName(), String.class));
+ parameters.put("accessTokenType",
+ Parameter.fromOrEmpty(accessToken.getTokenType().getValue(), String.class));
+ parameters.put("accessTokenValue", Parameter.fromOrEmpty(
+ ByteBuffer.wrap(accessToken.getTokenValue().getBytes(StandardCharsets.UTF_8)), ByteBuffer.class));
+ parameters.put("accessTokenIssuedAt", Parameter.fromOrEmpty(
+ LocalDateTime.ofInstant(accessToken.getIssuedAt(), ZoneOffset.UTC), LocalDateTime.class));
+ parameters.put("accessTokenExpiresAt", Parameter.fromOrEmpty(
+ LocalDateTime.ofInstant(accessToken.getExpiresAt(), ZoneOffset.UTC), LocalDateTime.class));
+ String accessTokenScopes = null;
+ if (!CollectionUtils.isEmpty(accessToken.getScopes())) {
+ accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getScopes(), ",");
+
+ }
+ parameters.put("accessTokenScopes", Parameter.fromOrEmpty(accessTokenScopes, String.class));
+ ByteBuffer refreshTokenValue = null;
+ LocalDateTime refreshTokenIssuedAt = null;
+ if (refreshToken != null) {
+ refreshTokenValue = ByteBuffer.wrap(refreshToken.getTokenValue().getBytes(StandardCharsets.UTF_8));
+ if (refreshToken.getIssuedAt() != null) {
+ refreshTokenIssuedAt = LocalDateTime.ofInstant(refreshToken.getIssuedAt(), ZoneOffset.UTC);
+ }
+
+ }
+
+ parameters.put("refreshTokenValue", Parameter.fromOrEmpty(refreshTokenValue, ByteBuffer.class));
+ parameters.put("refreshTokenIssuedAt", Parameter.fromOrEmpty(refreshTokenIssuedAt, LocalDateTime.class));
+ return parameters;
+ }
+
+ }
+
+ /**
+ * The default {@link BiFunction} that maps the current {@code io.r2dbc.spi.Row} to a
+ * {@link OAuth2AuthorizedClientHolder}.
+ */
+ public static class OAuth2AuthorizedClientRowMapper
+ implements BiFunction {
+
+ @Override
+ public OAuth2AuthorizedClientHolder apply(Row row, RowMetadata rowMetadata) {
+
+ String dbClientRegistrationId = row.get("client_registration_id", String.class);
+ OAuth2AccessToken.TokenType tokenType = null;
+ if (OAuth2AccessToken.TokenType.BEARER.getValue()
+ .equalsIgnoreCase(row.get("access_token_type", String.class))) {
+ tokenType = OAuth2AccessToken.TokenType.BEARER;
+ }
+ String tokenValue = new String(row.get("access_token_value", ByteBuffer.class).array(),
+ StandardCharsets.UTF_8);
+ Instant issuedAt = row.get("access_token_issued_at", LocalDateTime.class).toInstant(ZoneOffset.UTC);
+ Instant expiresAt = row.get("access_token_expires_at", LocalDateTime.class).toInstant(ZoneOffset.UTC);
+
+ Set scopes = Collections.emptySet();
+ String accessTokenScopes = row.get("access_token_scopes", String.class);
+ if (accessTokenScopes != null) {
+ scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes);
+ }
+ final OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, issuedAt, expiresAt,
+ scopes);
+
+ OAuth2RefreshToken refreshToken = null;
+ ByteBuffer refreshTokenValue = row.get("refresh_token_value", ByteBuffer.class);
+ if (refreshTokenValue != null) {
+ tokenValue = new String(refreshTokenValue.array(), StandardCharsets.UTF_8);
+ issuedAt = null;
+ LocalDateTime refreshTokenIssuedAt = row.get("refresh_token_issued_at", LocalDateTime.class);
+ if (refreshTokenIssuedAt != null) {
+ issuedAt = refreshTokenIssuedAt.toInstant(ZoneOffset.UTC);
+ }
+ refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt);
+ }
+
+ String dbPrincipalName = row.get("principal_name", String.class);
+ return new OAuth2AuthorizedClientHolder(dbClientRegistrationId, dbPrincipalName, accessToken, refreshToken);
+ }
+
+ }
+
+}
diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/R2dbcReactiveOAuth2AuthorizedClientServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/R2dbcReactiveOAuth2AuthorizedClientServiceTests.java
new file mode 100644
index 0000000000..41c5e9a0c5
--- /dev/null
+++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/R2dbcReactiveOAuth2AuthorizedClientServiceTests.java
@@ -0,0 +1,381 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client;
+
+import io.r2dbc.h2.H2ConnectionFactory;
+import io.r2dbc.spi.ConnectionFactory;
+import io.r2dbc.spi.Result;
+import org.junit.Before;
+import org.junit.Test;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+import org.springframework.core.io.ClassPathResource;
+import org.springframework.dao.DataRetrievalFailureException;
+import org.springframework.r2dbc.connection.init.CompositeDatabasePopulator;
+import org.springframework.r2dbc.connection.init.ConnectionFactoryInitializer;
+import org.springframework.r2dbc.connection.init.ResourceDatabasePopulator;
+import org.springframework.r2dbc.core.DatabaseClient;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
+import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link R2dbcReactiveOAuth2AuthorizedClientService}
+ *
+ * @author Ovidiu Popa
+ *
+ */
+public class R2dbcReactiveOAuth2AuthorizedClientServiceTests {
+
+ private static final String OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/client/oauth2-client-schema.sql";
+
+ private ClientRegistration clientRegistration;
+
+ private ReactiveClientRegistrationRepository clientRegistrationRepository;
+
+ private DatabaseClient databaseClient;
+
+ private static int principalId = 1000;
+
+ private R2dbcReactiveOAuth2AuthorizedClientService authorizedClientService;
+
+ @Before
+ public void setUp() {
+ final ConnectionFactory connectionFactory = createDb();
+ this.clientRegistration = TestClientRegistrations.clientRegistration().build();
+ this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);
+ given(this.clientRegistrationRepository.findByRegistrationId(anyString()))
+ .willReturn(Mono.just(this.clientRegistration));
+ this.databaseClient = DatabaseClient.create(connectionFactory);
+ this.authorizedClientService = new R2dbcReactiveOAuth2AuthorizedClientService(this.databaseClient,
+ this.clientRegistrationRepository);
+ }
+
+ @Test
+ public void constructorWhenDatabaseClientIsNullThenThrowIllegalArgumentException() {
+ assertThatExceptionOfType(IllegalArgumentException.class)
+ .isThrownBy(
+ () -> new R2dbcReactiveOAuth2AuthorizedClientService(null, this.clientRegistrationRepository))
+ .withMessageContaining("databaseClient cannot be null");
+ }
+
+ @Test
+ public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
+ assertThatExceptionOfType(IllegalArgumentException.class)
+ .isThrownBy(() -> new R2dbcReactiveOAuth2AuthorizedClientService(this.databaseClient, null))
+ .withMessageContaining("clientRegistrationRepository cannot be null");
+ }
+
+ @Test
+ public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
+ assertThatExceptionOfType(IllegalArgumentException.class)
+ .isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, "principalName"))
+ .withMessageContaining("clientRegistrationId cannot be empty");
+ }
+
+ @Test
+ public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
+ assertThatExceptionOfType(IllegalArgumentException.class)
+ .isThrownBy(() -> this.authorizedClientService
+ .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), null))
+ .withMessageContaining("principalName cannot be empty");
+ }
+
+ @Test
+ public void loadAuthorizedClientWhenDoesNotExistThenReturnNull() {
+ this.authorizedClientService.loadAuthorizedClient("registration-not-found", "principalName")
+ .as(StepVerifier::create).expectNextCount(0).verifyComplete();
+ }
+
+ @Test
+ public void loadAuthorizedClientWhenExistsThenReturnAuthorizedClient() {
+ Authentication principal = createPrincipal();
+ OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration);
+ this.authorizedClientService.saveAuthorizedClient(expected, principal).as(StepVerifier::create)
+ .verifyComplete();
+
+ this.authorizedClientService
+ .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName())
+ .as(StepVerifier::create).assertNext((authorizedClient) -> {
+ assertThat(authorizedClient).isNotNull();
+ assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration());
+ assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName());
+ assertThat(authorizedClient.getAccessToken().getTokenType())
+ .isEqualTo(expected.getAccessToken().getTokenType());
+ assertThat(authorizedClient.getAccessToken().getTokenValue())
+ .isEqualTo(expected.getAccessToken().getTokenValue());
+ assertThat(authorizedClient.getAccessToken().getIssuedAt())
+ .isEqualTo(expected.getAccessToken().getIssuedAt());
+ assertThat(authorizedClient.getAccessToken().getExpiresAt())
+ .isEqualTo(expected.getAccessToken().getExpiresAt());
+ assertThat(authorizedClient.getAccessToken().getScopes())
+ .isEqualTo(expected.getAccessToken().getScopes());
+ assertThat(authorizedClient.getRefreshToken().getTokenValue())
+ .isEqualTo(expected.getRefreshToken().getTokenValue());
+ assertThat(authorizedClient.getRefreshToken().getIssuedAt())
+ .isEqualTo(expected.getRefreshToken().getIssuedAt());
+ }).verifyComplete();
+ }
+
+ @Test
+ public void loadAuthorizedClientWhenExistsButNotFoundInClientRegistrationRepositoryThenThrowDataRetrievalFailureException() {
+ given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.empty());
+ Authentication principal = createPrincipal();
+ OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration);
+
+ this.authorizedClientService.saveAuthorizedClient(expected, principal).as(StepVerifier::create)
+ .verifyComplete();
+
+ this.authorizedClientService
+ .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName())
+ .as(StepVerifier::create)
+ .verifyErrorSatisfies((exception) -> assertThat(exception)
+ .isInstanceOf(DataRetrievalFailureException.class)
+ .hasMessage("The ClientRegistration with id '" + this.clientRegistration.getRegistrationId()
+ + "' exists in the data source, however, it was not found in the ReactiveClientRegistrationRepository."));
+ }
+
+ @Test
+ public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() {
+ Authentication principal = createPrincipal();
+
+ assertThatExceptionOfType(IllegalArgumentException.class)
+ .isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(null, principal))
+ .withMessageContaining("authorizedClient cannot be null");
+ }
+
+ @Test
+ public void saveAuthorizedClientWhenPrincipalIsNullThenThrowIllegalArgumentException() {
+ Authentication principal = createPrincipal();
+ OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+ assertThatExceptionOfType(IllegalArgumentException.class)
+ .isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, null))
+ .withMessageContaining("principal cannot be null");
+ }
+
+ @Test
+ public void saveAuthorizedClientWhenSaveThenLoadReturnsSaved() {
+ Authentication principal = createPrincipal();
+ final OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration);
+
+ this.authorizedClientService.saveAuthorizedClient(expected, principal).as(StepVerifier::create)
+ .verifyComplete();
+
+ this.authorizedClientService
+ .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName())
+ .as(StepVerifier::create).assertNext((authorizedClient) -> {
+ assertThat(authorizedClient).isNotNull();
+ assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration());
+ assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName());
+ assertThat(authorizedClient.getAccessToken().getTokenType())
+ .isEqualTo(expected.getAccessToken().getTokenType());
+ assertThat(authorizedClient.getAccessToken().getTokenValue())
+ .isEqualTo(expected.getAccessToken().getTokenValue());
+ assertThat(authorizedClient.getAccessToken().getIssuedAt())
+ .isEqualTo(expected.getAccessToken().getIssuedAt());
+ assertThat(authorizedClient.getAccessToken().getExpiresAt())
+ .isEqualTo(expected.getAccessToken().getExpiresAt());
+ assertThat(authorizedClient.getAccessToken().getScopes())
+ .isEqualTo(expected.getAccessToken().getScopes());
+ assertThat(authorizedClient.getRefreshToken().getTokenValue())
+ .isEqualTo(expected.getRefreshToken().getTokenValue());
+ assertThat(authorizedClient.getRefreshToken().getIssuedAt())
+ .isEqualTo(expected.getRefreshToken().getIssuedAt());
+ }).verifyComplete();
+
+ // Test save/load of NOT NULL attributes only
+ principal = createPrincipal();
+ OAuth2AuthorizedClient updatedExpectedPrincipal = createAuthorizedClient(principal, this.clientRegistration,
+ true);
+ this.authorizedClientService.saveAuthorizedClient(updatedExpectedPrincipal, principal).as(StepVerifier::create)
+ .verifyComplete();
+
+ this.authorizedClientService
+ .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName())
+ .as(StepVerifier::create).assertNext((authorizedClient) -> {
+ assertThat(authorizedClient).isNotNull();
+ assertThat(authorizedClient.getClientRegistration())
+ .isEqualTo(updatedExpectedPrincipal.getClientRegistration());
+ assertThat(authorizedClient.getPrincipalName())
+ .isEqualTo(updatedExpectedPrincipal.getPrincipalName());
+ assertThat(authorizedClient.getAccessToken().getTokenType())
+ .isEqualTo(updatedExpectedPrincipal.getAccessToken().getTokenType());
+ assertThat(authorizedClient.getAccessToken().getTokenValue())
+ .isEqualTo(updatedExpectedPrincipal.getAccessToken().getTokenValue());
+ assertThat(authorizedClient.getAccessToken().getIssuedAt())
+ .isEqualTo(updatedExpectedPrincipal.getAccessToken().getIssuedAt());
+ assertThat(authorizedClient.getAccessToken().getExpiresAt())
+ .isEqualTo(updatedExpectedPrincipal.getAccessToken().getExpiresAt());
+ assertThat(authorizedClient.getAccessToken().getScopes()).isEmpty();
+ assertThat(authorizedClient.getRefreshToken()).isNull();
+ }).verifyComplete();
+ }
+
+ @Test
+ public void saveAuthorizedClientWhenSaveClientWithExistingPrimaryKeyThenUpdate() {
+ // Given a saved authorized client
+ Authentication principal = createPrincipal();
+ OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+ this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal).as(StepVerifier::create)
+ .verifyComplete();
+
+ // When a client with the same principal and registration id is saved
+ OAuth2AuthorizedClient updatedAuthorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+ this.authorizedClientService.saveAuthorizedClient(updatedAuthorizedClient, principal).as(StepVerifier::create)
+ .verifyComplete();
+
+ // Then the saved client is updated
+ this.authorizedClientService
+ .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName())
+ .as(StepVerifier::create).assertNext((savedClient) -> {
+ assertThat(savedClient).isNotNull();
+ assertThat(savedClient.getClientRegistration())
+ .isEqualTo(updatedAuthorizedClient.getClientRegistration());
+ assertThat(savedClient.getPrincipalName()).isEqualTo(updatedAuthorizedClient.getPrincipalName());
+ assertThat(savedClient.getAccessToken().getTokenType())
+ .isEqualTo(updatedAuthorizedClient.getAccessToken().getTokenType());
+ assertThat(savedClient.getAccessToken().getTokenValue())
+ .isEqualTo(updatedAuthorizedClient.getAccessToken().getTokenValue());
+ assertThat(savedClient.getAccessToken().getIssuedAt())
+ .isEqualTo(updatedAuthorizedClient.getAccessToken().getIssuedAt());
+ assertThat(savedClient.getAccessToken().getExpiresAt())
+ .isEqualTo(updatedAuthorizedClient.getAccessToken().getExpiresAt());
+ assertThat(savedClient.getAccessToken().getScopes())
+ .isEqualTo(updatedAuthorizedClient.getAccessToken().getScopes());
+ assertThat(savedClient.getRefreshToken().getTokenValue())
+ .isEqualTo(updatedAuthorizedClient.getRefreshToken().getTokenValue());
+ assertThat(savedClient.getRefreshToken().getIssuedAt())
+ .isEqualTo(updatedAuthorizedClient.getRefreshToken().getIssuedAt());
+ });
+ }
+
+ @Test
+ public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
+ assertThatExceptionOfType(IllegalArgumentException.class)
+ .isThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(null, "principalName"))
+ .withMessageContaining("clientRegistrationId cannot be empty");
+ }
+
+ @Test
+ public void removeAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
+ assertThatExceptionOfType(IllegalArgumentException.class)
+ .isThrownBy(() -> this.authorizedClientService
+ .removeAuthorizedClient(this.clientRegistration.getRegistrationId(), null))
+ .withMessageContaining("principalName cannot be empty");
+ }
+
+ @Test
+ public void removeAuthorizedClientWhenExistsThenRemoved() {
+ Authentication principal = createPrincipal();
+ OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+
+ this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal).as(StepVerifier::create)
+ .verifyComplete();
+
+ this.authorizedClientService
+ .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName())
+ .as(StepVerifier::create).assertNext((dbAuthorizedClient) -> assertThat(dbAuthorizedClient).isNotNull())
+ .verifyComplete();
+
+ this.authorizedClientService
+ .removeAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName())
+ .as(StepVerifier::create).verifyComplete();
+
+ this.authorizedClientService
+ .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName())
+ .as(StepVerifier::create).expectNextCount(0).verifyComplete();
+ }
+
+ @Test
+ public void setAuthorizedClientRowMapperWhenNullThenThrowIllegalArgumentException() {
+ assertThatExceptionOfType(IllegalArgumentException.class)
+ .isThrownBy(() -> this.authorizedClientService.setAuthorizedClientRowMapper(null))
+ .withMessageContaining("authorizedClientRowMapper cannot be nul");
+ }
+
+ @Test
+ public void setAuthorizedClientParametersMapperWhenNullThenThrowIllegalArgumentException() {
+ assertThatExceptionOfType(IllegalArgumentException.class)
+ .isThrownBy(() -> this.authorizedClientService.setAuthorizedClientParametersMapper(null))
+ .withMessageContaining("authorizedClientParametersMapper cannot be nul");
+ }
+
+ private static ConnectionFactory createDb() {
+ ConnectionFactory connectionFactory = H2ConnectionFactory.inMemory("oauth-test");
+
+ Mono.from(connectionFactory.create())
+ .flatMapMany((connection) -> Flux
+ .from(connection.createStatement("drop table oauth2_authorized_client").execute())
+ .flatMap(Result::getRowsUpdated).onErrorResume((e) -> Mono.empty())
+ .thenMany(connection.close()))
+ .as(StepVerifier::create).verifyComplete();
+ ConnectionFactoryInitializer createDb = createDb(OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE);
+ createDb.setConnectionFactory(connectionFactory);
+ createDb.afterPropertiesSet();
+ return connectionFactory;
+ }
+
+ private static ConnectionFactoryInitializer createDb(String schema) {
+ ConnectionFactoryInitializer initializer = new ConnectionFactoryInitializer();
+
+ CompositeDatabasePopulator populator = new CompositeDatabasePopulator();
+ populator.addPopulators(new ResourceDatabasePopulator(new ClassPathResource(schema)));
+ initializer.setDatabasePopulator(populator);
+ return initializer;
+ }
+
+ private static Authentication createPrincipal() {
+ return new TestingAuthenticationToken("principal-" + principalId++, "password");
+ }
+
+ private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal,
+ ClientRegistration clientRegistration) {
+ return createAuthorizedClient(principal, clientRegistration, false);
+ }
+
+ private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal,
+ ClientRegistration clientRegistration, boolean requiredAttributesOnly) {
+ OAuth2AccessToken accessToken;
+ if (!requiredAttributesOnly) {
+ accessToken = TestOAuth2AccessTokens.scopes("read", "write");
+ }
+ else {
+ accessToken = TestOAuth2AccessTokens.noScopes();
+ }
+ OAuth2RefreshToken refreshToken = null;
+ if (!requiredAttributesOnly) {
+ refreshToken = TestOAuth2RefreshTokens.refreshToken();
+ }
+ return new OAuth2AuthorizedClient(clientRegistration, principal.getName(), accessToken, refreshToken);
+ }
+
+}