Create OAuth2AuthorizationResponse lazily

This commit creates `OAuth2AuthorizationResponse` as lazily as possible to prevent the creation when `authorizationRequest` is `null`.

Fixes gh-4848
This commit is contained in:
Johnny Lim 2017-11-18 00:06:06 +09:00 committed by Joe Grandja
parent c04b3b4114
commit edccafca84
2 changed files with 25 additions and 2 deletions

View File

@ -108,7 +108,6 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST); OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
} }
OAuth2AuthorizationResponse authorizationResponse = this.convert(request);
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.loadAuthorizationRequest(request); OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.loadAuthorizationRequest(request);
if (authorizationRequest == null) { if (authorizationRequest == null) {
@ -120,6 +119,8 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
String registrationId = (String) authorizationRequest.getAdditionalParameters().get(OAuth2ParameterNames.REGISTRATION_ID); String registrationId = (String) authorizationRequest.getAdditionalParameters().get(OAuth2ParameterNames.REGISTRATION_ID);
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId); ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
OAuth2AuthorizationResponse authorizationResponse = this.convert(request);
OAuth2LoginAuthenticationToken authenticationRequest = new OAuth2LoginAuthenticationToken( OAuth2LoginAuthenticationToken authenticationRequest = new OAuth2LoginAuthenticationToken(
clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse));
authenticationRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); authenticationRequest.setDetails(this.authenticationDetailsSource.buildDetails(request));

View File

@ -19,6 +19,7 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunner;
@ -53,8 +54,10 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
import static org.powermock.api.mockito.PowerMockito.verifyPrivate;
/** /**
* Tests for {@link OAuth2LoginAuthenticationFilter}. * Tests for {@link OAuth2LoginAuthenticationFilter}.
@ -62,7 +65,7 @@ import static org.mockito.Mockito.*;
* @author Joe Grandja * @author Joe Grandja
*/ */
@PowerMockIgnore("javax.security.*") @PowerMockIgnore("javax.security.*")
@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class}) @PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class, OAuth2LoginAuthenticationFilter.class})
@RunWith(PowerMockRunner.class) @RunWith(PowerMockRunner.class)
public class OAuth2LoginAuthenticationFilterTests { public class OAuth2LoginAuthenticationFilterTests {
private ClientRegistration registration1; private ClientRegistration registration1;
@ -263,6 +266,25 @@ public class OAuth2LoginAuthenticationFilterTests {
verify(this.filter).attemptAuthentication(any(HttpServletRequest.class), any(HttpServletResponse.class)); verify(this.filter).attemptAuthentication(any(HttpServletRequest.class), any(HttpServletResponse.class));
} }
@Test
public void attemptAuthenticationWhenAuthorizationRequestIsNullThenAuthorizationResponseNotCreated() throws Exception {
OAuth2LoginAuthenticationFilter filter = PowerMockito.spy(new OAuth2LoginAuthenticationFilter(
this.clientRegistrationRepository, this.authorizedClientService));
MockHttpServletRequest request = new MockHttpServletRequest();
request.addParameter(OAuth2ParameterNames.CODE, "code");
request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletResponse response = new MockHttpServletResponse();
try {
filter.attemptAuthentication(request, response);
fail();
} catch (OAuth2AuthenticationException ex) {
verifyPrivate(filter, never()).invoke("convert", any(HttpServletRequest.class));
}
}
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response, private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
ClientRegistration registration) { ClientRegistration registration) {
OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class); OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);