diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilter.java index 58d2e12ace..34925510aa 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilter.java @@ -25,6 +25,7 @@ import org.springframework.security.oauth2.client.authentication.OAuth2ClientAut import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCode; import org.springframework.security.oauth2.core.endpoint.AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter; @@ -95,6 +96,12 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException, IOException, ServletException { + if (!this.authorizationResponseSuccess(request) && !this.authorizationResponseError(request)) { + OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCode.INVALID_REQUEST); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + AuthorizationResponse authorizationResponse = this.convert(request); + AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.loadAuthorizationRequest(request); if (authorizationRequest == null) { OAuth2Error oauth2Error = new OAuth2Error(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE); @@ -102,8 +109,6 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio } this.authorizationRequestRepository.removeAuthorizationRequest(request); - AuthorizationResponse authorizationResponse = this.convert(request); - String registrationId = (String)authorizationRequest.getAdditionalParameters().get(OAuth2Parameter.REGISTRATION_ID); ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId); @@ -144,10 +149,6 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio } private AuthorizationResponse convert(HttpServletRequest request) { - if (!this.authorizationResponseSuccess(request) && !this.authorizationResponseError(request)) { - return null; - } - String code = request.getParameter(OAuth2Parameter.CODE); String errorCode = request.getParameter(OAuth2Parameter.ERROR); String state = request.getParameter(OAuth2Parameter.STATE);