Added HttpServletResponse to AuthorizationRequestRepository

This change enables AuthorizationRequestRepository to possibly save the AuthorizationRequestAttributes to a cookie.

Fixes gh-4446
This commit is contained in:
Luander Ribeiro 2017-07-24 20:43:20 +02:00 committed by Joe Grandja
parent ef1de5eda0
commit 65734414f7
4 changed files with 12 additions and 7 deletions

View File

@ -126,7 +126,7 @@ public class AuthorizationCodeRequestRedirectFilter extends OncePerRequestFilter
.state(this.stateGenerator.generateKey()) .state(this.stateGenerator.generateKey())
.build(); .build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request); this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request, response);
URI redirectUri = this.authorizationUriBuilder.build(authorizationRequestAttributes); URI redirectUri = this.authorizationUriBuilder.build(authorizationRequestAttributes);
this.authorizationRedirectStrategy.sendRedirect(request, response, redirectUri.toString()); this.authorizationRedirectStrategy.sendRedirect(request, response, redirectUri.toString());

View File

@ -18,6 +18,7 @@ package org.springframework.security.oauth2.client.authentication;
import org.springframework.security.oauth2.core.endpoint.AuthorizationRequestAttributes; import org.springframework.security.oauth2.core.endpoint.AuthorizationRequestAttributes;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/** /**
* Implementations of this interface are responsible for the persistence * Implementations of this interface are responsible for the persistence
@ -38,7 +39,8 @@ public interface AuthorizationRequestRepository {
AuthorizationRequestAttributes loadAuthorizationRequest(HttpServletRequest request); AuthorizationRequestAttributes loadAuthorizationRequest(HttpServletRequest request);
void saveAuthorizationRequest(AuthorizationRequestAttributes authorizationRequest, HttpServletRequest request); void saveAuthorizationRequest(AuthorizationRequestAttributes authorizationRequest, HttpServletRequest request,
HttpServletResponse response);
AuthorizationRequestAttributes removeAuthorizationRequest(HttpServletRequest request); AuthorizationRequestAttributes removeAuthorizationRequest(HttpServletRequest request);

View File

@ -18,6 +18,7 @@ package org.springframework.security.oauth2.client.authentication;
import org.springframework.security.oauth2.core.endpoint.AuthorizationRequestAttributes; import org.springframework.security.oauth2.core.endpoint.AuthorizationRequestAttributes;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession; import javax.servlet.http.HttpSession;
/** /**
@ -44,7 +45,8 @@ public final class HttpSessionAuthorizationRequestRepository implements Authoriz
} }
@Override @Override
public void saveAuthorizationRequest(AuthorizationRequestAttributes authorizationRequest, HttpServletRequest request) { public void saveAuthorizationRequest(AuthorizationRequestAttributes authorizationRequest, HttpServletRequest request,
HttpServletResponse response) {
if (authorizationRequest == null) { if (authorizationRequest == null) {
this.removeAuthorizationRequest(request); this.removeAuthorizationRequest(request);
return; return;

View File

@ -106,8 +106,8 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
String state = "some state"; String state = "some state";
request.addParameter(OAuth2Parameter.CODE, authCode); request.addParameter(OAuth2Parameter.CODE, authCode);
request.addParameter(OAuth2Parameter.STATE, state); request.addParameter(OAuth2Parameter.STATE, state);
setupAuthorizationRequest(authorizationRequestRepository, request, clientRegistration, state);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, state);
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(request, response, filterChain); filter.doFilter(request, response, filterChain);
@ -156,8 +156,8 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
String state = "some other state"; String state = "some other state";
request.addParameter(OAuth2Parameter.CODE, authCode); request.addParameter(OAuth2Parameter.CODE, authCode);
request.addParameter(OAuth2Parameter.STATE, state); request.addParameter(OAuth2Parameter.STATE, state);
setupAuthorizationRequest(authorizationRequestRepository, request, clientRegistration, "some state");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, "some state");
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(request, response, filterChain); filter.doFilter(request, response, filterChain);
@ -181,8 +181,8 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
String state = "some state"; String state = "some state";
request.addParameter(OAuth2Parameter.CODE, authCode); request.addParameter(OAuth2Parameter.CODE, authCode);
request.addParameter(OAuth2Parameter.STATE, state); request.addParameter(OAuth2Parameter.STATE, state);
setupAuthorizationRequest(authorizationRequestRepository, request, clientRegistration, state);
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, state);
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(request, response, filterChain); filter.doFilter(request, response, filterChain);
@ -227,6 +227,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
private void setupAuthorizationRequest(AuthorizationRequestRepository authorizationRequestRepository, private void setupAuthorizationRequest(AuthorizationRequestRepository authorizationRequestRepository,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response,
ClientRegistration clientRegistration, ClientRegistration clientRegistration,
String state) { String state) {
@ -239,7 +240,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
.state(state) .state(state)
.build(); .build();
authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request); authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request, response);
} }
private MockHttpServletRequest setupRequest(ClientRegistration clientRegistration) { private MockHttpServletRequest setupRequest(ClientRegistration clientRegistration) {