From b6befd133cd33d24065960613447facce9863904 Mon Sep 17 00:00:00 2001 From: Scott Frederick Date: Thu, 30 Mar 2023 16:32:57 -0500 Subject: [PATCH] Add SSL bundle support to RestTemplateBuilder auto-configuration Update RestTemplateBuilder auto-configuration so that an SSL can be configured via an SSL bundle. Closes gh-34810 --- .../src/docs/asciidoc/io/rest-client.adoc | 5 +- spring-boot-project/spring-boot/build.gradle | 1 + .../client/ClientHttpRequestFactories.java | 63 +++++++++++++++---- .../ClientHttpRequestFactorySettings.java | 45 +++++++++++-- .../boot/web/client/RestTemplateBuilder.java | 15 +++++ .../HttpWebServiceMessageSenderBuilder.java | 17 ++++- ...ClientHttpRequestFactorySettingsTests.java | 17 +++++ 7 files changed, 141 insertions(+), 22 deletions(-) diff --git a/spring-boot-project/spring-boot-docs/src/docs/asciidoc/io/rest-client.adoc b/spring-boot-project/spring-boot-docs/src/docs/asciidoc/io/rest-client.adoc index a74737a0160..a0f420cb37c 100644 --- a/spring-boot-project/spring-boot-docs/src/docs/asciidoc/io/rest-client.adoc +++ b/spring-boot-project/spring-boot-docs/src/docs/asciidoc/io/rest-client.adoc @@ -13,8 +13,9 @@ The following code shows a typical example: include::code:MyService[] -TIP: `RestTemplateBuilder` includes a number of useful methods that can be used to quickly configure a `RestTemplate`. -For example, to add BASIC auth support, you can use `builder.basicAuthentication("user", "password").build()`. +`RestTemplateBuilder` includes a number of useful methods that can be used to quickly configure a `RestTemplate`. +For example, to add BASIC authentication support, you can use `builder.basicAuthentication("user", "password").build()`. +To add SSL support using an <>, you can use `builder.setSslBundle(sslBundle).build()`. diff --git a/spring-boot-project/spring-boot/build.gradle b/spring-boot-project/spring-boot/build.gradle index 1c12803ea35..f0831197439 100644 --- a/spring-boot-project/spring-boot/build.gradle +++ b/spring-boot-project/spring-boot/build.gradle @@ -29,6 +29,7 @@ dependencies { optional("com.oracle.database.jdbc:ucp") optional("com.oracle.database.jdbc:ojdbc8") optional("com.samskivert:jmustache") + optional("com.squareup.okhttp3:okhttp") optional("com.zaxxer:HikariCP") optional("io.netty:netty-tcnative-boringssl-static") optional("io.projectreactor:reactor-tools") diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactories.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactories.java index 232d87b8464..416791cc2c3 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactories.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactories.java @@ -23,13 +23,22 @@ import java.time.Duration; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; + +import okhttp3.OkHttpClient; import org.apache.hc.client5.http.classic.HttpClient; import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManager; import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManagerBuilder; +import org.apache.hc.client5.http.ssl.DefaultHostnameVerifier; +import org.apache.hc.client5.http.ssl.SSLConnectionSocketFactory; import org.apache.hc.core5.http.io.SocketConfig; import org.springframework.boot.context.properties.PropertyMapper; +import org.springframework.boot.ssl.SslBundle; +import org.springframework.boot.ssl.SslOptions; import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; @@ -37,6 +46,7 @@ import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; import org.springframework.util.ReflectionUtils; /** @@ -45,6 +55,7 @@ import org.springframework.util.ReflectionUtils; * * @author Andy Wilkinson * @author Phillip Webb + * @author Scott Frederick * @since 3.0.0 */ public final class ClientHttpRequestFactories { @@ -134,25 +145,39 @@ public final class ClientHttpRequestFactories { static class HttpComponents { static HttpComponentsClientHttpRequestFactory get(ClientHttpRequestFactorySettings settings) { - HttpComponentsClientHttpRequestFactory requestFactory = createRequestFactory(settings.readTimeout()); + HttpComponentsClientHttpRequestFactory requestFactory = createRequestFactory(settings.readTimeout(), + settings.sslBundle()); PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull(); map.from(settings::connectTimeout).asInt(Duration::toMillis).to(requestFactory::setConnectTimeout); map.from(settings::bufferRequestBody).to(requestFactory::setBufferRequestBody); return requestFactory; } - private static HttpComponentsClientHttpRequestFactory createRequestFactory(Duration readTimeout) { - return (readTimeout != null) ? new HttpComponentsClientHttpRequestFactory(createHttpClient(readTimeout)) - : new HttpComponentsClientHttpRequestFactory(); + private static HttpComponentsClientHttpRequestFactory createRequestFactory(Duration readTimeout, + SslBundle sslBundle) { + return new HttpComponentsClientHttpRequestFactory(createHttpClient(readTimeout, sslBundle)); } - private static HttpClient createHttpClient(Duration readTimeout) { - SocketConfig socketConfig = SocketConfig.custom() - .setSoTimeout((int) readTimeout.toMillis(), TimeUnit.MILLISECONDS) - .build(); - PoolingHttpClientConnectionManager connectionManager = PoolingHttpClientConnectionManagerBuilder.create() - .setDefaultSocketConfig(socketConfig) - .build(); + private static HttpClient createHttpClient(Duration readTimeout, SslBundle sslBundle) { + PoolingHttpClientConnectionManagerBuilder connectionManagerBuilder = PoolingHttpClientConnectionManagerBuilder + .create(); + if (readTimeout != null) { + SocketConfig socketConfig = SocketConfig.custom() + .setSoTimeout((int) readTimeout.toMillis(), TimeUnit.MILLISECONDS) + .build(); + connectionManagerBuilder.setDefaultSocketConfig(socketConfig); + } + if (sslBundle != null) { + SslOptions options = sslBundle.getOptions(); + String[] enabledProtocols = (!CollectionUtils.isEmpty(options.getEnabledProtocols())) + ? options.getEnabledProtocols().toArray(String[]::new) : null; + String[] ciphers = (!CollectionUtils.isEmpty(options.getCiphers())) + ? options.getCiphers().toArray(String[]::new) : null; + SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(sslBundle.createSslContext(), + enabledProtocols, ciphers, new DefaultHostnameVerifier()); + connectionManagerBuilder.setSSLSocketFactory(socketFactory); + } + PoolingHttpClientConnectionManager connectionManager = connectionManagerBuilder.build(); return HttpClientBuilder.create().setConnectionManager(connectionManager).build(); } @@ -166,13 +191,27 @@ public final class ClientHttpRequestFactories { static OkHttp3ClientHttpRequestFactory get(ClientHttpRequestFactorySettings settings) { Assert.state(settings.bufferRequestBody() == null, () -> "OkHttp3ClientHttpRequestFactory does not support request body buffering"); - OkHttp3ClientHttpRequestFactory requestFactory = new OkHttp3ClientHttpRequestFactory(); + OkHttp3ClientHttpRequestFactory requestFactory = createRequestFactory(settings.sslBundle()); PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull(); map.from(settings::connectTimeout).asInt(Duration::toMillis).to(requestFactory::setConnectTimeout); map.from(settings::readTimeout).asInt(Duration::toMillis).to(requestFactory::setReadTimeout); return requestFactory; } + private static OkHttp3ClientHttpRequestFactory createRequestFactory(SslBundle sslBundle) { + if (sslBundle != null) { + SSLSocketFactory socketFactory = sslBundle.createSslContext().getSocketFactory(); + TrustManager[] trustManagers = sslBundle.getManagers().getTrustManagers(); + Assert.state(trustManagers.length == 1, + "Trust material must be provided in the SSL bundle for OkHttp3ClientHttpRequestFactory"); + OkHttpClient client = new OkHttpClient.Builder() + .sslSocketFactory(socketFactory, (X509TrustManager) trustManagers[0]) + .build(); + return new OkHttp3ClientHttpRequestFactory(client); + } + return new OkHttp3ClientHttpRequestFactory(); + } + } /** diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettings.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettings.java index 9d68f4f7096..22deb5a4a16 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettings.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettings.java @@ -18,6 +18,7 @@ package org.springframework.boot.web.client; import java.time.Duration; +import org.springframework.boot.ssl.SslBundle; import org.springframework.http.client.ClientHttpRequestFactory; /** @@ -26,20 +27,37 @@ import org.springframework.http.client.ClientHttpRequestFactory; * @param connectTimeout the connect timeout * @param readTimeout the read timeout * @param bufferRequestBody if request body buffering is used + * @param sslBundle the SSL bundle providing SSL configuration * @author Andy Wilkinson * @author Phillip Webb + * @author Scott Frederick * @since 3.0.0 * @see ClientHttpRequestFactories */ -public record ClientHttpRequestFactorySettings(Duration connectTimeout, Duration readTimeout, - Boolean bufferRequestBody) { +public record ClientHttpRequestFactorySettings(Duration connectTimeout, Duration readTimeout, Boolean bufferRequestBody, + SslBundle sslBundle) { /** * Use defaults for the {@link ClientHttpRequestFactory} which can differ depending on * the implementation. */ public static final ClientHttpRequestFactorySettings DEFAULTS = new ClientHttpRequestFactorySettings(null, null, - null); + null, null); + + /** + * Create a new {@link ClientHttpRequestFactorySettings} instance. + * @param connectTimeout the connection timeout + * @param readTimeout the read timeout + * @param bufferRequestBody the bugger request body + * @param sslBundle the ssl bundle + * @since 3.1.0 + */ + public ClientHttpRequestFactorySettings { + } + + public ClientHttpRequestFactorySettings(Duration connectTimeout, Duration readTimeout, Boolean bufferRequestBody) { + this(connectTimeout, readTimeout, bufferRequestBody, null); + } /** * Return a new {@link ClientHttpRequestFactorySettings} instance with an updated @@ -48,7 +66,8 @@ public record ClientHttpRequestFactorySettings(Duration connectTimeout, Duration * @return a new {@link ClientHttpRequestFactorySettings} instance */ public ClientHttpRequestFactorySettings withConnectTimeout(Duration connectTimeout) { - return new ClientHttpRequestFactorySettings(connectTimeout, this.readTimeout, this.bufferRequestBody); + return new ClientHttpRequestFactorySettings(connectTimeout, this.readTimeout, this.bufferRequestBody, + this.sslBundle); } /** @@ -59,7 +78,8 @@ public record ClientHttpRequestFactorySettings(Duration connectTimeout, Duration */ public ClientHttpRequestFactorySettings withReadTimeout(Duration readTimeout) { - return new ClientHttpRequestFactorySettings(this.connectTimeout, readTimeout, this.bufferRequestBody); + return new ClientHttpRequestFactorySettings(this.connectTimeout, readTimeout, this.bufferRequestBody, + this.sslBundle); } /** @@ -69,7 +89,20 @@ public record ClientHttpRequestFactorySettings(Duration connectTimeout, Duration * @return a new {@link ClientHttpRequestFactorySettings} instance */ public ClientHttpRequestFactorySettings withBufferRequestBody(Boolean bufferRequestBody) { - return new ClientHttpRequestFactorySettings(this.connectTimeout, this.readTimeout, bufferRequestBody); + return new ClientHttpRequestFactorySettings(this.connectTimeout, this.readTimeout, bufferRequestBody, + this.sslBundle); + } + + /** + * Return a new {@link ClientHttpRequestFactorySettings} instance with an updated SSL + * bundle setting. + * @param sslBundle the new SSL bundle setting + * @return a new {@link ClientHttpRequestFactorySettings} instance + * @since 3.1.0 + */ + public ClientHttpRequestFactorySettings withSslBundle(SslBundle sslBundle) { + return new ClientHttpRequestFactorySettings(this.connectTimeout, this.readTimeout, this.bufferRequestBody, + sslBundle); } } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java index 9127c5246af..e1192a6a17e 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java @@ -33,6 +33,7 @@ import java.util.function.Supplier; import reactor.netty.http.client.HttpClientRequest; import org.springframework.beans.BeanUtils; +import org.springframework.boot.ssl.SslBundle; import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestInterceptor; @@ -64,6 +65,7 @@ import org.springframework.web.util.UriTemplateHandler; * @author Dmytro Nosan * @author Kevin Strijbos * @author Ilya Lukyanovich + * @author Scott Frederick * @since 1.4.0 */ public class RestTemplateBuilder { @@ -453,6 +455,19 @@ public class RestTemplateBuilder { this.customizers, this.requestCustomizers); } + /** + * Sets the SSL bundle on the underlying {@link ClientHttpRequestFactory}. + * @param sslBundle the SSL bundle + * @return a new builder instance + * @since 2.1.0 + */ + public RestTemplateBuilder setSslBundle(SslBundle sslBundle) { + return new RestTemplateBuilder(this.requestFactorySettings.withSslBundle(sslBundle), this.detectRequestFactory, + this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler, + this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers, + this.requestCustomizers); + } + /** * Set the {@link RestTemplateCustomizer RestTemplateCustomizers} that should be * applied to the {@link RestTemplate}. Customizers are applied in the order that they diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/webservices/client/HttpWebServiceMessageSenderBuilder.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/webservices/client/HttpWebServiceMessageSenderBuilder.java index f823566492e..8f5dbb3d7c0 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/webservices/client/HttpWebServiceMessageSenderBuilder.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/webservices/client/HttpWebServiceMessageSenderBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2022 the original author or authors. + * Copyright 2012-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import java.time.Duration; import java.util.function.Function; import java.util.function.Supplier; +import org.springframework.boot.ssl.SslBundle; import org.springframework.boot.web.client.ClientHttpRequestFactories; import org.springframework.boot.web.client.ClientHttpRequestFactorySettings; import org.springframework.http.client.ClientHttpRequestFactory; @@ -40,6 +41,8 @@ public class HttpWebServiceMessageSenderBuilder { private Duration readTimeout; + private SslBundle sslBundle; + private Function requestFactory; /** @@ -62,6 +65,16 @@ public class HttpWebServiceMessageSenderBuilder { return this; } + /** + * Set an {@link SslBundle} that will be used to configure a secure connection. + * @param sslBundle the SSL bundle + * @return a new builder instance + */ + public HttpWebServiceMessageSenderBuilder sslBundle(SslBundle sslBundle) { + this.sslBundle = sslBundle; + return this; + } + /** * Set the {@code Supplier} of {@link ClientHttpRequestFactory} that should be called * to create the HTTP-based {@link WebServiceMessageSender}. @@ -100,7 +113,7 @@ public class HttpWebServiceMessageSenderBuilder { private ClientHttpRequestFactory getRequestFactory() { ClientHttpRequestFactorySettings settings = new ClientHttpRequestFactorySettings(this.connectTimeout, - this.readTimeout, null); + this.readTimeout, null, this.sslBundle); return (this.requestFactory != null) ? this.requestFactory.apply(settings) : ClientHttpRequestFactories.get(settings); } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettingsTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettingsTests.java index 5d6663a5c75..8103a3ae697 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettingsTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/client/ClientHttpRequestFactorySettingsTests.java @@ -20,7 +20,10 @@ import java.time.Duration; import org.junit.jupiter.api.Test; +import org.springframework.boot.ssl.SslBundle; + import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; /** * Tests for {@link ClientHttpRequestFactorySettings}. @@ -37,6 +40,7 @@ class ClientHttpRequestFactorySettingsTests { assertThat(settings.connectTimeout()).isNull(); assertThat(settings.readTimeout()).isNull(); assertThat(settings.bufferRequestBody()).isNull(); + assertThat(settings.sslBundle()).isNull(); } @Test @@ -46,6 +50,7 @@ class ClientHttpRequestFactorySettingsTests { assertThat(settings.connectTimeout()).isEqualTo(ONE_SECOND); assertThat(settings.readTimeout()).isNull(); assertThat(settings.bufferRequestBody()).isNull(); + assertThat(settings.sslBundle()).isNull(); } @Test @@ -55,6 +60,7 @@ class ClientHttpRequestFactorySettingsTests { assertThat(settings.connectTimeout()).isNull(); assertThat(settings.readTimeout()).isEqualTo(ONE_SECOND); assertThat(settings.bufferRequestBody()).isNull(); + assertThat(settings.sslBundle()).isNull(); } @Test @@ -64,6 +70,17 @@ class ClientHttpRequestFactorySettingsTests { assertThat(settings.connectTimeout()).isNull(); assertThat(settings.readTimeout()).isNull(); assertThat(settings.bufferRequestBody()).isTrue(); + assertThat(settings.sslBundle()).isNull(); + } + + @Test + void withSslBundleReturnsInstanceWithUpdatedSslBundle() { + SslBundle sslBundle = mock(SslBundle.class); + ClientHttpRequestFactorySettings settings = ClientHttpRequestFactorySettings.DEFAULTS.withSslBundle(sslBundle); + assertThat(settings.connectTimeout()).isNull(); + assertThat(settings.readTimeout()).isNull(); + assertThat(settings.bufferRequestBody()).isNull(); + assertThat(settings.sslBundle()).isSameAs(sslBundle); } }