Add SSL bundle support to RestTemplateBuilder auto-configuration

Update RestTemplateBuilder auto-configuration so that an SSL can be
configured via an SSL bundle.

Closes gh-34810
This commit is contained in:
Scott Frederick 2023-03-30 16:32:57 -05:00 committed by Phillip Webb
parent fd5fd1491a
commit b6befd133c
7 changed files with 141 additions and 22 deletions

View File

@ -13,8 +13,9 @@ The following code shows a typical example:
include::code:MyService[]
TIP: `RestTemplateBuilder` includes a number of useful methods that can be used to quickly configure a `RestTemplate`.
For example, to add BASIC auth support, you can use `builder.basicAuthentication("user", "password").build()`.
`RestTemplateBuilder` includes a number of useful methods that can be used to quickly configure a `RestTemplate`.
For example, to add BASIC authentication support, you can use `builder.basicAuthentication("user", "password").build()`.
To add SSL support using an <<features#features.ssl.bundles,SSL bundle>>, you can use `builder.setSslBundle(sslBundle).build()`.

View File

@ -29,6 +29,7 @@ dependencies {
optional("com.oracle.database.jdbc:ucp")
optional("com.oracle.database.jdbc:ojdbc8")
optional("com.samskivert:jmustache")
optional("com.squareup.okhttp3:okhttp")
optional("com.zaxxer:HikariCP")
optional("io.netty:netty-tcnative-boringssl-static")
optional("io.projectreactor:reactor-tools")

View File

@ -23,13 +23,22 @@ import java.time.Duration;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import okhttp3.OkHttpClient;
import org.apache.hc.client5.http.classic.HttpClient;
import org.apache.hc.client5.http.impl.classic.HttpClientBuilder;
import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManager;
import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManagerBuilder;
import org.apache.hc.client5.http.ssl.DefaultHostnameVerifier;
import org.apache.hc.client5.http.ssl.SSLConnectionSocketFactory;
import org.apache.hc.core5.http.io.SocketConfig;
import org.springframework.boot.context.properties.PropertyMapper;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.ssl.SslOptions;
import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
@ -37,6 +46,7 @@ import org.springframework.http.client.OkHttp3ClientHttpRequestFactory;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ReflectionUtils;
/**
@ -45,6 +55,7 @@ import org.springframework.util.ReflectionUtils;
*
* @author Andy Wilkinson
* @author Phillip Webb
* @author Scott Frederick
* @since 3.0.0
*/
public final class ClientHttpRequestFactories {
@ -134,25 +145,39 @@ public final class ClientHttpRequestFactories {
static class HttpComponents {
static HttpComponentsClientHttpRequestFactory get(ClientHttpRequestFactorySettings settings) {
HttpComponentsClientHttpRequestFactory requestFactory = createRequestFactory(settings.readTimeout());
HttpComponentsClientHttpRequestFactory requestFactory = createRequestFactory(settings.readTimeout(),
settings.sslBundle());
PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull();
map.from(settings::connectTimeout).asInt(Duration::toMillis).to(requestFactory::setConnectTimeout);
map.from(settings::bufferRequestBody).to(requestFactory::setBufferRequestBody);
return requestFactory;
}
private static HttpComponentsClientHttpRequestFactory createRequestFactory(Duration readTimeout) {
return (readTimeout != null) ? new HttpComponentsClientHttpRequestFactory(createHttpClient(readTimeout))
: new HttpComponentsClientHttpRequestFactory();
private static HttpComponentsClientHttpRequestFactory createRequestFactory(Duration readTimeout,
SslBundle sslBundle) {
return new HttpComponentsClientHttpRequestFactory(createHttpClient(readTimeout, sslBundle));
}
private static HttpClient createHttpClient(Duration readTimeout) {
SocketConfig socketConfig = SocketConfig.custom()
.setSoTimeout((int) readTimeout.toMillis(), TimeUnit.MILLISECONDS)
.build();
PoolingHttpClientConnectionManager connectionManager = PoolingHttpClientConnectionManagerBuilder.create()
.setDefaultSocketConfig(socketConfig)
.build();
private static HttpClient createHttpClient(Duration readTimeout, SslBundle sslBundle) {
PoolingHttpClientConnectionManagerBuilder connectionManagerBuilder = PoolingHttpClientConnectionManagerBuilder
.create();
if (readTimeout != null) {
SocketConfig socketConfig = SocketConfig.custom()
.setSoTimeout((int) readTimeout.toMillis(), TimeUnit.MILLISECONDS)
.build();
connectionManagerBuilder.setDefaultSocketConfig(socketConfig);
}
if (sslBundle != null) {
SslOptions options = sslBundle.getOptions();
String[] enabledProtocols = (!CollectionUtils.isEmpty(options.getEnabledProtocols()))
? options.getEnabledProtocols().toArray(String[]::new) : null;
String[] ciphers = (!CollectionUtils.isEmpty(options.getCiphers()))
? options.getCiphers().toArray(String[]::new) : null;
SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(sslBundle.createSslContext(),
enabledProtocols, ciphers, new DefaultHostnameVerifier());
connectionManagerBuilder.setSSLSocketFactory(socketFactory);
}
PoolingHttpClientConnectionManager connectionManager = connectionManagerBuilder.build();
return HttpClientBuilder.create().setConnectionManager(connectionManager).build();
}
@ -166,13 +191,27 @@ public final class ClientHttpRequestFactories {
static OkHttp3ClientHttpRequestFactory get(ClientHttpRequestFactorySettings settings) {
Assert.state(settings.bufferRequestBody() == null,
() -> "OkHttp3ClientHttpRequestFactory does not support request body buffering");
OkHttp3ClientHttpRequestFactory requestFactory = new OkHttp3ClientHttpRequestFactory();
OkHttp3ClientHttpRequestFactory requestFactory = createRequestFactory(settings.sslBundle());
PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull();
map.from(settings::connectTimeout).asInt(Duration::toMillis).to(requestFactory::setConnectTimeout);
map.from(settings::readTimeout).asInt(Duration::toMillis).to(requestFactory::setReadTimeout);
return requestFactory;
}
private static OkHttp3ClientHttpRequestFactory createRequestFactory(SslBundle sslBundle) {
if (sslBundle != null) {
SSLSocketFactory socketFactory = sslBundle.createSslContext().getSocketFactory();
TrustManager[] trustManagers = sslBundle.getManagers().getTrustManagers();
Assert.state(trustManagers.length == 1,
"Trust material must be provided in the SSL bundle for OkHttp3ClientHttpRequestFactory");
OkHttpClient client = new OkHttpClient.Builder()
.sslSocketFactory(socketFactory, (X509TrustManager) trustManagers[0])
.build();
return new OkHttp3ClientHttpRequestFactory(client);
}
return new OkHttp3ClientHttpRequestFactory();
}
}
/**

View File

@ -18,6 +18,7 @@ package org.springframework.boot.web.client;
import java.time.Duration;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.http.client.ClientHttpRequestFactory;
/**
@ -26,20 +27,37 @@ import org.springframework.http.client.ClientHttpRequestFactory;
* @param connectTimeout the connect timeout
* @param readTimeout the read timeout
* @param bufferRequestBody if request body buffering is used
* @param sslBundle the SSL bundle providing SSL configuration
* @author Andy Wilkinson
* @author Phillip Webb
* @author Scott Frederick
* @since 3.0.0
* @see ClientHttpRequestFactories
*/
public record ClientHttpRequestFactorySettings(Duration connectTimeout, Duration readTimeout,
Boolean bufferRequestBody) {
public record ClientHttpRequestFactorySettings(Duration connectTimeout, Duration readTimeout, Boolean bufferRequestBody,
SslBundle sslBundle) {
/**
* Use defaults for the {@link ClientHttpRequestFactory} which can differ depending on
* the implementation.
*/
public static final ClientHttpRequestFactorySettings DEFAULTS = new ClientHttpRequestFactorySettings(null, null,
null);
null, null);
/**
* Create a new {@link ClientHttpRequestFactorySettings} instance.
* @param connectTimeout the connection timeout
* @param readTimeout the read timeout
* @param bufferRequestBody the bugger request body
* @param sslBundle the ssl bundle
* @since 3.1.0
*/
public ClientHttpRequestFactorySettings {
}
public ClientHttpRequestFactorySettings(Duration connectTimeout, Duration readTimeout, Boolean bufferRequestBody) {
this(connectTimeout, readTimeout, bufferRequestBody, null);
}
/**
* Return a new {@link ClientHttpRequestFactorySettings} instance with an updated
@ -48,7 +66,8 @@ public record ClientHttpRequestFactorySettings(Duration connectTimeout, Duration
* @return a new {@link ClientHttpRequestFactorySettings} instance
*/
public ClientHttpRequestFactorySettings withConnectTimeout(Duration connectTimeout) {
return new ClientHttpRequestFactorySettings(connectTimeout, this.readTimeout, this.bufferRequestBody);
return new ClientHttpRequestFactorySettings(connectTimeout, this.readTimeout, this.bufferRequestBody,
this.sslBundle);
}
/**
@ -59,7 +78,8 @@ public record ClientHttpRequestFactorySettings(Duration connectTimeout, Duration
*/
public ClientHttpRequestFactorySettings withReadTimeout(Duration readTimeout) {
return new ClientHttpRequestFactorySettings(this.connectTimeout, readTimeout, this.bufferRequestBody);
return new ClientHttpRequestFactorySettings(this.connectTimeout, readTimeout, this.bufferRequestBody,
this.sslBundle);
}
/**
@ -69,7 +89,20 @@ public record ClientHttpRequestFactorySettings(Duration connectTimeout, Duration
* @return a new {@link ClientHttpRequestFactorySettings} instance
*/
public ClientHttpRequestFactorySettings withBufferRequestBody(Boolean bufferRequestBody) {
return new ClientHttpRequestFactorySettings(this.connectTimeout, this.readTimeout, bufferRequestBody);
return new ClientHttpRequestFactorySettings(this.connectTimeout, this.readTimeout, bufferRequestBody,
this.sslBundle);
}
/**
* Return a new {@link ClientHttpRequestFactorySettings} instance with an updated SSL
* bundle setting.
* @param sslBundle the new SSL bundle setting
* @return a new {@link ClientHttpRequestFactorySettings} instance
* @since 3.1.0
*/
public ClientHttpRequestFactorySettings withSslBundle(SslBundle sslBundle) {
return new ClientHttpRequestFactorySettings(this.connectTimeout, this.readTimeout, this.bufferRequestBody,
sslBundle);
}
}

View File

@ -33,6 +33,7 @@ import java.util.function.Supplier;
import reactor.netty.http.client.HttpClientRequest;
import org.springframework.beans.BeanUtils;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
@ -64,6 +65,7 @@ import org.springframework.web.util.UriTemplateHandler;
* @author Dmytro Nosan
* @author Kevin Strijbos
* @author Ilya Lukyanovich
* @author Scott Frederick
* @since 1.4.0
*/
public class RestTemplateBuilder {
@ -453,6 +455,19 @@ public class RestTemplateBuilder {
this.customizers, this.requestCustomizers);
}
/**
* Sets the SSL bundle on the underlying {@link ClientHttpRequestFactory}.
* @param sslBundle the SSL bundle
* @return a new builder instance
* @since 2.1.0
*/
public RestTemplateBuilder setSslBundle(SslBundle sslBundle) {
return new RestTemplateBuilder(this.requestFactorySettings.withSslBundle(sslBundle), this.detectRequestFactory,
this.rootUri, this.messageConverters, this.interceptors, this.requestFactory, this.uriTemplateHandler,
this.errorHandler, this.basicAuthentication, this.defaultHeaders, this.customizers,
this.requestCustomizers);
}
/**
* Set the {@link RestTemplateCustomizer RestTemplateCustomizers} that should be
* applied to the {@link RestTemplate}. Customizers are applied in the order that they

View File

@ -1,5 +1,5 @@
/*
* Copyright 2012-2022 the original author or authors.
* Copyright 2012-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ import java.time.Duration;
import java.util.function.Function;
import java.util.function.Supplier;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.web.client.ClientHttpRequestFactories;
import org.springframework.boot.web.client.ClientHttpRequestFactorySettings;
import org.springframework.http.client.ClientHttpRequestFactory;
@ -40,6 +41,8 @@ public class HttpWebServiceMessageSenderBuilder {
private Duration readTimeout;
private SslBundle sslBundle;
private Function<ClientHttpRequestFactorySettings, ClientHttpRequestFactory> requestFactory;
/**
@ -62,6 +65,16 @@ public class HttpWebServiceMessageSenderBuilder {
return this;
}
/**
* Set an {@link SslBundle} that will be used to configure a secure connection.
* @param sslBundle the SSL bundle
* @return a new builder instance
*/
public HttpWebServiceMessageSenderBuilder sslBundle(SslBundle sslBundle) {
this.sslBundle = sslBundle;
return this;
}
/**
* Set the {@code Supplier} of {@link ClientHttpRequestFactory} that should be called
* to create the HTTP-based {@link WebServiceMessageSender}.
@ -100,7 +113,7 @@ public class HttpWebServiceMessageSenderBuilder {
private ClientHttpRequestFactory getRequestFactory() {
ClientHttpRequestFactorySettings settings = new ClientHttpRequestFactorySettings(this.connectTimeout,
this.readTimeout, null);
this.readTimeout, null, this.sslBundle);
return (this.requestFactory != null) ? this.requestFactory.apply(settings)
: ClientHttpRequestFactories.get(settings);
}

View File

@ -20,7 +20,10 @@ import java.time.Duration;
import org.junit.jupiter.api.Test;
import org.springframework.boot.ssl.SslBundle;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link ClientHttpRequestFactorySettings}.
@ -37,6 +40,7 @@ class ClientHttpRequestFactorySettingsTests {
assertThat(settings.connectTimeout()).isNull();
assertThat(settings.readTimeout()).isNull();
assertThat(settings.bufferRequestBody()).isNull();
assertThat(settings.sslBundle()).isNull();
}
@Test
@ -46,6 +50,7 @@ class ClientHttpRequestFactorySettingsTests {
assertThat(settings.connectTimeout()).isEqualTo(ONE_SECOND);
assertThat(settings.readTimeout()).isNull();
assertThat(settings.bufferRequestBody()).isNull();
assertThat(settings.sslBundle()).isNull();
}
@Test
@ -55,6 +60,7 @@ class ClientHttpRequestFactorySettingsTests {
assertThat(settings.connectTimeout()).isNull();
assertThat(settings.readTimeout()).isEqualTo(ONE_SECOND);
assertThat(settings.bufferRequestBody()).isNull();
assertThat(settings.sslBundle()).isNull();
}
@Test
@ -64,6 +70,17 @@ class ClientHttpRequestFactorySettingsTests {
assertThat(settings.connectTimeout()).isNull();
assertThat(settings.readTimeout()).isNull();
assertThat(settings.bufferRequestBody()).isTrue();
assertThat(settings.sslBundle()).isNull();
}
@Test
void withSslBundleReturnsInstanceWithUpdatedSslBundle() {
SslBundle sslBundle = mock(SslBundle.class);
ClientHttpRequestFactorySettings settings = ClientHttpRequestFactorySettings.DEFAULTS.withSslBundle(sslBundle);
assertThat(settings.connectTimeout()).isNull();
assertThat(settings.readTimeout()).isNull();
assertThat(settings.bufferRequestBody()).isNull();
assertThat(settings.sslBundle()).isSameAs(sslBundle);
}
}