Improve RequestAttributesThreadLocalAccessor

Ensure access to request attributes after initial REQUEST dispatch
is done, and the RequestAttributes markedCompleted.

Closes gh-32296
This commit is contained in:
rstoyanchev 2024-05-10 18:29:30 +01:00
parent 3ada9a0c79
commit 2c9ed4608f
2 changed files with 143 additions and 1 deletions

View File

@ -16,7 +16,12 @@
package org.springframework.web.context.request;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import io.micrometer.context.ThreadLocalAccessor;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.lang.Nullable;
@ -26,6 +31,7 @@ import org.springframework.lang.Nullable;
* {@link RequestAttributes} propagation.
*
* @author Tadaya Tsuyukubo
* @author Rossen Stoyanchev
* @since 6.2
*/
public class RequestAttributesThreadLocalAccessor implements ThreadLocalAccessor<RequestAttributes> {
@ -44,7 +50,11 @@ public class RequestAttributesThreadLocalAccessor implements ThreadLocalAccessor
@Override
@Nullable
public RequestAttributes getValue() {
return RequestContextHolder.getRequestAttributes();
RequestAttributes request = RequestContextHolder.getRequestAttributes();
if (request instanceof ServletRequestAttributes sra && !(sra instanceof SnapshotServletRequestAttributes)) {
request = new SnapshotServletRequestAttributes(sra);
}
return request;
}
@Override
@ -57,4 +67,82 @@ public class RequestAttributesThreadLocalAccessor implements ThreadLocalAccessor
RequestContextHolder.resetRequestAttributes();
}
/**
* ServletRequestAttributes that takes another instance, and makes a copy of the
* request attributes at present to provides extended read access during async
* handling when the DispatcherServlet has exited from the initial REQUEST dispatch
* and marked the request {@link ServletRequestAttributes#requestCompleted()}.
* <p>Note that beyond access to request attributes, here is no attempt to support
* setting or removing request attributes, nor to access session attributes after
* the initial REQUEST dispatch has exited.
*/
private static final class SnapshotServletRequestAttributes extends ServletRequestAttributes {
private final ServletRequestAttributes delegate;
private final Map<String, Object> attributeMap;
public SnapshotServletRequestAttributes(ServletRequestAttributes requestAttributes) {
super(requestAttributes.getRequest(), requestAttributes.getResponse());
this.delegate = requestAttributes;
this.attributeMap = getAttributes(requestAttributes.getRequest());
}
private static Map<String, Object> getAttributes(HttpServletRequest request) {
Map<String, Object> map = new HashMap<>();
Enumeration<String> names = request.getAttributeNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
map.put(name, request.getAttribute(name));
}
return map;
}
// Delegate methods that check isRequestActive()
@Nullable
@Override
public Object getAttribute(String name, int scope) {
if (scope == RequestAttributes.SCOPE_REQUEST && !this.delegate.isRequestActive()) {
return this.attributeMap.get(name);
}
try {
return this.delegate.getAttribute(name, scope);
}
catch (IllegalStateException ex) {
if (scope == RequestAttributes.SCOPE_REQUEST) {
return this.attributeMap.get(name);
}
throw ex;
}
}
@Override
public String[] getAttributeNames(int scope) {
if (scope == RequestAttributes.SCOPE_REQUEST && !this.delegate.isRequestActive()) {
return this.attributeMap.keySet().toArray(new String[0]);
}
try {
return this.delegate.getAttributeNames(scope);
}
catch (IllegalStateException ex) {
if (scope == RequestAttributes.SCOPE_REQUEST) {
return this.attributeMap.keySet().toArray(new String[0]);
}
throw ex;
}
}
@Override
public void setAttribute(String name, Object value, int scope) {
this.delegate.setAttribute(name, value, scope);
}
@Override
public void removeAttribute(String name, int scope) {
this.delegate.removeAttribute(name, scope);
}
}
}

View File

@ -23,12 +23,18 @@ import io.micrometer.context.ContextRegistry;
import io.micrometer.context.ContextSnapshot;
import io.micrometer.context.ContextSnapshot.Scope;
import io.micrometer.context.ContextSnapshotFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.Mockito.mock;
import static org.springframework.web.context.request.RequestAttributes.SCOPE_REQUEST;
/**
* Tests for {@link RequestAttributesThreadLocalAccessor}.
@ -73,6 +79,54 @@ class RequestAttributesThreadLocalAccessorTests {
assertThat(requestAfterScope).hasValueSatisfying(value -> assertThat(value).isSameAs(previousRequest));
}
@Test
void accessAfterRequestMarkedCompleted() {
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
servletRequest.setAttribute("k1", "v1");
servletRequest.setAttribute("k2", "v2");
ServletRequestAttributes attributes = new ServletRequestAttributes(servletRequest, new MockHttpServletResponse());
ContextSnapshot snapshot = getSnapshotFor(attributes);
attributes.requestCompleted(); // REQUEST dispatch ends, async handling continues
try (Scope scope = snapshot.setThreadLocals()) {
RequestAttributes current = RequestContextHolder.getRequestAttributes();
assertThat(current).isNotNull();
assertThat(current.getAttributeNames(SCOPE_REQUEST)).containsExactly("k1", "k2");
assertThat(current.getAttribute("k1", SCOPE_REQUEST)).isEqualTo("v1");
assertThat(current.getAttribute("k2", SCOPE_REQUEST)).isEqualTo("v2");
assertThatIllegalStateException().isThrownBy(() -> current.setAttribute("k3", "v3", SCOPE_REQUEST));
}
}
@Test
void accessBeforeRequestMarkedCompleted() {
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
ServletRequestAttributes previous = new ServletRequestAttributes(servletRequest, new MockHttpServletResponse());
ContextSnapshot snapshot = getSnapshotFor(previous);
RequestContextHolder.setRequestAttributes(previous);
try {
try (Scope scope = snapshot.setThreadLocals()) {
RequestAttributes attributes = RequestContextHolder.getRequestAttributes();
assertThat(attributes).isNotNull();
attributes.setAttribute("k1", "v1", SCOPE_REQUEST);
}
RequestAttributes attributes = RequestContextHolder.getRequestAttributes();
assertThat(attributes).isNotNull();
attributes.setAttribute("k2", "v2", SCOPE_REQUEST);
}
finally {
RequestContextHolder.resetRequestAttributes();
}
assertThat(previous.getAttributeNames(SCOPE_REQUEST)).containsExactly("k1", "k2");
assertThat(previous.getAttribute("k1", SCOPE_REQUEST)).isEqualTo("v1");
assertThat(previous.getAttribute("k2", SCOPE_REQUEST)).isEqualTo("v2");
}
private ContextSnapshot getSnapshotFor(RequestAttributes request) {
RequestContextHolder.setRequestAttributes(request);
try {