diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeRequestRedirectFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeRequestRedirectFilter.java index d6819dd148..b8b0cea523 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeRequestRedirectFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeRequestRedirectFilter.java @@ -126,7 +126,7 @@ public class AuthorizationCodeRequestRedirectFilter extends OncePerRequestFilter .state(this.stateGenerator.generateKey()) .build(); - this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request, response); URI redirectUri = this.authorizationUriBuilder.build(authorizationRequestAttributes); this.authorizationRedirectStrategy.sendRedirect(request, response, redirectUri.toString()); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationRequestRepository.java index 3954a03bb5..ecacce1264 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationRequestRepository.java @@ -18,6 +18,7 @@ package org.springframework.security.oauth2.client.authentication; import org.springframework.security.oauth2.core.endpoint.AuthorizationRequestAttributes; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; /** * Implementations of this interface are responsible for the persistence @@ -38,7 +39,8 @@ public interface AuthorizationRequestRepository { AuthorizationRequestAttributes loadAuthorizationRequest(HttpServletRequest request); - void saveAuthorizationRequest(AuthorizationRequestAttributes authorizationRequest, HttpServletRequest request); + void saveAuthorizationRequest(AuthorizationRequestAttributes authorizationRequest, HttpServletRequest request, + HttpServletResponse response); AuthorizationRequestAttributes removeAuthorizationRequest(HttpServletRequest request); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/HttpSessionAuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/HttpSessionAuthorizationRequestRepository.java index dfc343d9b0..930f1bd208 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/HttpSessionAuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/HttpSessionAuthorizationRequestRepository.java @@ -18,6 +18,7 @@ package org.springframework.security.oauth2.client.authentication; import org.springframework.security.oauth2.core.endpoint.AuthorizationRequestAttributes; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; /** @@ -44,7 +45,8 @@ public final class HttpSessionAuthorizationRequestRepository implements Authoriz } @Override - public void saveAuthorizationRequest(AuthorizationRequestAttributes authorizationRequest, HttpServletRequest request) { + public void saveAuthorizationRequest(AuthorizationRequestAttributes authorizationRequest, HttpServletRequest request, + HttpServletResponse response) { if (authorizationRequest == null) { this.removeAuthorizationRequest(request); return; diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeAuthenticationProcessingFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeAuthenticationProcessingFilterTests.java index f7cd49b88e..df25a49aae 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeAuthenticationProcessingFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeAuthenticationProcessingFilterTests.java @@ -106,8 +106,8 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { String state = "some state"; request.addParameter(OAuth2Parameter.CODE, authCode); request.addParameter(OAuth2Parameter.STATE, state); - setupAuthorizationRequest(authorizationRequestRepository, request, clientRegistration, state); MockHttpServletResponse response = new MockHttpServletResponse(); + setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, state); FilterChain filterChain = mock(FilterChain.class); filter.doFilter(request, response, filterChain); @@ -156,8 +156,8 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { String state = "some other state"; request.addParameter(OAuth2Parameter.CODE, authCode); request.addParameter(OAuth2Parameter.STATE, state); - setupAuthorizationRequest(authorizationRequestRepository, request, clientRegistration, "some state"); MockHttpServletResponse response = new MockHttpServletResponse(); + setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, "some state"); FilterChain filterChain = mock(FilterChain.class); filter.doFilter(request, response, filterChain); @@ -181,8 +181,8 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { String state = "some state"; request.addParameter(OAuth2Parameter.CODE, authCode); request.addParameter(OAuth2Parameter.STATE, state); - setupAuthorizationRequest(authorizationRequestRepository, request, clientRegistration, state); MockHttpServletResponse response = new MockHttpServletResponse(); + setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, state); FilterChain filterChain = mock(FilterChain.class); filter.doFilter(request, response, filterChain); @@ -227,6 +227,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { private void setupAuthorizationRequest(AuthorizationRequestRepository authorizationRequestRepository, HttpServletRequest request, + HttpServletResponse response, ClientRegistration clientRegistration, String state) { @@ -239,7 +240,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests { .state(state) .build(); - authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request); + authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request, response); } private MockHttpServletRequest setupRequest(ClientRegistration clientRegistration) {