Don't force downcasting of RequestAttributes to ServletRequestAttributes

Fixes gh-7953
This commit is contained in:
Stephane Maldini 2020-02-06 14:16:38 -08:00 committed by Joe Grandja
parent 2dc8147106
commit 0012e24c46
4 changed files with 73 additions and 37 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,6 +23,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import reactor.core.CoreSubscriber;
@ -92,32 +93,21 @@ class SecurityReactorContextConfiguration {
}
private static boolean contextAttributesAvailable() {
HttpServletRequest servletRequest = null;
HttpServletResponse servletResponse = null;
ServletRequestAttributes requestAttributes =
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if (requestAttributes != null) {
servletRequest = requestAttributes.getRequest();
servletResponse = requestAttributes.getResponse();
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication != null || servletRequest != null || servletResponse != null) {
return true;
}
return false;
return SecurityContextHolder.getContext().getAuthentication() != null ||
RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes;
}
private static Map<Object, Object> getContextAttributes() {
HttpServletRequest servletRequest = null;
HttpServletResponse servletResponse = null;
ServletRequestAttributes requestAttributes =
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if (requestAttributes != null) {
servletRequest = requestAttributes.getRequest();
servletResponse = requestAttributes.getResponse();
RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
if (requestAttributes instanceof ServletRequestAttributes) {
ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes;
servletRequest = servletRequestAttributes.getRequest();
servletResponse = servletRequestAttributes.getResponse(); // possible null
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication == null && servletRequest == null && servletResponse == null) {
if (authentication == null && servletRequest == null) {
return Collections.emptyMap();
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,6 +28,7 @@ import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.client.ClientRequest;
@ -36,6 +37,7 @@ import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.test.StepVerifier;
import reactor.util.context.Context;
@ -139,6 +141,52 @@ public class SecurityReactorContextConfigurationTests {
assertThat(resultContext).isSameAs(parentContext);
}
@Test
public void createSubscriberIfNecessaryWhenNotServletRequestAttributesThenStillCreate() {
RequestContextHolder.setRequestAttributes(
new RequestAttributes() {
@Override
public Object getAttribute(String name, int scope) {
return null;
}
@Override
public void setAttribute(String name, Object value, int scope) {
}
@Override
public void removeAttribute(String name, int scope) {
}
@Override
public String[] getAttributeNames(int scope) {
return new String[0];
}
@Override
public void registerDestructionCallback(String name, Runnable callback, int scope) {
}
@Override
public Object resolveReference(String key) {
return null;
}
@Override
public String getSessionId() {
return null;
}
@Override
public Object getSessionMutex() {
return null;
}
});
CoreSubscriber<Object> subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(Operators.emptySubscriber());
assertThat(subscriber).isInstanceOf(SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.class);
}
@Test
public void createPublisherWhenLastOperatorAddedThenSecurityContextAttributesAvailable() {
// Trigger the importing of SecurityReactorContextConfiguration via OAuth2ImportSelector

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,6 +28,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
@ -121,9 +122,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
private static HttpServletRequest getHttpServletRequestOrDefault(Map<String, Object> attributes) {
HttpServletRequest servletRequest = (HttpServletRequest) attributes.get(HttpServletRequest.class.getName());
if (servletRequest == null) {
ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if (context != null) {
servletRequest = context.getRequest();
RequestAttributes context = RequestContextHolder.getRequestAttributes();
if (context instanceof ServletRequestAttributes) {
servletRequest = ((ServletRequestAttributes) context).getRequest();
}
}
return servletRequest;
@ -132,9 +133,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
private static HttpServletResponse getHttpServletResponseOrDefault(Map<String, Object> attributes) {
HttpServletResponse servletResponse = (HttpServletResponse) attributes.get(HttpServletResponse.class.getName());
if (servletResponse == null) {
ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if (context != null) {
servletResponse = context.getResponse();
RequestAttributes context = RequestContextHolder.getRequestAttributes();
if (context instanceof ServletRequestAttributes) {
servletResponse = ((ServletRequestAttributes) context).getResponse();
}
}
return servletResponse;

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -36,6 +36,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.client.ClientRequest;
@ -389,15 +390,11 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
return;
}
ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
HttpServletRequest request = null;
HttpServletResponse response = null;
if (context != null) {
request = context.getRequest();
response = context.getResponse();
RequestAttributes context = RequestContextHolder.getRequestAttributes();
if (context instanceof ServletRequestAttributes) {
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ((ServletRequestAttributes) context).getRequest());
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ((ServletRequestAttributes) context).getResponse());
}
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
}
private void populateDefaultAuthentication(Map<String, Object> attrs) {