diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientTests.java index 918cda0ddc..ab5c1f3111 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientTests.java @@ -17,22 +17,19 @@ package org.springframework.security.oauth2.client; import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2AccessToken; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; /** * Tests for {@link OAuth2AuthorizedClient}. * * @author Joe Grandja */ -@RunWith(PowerMockRunner.class) -@PrepareForTest(ClientRegistration.class) public class OAuth2AuthorizedClientTests { private ClientRegistration clientRegistration; private String principalName; @@ -40,9 +37,9 @@ public class OAuth2AuthorizedClientTests { @Before public void setUp() { - this.clientRegistration = mock(ClientRegistration.class); + this.clientRegistration = clientRegistration().build(); this.principalName = "principal"; - this.accessToken = mock(OAuth2AccessToken.class); + this.accessToken = noScopes(); } @Test(expected = IllegalArgumentException.class) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 622911124f..2615c83dda 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -15,62 +15,50 @@ */ package org.springframework.security.oauth2.client.authentication; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; -import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; -import java.util.Collections; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.accessTokenResponse; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.error; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success; /** * Tests for {@link OAuth2AuthorizationCodeAuthenticationProvider}. * * @author Joe Grandja */ -@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class, - OAuth2AuthorizationResponse.class, OAuth2AccessTokenResponse.class}) -@RunWith(PowerMockRunner.class) public class OAuth2AuthorizationCodeAuthenticationProviderTests { private ClientRegistration clientRegistration; private OAuth2AuthorizationRequest authorizationRequest; - private OAuth2AuthorizationResponse authorizationResponse; - private OAuth2AuthorizationExchange authorizationExchange; private OAuth2AccessTokenResponseClient accessTokenResponseClient; private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider; @Before @SuppressWarnings("unchecked") - public void setUp() throws Exception { - this.clientRegistration = mock(ClientRegistration.class); - this.authorizationRequest = mock(OAuth2AuthorizationRequest.class); - this.authorizationResponse = mock(OAuth2AuthorizationResponse.class); - this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse); + public void setUp() { + this.clientRegistration = clientRegistration().build(); + this.authorizationRequest = request().build(); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient); - - when(this.authorizationRequest.getState()).thenReturn("12345"); - when(this.authorizationResponse.getState()).thenReturn("12345"); - when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com"); - when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com"); } @Test @@ -86,60 +74,62 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthorizationException() { - when(this.authorizationResponse.statusError()).thenReturn(true); - when(this.authorizationResponse.getError()).thenReturn(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); + OAuth2AuthorizationResponse authorizationResponse = error().errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange( + this.authorizationRequest, authorizationResponse); assertThatThrownBy(() -> { this.authenticationProvider.authenticate( new OAuth2AuthorizationCodeAuthenticationToken( - this.clientRegistration, this.authorizationExchange)); + this.clientRegistration, authorizationExchange)); }).isInstanceOf(OAuth2AuthorizationException.class).hasMessageContaining(OAuth2ErrorCodes.INVALID_REQUEST); } @Test public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthorizationException() { - when(this.authorizationRequest.getState()).thenReturn("12345"); - when(this.authorizationResponse.getState()).thenReturn("67890"); + OAuth2AuthorizationResponse authorizationResponse = success().state("67890").build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange( + this.authorizationRequest, authorizationResponse); assertThatThrownBy(() -> { this.authenticationProvider.authenticate( new OAuth2AuthorizationCodeAuthenticationToken( - this.clientRegistration, this.authorizationExchange)); + this.clientRegistration, authorizationExchange)); }).isInstanceOf(OAuth2AuthorizationException.class).hasMessageContaining("invalid_state_parameter"); } @Test public void authenticateWhenAuthorizationResponseRedirectUriNotEqualAuthorizationRequestRedirectUriThenThrowOAuth2AuthorizationException() { - when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com"); - when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example2.com"); + OAuth2AuthorizationResponse authorizationResponse = success().redirectUri("http://example2.com").build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange( + this.authorizationRequest, authorizationResponse); assertThatThrownBy(() -> { this.authenticationProvider.authenticate( new OAuth2AuthorizationCodeAuthenticationToken( - this.clientRegistration, this.authorizationExchange)); + this.clientRegistration, authorizationExchange)); }).isInstanceOf(OAuth2AuthorizationException.class).hasMessageContaining("invalid_redirect_uri_parameter"); } @Test public void authenticateWhenAuthorizationSuccessResponseThenExchangedForAccessToken() { - OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class); - OAuth2RefreshToken refreshToken = mock(OAuth2RefreshToken.class); - OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class); - when(accessTokenResponse.getAccessToken()).thenReturn(accessToken); - when(accessTokenResponse.getRefreshToken()).thenReturn(refreshToken); + OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().refreshToken("refresh").build(); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange( + this.authorizationRequest, success().build()); OAuth2AuthorizationCodeAuthenticationToken authenticationResult = (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider.authenticate( - new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, authorizationExchange)); assertThat(authenticationResult.isAuthenticated()).isTrue(); assertThat(authenticationResult.getPrincipal()).isEqualTo(this.clientRegistration.getClientId()); - assertThat(authenticationResult.getCredentials()).isEqualTo(accessToken.getTokenValue()); + assertThat(authenticationResult.getCredentials()) + .isEqualTo(accessTokenResponse.getAccessToken().getTokenValue()); assertThat(authenticationResult.getAuthorities()).isEqualTo(Collections.emptyList()); assertThat(authenticationResult.getClientRegistration()).isEqualTo(this.clientRegistration); - assertThat(authenticationResult.getAuthorizationExchange()).isEqualTo(this.authorizationExchange); - assertThat(authenticationResult.getAccessToken()).isEqualTo(accessToken); - assertThat(authenticationResult.getRefreshToken()).isEqualTo(refreshToken); + assertThat(authenticationResult.getAuthorizationExchange()).isEqualTo(authorizationExchange); + assertThat(authenticationResult.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(authenticationResult.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java index 0017a46c91..e0d6bd07b6 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java @@ -15,30 +15,27 @@ */ package org.springframework.security.oauth2.client.authentication; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; - -import java.util.Collections; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success; /** * Tests for {@link OAuth2AuthorizationCodeAuthenticationToken}. * * @author Joe Grandja */ -@RunWith(PowerMockRunner.class) -@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationExchange.class, OAuth2AuthorizationResponse.class}) public class OAuth2AuthorizationCodeAuthenticationTokenTests { private ClientRegistration clientRegistration; private OAuth2AuthorizationExchange authorizationExchange; @@ -46,9 +43,10 @@ public class OAuth2AuthorizationCodeAuthenticationTokenTests { @Before public void setUp() { - this.clientRegistration = mock(ClientRegistration.class); - this.authorizationExchange = mock(OAuth2AuthorizationExchange.class); - this.accessToken = mock(OAuth2AccessToken.class); + this.clientRegistration = clientRegistration().build(); + this.authorizationExchange = new OAuth2AuthorizationExchange(request().build(), + success().code("code").build()); + this.accessToken = noScopes(); } @Test @@ -65,10 +63,6 @@ public class OAuth2AuthorizationCodeAuthenticationTokenTests { @Test public void constructorAuthorizationRequestResponseWhenAllParametersProvidedAndValidThenCreated() { - OAuth2AuthorizationResponse authorizationResponse = mock(OAuth2AuthorizationResponse.class); - when(authorizationResponse.getCode()).thenReturn("code"); - when(this.authorizationExchange.getAuthorizationResponse()).thenReturn(authorizationResponse); - OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, this.authorizationExchange); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java index 69949edb67..66fcece124 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java @@ -15,15 +15,21 @@ */ package org.springframework.security.oauth2.client.authentication; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; @@ -34,7 +40,6 @@ import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; @@ -42,30 +47,22 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.user.OAuth2User; -import java.time.Instant; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyCollection; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.error; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success; /** * Tests for {@link OAuth2LoginAuthenticationProvider}. * * @author Joe Grandja */ -@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class, - OAuth2AuthorizationResponse.class, OAuth2AccessTokenResponse.class}) -@RunWith(PowerMockRunner.class) public class OAuth2LoginAuthenticationProviderTests { private ClientRegistration clientRegistration; private OAuth2AuthorizationRequest authorizationRequest; @@ -81,19 +78,13 @@ public class OAuth2LoginAuthenticationProviderTests { @Before @SuppressWarnings("unchecked") public void setUp() throws Exception { - this.clientRegistration = mock(ClientRegistration.class); - this.authorizationRequest = mock(OAuth2AuthorizationRequest.class); - this.authorizationResponse = mock(OAuth2AuthorizationResponse.class); + this.clientRegistration = clientRegistration().build(); + this.authorizationRequest = request().scope("scope1", "scope2").build(); + this.authorizationResponse = success().build(); this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.userService = mock(OAuth2UserService.class); this.authenticationProvider = new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, this.userService); - - when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Arrays.asList("scope1", "scope2"))); - when(this.authorizationRequest.getState()).thenReturn("12345"); - when(this.authorizationResponse.getState()).thenReturn("12345"); - when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com"); - when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com"); } @Test @@ -121,11 +112,13 @@ public class OAuth2LoginAuthenticationProviderTests { @Test public void authenticateWhenAuthorizationRequestContainsOpenidScopeThenReturnNull() { - when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Collections.singleton("openid"))); + OAuth2AuthorizationRequest authorizationRequest = request().scope("openid").build(); + OAuth2AuthorizationExchange authorizationExchange = + new OAuth2AuthorizationExchange(authorizationRequest, this.authorizationResponse); OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); assertThat(authentication).isNull(); } @@ -135,11 +128,13 @@ public class OAuth2LoginAuthenticationProviderTests { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_REQUEST)); - when(this.authorizationResponse.statusError()).thenReturn(true); - when(this.authorizationResponse.getError()).thenReturn(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)); + OAuth2AuthorizationResponse authorizationResponse = + error().errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build(); + OAuth2AuthorizationExchange authorizationExchange = + new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); } @Test @@ -147,11 +142,13 @@ public class OAuth2LoginAuthenticationProviderTests { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("invalid_state_parameter")); - when(this.authorizationRequest.getState()).thenReturn("12345"); - when(this.authorizationResponse.getState()).thenReturn("67890"); + OAuth2AuthorizationResponse authorizationResponse = + success().state("67890").build(); + OAuth2AuthorizationExchange authorizationExchange = + new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); } @Test @@ -159,11 +156,13 @@ public class OAuth2LoginAuthenticationProviderTests { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("invalid_redirect_uri_parameter")); - when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com"); - when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example2.com"); + OAuth2AuthorizationResponse authorizationResponse = + success().redirectUri("http://example2.com").build(); + OAuth2AuthorizationExchange authorizationExchange = + new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); } @Test diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationTokenTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationTokenTests.java index 9e635d0c2d..5e651dc2a1 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationTokenTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationTokenTests.java @@ -15,30 +15,30 @@ */ package org.springframework.security.oauth2.client.authentication; +import java.util.Collection; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.user.OAuth2User; -import java.util.Collection; -import java.util.Collections; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success; /** * Tests for {@link OAuth2LoginAuthenticationToken}. * * @author Joe Grandja */ -@RunWith(PowerMockRunner.class) -@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationExchange.class}) public class OAuth2LoginAuthenticationTokenTests { private OAuth2User principal; private Collection authorities; @@ -50,9 +50,10 @@ public class OAuth2LoginAuthenticationTokenTests { public void setUp() { this.principal = mock(OAuth2User.class); this.authorities = Collections.emptyList(); - this.clientRegistration = mock(ClientRegistration.class); - this.authorizationExchange = mock(OAuth2AuthorizationExchange.class); - this.accessToken = mock(OAuth2AccessToken.class); + this.clientRegistration = clientRegistration().build(); + this.authorizationExchange = new OAuth2AuthorizationExchange( + request().build(), success().code("code").build()); + this.accessToken = noScopes(); } @Test(expected = IllegalArgumentException.class) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java index f88c1a0042..dfee000019 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java @@ -15,16 +15,15 @@ */ package org.springframework.security.oauth2.client.endpoint; +import java.time.Instant; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PowerMockIgnore; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -36,27 +35,19 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExch import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; -import java.time.Instant; -import java.util.Arrays; -import java.util.LinkedHashSet; -import java.util.Set; - import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success; /** * Tests for {@link NimbusAuthorizationCodeTokenResponseClient}. * * @author Joe Grandja */ -@PowerMockIgnore("okhttp3.*") -@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class, OAuth2AuthorizationResponse.class, OAuth2AuthorizationExchange.class}) -@RunWith(PowerMockRunner.class) public class NimbusAuthorizationCodeTokenResponseClientTests { - private ClientRegistration clientRegistration; - private ClientRegistration.ProviderDetails providerDetails; + private ClientRegistration.Builder clientRegistrationBuilder; private OAuth2AuthorizationRequest authorizationRequest; private OAuth2AuthorizationResponse authorizationResponse; private OAuth2AuthorizationExchange authorizationExchange; @@ -67,18 +58,11 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { @Before public void setUp() throws Exception { - this.clientRegistration = mock(ClientRegistration.class); - this.providerDetails = mock(ClientRegistration.ProviderDetails.class); - this.authorizationRequest = mock(OAuth2AuthorizationRequest.class); - this.authorizationResponse = mock(OAuth2AuthorizationResponse.class); + this.clientRegistrationBuilder = clientRegistration() + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC); + this.authorizationRequest = request().build(); + this.authorizationResponse = success().build(); this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse); - - when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails); - when(this.clientRegistration.getClientId()).thenReturn("client-id"); - when(this.clientRegistration.getClientSecret()).thenReturn("secret"); - when(this.clientRegistration.getClientAuthenticationMethod()).thenReturn(ClientAuthenticationMethod.BASIC); - when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com"); - when(this.authorizationResponse.getCode()).thenReturn("code"); } @Test @@ -100,12 +84,13 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { server.start(); String tokenUri = server.url("/oauth2/token").toString(); - when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); + this.clientRegistrationBuilder.tokenUri(tokenUri); Instant expiresAtBefore = Instant.now().plusSeconds(3600); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); + new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange)); Instant expiresAtAfter = Instant.now().plusSeconds(3600); @@ -126,10 +111,13 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { this.exception.expect(IllegalArgumentException.class); String redirectUri = "http:\\example.com"; - when(this.clientRegistration.getRedirectUriTemplate()).thenReturn(redirectUri); + OAuth2AuthorizationRequest authorizationRequest = request().redirectUri(redirectUri).build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange( + authorizationRequest, this.authorizationResponse); this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); + new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), authorizationExchange)); } @Test @@ -137,10 +125,11 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { this.exception.expect(IllegalArgumentException.class); String tokenUri = "http:\\provider.com\\oauth2\\token"; - when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); + this.clientRegistrationBuilder.tokenUri(tokenUri); this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); + new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange)); } @Test @@ -165,11 +154,12 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { server.start(); String tokenUri = server.url("/oauth2/token").toString(); - when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); + this.clientRegistrationBuilder.tokenUri(tokenUri); try { this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); + new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange)); } finally { server.shutdown(); } @@ -180,10 +170,11 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { this.exception.expect(OAuth2AuthorizationException.class); String tokenUri = "http://invalid-provider.com/oauth2/token"; - when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); + this.clientRegistrationBuilder.tokenUri(tokenUri); this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); + new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange)); } @Test @@ -203,11 +194,12 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { server.start(); String tokenUri = server.url("/oauth2/token").toString(); - when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); + this.clientRegistrationBuilder.tokenUri(tokenUri); try { this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); + new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange)); } finally { server.shutdown(); } @@ -225,11 +217,12 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { server.start(); String tokenUri = server.url("/oauth2/token").toString(); - when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); + this.clientRegistrationBuilder.tokenUri(tokenUri); try { this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); + new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange)); } finally { server.shutdown(); } @@ -254,11 +247,12 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { server.start(); String tokenUri = server.url("/oauth2/token").toString(); - when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); + this.clientRegistrationBuilder.tokenUri(tokenUri); try { this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); + new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange)); } finally { server.shutdown(); } @@ -280,13 +274,16 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { server.start(); String tokenUri = server.url("/oauth2/token").toString(); - when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); + this.clientRegistrationBuilder.tokenUri(tokenUri); - Set requestedScopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email", "address")); - when(this.authorizationRequest.getScopes()).thenReturn(requestedScopes); + OAuth2AuthorizationRequest authorizationRequest = + request().scope("openid", "profile", "email", "address").build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange( + authorizationRequest, this.authorizationResponse); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); + new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), authorizationExchange)); server.shutdown(); @@ -308,13 +305,16 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { server.start(); String tokenUri = server.url("/oauth2/token").toString(); - when(this.providerDetails.getTokenUri()).thenReturn(tokenUri); + this.clientRegistrationBuilder.tokenUri(tokenUri); - Set requestedScopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email", "address")); - when(this.authorizationRequest.getScopes()).thenReturn(requestedScopes); + OAuth2AuthorizationRequest authorizationRequest = + request().scope("openid", "profile", "email", "address").build(); + OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange( + authorizationRequest, this.authorizationResponse); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange)); + new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), authorizationExchange)); server.shutdown(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestTests.java index a676902b93..7566102ae0 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestTests.java @@ -17,31 +17,28 @@ package org.springframework.security.oauth2.client.endpoint; import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges.success; /** * Tests for {@link OAuth2AuthorizationCodeGrantRequest}. * * @author Joe Grandja */ -@RunWith(PowerMockRunner.class) -@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationExchange.class}) public class OAuth2AuthorizationCodeGrantRequestTests { private ClientRegistration clientRegistration; private OAuth2AuthorizationExchange authorizationExchange; @Before public void setUp() { - this.clientRegistration = mock(ClientRegistration.class); - this.authorizationExchange = mock(OAuth2AuthorizationExchange.class); + this.clientRegistration = clientRegistration().build(); + this.authorizationExchange = success(); } @Test(expected = IllegalArgumentException.class) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java index cc86577ce0..2cc872351a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java @@ -15,16 +15,22 @@ */ package org.springframework.security.oauth2.client.oidc.authentication; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; -import org.powermock.api.mockito.PowerMockito; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; @@ -36,7 +42,6 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; @@ -47,33 +52,27 @@ import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; - -import java.time.Instant; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; +import org.springframework.test.util.ReflectionTestUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; -import static org.mockito.ArgumentMatchers.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyCollection; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.error; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses.success; /** * Tests for {@link OidcAuthorizationCodeAuthenticationProvider}. * * @author Joe Grandja */ -@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class, OAuth2AuthorizationResponse.class, - OAuth2AccessTokenResponse.class, OidcAuthorizationCodeAuthenticationProvider.class}) -@RunWith(PowerMockRunner.class) public class OidcAuthorizationCodeAuthenticationProviderTests { private ClientRegistration clientRegistration; - private ClientRegistration.ProviderDetails providerDetails; private OAuth2AuthorizationRequest authorizationRequest; private OAuth2AuthorizationResponse authorizationResponse; private OAuth2AuthorizationExchange authorizationExchange; @@ -88,26 +87,16 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { @Before @SuppressWarnings("unchecked") public void setUp() throws Exception { - this.clientRegistration = mock(ClientRegistration.class); - this.providerDetails = mock(ClientRegistration.ProviderDetails.class); - this.authorizationRequest = mock(OAuth2AuthorizationRequest.class); - this.authorizationResponse = mock(OAuth2AuthorizationResponse.class); + this.clientRegistration = clientRegistration().clientId("client1").build(); + this.authorizationRequest = request().scope("openid", "profile", "email").build(); + this.authorizationResponse = success().build(); this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.accessTokenResponse = this.accessTokenSuccessResponse(); this.userService = mock(OAuth2UserService.class); - this.authenticationProvider = PowerMockito.spy( - new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService)); + this.authenticationProvider = + new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService); - when(this.clientRegistration.getRegistrationId()).thenReturn("client-registration-id-1"); - when(this.clientRegistration.getClientId()).thenReturn("client1"); - when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails); - when(this.providerDetails.getJwkSetUri()).thenReturn("https://provider.com/oauth2/keys"); - when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Arrays.asList("openid", "profile", "email"))); - when(this.authorizationRequest.getState()).thenReturn("12345"); - when(this.authorizationResponse.getState()).thenReturn("12345"); - when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com"); - when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com"); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(this.accessTokenResponse); } @@ -136,11 +125,13 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenAuthorizationRequestDoesNotContainOpenidScopeThenReturnNull() { - when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Collections.singleton("scope1"))); + OAuth2AuthorizationRequest authorizationRequest = request().scope("scope1").build(); + OAuth2AuthorizationExchange authorizationExchange = + new OAuth2AuthorizationExchange(authorizationRequest, this.authorizationResponse); OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); assertThat(authentication).isNull(); } @@ -150,11 +141,12 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_SCOPE)); - when(this.authorizationResponse.statusError()).thenReturn(true); - when(this.authorizationResponse.getError()).thenReturn(new OAuth2Error(OAuth2ErrorCodes.INVALID_SCOPE)); + OAuth2AuthorizationResponse authorizationResponse = error().errorCode(OAuth2ErrorCodes.INVALID_SCOPE).build(); + OAuth2AuthorizationExchange authorizationExchange = + new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); } @Test @@ -162,11 +154,12 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("invalid_state_parameter")); - when(this.authorizationRequest.getState()).thenReturn("34567"); - when(this.authorizationResponse.getState()).thenReturn("89012"); + OAuth2AuthorizationResponse authorizationResponse = success().state("89012").build(); + OAuth2AuthorizationExchange authorizationExchange = + new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); } @Test @@ -174,11 +167,12 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("invalid_redirect_uri_parameter")); - when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example1.com"); - when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example2.com"); + OAuth2AuthorizationResponse authorizationResponse = success().redirectUri("http://example2.com").build(); + OAuth2AuthorizationExchange authorizationExchange = + new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); } @Test @@ -201,10 +195,10 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("missing_signature_verifier")); - when(this.providerDetails.getJwkSetUri()).thenReturn(null); + ClientRegistration clientRegistration = clientRegistration().jwkSetUri(null).build(); this.authenticationProvider.authenticate( - new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + new OAuth2LoginAuthenticationToken(clientRegistration, this.authorizationExchange)); } @Test @@ -434,7 +428,8 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { JwtDecoder jwtDecoder = mock(JwtDecoder.class); when(jwtDecoder.decode(anyString())).thenReturn(idToken); - PowerMockito.doReturn(jwtDecoder).when(this.authenticationProvider, "getJwtDecoder", any(ClientRegistration.class)); + ReflectionTestUtils.setField(this.authenticationProvider, + "jwtDecoders", Collections.singletonMap("registration-id", jwtDecoder)); } private OAuth2AccessTokenResponse accessTokenSuccessResponse() { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index 6a87b4e49f..42ed9ffcfc 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -15,6 +15,14 @@ */ package org.springframework.security.oauth2.client.oidc.userinfo; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -23,17 +31,13 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PowerMockIgnore; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; import org.springframework.security.oauth2.core.AuthenticationMethod; -import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; @@ -43,31 +47,19 @@ import org.springframework.security.oauth2.core.oidc.StandardClaimNames; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; -import java.util.Arrays; -import java.util.HashMap; -import java.util.LinkedHashSet; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.TimeUnit; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.hamcrest.CoreMatchers.containsString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; /** * Tests for {@link OidcUserService}. * * @author Joe Grandja */ -@PowerMockIgnore({"okhttp3.*", "okio.Buffer"}) -@PrepareForTest(ClientRegistration.class) -@RunWith(PowerMockRunner.class) public class OidcUserServiceTests { - private ClientRegistration clientRegistration; - private ClientRegistration.ProviderDetails providerDetails; - private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint; + private ClientRegistration.Builder clientRegistrationBuilder; private OAuth2AccessToken accessToken; private OidcIdToken idToken; private OidcUserService userService = new OidcUserService(); @@ -80,26 +72,17 @@ public class OidcUserServiceTests { public void setup() throws Exception { this.server = new MockWebServer(); this.server.start(); - this.clientRegistration = mock(ClientRegistration.class); - this.providerDetails = mock(ClientRegistration.ProviderDetails.class); - this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class); - when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails); - when(this.providerDetails.getUserInfoEndpoint()).thenReturn(this.userInfoEndpoint); - when(this.clientRegistration.getAuthorizationGrantType()).thenReturn(AuthorizationGrantType.AUTHORIZATION_CODE); + this.clientRegistrationBuilder = clientRegistration() + .userInfoUri(null) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName(StandardClaimNames.SUB); - when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); - when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.SUB); + this.accessToken = scopes(OidcScopes.OPENID, OidcScopes.PROFILE); - this.accessToken = mock(OAuth2AccessToken.class); - Set authorizedScopes = new LinkedHashSet<>(Arrays.asList(OidcScopes.OPENID, OidcScopes.PROFILE)); - when(this.accessToken.getScopes()).thenReturn(authorizedScopes); - - this.idToken = mock(OidcIdToken.class); Map idTokenClaims = new HashMap<>(); idTokenClaims.put(IdTokenClaimNames.ISS, "https://provider.com"); idTokenClaims.put(IdTokenClaimNames.SUB, "subject1"); - when(this.idToken.getClaims()).thenReturn(idTokenClaims); - when(this.idToken.getSubject()).thenReturn("subject1"); + this.idToken = new OidcIdToken("access-token", Instant.MIN, Instant.MAX, idTokenClaims); this.userService.setOauth2UserService(new DefaultOAuth2UserService()); } @@ -123,22 +106,23 @@ public class OidcUserServiceTests { @Test public void loadUserWhenUserInfoUriIsNullThenUserInfoEndpointNotRequested() { - when(this.userInfoEndpoint.getUri()).thenReturn(null); - OidcUser user = this.userService.loadUser( - new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); + new OidcUserRequest(this.clientRegistrationBuilder.build(), this.accessToken, this.idToken)); assertThat(user.getUserInfo()).isNull(); } @Test public void loadUserWhenAuthorizedScopesDoesNotContainUserInfoScopesThenUserInfoEndpointNotRequested() { - Set authorizedScopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2")); - when(this.accessToken.getScopes()).thenReturn(authorizedScopes); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri("http://provider.com/user").build(); - when(this.userInfoEndpoint.getUri()).thenReturn("http://provider.com/user"); + Set authorizedScopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2")); + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "access-token", + Instant.MIN, Instant.MAX, authorizedScopes); OidcUser user = this.userService.loadUser( - new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); + new OidcUserRequest(clientRegistration, accessToken, this.idToken)); assertThat(user.getUserInfo()).isNull(); } @@ -156,11 +140,11 @@ public class OidcUserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); OidcUser user = this.userService.loadUser( - new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); + new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThat(user.getIdToken()).isNotNull(); assertThat(user.getUserInfo()).isNotNull(); @@ -196,11 +180,11 @@ public class OidcUserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri) + .userNameAttributeName(StandardClaimNames.EMAIL).build(); - this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); + this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); } @Test @@ -215,10 +199,10 @@ public class OidcUserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); + this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); } @Test @@ -238,10 +222,10 @@ public class OidcUserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); + this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); } @Test @@ -253,10 +237,10 @@ public class OidcUserServiceTests { String userInfoUri = server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); + this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); } @Test @@ -266,10 +250,10 @@ public class OidcUserServiceTests { String userInfoUri = "http://invalid-provider.com/user"; - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); + this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); } @Test @@ -286,12 +270,12 @@ public class OidcUserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri) + .userNameAttributeName(StandardClaimNames.EMAIL).build(); OidcUser user = this.userService.loadUser( - new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); + new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThat(user.getName()).isEqualTo("user1@example.com"); } @@ -311,10 +295,10 @@ public class OidcUserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); + this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT)) .isEqualTo(MediaType.APPLICATION_JSON_VALUE); } @@ -334,11 +318,10 @@ public class OidcUserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); + this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name()); assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); @@ -360,11 +343,11 @@ public class OidcUserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.FORM); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM).build(); - this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken)); + this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name()); assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java index 1b3a9bc9c9..57d054eb34 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java @@ -15,6 +15,12 @@ */ package org.springframework.security.oauth2.client.userinfo; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.junit.After; @@ -22,42 +28,29 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PowerMockIgnore; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.user.OAuth2User; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; /** * Tests for {@link CustomUserTypesOAuth2UserService}. * * @author Joe Grandja */ -@PowerMockIgnore("okhttp3.*") -@PrepareForTest(ClientRegistration.class) -@RunWith(PowerMockRunner.class) public class CustomUserTypesOAuth2UserServiceTests { - private ClientRegistration clientRegistration; - private ClientRegistration.ProviderDetails providerDetails; - private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint; + private ClientRegistration.Builder clientRegistrationBuilder; private OAuth2AccessToken accessToken; private CustomUserTypesOAuth2UserService userService; private MockWebServer server; @@ -69,14 +62,9 @@ public class CustomUserTypesOAuth2UserServiceTests { public void setUp() throws Exception { this.server = new MockWebServer(); this.server.start(); - this.clientRegistration = mock(ClientRegistration.class); - this.providerDetails = mock(ClientRegistration.ProviderDetails.class); - this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class); - when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails); - when(this.providerDetails.getUserInfoEndpoint()).thenReturn(this.userInfoEndpoint); String registrationId = "client-registration-id-1"; - when(this.clientRegistration.getRegistrationId()).thenReturn(registrationId); - this.accessToken = mock(OAuth2AccessToken.class); + this.clientRegistrationBuilder = clientRegistration().registrationId(registrationId); + this.accessToken = noScopes(); Map> customUserTypes = new HashMap<>(); customUserTypes.put(registrationId, CustomOAuth2User.class); @@ -120,9 +108,10 @@ public class CustomUserTypesOAuth2UserServiceTests { @Test public void loadUserWhenCustomUserTypeNotFoundThenReturnNull() { - when(this.clientRegistration.getRegistrationId()).thenReturn("other-client-registration-id-1"); + ClientRegistration clientRegistration = + clientRegistration().registrationId("other-client-registration-id-1").build(); - OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThat(user).isNull(); } @@ -138,10 +127,10 @@ public class CustomUserTypesOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); - OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThat(user.getName()).isEqualTo("first last"); assertThat(user.getAttributes().size()).isEqualTo(4); @@ -169,10 +158,10 @@ public class CustomUserTypesOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test @@ -184,10 +173,10 @@ public class CustomUserTypesOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test @@ -197,10 +186,18 @@ public class CustomUserTypesOAuth2UserServiceTests { String userInfoUri = "http://invalid-provider.com/user"; - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); + } + + private ClientRegistration.Builder withRegistrationId(String registrationId) { + return ClientRegistration + .withRegistrationId(registrationId) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .clientId("client") + .tokenUri("/token"); } private MockResponse jsonResponse(String json) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java index a2fc4733bd..09175384e6 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java @@ -15,6 +15,8 @@ */ package org.springframework.security.oauth2.client.userinfo; +import java.util.concurrent.TimeUnit; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -23,10 +25,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PowerMockIgnore; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -37,25 +36,18 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; -import java.util.concurrent.TimeUnit; - import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; /** * Tests for {@link DefaultOAuth2UserService}. * * @author Joe Grandja */ -@PowerMockIgnore({"okhttp3.*", "okio.Buffer"}) -@PrepareForTest(ClientRegistration.class) -@RunWith(PowerMockRunner.class) public class DefaultOAuth2UserServiceTests { - private ClientRegistration clientRegistration; - private ClientRegistration.ProviderDetails providerDetails; - private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint; + private ClientRegistration.Builder clientRegistrationBuilder; private OAuth2AccessToken accessToken; private DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); private MockWebServer server; @@ -67,12 +59,10 @@ public class DefaultOAuth2UserServiceTests { public void setup() throws Exception { this.server = new MockWebServer(); this.server.start(); - this.clientRegistration = mock(ClientRegistration.class); - this.providerDetails = mock(ClientRegistration.ProviderDetails.class); - this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class); - when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails); - when(this.providerDetails.getUserInfoEndpoint()).thenReturn(this.userInfoEndpoint); - this.accessToken = mock(OAuth2AccessToken.class); + this.clientRegistrationBuilder = clientRegistration() + .userInfoUri(null) + .userNameAttributeName(null); + this.accessToken = noScopes(); } @After @@ -103,8 +93,8 @@ public class DefaultOAuth2UserServiceTests { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("missing_user_info_uri")); - when(this.userInfoEndpoint.getUri()).thenReturn(null); - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); + this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test @@ -112,9 +102,9 @@ public class DefaultOAuth2UserServiceTests { this.exception.expect(OAuth2AuthenticationException.class); this.exception.expectMessage(containsString("missing_user_name_attribute")); - when(this.userInfoEndpoint.getUri()).thenReturn("http://provider.com/user"); - when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(null); - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri("http://provider.com/user").build(); + this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test @@ -131,12 +121,12 @@ public class DefaultOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); - when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name"); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName("user-name").build(); - OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThat(user.getName()).isEqualTo("user1"); assertThat(user.getAttributes().size()).isEqualTo(6); @@ -171,12 +161,12 @@ public class DefaultOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); - when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name"); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test @@ -194,12 +184,12 @@ public class DefaultOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); - when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name"); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test @@ -215,12 +205,12 @@ public class DefaultOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); - when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name"); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test @@ -232,12 +222,12 @@ public class DefaultOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); - when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name"); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } @Test @@ -247,12 +237,12 @@ public class DefaultOAuth2UserServiceTests { String userInfoUri = "http://invalid-provider.com/user"; - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); - when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name"); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); } // gh-5294 @@ -270,12 +260,12 @@ public class DefaultOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); - when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name"); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT)) .isEqualTo(MediaType.APPLICATION_JSON_VALUE); } @@ -295,12 +285,12 @@ public class DefaultOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER); - when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name"); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name()); assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); @@ -322,12 +312,12 @@ public class DefaultOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); - when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri); - when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.FORM); - when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name"); - when(this.accessToken.getTokenValue()).thenReturn("access-token"); + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.FORM) + .userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken)); + this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); RecordedRequest request = this.server.takeRequest(); assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name()); assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java index 4ba66595b7..1a01ec536b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java @@ -15,13 +15,17 @@ */ package org.springframework.security.oauth2.client.web; +import java.util.HashMap; +import java.util.Map; +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; + import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PowerMockIgnore; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AnonymousAuthenticationToken; @@ -39,36 +43,31 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.OAuth2RefreshToken; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.web.savedrequest.HttpSessionRequestCache; import org.springframework.security.web.savedrequest.RequestCache; -import javax.servlet.FilterChain; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; +import static org.springframework.security.oauth2.core.TestOAuth2RefreshTokens.refreshToken; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges.success; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request; /** * Tests for {@link OAuth2AuthorizationCodeGrantFilter}. * * @author Joe Grandja */ -@PowerMockIgnore("javax.security.*") -@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class, OAuth2AuthorizationCodeGrantFilter.class}) -@RunWith(PowerMockRunner.class) public class OAuth2AuthorizationCodeGrantFilterTests { private ClientRegistration registration1; private String principalName1 = "principal-1"; @@ -367,19 +366,15 @@ public class OAuth2AuthorizationCodeGrantFilterTests { ClientRegistration registration) { Map additionalParameters = new HashMap<>(); additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId()); - OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class); - when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters); - when(authorizationRequest.getRedirectUri()).thenReturn(request.getRequestURL().toString()); - when(authorizationRequest.getState()).thenReturn("state"); + OAuth2AuthorizationRequest authorizationRequest = request() + .additionalParameters(additionalParameters) + .redirectUri(request.getRequestURL().toString()).build(); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); } private void setUpAuthenticationResult(ClientRegistration registration) { - OAuth2AuthorizationCodeAuthenticationToken authentication = mock(OAuth2AuthorizationCodeAuthenticationToken.class); - when(authentication.getClientRegistration()).thenReturn(registration); - when(authentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class)); - when(authentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class)); - when(authentication.getRefreshToken()).thenReturn(mock(OAuth2RefreshToken.class)); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken(registration, success(), noScopes(), refreshToken()); when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authentication); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java index 096befaca6..6b0bd666fd 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java @@ -15,13 +15,16 @@ */ package org.springframework.security.oauth2.client.web; +import java.util.HashMap; +import java.util.Map; +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; -import org.powermock.core.classloader.annotations.PowerMockIgnore; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AuthenticationManager; @@ -42,7 +45,6 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2RefreshToken; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -51,24 +53,22 @@ import org.springframework.security.web.authentication.AuthenticationFailureHand import org.springframework.security.web.util.UrlUtils; import org.springframework.web.util.UriComponentsBuilder; -import javax.servlet.FilterChain; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.util.HashMap; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges.success; /** * Tests for {@link OAuth2LoginAuthenticationFilter}. * * @author Joe Grandja */ -@PowerMockIgnore("javax.security.*") -@PrepareForTest({OAuth2AuthorizationExchange.class, OAuth2LoginAuthenticationFilter.class}) -@RunWith(PowerMockRunner.class) public class OAuth2LoginAuthenticationFilterTests { private ClientRegistration registration1; private ClientRegistration registration2; @@ -440,7 +440,7 @@ public class OAuth2LoginAuthenticationFilterTests { when(this.loginAuthentication.getName()).thenReturn(this.principalName1); when(this.loginAuthentication.getAuthorities()).thenReturn(AuthorityUtils.createAuthorityList("ROLE_USER")); when(this.loginAuthentication.getClientRegistration()).thenReturn(registration); - when(this.loginAuthentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class)); + when(this.loginAuthentication.getAuthorizationExchange()).thenReturn(success()); when(this.loginAuthentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class)); when(this.loginAuthentication.getRefreshToken()).thenReturn(mock(OAuth2RefreshToken.class)); when(this.loginAuthentication.isAuthenticated()).thenReturn(true);