From a0af708c036be0bb7054e15f5fd26eaae7b6f39e Mon Sep 17 00:00:00 2001 From: Nicklas Wiegandt Date: Sun, 13 Oct 2024 19:07:00 +0200 Subject: [PATCH] Add cookie support to RestClient See gh-33697 --- .../ROOT/pages/integration/rest-clients.adoc | 2 + .../web/client/DefaultRestClient.java | 60 ++++++++ .../web/client/DefaultRestClientBuilder.java | 40 ++++++ .../web/client/RestClient.java | 36 +++++ .../web/client/RestClientBuilderTests.java | 133 ++++++++++++++++++ .../client/RestClientIntegrationTests.java | 68 +++++++++ 6 files changed, 339 insertions(+) diff --git a/framework-docs/modules/ROOT/pages/integration/rest-clients.adoc b/framework-docs/modules/ROOT/pages/integration/rest-clients.adoc index ccdefbd1fc..482217e6d6 100644 --- a/framework-docs/modules/ROOT/pages/integration/rest-clients.adoc +++ b/framework-docs/modules/ROOT/pages/integration/rest-clients.adoc @@ -38,6 +38,7 @@ RestClient customClient = RestClient.builder() .baseUrl("https://example.com") .defaultUriVariables(Map.of("variable", "foo")) .defaultHeader("My-Header", "Foo") + .defaultCookie("My-Cookie", "Bar") .requestInterceptor(myCustomInterceptor) .requestInitializer(myCustomInitializer) .build(); @@ -55,6 +56,7 @@ val customClient = RestClient.builder() .baseUrl("https://example.com") .defaultUriVariables(mapOf("variable" to "foo")) .defaultHeader("My-Header", "Foo") + .defaultCookie("My-Cookie", "Bar") .requestInterceptor(myCustomInterceptor) .requestInitializer(myCustomInitializer) .build() diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index 9bf267c421..a9ab2c1364 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -40,6 +40,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ResolvableType; +import org.springframework.http.HttpCookie; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpRequest; @@ -64,6 +65,8 @@ import org.springframework.http.converter.SmartHttpMessageConverter; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.web.util.UriBuilder; import org.springframework.web.util.UriBuilderFactory; @@ -103,6 +106,9 @@ final class DefaultRestClient implements RestClient { @Nullable private final HttpHeaders defaultHeaders; + @Nullable + private final MultiValueMap defaultCookies; + @Nullable private final Consumer> defaultRequest; @@ -123,6 +129,7 @@ final class DefaultRestClient implements RestClient { @Nullable List initializers, UriBuilderFactory uriBuilderFactory, @Nullable HttpHeaders defaultHeaders, + @Nullable MultiValueMap defaultCookies, @Nullable Consumer> defaultRequest, @Nullable List statusHandlers, List> messageConverters, @@ -135,6 +142,7 @@ final class DefaultRestClient implements RestClient { this.interceptors = interceptors; this.uriBuilderFactory = uriBuilderFactory; this.defaultHeaders = defaultHeaders; + this.defaultCookies = defaultCookies; this.defaultRequest = defaultRequest; this.defaultStatusHandlers = (statusHandlers != null ? new ArrayList<>(statusHandlers) : new ArrayList<>()); this.messageConverters = messageConverters; @@ -293,6 +301,8 @@ final class DefaultRestClient implements RestClient { private class DefaultRequestBodyUriSpec implements RequestBodyUriSpec { + private static final String COOKIE_DELIMITER = "; "; + private final HttpMethod httpMethod; @Nullable @@ -301,6 +311,9 @@ final class DefaultRestClient implements RestClient { @Nullable private HttpHeaders headers; + @Nullable + private MultiValueMap cookies; + @Nullable private InternalBody body; @@ -356,6 +369,13 @@ final class DefaultRestClient implements RestClient { return this.headers; } + private MultiValueMap getCookies() { + if (this.cookies == null) { + this.cookies = new LinkedMultiValueMap<>(3); + } + return this.cookies; + } + @Override public DefaultRequestBodyUriSpec header(String headerName, String... headerValues) { for (String headerValue : headerValues) { @@ -382,6 +402,18 @@ final class DefaultRestClient implements RestClient { return this; } + @Override + public DefaultRequestBodyUriSpec cookie(String name, String value) { + getCookies().add(name, value); + return this; + } + + @Override + public DefaultRequestBodyUriSpec cookies(Consumer> cookiesConsumer) { + cookiesConsumer.accept(getCookies()); + return this; + } + @Override public DefaultRequestBodyUriSpec contentType(MediaType contentType) { getHeaders().setContentType(contentType); @@ -525,6 +557,12 @@ final class DefaultRestClient implements RestClient { try { uri = initUri(); HttpHeaders headers = initHeaders(); + + MultiValueMap cookies = initCookies(); + if (!CollectionUtils.isEmpty(cookies)) { + headers.put(HttpHeaders.COOKIE, List.of(cookiesToHeaderValue(cookies))); + } + ClientHttpRequest clientRequest = createRequest(uri); clientRequest.getHeaders().addAll(headers); Map attributes = getAttributes(); @@ -599,6 +637,28 @@ final class DefaultRestClient implements RestClient { } } + private MultiValueMap initCookies() { + MultiValueMap mergedCookies = new LinkedMultiValueMap<>(); + + if(!CollectionUtils.isEmpty(defaultCookies)) { + mergedCookies.putAll(defaultCookies); + } + + if(!CollectionUtils.isEmpty(this.cookies)) { + mergedCookies.putAll(this.cookies); + } + + return mergedCookies; + } + + private String cookiesToHeaderValue(MultiValueMap cookies) { + List flatCookies = new ArrayList<>(); + cookies.forEach((name, cookieValues) -> cookieValues.forEach(value -> + flatCookies.add(new HttpCookie(name, value).toString()) + )); + return String.join(COOKIE_DELIMITER, flatCookies); + } + private ClientHttpRequest createRequest(URI uri) throws IOException { ClientHttpRequestFactory factory; if (DefaultRestClient.this.interceptors != null) { diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java index aeb6eab326..9522cd2496 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java @@ -56,6 +56,8 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.web.util.DefaultUriBuilderFactory; import org.springframework.web.util.UriBuilderFactory; import org.springframework.web.util.UriTemplateHandler; @@ -127,6 +129,9 @@ final class DefaultRestClientBuilder implements RestClient.Builder { @Nullable private HttpHeaders defaultHeaders; + @Nullable + private MultiValueMap defaultCookies; + @Nullable private Consumer> defaultRequest; @@ -169,6 +174,8 @@ final class DefaultRestClientBuilder implements RestClient.Builder { else { this.defaultHeaders = null; } + this.defaultCookies = (other.defaultCookies != null ? + new LinkedMultiValueMap<>(other.defaultCookies) : null); this.defaultRequest = other.defaultRequest; this.statusHandlers = (other.statusHandlers != null ? new ArrayList<>(other.statusHandlers) : null); @@ -289,6 +296,25 @@ final class DefaultRestClientBuilder implements RestClient.Builder { return this.defaultHeaders; } + @Override + public RestClient.Builder defaultCookie(String cookie, String... values) { + initCookies().addAll(cookie, Arrays.asList(values)); + return this; + } + + @Override + public RestClient.Builder defaultCookies(Consumer> cookiesConsumer) { + cookiesConsumer.accept(initCookies()); + return this; + } + + private MultiValueMap initCookies() { + if (this.defaultCookies == null) { + this.defaultCookies = new LinkedMultiValueMap<>(3); + } + return this.defaultCookies; + } + @Override public RestClient.Builder defaultRequest(Consumer> defaultRequest) { this.defaultRequest = this.defaultRequest != null ? @@ -443,11 +469,13 @@ final class DefaultRestClientBuilder implements RestClient.Builder { ClientHttpRequestFactory requestFactory = initRequestFactory(); UriBuilderFactory uriBuilderFactory = initUriBuilderFactory(); HttpHeaders defaultHeaders = copyDefaultHeaders(); + MultiValueMap defaultCookies = copyDefaultCookies(); List> messageConverters = (this.messageConverters != null ? this.messageConverters : initMessageConverters()); return new DefaultRestClient(requestFactory, this.interceptors, this.initializers, uriBuilderFactory, defaultHeaders, + defaultCookies, this.defaultRequest, this.statusHandlers, messageConverters, @@ -501,4 +529,16 @@ final class DefaultRestClientBuilder implements RestClient.Builder { } } + @Nullable + private MultiValueMap copyDefaultCookies() { + if (this.defaultCookies != null) { + MultiValueMap copy = new LinkedMultiValueMap<>(this.defaultCookies.size()); + this.defaultCookies.forEach((key, values) -> copy.put(key, new ArrayList<>(values))); + return CollectionUtils.unmodifiableMultiValueMap(copy); + } + else { + return null; + } + } + } diff --git a/spring-web/src/main/java/org/springframework/web/client/RestClient.java b/spring-web/src/main/java/org/springframework/web/client/RestClient.java index c0e7ec0b50..7de2e2eab3 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestClient.java @@ -45,6 +45,7 @@ import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.observation.ClientRequestObservationConvention; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; import org.springframework.web.util.DefaultUriBuilderFactory; import org.springframework.web.util.UriBuilder; import org.springframework.web.util.UriBuilderFactory; @@ -312,6 +313,23 @@ public interface RestClient { */ Builder defaultHeaders(Consumer headersConsumer); + /** + * Global option to specify a cookie to be added to every request, + * if the request does not already contain such a cookie. + * @param cookie the cookie name + * @param values the cookie values + * @since 6.2 + */ + Builder defaultCookie(String cookie, String... values); + + /** + * Provides access to every {@link #defaultCookie(String, String...)} + * declared so far with the possibility to add, replace, or remove. + * @param cookiesConsumer a function that consumes the cookies map + * @since 6.2 + */ + Builder defaultCookies(Consumer> cookiesConsumer); + /** * Provide a consumer to customize every request being built. * @param defaultRequest the consumer to use for modifying requests @@ -519,6 +537,24 @@ public interface RestClient { */ S acceptCharset(Charset... acceptableCharsets); + /** + * Add a cookie with the given name and value. + * @param name the cookie name + * @param value the cookie value + * @return this builder + * @since 6.2 + */ + S cookie(String name, String value); + + /** + * Provides access to every cookie declared so far with the possibility + * to add, replace, or remove values. + * @param cookiesConsumer the consumer to provide access to + * @return this builder + * @since 6.2 + */ + S cookies(Consumer> cookiesConsumer); + /** * Set the value of the {@code If-Modified-Since} header. * @param ifModifiedSince the new value of the header diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java index 4247390169..918d718730 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientBuilderTests.java @@ -21,6 +21,7 @@ import java.net.URI; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; @@ -42,6 +43,7 @@ import static org.assertj.core.api.Assertions.fail; /** * @author Arjen Poutsma * @author Sebastien Deleuze + * @author Nicklas Wiegandt */ public class RestClientBuilderTests { @@ -137,6 +139,123 @@ public class RestClientBuilderTests { assertThatIllegalArgumentException().isThrownBy(() -> builder.messageConverters(converters)); } + @Test + void defaultCookieAddsCookieToDefaultCookiesMap() { + RestClient.Builder builder = RestClient.builder(); + + builder.defaultCookie("myCookie", "testValue"); + + assertThat(fieldValue("defaultCookies", (DefaultRestClientBuilder) builder)) + .asInstanceOf(InstanceOfAssertFactories.MAP) + .containsExactly(Map.entry("myCookie", List.of("testValue"))); + } + + @Test + void defaultCookieWithMultipleValuesAddsCookieToDefaultCookiesMapWithAllValues() { + RestClient.Builder builder = RestClient.builder(); + + builder.defaultCookie("myCookie", "testValue1", "testValue2"); + + assertThat(fieldValue("defaultCookies", (DefaultRestClientBuilder) builder)) + .asInstanceOf(InstanceOfAssertFactories.MAP) + .containsExactly(Map.entry("myCookie", List.of("testValue1", "testValue2"))); + } + + @Test + void defaultCookiesAllowsToAddCookie() { + RestClient.Builder builder = RestClient.builder(); + builder.defaultCookie("firstCookie", "firstValue"); + + builder.defaultCookies(cookies -> cookies.add("secondCookie", "secondValue")); + + assertThat(fieldValue("defaultCookies", (DefaultRestClientBuilder) builder)) + .asInstanceOf(InstanceOfAssertFactories.MAP) + .containsExactly( + Map.entry("firstCookie", List.of("firstValue")), + Map.entry("secondCookie", List.of("secondValue")) + ); + } + + @Test + void defaultCookiesAllowsToRemoveCookie() { + RestClient.Builder builder = RestClient.builder(); + builder.defaultCookie("firstCookie", "firstValue"); + builder.defaultCookie("secondCookie", "secondValue"); + + builder.defaultCookies(cookies -> cookies.remove("firstCookie")); + + assertThat(fieldValue("defaultCookies", (DefaultRestClientBuilder) builder)) + .asInstanceOf(InstanceOfAssertFactories.MAP) + .containsExactly(Map.entry("secondCookie", List.of("secondValue"))); + } + + @Test + void copyConstructorCopiesDefaultCookies() { + DefaultRestClientBuilder sourceBuilder = new DefaultRestClientBuilder(); + sourceBuilder.defaultCookie("firstCookie", "firstValue"); + sourceBuilder.defaultCookie("secondCookie", "secondValue"); + + DefaultRestClientBuilder copiedBuilder = new DefaultRestClientBuilder(sourceBuilder); + + assertThat(fieldValue("defaultCookies", copiedBuilder)) + .asInstanceOf(InstanceOfAssertFactories.MAP) + .containsExactly( + Map.entry("firstCookie", List.of("firstValue")), + Map.entry("secondCookie", List.of("secondValue")) + ); + } + + @Test + void copyConstructorCopiesDefaultCookiesImmutable() { + DefaultRestClientBuilder sourceBuilder = new DefaultRestClientBuilder(); + sourceBuilder.defaultCookie("firstCookie", "firstValue"); + sourceBuilder.defaultCookie("secondCookie", "secondValue"); + DefaultRestClientBuilder copiedBuilder = new DefaultRestClientBuilder(sourceBuilder); + + sourceBuilder.defaultCookie("thirdCookie", "thirdValue"); + + assertThat(fieldValue("defaultCookies", copiedBuilder)) + .asInstanceOf(InstanceOfAssertFactories.MAP) + .containsExactly( + Map.entry("firstCookie", List.of("firstValue")), + Map.entry("secondCookie", List.of("secondValue")) + ); + } + + @Test + void buildCopiesDefaultCookies() { + RestClient.Builder builder = RestClient.builder(); + builder.defaultCookie("firstCookie", "firstValue"); + builder.defaultCookie("secondCookie", "secondValue"); + + RestClient restClient = builder.build(); + + assertThat(fieldValue("defaultCookies", restClient)) + .asInstanceOf(InstanceOfAssertFactories.MAP) + .containsExactly( + Map.entry("firstCookie", List.of("firstValue")), + Map.entry("secondCookie", List.of("secondValue")) + ); + } + + @Test + void buildCopiesDefaultCookiesImmutable() { + RestClient.Builder builder = RestClient.builder(); + builder.defaultCookie("firstCookie", "firstValue"); + builder.defaultCookie("secondCookie", "secondValue"); + RestClient restClient = builder.build(); + + builder.defaultCookie("thirdCookie", "thirdValue"); + builder.defaultCookie("firstCookie", "fourthValue"); + + assertThat(fieldValue("defaultCookies", restClient)) + .asInstanceOf(InstanceOfAssertFactories.MAP) + .containsExactly( + Map.entry("firstCookie", List.of("firstValue")), + Map.entry("secondCookie", List.of("secondValue")) + ); + } + @Nullable private static Object fieldValue(String name, DefaultRestClientBuilder instance) { try { @@ -150,4 +269,18 @@ public class RestClientBuilderTests { return null; } } + + @Nullable + private static Object fieldValue(String name, RestClient instance) { + try { + Field field = DefaultRestClient.class.getDeclaredField(name); + field.setAccessible(true); + + return field.get(instance); + } + catch (NoSuchFieldException | IllegalAccessException ex) { + fail(ex.getMessage(), ex); + return null; + } + } } diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java index 7ec28be1bb..43be4226af 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java @@ -814,6 +814,25 @@ class RestClientIntegrationTests { expectRequest(request -> assertThat(request.getHeader("foo")).isEqualTo("bar")); } + @ParameterizedRestClientTest + void retrieveDefaultCookiesAsCookieHeader(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + prepareResponse(response -> + response.setHeader("Content-Type", "text/plain").setBody("Hello Spring!")); + RestClient restClientWithCookies = this.restClient.mutate() + .defaultCookie("testCookie", "firstValue", "secondValue") + .build(); + + restClientWithCookies.get() + .uri("/greeting") + .header("X-Test-Header", "testvalue") + .retrieve(); + + expectRequest(request -> + assertThat(request.getHeader(HttpHeaders.COOKIE)) + .isEqualTo("testCookie=firstValue; testCookie=secondValue") + ); + } @ParameterizedRestClientTest void filterForErrorHandling(ClientHttpRequestFactory requestFactory) { @@ -947,6 +966,55 @@ class RestClientIntegrationTests { expectRequest(request -> assertThat(request.getPath()).isEqualTo("/foo%20bar")); } + @ParameterizedRestClientTest + void cookieAddsCookie(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + prepareResponse(response -> response.setHeader("Content-Type", "text/plain") + .setBody("Hello Spring!")); + + this.restClient.get() + .uri("/greeting") + .cookie("foo", "bar") + .retrieve() + .body(String.class); + + expectRequest(request -> assertThat(request.getHeader("Cookie")).isEqualTo("foo=bar")); + } + + @ParameterizedRestClientTest + void cookieOverridesDefaultCookie(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + prepareResponse(response -> response.setHeader("Content-Type", "text/plain") + .setBody("Hello Spring!")); + RestClient restClientWithCookies = this.restClient.mutate() + .defaultCookie("testCookie", "firstValue", "secondValue") + .build(); + + restClientWithCookies.get() + .uri("/greeting") + .cookie("testCookie", "test") + .retrieve() + .body(String.class); + + expectRequest(request -> assertThat(request.getHeader("Cookie")).isEqualTo("testCookie=test")); + } + + @ParameterizedRestClientTest + void cookiesCanRemoveCookie(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + prepareResponse(response -> response.setHeader("Content-Type", "text/plain") + .setBody("Hello Spring!")); + + this.restClient.get() + .uri("/greeting") + .cookie("foo", "bar") + .cookie("test", "Hello") + .cookies(cookies -> cookies.remove("foo")) + .retrieve() + .body(String.class); + + expectRequest(request -> assertThat(request.getHeader("Cookie")).isEqualTo("test=Hello")); + } private void prepareResponse(Consumer consumer) { MockResponse response = new MockResponse();