Don't force downcasting of RequestAttributes to ServletRequestAttributes

Fixes gh-7953
This commit is contained in:
Stephane Maldini 2020-02-06 14:16:38 -08:00 committed by Joe Grandja
parent 2dc8147106
commit 0012e24c46
4 changed files with 73 additions and 37 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -23,6 +23,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.context.request.ServletRequestAttributes;
import reactor.core.CoreSubscriber; import reactor.core.CoreSubscriber;
@ -92,32 +93,21 @@ class SecurityReactorContextConfiguration {
} }
private static boolean contextAttributesAvailable() { private static boolean contextAttributesAvailable() {
HttpServletRequest servletRequest = null; return SecurityContextHolder.getContext().getAuthentication() != null ||
HttpServletResponse servletResponse = null; RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes;
ServletRequestAttributes requestAttributes =
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if (requestAttributes != null) {
servletRequest = requestAttributes.getRequest();
servletResponse = requestAttributes.getResponse();
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication != null || servletRequest != null || servletResponse != null) {
return true;
}
return false;
} }
private static Map<Object, Object> getContextAttributes() { private static Map<Object, Object> getContextAttributes() {
HttpServletRequest servletRequest = null; HttpServletRequest servletRequest = null;
HttpServletResponse servletResponse = null; HttpServletResponse servletResponse = null;
ServletRequestAttributes requestAttributes = RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); if (requestAttributes instanceof ServletRequestAttributes) {
if (requestAttributes != null) { ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes;
servletRequest = requestAttributes.getRequest(); servletRequest = servletRequestAttributes.getRequest();
servletResponse = requestAttributes.getResponse(); servletResponse = servletRequestAttributes.getResponse(); // possible null
} }
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication == null && servletRequest == null && servletResponse == null) { if (authentication == null && servletRequest == null) {
return Collections.emptyMap(); return Collections.emptyMap();
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,6 +28,7 @@ import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction; import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientRequest;
@ -36,6 +37,7 @@ import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import reactor.core.CoreSubscriber; import reactor.core.CoreSubscriber;
import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
import reactor.util.context.Context; import reactor.util.context.Context;
@ -139,6 +141,52 @@ public class SecurityReactorContextConfigurationTests {
assertThat(resultContext).isSameAs(parentContext); assertThat(resultContext).isSameAs(parentContext);
} }
@Test
public void createSubscriberIfNecessaryWhenNotServletRequestAttributesThenStillCreate() {
RequestContextHolder.setRequestAttributes(
new RequestAttributes() {
@Override
public Object getAttribute(String name, int scope) {
return null;
}
@Override
public void setAttribute(String name, Object value, int scope) {
}
@Override
public void removeAttribute(String name, int scope) {
}
@Override
public String[] getAttributeNames(int scope) {
return new String[0];
}
@Override
public void registerDestructionCallback(String name, Runnable callback, int scope) {
}
@Override
public Object resolveReference(String key) {
return null;
}
@Override
public String getSessionId() {
return null;
}
@Override
public Object getSessionMutex() {
return null;
}
});
CoreSubscriber<Object> subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(Operators.emptySubscriber());
assertThat(subscriber).isInstanceOf(SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.class);
}
@Test @Test
public void createPublisherWhenLastOperatorAddedThenSecurityContextAttributesAvailable() { public void createPublisherWhenLastOperatorAddedThenSecurityContextAttributesAvailable() {
// Trigger the importing of SecurityReactorContextConfiguration via OAuth2ImportSelector // Trigger the importing of SecurityReactorContextConfiguration via OAuth2ImportSelector

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,6 +28,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.context.request.ServletRequestAttributes;
@ -121,9 +122,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
private static HttpServletRequest getHttpServletRequestOrDefault(Map<String, Object> attributes) { private static HttpServletRequest getHttpServletRequestOrDefault(Map<String, Object> attributes) {
HttpServletRequest servletRequest = (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()); HttpServletRequest servletRequest = (HttpServletRequest) attributes.get(HttpServletRequest.class.getName());
if (servletRequest == null) { if (servletRequest == null) {
ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); RequestAttributes context = RequestContextHolder.getRequestAttributes();
if (context != null) { if (context instanceof ServletRequestAttributes) {
servletRequest = context.getRequest(); servletRequest = ((ServletRequestAttributes) context).getRequest();
} }
} }
return servletRequest; return servletRequest;
@ -132,9 +133,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
private static HttpServletResponse getHttpServletResponseOrDefault(Map<String, Object> attributes) { private static HttpServletResponse getHttpServletResponseOrDefault(Map<String, Object> attributes) {
HttpServletResponse servletResponse = (HttpServletResponse) attributes.get(HttpServletResponse.class.getName()); HttpServletResponse servletResponse = (HttpServletResponse) attributes.get(HttpServletResponse.class.getName());
if (servletResponse == null) { if (servletResponse == null) {
ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); RequestAttributes context = RequestContextHolder.getRequestAttributes();
if (context != null) { if (context instanceof ServletRequestAttributes) {
servletResponse = context.getResponse(); servletResponse = ((ServletRequestAttributes) context).getResponse();
} }
} }
return servletResponse; return servletResponse;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2019 the original author or authors. * Copyright 2002-2020 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -36,6 +36,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientRequest;
@ -389,15 +390,11 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) { attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
return; return;
} }
ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); RequestAttributes context = RequestContextHolder.getRequestAttributes();
HttpServletRequest request = null; if (context instanceof ServletRequestAttributes) {
HttpServletResponse response = null; attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ((ServletRequestAttributes) context).getRequest());
if (context != null) { attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ((ServletRequestAttributes) context).getResponse());
request = context.getRequest();
response = context.getResponse();
} }
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
} }
private void populateDefaultAuthentication(Map<String, Object> attrs) { private void populateDefaultAuthentication(Map<String, Object> attrs) {