diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepository.java index 1cddacc26c..98ba597c41 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepository.java @@ -15,16 +15,17 @@ */ package org.springframework.security.oauth2.client.registration; +import java.util.Arrays; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentHashMap; - -import org.springframework.util.Assert; import reactor.core.publisher.Mono; +import org.springframework.util.Assert; + /** * A Reactive {@link ClientRegistrationRepository} that stores {@link ClientRegistration}(s) in-memory. * @@ -45,12 +46,12 @@ public final class InMemoryReactiveClientRegistrationRepository * @param registrations the client registration(s) */ public InMemoryReactiveClientRegistrationRepository(ClientRegistration... registrations) { - Assert.notEmpty(registrations, "registrations cannot be empty"); - this.clientIdToClientRegistration = new ConcurrentHashMap<>(); - for (ClientRegistration registration : registrations) { - Assert.notNull(registration, "registrations cannot contain null values"); - this.clientIdToClientRegistration.put(registration.getRegistrationId(), registration); - } + this(toList(registrations)); + } + + private static List toList(ClientRegistration... registrations) { + Assert.notEmpty(registrations, "registrations cannot be null or empty"); + return Arrays.asList(registrations); } /** @@ -59,8 +60,7 @@ public final class InMemoryReactiveClientRegistrationRepository * @param registrations the client registration(s) */ public InMemoryReactiveClientRegistrationRepository(List registrations) { - Assert.notEmpty(registrations, "registrations cannot be null or empty"); - this.clientIdToClientRegistration = toConcurrentMap(registrations); + this.clientIdToClientRegistration = toUnmodifiableConcurrentMap(registrations); } @Override @@ -78,11 +78,17 @@ public final class InMemoryReactiveClientRegistrationRepository return this.clientIdToClientRegistration.values().iterator(); } - private ConcurrentHashMap toConcurrentMap(List registrations) { + private static Map toUnmodifiableConcurrentMap(List registrations) { + Assert.notEmpty(registrations, "registrations cannot be null or empty"); ConcurrentHashMap result = new ConcurrentHashMap<>(); for (ClientRegistration registration : registrations) { + Assert.notNull(registration, "no registration can be null"); + if (result.containsKey(registration.getRegistrationId())) { + throw new IllegalStateException(String.format("Duplicate key %s", + registration.getRegistrationId())); + } result.put(registration.getRegistrationId(), registration); } - return result; + return Collections.unmodifiableMap(result); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepositoryTests.java index 4c29adb86b..5ca19c509b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepositoryTests.java @@ -16,16 +16,16 @@ package org.springframework.security.oauth2.client.registration; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - +import java.util.Arrays; import java.util.List; import org.junit.Before; import org.junit.Test; - import reactor.test.StepVerifier; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + /** * @author Rob Winch * @since 5.1 @@ -61,6 +61,12 @@ public class InMemoryReactiveClientRegistrationRepositoryTests { .isInstanceOf(IllegalArgumentException.class); } + @Test(expected = IllegalStateException.class) + public void constructorListClientRegistrationWhenDuplicateIdThenIllegalArgumentException() { + List registrations = Arrays.asList(this.registration, this.registration); + new InMemoryReactiveClientRegistrationRepository(registrations); + } + @Test public void constructorWhenClientRegistrationIsNullThenIllegalArgumentException() { ClientRegistration registration = null;