Introduce request attributes in RestClient

This commit introduces request attributes in the RestClient and
underlying infrastructure (i.e. HttpRequest).

Closes gh-32027
This commit is contained in:
Arjen Poutsma 2024-06-10 10:06:52 +02:00
parent c36e270481
commit 60b5bbe334
21 changed files with 385 additions and 9 deletions

View File

@ -1007,7 +1007,8 @@ method parameters:
is supported for non-String values.
| `@RequestAttribute`
| Provide an `Object` to add as a request attribute. Only supported by `WebClient`.
| Provide an `Object` to add as a request attribute. Only supported by `RestClient`
and `WebClient`.
| `@RequestBody`
| Provide the body of the request either as an Object to be serialized, or a

View File

@ -18,6 +18,8 @@ package org.springframework.mock.http.client;
import java.io.IOException;
import java.net.URI;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.ClientHttpRequest;
@ -46,6 +48,9 @@ public class MockClientHttpRequest extends MockHttpOutputMessage implements Clie
private boolean executed = false;
@Nullable
Map<String, Object> attributes;
/**
* Create a {@code MockClientHttpRequest} with {@link HttpMethod#GET GET} as
@ -115,6 +120,16 @@ public class MockClientHttpRequest extends MockHttpOutputMessage implements Clie
return this.executed;
}
@Override
public Map<String, Object> getAttributes() {
Map<String, Object> attributes = this.attributes;
if (attributes == null) {
attributes = new ConcurrentHashMap<>();
this.attributes = attributes;
}
return attributes;
}
/**
* Set the {@link #isExecuted() executed} flag to {@code true} and return the
* configured {@link #setResponse(ClientHttpResponse) response}.

View File

@ -17,6 +17,7 @@
package org.springframework.http;
import java.net.URI;
import java.util.Map;
/**
* Represents an HTTP request message, consisting of a
@ -41,4 +42,10 @@ public interface HttpRequest extends HttpMessage {
*/
URI getURI();
/**
* Return a mutable map of request attributes for this request.
* @since 6.2
*/
Map<String, Object> getAttributes();
}

View File

@ -18,6 +18,8 @@ package org.springframework.http.client;
import java.io.IOException;
import java.io.OutputStream;
import java.util.LinkedHashMap;
import java.util.Map;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.Nullable;
@ -39,6 +41,9 @@ public abstract class AbstractClientHttpRequest implements ClientHttpRequest {
@Nullable
private HttpHeaders readOnlyHeaders;
@Nullable
private Map<String, Object> attributes;
@Override
public final HttpHeaders getHeaders() {
@ -60,6 +65,16 @@ public abstract class AbstractClientHttpRequest implements ClientHttpRequest {
return getBodyInternal(this.headers);
}
@Override
public Map<String, Object> getAttributes() {
Map<String, Object> attributes = this.attributes;
if (attributes == null) {
attributes = new LinkedHashMap<>();
this.attributes = attributes;
}
return attributes;
}
@Override
public final ClientHttpResponse execute() throws IOException {
assertNotExecuted();

View File

@ -91,6 +91,7 @@ class InterceptingClientHttpRequest extends AbstractBufferingClientHttpRequest {
HttpMethod method = request.getMethod();
ClientHttpRequest delegate = requestFactory.createRequest(request.getURI(), method);
request.getHeaders().forEach((key, value) -> delegate.getHeaders().addAll(key, value));
request.getAttributes().forEach((key, value) -> delegate.getAttributes().put(key, value));
if (body.length > 0) {
if (delegate instanceof StreamingHttpOutputMessage streamingOutputMessage) {
streamingOutputMessage.setBody(new StreamingHttpOutputMessage.Body() {

View File

@ -17,6 +17,7 @@
package org.springframework.http.client.support;
import java.net.URI;
import java.util.Map;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
@ -70,6 +71,14 @@ public class HttpRequestWrapper implements HttpRequest {
return this.request.getURI();
}
/**
* Return the attributes of the wrapped request.
*/
@Override
public Map<String, Object> getAttributes() {
return this.request.getAttributes();
}
/**
* Return the headers of the wrapped request.
*/

View File

@ -29,11 +29,16 @@ import java.net.URLEncoder;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.util.AbstractCollection;
import java.util.AbstractMap;
import java.util.AbstractSet;
import java.util.Arrays;
import java.util.Collection;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import jakarta.servlet.http.HttpServletRequest;
@ -67,6 +72,10 @@ public class ServletServerHttpRequest implements ServerHttpRequest {
@Nullable
private HttpHeaders headers;
@Nullable
private Map<String, Object> attributes;
@Nullable
private ServerHttpAsyncRequestControl asyncRequestControl;
@ -207,6 +216,16 @@ public class ServletServerHttpRequest implements ServerHttpRequest {
return new InetSocketAddress(this.servletRequest.getRemoteHost(), this.servletRequest.getRemotePort());
}
@Override
public Map<String, Object> getAttributes() {
Map<String, Object> attributes = this.attributes;
if (attributes == null) {
attributes = new AttributesMap();
this.attributes = attributes;
}
return attributes;
}
@Override
public InputStream getBody() throws IOException {
if (isFormPost(this.servletRequest) && this.servletRequest.getQueryString() == null) {
@ -276,4 +295,151 @@ public class ServletServerHttpRequest implements ServerHttpRequest {
return new ByteArrayInputStream(bytes);
}
private final class AttributesMap extends AbstractMap<String, Object> {
@Nullable
private transient Set<String> keySet;
@Nullable
private transient Collection<Object> values;
@Nullable
private transient Set<Entry<String, Object>> entrySet;
@Override
public int size() {
int size = 0;
for (Enumeration<?> names = servletRequest.getAttributeNames(); names.hasMoreElements(); names.nextElement()) {
size++;
}
return size;
}
@Override
@Nullable
public Object get(Object key) {
if (key instanceof String name) {
return servletRequest.getAttribute(name);
}
else {
return null;
}
}
@Override
@Nullable
public Object put(String key, Object value) {
Object old = get(key);
servletRequest.setAttribute(key, value);
return old;
}
@Override
@Nullable
public Object remove(Object key) {
if (key instanceof String name) {
Object old = get(key);
servletRequest.removeAttribute(name);
return old;
}
else {
return null;
}
}
@Override
public void clear() {
for (Enumeration<String> names = servletRequest.getAttributeNames(); names.hasMoreElements(); ) {
String name = names.nextElement();
servletRequest.removeAttribute(name);
}
}
@Override
public Set<String> keySet() {
Set<String> keySet = this.keySet;
if (keySet == null) {
keySet = new AbstractSet<>() {
@Override
public Iterator<String> iterator() {
return servletRequest.getAttributeNames().asIterator();
}
@Override
public int size() {
return AttributesMap.this.size();
}
};
this.keySet = keySet;
}
return keySet;
}
@Override
public Collection<Object> values() {
Collection<Object> values = this.values;
if (values == null) {
values = new AbstractCollection<>() {
@Override
public Iterator<Object> iterator() {
Enumeration<String> e = servletRequest.getAttributeNames();
return new Iterator<>() {
@Override
public boolean hasNext() {
return e.hasMoreElements();
}
@Override
public Object next() {
String name = e.nextElement();
return servletRequest.getAttribute(name);
}
};
}
@Override
public int size() {
return AttributesMap.this.size();
}
};
this.values = values;
}
return values;
}
@Override
public Set<Entry<String, Object>> entrySet() {
Set<Entry<String, Object>> entrySet = this.entrySet;
if (entrySet == null) {
entrySet = new AbstractSet<>() {
@Override
public Iterator<Entry<String, Object>> iterator() {
Enumeration<String> e = servletRequest.getAttributeNames();
return new Iterator<>() {
@Override
public boolean hasNext() {
return e.hasMoreElements();
}
@Override
public Entry<String, Object> next() {
String name = e.nextElement();
Object value = servletRequest.getAttribute(name);
return new SimpleImmutableEntry<>(name, value);
}
};
}
@Override
public int size() {
return AttributesMap.this.size();
}
};
this.entrySet = entrySet;
}
return entrySet;
}
}
}

View File

@ -19,6 +19,9 @@ package org.springframework.http.server.reactive;
import java.net.URI;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.function.Supplier;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@ -68,6 +71,9 @@ public abstract class AbstractServerHttpRequest implements ServerHttpRequest {
@Nullable
private String logPrefix;
@Nullable
private Supplier<Map<String, Object>> attributesSupplier;
/**
* Constructor with the method, URI and headers for the request.
@ -122,6 +128,16 @@ public abstract class AbstractServerHttpRequest implements ServerHttpRequest {
return this.uri;
}
@Override
public Map<String, Object> getAttributes() {
if (this.attributesSupplier != null) {
return this.attributesSupplier.get();
}
else {
return Collections.emptyMap();
}
}
@Override
public RequestPath getPath() {
return this.path;
@ -230,4 +246,12 @@ public abstract class AbstractServerHttpRequest implements ServerHttpRequest {
return getId();
}
/**
* Set the attribute supplier.
* <p><strong>Note:</strong> This is exposed mainly for internal framework
* use.
*/
public void setAttributesSupplier(Supplier<Map<String, Object>> attributesSupplier) {
this.attributesSupplier = attributesSupplier;
}
}

View File

@ -18,6 +18,7 @@ package org.springframework.http.server.reactive;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.Map;
import reactor.core.publisher.Flux;
@ -70,6 +71,11 @@ public class ServerHttpRequestDecorator implements ServerHttpRequest {
return getDelegate().getURI();
}
@Override
public Map<String, Object> getAttributes() {
return getDelegate().getAttributes();
}
@Override
public RequestPath getPath() {
return getDelegate().getPath();

View File

@ -28,6 +28,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
@ -82,6 +83,8 @@ final class DefaultRestClient implements RestClient {
private static final ClientRequestObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultClientRequestObservationConvention();
private static final String URI_TEMPLATE_ATTRIBUTE = RestClient.class.getName() + ".uriTemplate";
private final ClientHttpRequestFactory clientRequestFactory;
@ -297,7 +300,7 @@ final class DefaultRestClient implements RestClient {
private InternalBody body;
@Nullable
private String uriTemplate;
private Map<String, Object> attributes;
@Nullable
private Consumer<ClientHttpRequest> httpRequestConsumer;
@ -308,19 +311,19 @@ final class DefaultRestClient implements RestClient {
@Override
public RequestBodySpec uri(String uriTemplate, Object... uriVariables) {
this.uriTemplate = uriTemplate;
attribute(URI_TEMPLATE_ATTRIBUTE, uriTemplate);
return uri(DefaultRestClient.this.uriBuilderFactory.expand(uriTemplate, uriVariables));
}
@Override
public RequestBodySpec uri(String uriTemplate, Map<String, ?> uriVariables) {
this.uriTemplate = uriTemplate;
attribute(URI_TEMPLATE_ATTRIBUTE, uriTemplate);
return uri(DefaultRestClient.this.uriBuilderFactory.expand(uriTemplate, uriVariables));
}
@Override
public RequestBodySpec uri(String uriTemplate, Function<UriBuilder, URI> uriFunction) {
this.uriTemplate = uriTemplate;
attribute(URI_TEMPLATE_ATTRIBUTE, uriTemplate);
return uri(uriFunction.apply(DefaultRestClient.this.uriBuilderFactory.uriString(uriTemplate)));
}
@ -392,6 +395,27 @@ final class DefaultRestClient implements RestClient {
return this;
}
@Override
public RequestBodySpec attribute(String name, Object value) {
getAttributes().put(name, value);
return this;
}
@Override
public RequestBodySpec attributes(Consumer<Map<String, Object>> attributesConsumer) {
attributesConsumer.accept(getAttributes());
return this;
}
private Map<String, Object> getAttributes() {
Map<String, Object> attributes = this.attributes;
if (attributes == null) {
attributes = new ConcurrentHashMap<>(4);
this.attributes = attributes;
}
return attributes;
}
@Override
public RequestBodySpec httpRequest(Consumer<ClientHttpRequest> requestConsumer) {
this.httpRequestConsumer = (this.httpRequestConsumer != null ?
@ -483,8 +507,10 @@ final class DefaultRestClient implements RestClient {
HttpHeaders headers = initHeaders();
ClientHttpRequest clientRequest = createRequest(uri);
clientRequest.getHeaders().addAll(headers);
Map<String, Object> attributes = getAttributes();
clientRequest.getAttributes().putAll(attributes);
ClientRequestObservationContext observationContext = new ClientRequestObservationContext(clientRequest);
observationContext.setUriTemplate(this.uriTemplate);
observationContext.setUriTemplate((String) attributes.get(URI_TEMPLATE_ATTRIBUTE));
observation = ClientHttpObservationDocumentation.HTTP_CLIENT_EXCHANGES.observation(observationConvention,
DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, observationRegistry).start();
if (this.body != null) {

View File

@ -498,6 +498,24 @@ public interface RestClient {
*/
S headers(Consumer<HttpHeaders> headersConsumer);
/**
* Set the attribute with the given name to the given value.
* @param name the name of the attribute to add
* @param value the value of the attribute to add
* @return this builder
* @since 6.2
*/
S attribute(String name, Object value);
/**
* Provides access to every attribute declared so far with the
* possibility to add, replace, or remove values.
* @param attributesConsumer the consumer to provide access to
* @return this builder
* @since 6.2
*/
S attributes(Consumer<Map<String, Object>> attributesConsumer);
/**
* Callback for access to the {@link ClientHttpRequest} that in turn
* provides access to the native request of the underlying HTTP library.

View File

@ -56,7 +56,7 @@ public final class RestClientAdapter implements HttpExchangeAdapter {
@Override
public boolean supportsRequestAttributes() {
return false;
return true;
}
@Override
@ -121,6 +121,8 @@ public final class RestClientAdapter implements HttpExchangeAdapter {
bodySpec.header(HttpHeaders.COOKIE, String.join("; ", cookies));
}
bodySpec.attributes(attributes -> attributes.putAll(values.getAttributes()));
if (values.getBodyValue() != null) {
bodySpec.body(values.getBodyValue());
}

View File

@ -42,6 +42,7 @@ import org.springframework.http.MediaType;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.http.codec.multipart.Part;
import org.springframework.http.server.reactive.AbstractServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.lang.Nullable;
@ -137,6 +138,10 @@ public class DefaultServerWebExchange implements ServerWebExchange {
this.formDataMono = initFormData(request, codecConfigurer, getLogPrefix());
this.multipartDataMono = initMultipartData(codecConfigurer, getLogPrefix());
this.applicationContext = applicationContext;
if (request instanceof AbstractServerHttpRequest abstractServerHttpRequest) {
abstractServerHttpRequest.setAttributesSupplier(() -> this.attributes);
}
}
private static Mono<MultiValueMap<String, String>> initFormData(ServerHttpRequest request,

View File

@ -115,6 +115,32 @@ class InterceptingClientHttpRequestFactoryTests {
request.execute();
}
@Test
void changeAttribute() throws Exception {
final String attrName = "Foo";
final String attrValue = "Bar";
ClientHttpRequestInterceptor interceptor = (request, body, execution) -> {
System.out.println("interceptor");
request.getAttributes().put(attrName, attrValue);
return execution.execute(request, body);
};
requestMock = new MockClientHttpRequest() {
@Override
protected ClientHttpResponse executeInternal() {
System.out.println("execute");
assertThat(getAttributes()).containsEntry(attrName, attrValue);
return responseMock;
}
};
requestFactory = new InterceptingClientHttpRequestFactory(requestFactoryMock, Collections.singletonList(interceptor));
ClientHttpRequest request = requestFactory.createRequest(URI.create("https://example.com"), HttpMethod.GET);
request.execute();
}
@Test
void changeURI() throws Exception {
final URI changedUri = URI.create("https://example.com/2");

View File

@ -217,4 +217,10 @@ class ServletServerHttpRequestTests {
assertThat(request.getHeaders().getContentLength()).isEqualTo(result.length);
}
@Test
void attributes() {
request.getAttributes().put("foo", "bar");
assertThat(mockRequest.getAttribute("foo")).isEqualTo("bar");
}
}

View File

@ -17,6 +17,8 @@
package org.springframework.web.util;
import java.net.URI;
import java.util.Collections;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
@ -363,6 +365,11 @@ class ForwardedHeaderUtilsTests {
return UriComponentsBuilder.fromUriString("/").build().toUri();
}
@Override
public Map<String, Object> getAttributes() {
return Collections.emptyMap();
}
@Override
public HttpHeaders getHeaders() {
return new HttpHeaders();

View File

@ -18,6 +18,8 @@ package org.springframework.web.testfixture.http.client;
import java.io.IOException;
import java.net.URI;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.ClientHttpRequest;
@ -46,6 +48,9 @@ public class MockClientHttpRequest extends MockHttpOutputMessage implements Clie
private boolean executed = false;
@Nullable
Map<String, Object> attributes;
/**
* Create a {@code MockClientHttpRequest} with {@link HttpMethod#GET GET} as
@ -115,6 +120,16 @@ public class MockClientHttpRequest extends MockHttpOutputMessage implements Clie
return this.executed;
}
@Override
public Map<String, Object> getAttributes() {
Map<String, Object> attributes = this.attributes;
if (attributes == null) {
attributes = new ConcurrentHashMap<>();
this.attributes = attributes;
}
return attributes;
}
/**
* Set the {@link #isExecuted() executed} flag to {@code true} and return the
* configured {@link #setResponse(ClientHttpResponse) response}.

View File

@ -18,6 +18,8 @@ package org.springframework.web.reactive.function.client;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
@ -65,6 +67,11 @@ final class DefaultClientResponseBuilder implements ClientResponse.Builder {
public HttpHeaders getHeaders() {
return HttpHeaders.EMPTY;
}
@Override
public Map<String, Object> getAttributes() {
return Collections.emptyMap();
}
};

View File

@ -17,6 +17,7 @@
package org.springframework.web.reactive.function.client;
import java.net.URI;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -149,6 +150,11 @@ public abstract class ExchangeFunctions {
return request.url();
}
@Override
public Map<String, Object> getAttributes() {
return request.attributes();
}
@Override
public HttpHeaders getHeaders() {
return request.headers();

View File

@ -193,7 +193,7 @@ class DefaultServerRequestBuilder implements ServerRequest.Builder {
@Override
public ServerRequest build() {
ServerHttpRequest serverHttpRequest = new BuiltServerHttpRequest(this.exchange.getRequest().getId(),
this.method, this.uri, this.contextPath, this.headers, this.cookies, this.body);
this.method, this.uri, this.contextPath, this.headers, this.cookies, this.body, this.attributes);
ServerWebExchange exchange = new DelegatingServerWebExchange(
serverHttpRequest, this.attributes, this.exchange, this.messageReaders);
return new DefaultServerRequest(exchange, this.messageReaders);
@ -220,8 +220,10 @@ class DefaultServerRequestBuilder implements ServerRequest.Builder {
private final Flux<DataBuffer> body;
private final Map<String, Object> attributes;
public BuiltServerHttpRequest(String id, HttpMethod method, URI uri, @Nullable String contextPath,
HttpHeaders headers, MultiValueMap<String, HttpCookie> cookies, Flux<DataBuffer> body) {
HttpHeaders headers, MultiValueMap<String, HttpCookie> cookies, Flux<DataBuffer> body, Map<String, Object> attributes) {
this.id = id;
this.method = method;
@ -231,6 +233,7 @@ class DefaultServerRequestBuilder implements ServerRequest.Builder {
this.cookies = unmodifiableCopy(cookies);
this.queryParams = parseQueryParams(uri);
this.body = body;
this.attributes = attributes;
}
private static <K, V> MultiValueMap<K, V> unmodifiableCopy(MultiValueMap<K, V> original) {
@ -273,6 +276,11 @@ class DefaultServerRequestBuilder implements ServerRequest.Builder {
return this.uri;
}
@Override
public Map<String, Object> getAttributes() {
return this.attributes;
}
@Override
public RequestPath getPath() {
return this.path;

View File

@ -19,6 +19,7 @@ package org.springframework.web.reactive.function.client;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.OptionalLong;
@ -324,6 +325,11 @@ class DefaultClientResponseTests {
public HttpHeaders getHeaders() {
return HttpHeaders.EMPTY;
}
@Override
public Map<String, Object> getAttributes() {
return Collections.emptyMap();
}
};
given(mockExchangeStrategies.messageReaders()).willReturn(