Don't force downcasting of RequestAttributes to ServletRequestAttributes
Fixes gh-7953
This commit is contained in:
parent
2dc8147106
commit
0012e24c46
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -23,6 +23,7 @@ import org.springframework.context.annotation.Bean;
|
|||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.web.context.request.RequestAttributes;
|
||||
import org.springframework.web.context.request.RequestContextHolder;
|
||||
import org.springframework.web.context.request.ServletRequestAttributes;
|
||||
import reactor.core.CoreSubscriber;
|
||||
|
@ -92,32 +93,21 @@ class SecurityReactorContextConfiguration {
|
|||
}
|
||||
|
||||
private static boolean contextAttributesAvailable() {
|
||||
HttpServletRequest servletRequest = null;
|
||||
HttpServletResponse servletResponse = null;
|
||||
ServletRequestAttributes requestAttributes =
|
||||
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
|
||||
if (requestAttributes != null) {
|
||||
servletRequest = requestAttributes.getRequest();
|
||||
servletResponse = requestAttributes.getResponse();
|
||||
}
|
||||
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
|
||||
if (authentication != null || servletRequest != null || servletResponse != null) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
return SecurityContextHolder.getContext().getAuthentication() != null ||
|
||||
RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes;
|
||||
}
|
||||
|
||||
private static Map<Object, Object> getContextAttributes() {
|
||||
HttpServletRequest servletRequest = null;
|
||||
HttpServletResponse servletResponse = null;
|
||||
ServletRequestAttributes requestAttributes =
|
||||
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
|
||||
if (requestAttributes != null) {
|
||||
servletRequest = requestAttributes.getRequest();
|
||||
servletResponse = requestAttributes.getResponse();
|
||||
RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
|
||||
if (requestAttributes instanceof ServletRequestAttributes) {
|
||||
ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes;
|
||||
servletRequest = servletRequestAttributes.getRequest();
|
||||
servletResponse = servletRequestAttributes.getResponse(); // possible null
|
||||
}
|
||||
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
|
||||
if (authentication == null && servletRequest == null && servletResponse == null) {
|
||||
if (authentication == null && servletRequest == null) {
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,6 +28,7 @@ import org.springframework.security.config.test.SpringTestRule;
|
|||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
|
||||
import org.springframework.web.context.request.RequestAttributes;
|
||||
import org.springframework.web.context.request.RequestContextHolder;
|
||||
import org.springframework.web.context.request.ServletRequestAttributes;
|
||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||
|
@ -36,6 +37,7 @@ import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
|
|||
import reactor.core.CoreSubscriber;
|
||||
import reactor.core.publisher.BaseSubscriber;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.publisher.Operators;
|
||||
import reactor.test.StepVerifier;
|
||||
import reactor.util.context.Context;
|
||||
|
||||
|
@ -139,6 +141,52 @@ public class SecurityReactorContextConfigurationTests {
|
|||
assertThat(resultContext).isSameAs(parentContext);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createSubscriberIfNecessaryWhenNotServletRequestAttributesThenStillCreate() {
|
||||
RequestContextHolder.setRequestAttributes(
|
||||
new RequestAttributes() {
|
||||
@Override
|
||||
public Object getAttribute(String name, int scope) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAttribute(String name, Object value, int scope) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeAttribute(String name, int scope) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] getAttributeNames(int scope) {
|
||||
return new String[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void registerDestructionCallback(String name, Runnable callback, int scope) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object resolveReference(String key) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getSessionId() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getSessionMutex() {
|
||||
return null;
|
||||
}
|
||||
});
|
||||
|
||||
CoreSubscriber<Object> subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(Operators.emptySubscriber());
|
||||
assertThat(subscriber).isInstanceOf(SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createPublisherWhenLastOperatorAddedThenSecurityContextAttributesAvailable() {
|
||||
// Trigger the importing of SecurityReactorContextConfiguration via OAuth2ImportSelector
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,6 +28,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
|||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.context.request.RequestAttributes;
|
||||
import org.springframework.web.context.request.RequestContextHolder;
|
||||
import org.springframework.web.context.request.ServletRequestAttributes;
|
||||
|
||||
|
@ -121,9 +122,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
|
|||
private static HttpServletRequest getHttpServletRequestOrDefault(Map<String, Object> attributes) {
|
||||
HttpServletRequest servletRequest = (HttpServletRequest) attributes.get(HttpServletRequest.class.getName());
|
||||
if (servletRequest == null) {
|
||||
ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
|
||||
if (context != null) {
|
||||
servletRequest = context.getRequest();
|
||||
RequestAttributes context = RequestContextHolder.getRequestAttributes();
|
||||
if (context instanceof ServletRequestAttributes) {
|
||||
servletRequest = ((ServletRequestAttributes) context).getRequest();
|
||||
}
|
||||
}
|
||||
return servletRequest;
|
||||
|
@ -132,9 +133,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
|
|||
private static HttpServletResponse getHttpServletResponseOrDefault(Map<String, Object> attributes) {
|
||||
HttpServletResponse servletResponse = (HttpServletResponse) attributes.get(HttpServletResponse.class.getName());
|
||||
if (servletResponse == null) {
|
||||
ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
|
||||
if (context != null) {
|
||||
servletResponse = context.getResponse();
|
||||
RequestAttributes context = RequestContextHolder.getRequestAttributes();
|
||||
if (context instanceof ServletRequestAttributes) {
|
||||
servletResponse = ((ServletRequestAttributes) context).getResponse();
|
||||
}
|
||||
}
|
||||
return servletResponse;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -36,6 +36,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
|
|||
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
|
||||
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.context.request.RequestAttributes;
|
||||
import org.springframework.web.context.request.RequestContextHolder;
|
||||
import org.springframework.web.context.request.ServletRequestAttributes;
|
||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||
|
@ -389,15 +390,11 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
|
|||
attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
|
||||
return;
|
||||
}
|
||||
ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
|
||||
HttpServletRequest request = null;
|
||||
HttpServletResponse response = null;
|
||||
if (context != null) {
|
||||
request = context.getRequest();
|
||||
response = context.getResponse();
|
||||
RequestAttributes context = RequestContextHolder.getRequestAttributes();
|
||||
if (context instanceof ServletRequestAttributes) {
|
||||
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ((ServletRequestAttributes) context).getRequest());
|
||||
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ((ServletRequestAttributes) context).getResponse());
|
||||
}
|
||||
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
|
||||
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
|
||||
}
|
||||
|
||||
private void populateDefaultAuthentication(Map<String, Object> attrs) {
|
||||
|
|
Loading…
Reference in New Issue