From f84f31b47d9a81ba187b46cdf9ef90edf68050f0 Mon Sep 17 00:00:00 2001 From: Andy Wilkinson Date: Wed, 13 Jul 2016 16:40:32 +0100 Subject: [PATCH] Add setReadTimeout and setConnectTimeout to RestTemplateBuilder Closes gh-6346 --- spring-boot-parent/pom.xml | 15 ++ spring-boot/pom.xml | 15 ++ .../boot/web/client/RestTemplateBuilder.java | 211 +++++++++++++++--- .../web/client/RestTemplateBuilderTests.java | 118 ++++++++++ 4 files changed, 322 insertions(+), 37 deletions(-) diff --git a/spring-boot-parent/pom.xml b/spring-boot-parent/pom.xml index 108f39305f7..b79c87ecaba 100644 --- a/spring-boot-parent/pom.xml +++ b/spring-boot-parent/pom.xml @@ -52,6 +52,21 @@ guava 18.0 + + com.squareup.okhttp + okhttp + 2.7.5 + + + com.squareup.okhttp3 + okhttp + 3.4.1 + + + io.netty + netty-all + 4.0.38.Final + io.spring.gradle dependency-management-plugin diff --git a/spring-boot/pom.xml b/spring-boot/pom.xml index e74d32cccd7..a6c8e35209d 100644 --- a/spring-boot/pom.xml +++ b/spring-boot/pom.xml @@ -271,6 +271,16 @@ h2 test + + com.squareup.okhttp + okhttp + test + + + com.squareup.okhttp3 + okhttp + test + mysql mysql-connector-java @@ -306,6 +316,11 @@ mariadb-java-client test + + io.netty + netty-all + test + org.postgresql postgresql diff --git a/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java b/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java index 566ec1b4b26..9b96e61db14 100644 --- a/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java +++ b/spring-boot/src/main/java/org/springframework/boot/web/client/RestTemplateBuilder.java @@ -16,6 +16,8 @@ package org.springframework.boot.web.client; +import java.lang.reflect.Field; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -26,6 +28,7 @@ import java.util.Map; import java.util.Set; import org.springframework.beans.BeanUtils; +import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.client.support.BasicAuthorizationInterceptor; @@ -33,6 +36,7 @@ import org.springframework.http.converter.HttpMessageConverter; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; +import org.springframework.util.ReflectionUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestTemplate; import org.springframework.web.util.UriTemplateHandler; @@ -51,6 +55,7 @@ import org.springframework.web.util.UriTemplateHandler; * * @author Stephane Nicoll * @author Phillip Webb + * @author Andy Wilkinson * @since 1.4.0 */ public class RestTemplateBuilder { @@ -84,7 +89,9 @@ public class RestTemplateBuilder { private final BasicAuthorizationInterceptor basicAuthorization; - private final Set customizers; + private final Set restTemplateCustomizers; + + private final Set requestFactoryCustomizers; /** * Create a new {@link RestTemplateBuilder} instance. @@ -100,8 +107,9 @@ public class RestTemplateBuilder { this.uriTemplateHandler = null; this.errorHandler = null; this.basicAuthorization = null; - this.customizers = Collections.unmodifiableSet( + this.restTemplateCustomizers = Collections.unmodifiableSet( new LinkedHashSet(Arrays.asList(customizers))); + this.requestFactoryCustomizers = Collections.emptySet(); } private RestTemplateBuilder(boolean detectRequestFactory, String rootUri, @@ -109,7 +117,8 @@ public class RestTemplateBuilder { ClientHttpRequestFactory requestFactory, UriTemplateHandler uriTemplateHandler, ResponseErrorHandler errorHandler, BasicAuthorizationInterceptor basicAuthorization, - Set customizers) { + Set restTemplateCustomizers, + Set requestFactoryCustomizers) { super(); this.detectRequestFactory = detectRequestFactory; this.rootUri = rootUri; @@ -118,7 +127,8 @@ public class RestTemplateBuilder { this.uriTemplateHandler = uriTemplateHandler; this.errorHandler = errorHandler; this.basicAuthorization = basicAuthorization; - this.customizers = customizers; + this.restTemplateCustomizers = restTemplateCustomizers; + this.requestFactoryCustomizers = requestFactoryCustomizers; } /** @@ -131,7 +141,8 @@ public class RestTemplateBuilder { public RestTemplateBuilder detectRequestFactory(boolean detectRequestFactory) { return new RestTemplateBuilder(detectRequestFactory, this.rootUri, this.messageConverters, this.requestFactory, this.uriTemplateHandler, - this.errorHandler, this.basicAuthorization, this.customizers); + this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, + this.requestFactoryCustomizers); } /** @@ -143,7 +154,8 @@ public class RestTemplateBuilder { public RestTemplateBuilder rootUri(String rootUri) { return new RestTemplateBuilder(this.detectRequestFactory, rootUri, this.messageConverters, this.requestFactory, this.uriTemplateHandler, - this.errorHandler, this.basicAuthorization, this.customizers); + this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, + this.requestFactoryCustomizers); } /** @@ -173,7 +185,8 @@ public class RestTemplateBuilder { Collections.unmodifiableSet( new LinkedHashSet>(messageConverters)), this.requestFactory, this.uriTemplateHandler, this.errorHandler, - this.basicAuthorization, this.customizers); + this.basicAuthorization, this.restTemplateCustomizers, + this.requestFactoryCustomizers); } /** @@ -202,7 +215,7 @@ public class RestTemplateBuilder { return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, append(this.messageConverters, messageConverters), this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, - this.customizers); + this.restTemplateCustomizers, this.requestFactoryCustomizers); } /** @@ -217,7 +230,8 @@ public class RestTemplateBuilder { Collections.unmodifiableSet(new LinkedHashSet>( new RestTemplate().getMessageConverters())), this.requestFactory, this.uriTemplateHandler, this.errorHandler, - this.basicAuthorization, this.customizers); + this.basicAuthorization, this.restTemplateCustomizers, + this.requestFactoryCustomizers); } /** @@ -242,7 +256,8 @@ public class RestTemplateBuilder { Assert.notNull(requestFactory, "RequestFactory must not be null"); return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, this.messageConverters, requestFactory, this.uriTemplateHandler, - this.errorHandler, this.basicAuthorization, this.customizers); + this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, + this.requestFactoryCustomizers); } /** @@ -255,7 +270,8 @@ public class RestTemplateBuilder { Assert.notNull(uriTemplateHandler, "UriTemplateHandler must not be null"); return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, this.messageConverters, this.requestFactory, uriTemplateHandler, - this.errorHandler, this.basicAuthorization, this.customizers); + this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, + this.requestFactoryCustomizers); } /** @@ -268,7 +284,8 @@ public class RestTemplateBuilder { Assert.notNull(errorHandler, "ErrorHandler must not be null"); return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, this.messageConverters, this.requestFactory, this.uriTemplateHandler, - errorHandler, this.basicAuthorization, this.customizers); + errorHandler, this.basicAuthorization, this.restTemplateCustomizers, + this.requestFactoryCustomizers); } /** @@ -282,14 +299,14 @@ public class RestTemplateBuilder { return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, this.messageConverters, this.requestFactory, this.uriTemplateHandler, this.errorHandler, new BasicAuthorizationInterceptor(username, password), - this.customizers); + this.restTemplateCustomizers, this.requestFactoryCustomizers); } /** - * Set the {@link RestTemplateCustomizer RestTemplateCustomizers} that should be applied - * to the {@link RestTemplate}. Customizers are applied in the order that they were - * added after builder configuration has been applied. Setting this value will replace - * any previously configured customizers. + * Set the {@link RestTemplateCustomizer RestTemplateCustomizers} that should be + * applied to the {@link RestTemplate}. Customizers are applied in the order that they + * were added after builder configuration has been applied. Setting this value will + * replace any previously configured customizers. * @param restTemplateCustomizers the customizers to set * @return a new builder instance * @see #additionalCustomizers(RestTemplateCustomizer...) @@ -302,10 +319,10 @@ public class RestTemplateBuilder { } /** - * Set the {@link RestTemplateCustomizer RestTemplateCustomizers} that should be applied - * to the {@link RestTemplate}. Customizers are applied in the order that they were - * added after builder configuration has been applied. Setting this value will replace - * any previously configured customizers. + * Set the {@link RestTemplateCustomizer RestTemplateCustomizers} that should be + * applied to the {@link RestTemplate}. Customizers are applied in the order that they + * were added after builder configuration has been applied. Setting this value will + * replace any previously configured customizers. * @param restTemplateCustomizers the customizers to set * @return a new builder instance * @see #additionalCustomizers(RestTemplateCustomizer...) @@ -318,13 +335,14 @@ public class RestTemplateBuilder { this.messageConverters, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, Collections.unmodifiableSet(new LinkedHashSet( - restTemplateCustomizers))); + restTemplateCustomizers)), + this.requestFactoryCustomizers); } /** - * Add {@link RestTemplateCustomizer RestTemplateCustomizers} that should be applied to - * the {@link RestTemplate}. Customizers are applied in the order that they were added - * after builder configuration has been applied. + * Add {@link RestTemplateCustomizer RestTemplateCustomizers} that should be applied + * to the {@link RestTemplate}. Customizers are applied in the order that they were + * added after builder configuration has been applied. * @param restTemplateCustomizers the customizers to add * @return a new builder instance * @see #customizers(RestTemplateCustomizer...) @@ -337,9 +355,9 @@ public class RestTemplateBuilder { } /** - * Add {@link RestTemplateCustomizer RestTemplateCustomizers} that should be applied to - * the {@link RestTemplate}. Customizers are applied in the order that they were added - * after builder configuration has been applied. + * Add {@link RestTemplateCustomizer RestTemplateCustomizers} that should be applied + * to the {@link RestTemplate}. Customizers are applied in the order that they were + * added after builder configuration has been applied. * @param customizers the customizers to add * @return a new builder instance * @see #customizers(RestTemplateCustomizer...) @@ -350,7 +368,38 @@ public class RestTemplateBuilder { return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, this.messageConverters, this.requestFactory, this.uriTemplateHandler, this.errorHandler, this.basicAuthorization, - append(this.customizers, customizers)); + append(this.restTemplateCustomizers, customizers), + this.requestFactoryCustomizers); + } + + /** + * Sets the connect timeout in milliseconds on the underlying + * {@link ClientHttpRequestFactory}. + * + * @param connectTimeout the connect timeout in milliseconds + * @return a new builder instance. + */ + public RestTemplateBuilder setConnectTimeout(int connectTimeout) { + return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, + this.messageConverters, this.requestFactory, this.uriTemplateHandler, + this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, + append(this.requestFactoryCustomizers, + new ConnectTimeoutRequestFactoryCustomizer(connectTimeout))); + } + + /** + * Sets the read timeout in milliseconds on the underlying + * {@link ClientHttpRequestFactory}. + * + * @param readTimeout the read timeout in milliseconds + * @return a new builder instance. + */ + public RestTemplateBuilder setReadTimeout(int readTimeout) { + return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, + this.messageConverters, this.requestFactory, this.uriTemplateHandler, + this.errorHandler, this.basicAuthorization, this.restTemplateCustomizers, + append(this.requestFactoryCustomizers, + new ReadTimeoutRequestFactoryCustomizer(readTimeout))); } /** @@ -386,12 +435,7 @@ public class RestTemplateBuilder { * @see RestTemplateBuilder#build(Class) */ public T configure(T restTemplate) { - if (this.requestFactory != null) { - restTemplate.setRequestFactory(this.requestFactory); - } - else if (this.detectRequestFactory) { - restTemplate.setRequestFactory(detectRequestFactory()); - } + configureRequestFactory(restTemplate); if (!CollectionUtils.isEmpty(this.messageConverters)) { restTemplate.setMessageConverters( new ArrayList>(this.messageConverters)); @@ -408,14 +452,47 @@ public class RestTemplateBuilder { if (this.basicAuthorization != null) { restTemplate.getInterceptors().add(this.basicAuthorization); } - if (!CollectionUtils.isEmpty(this.customizers)) { - for (RestTemplateCustomizer customizer : this.customizers) { + if (!CollectionUtils.isEmpty(this.restTemplateCustomizers)) { + for (RestTemplateCustomizer customizer : this.restTemplateCustomizers) { customizer.customize(restTemplate); } } return restTemplate; } + private void configureRequestFactory(RestTemplate restTemplate) { + ClientHttpRequestFactory requestFactory = null; + if (this.requestFactory != null) { + requestFactory = unwrapRequestFactoryIfNecessary(this.requestFactory); + } + else if (this.detectRequestFactory) { + requestFactory = detectRequestFactory(); + } + if (requestFactory != null) { + for (RequestFactoryCustomizer customizer : this.requestFactoryCustomizers) { + customizer.customize(requestFactory); + } + restTemplate.setRequestFactory(requestFactory); + } + } + + private ClientHttpRequestFactory unwrapRequestFactoryIfNecessary( + ClientHttpRequestFactory requestFactory) { + if (!(requestFactory instanceof AbstractClientHttpRequestFactoryWrapper)) { + return requestFactory; + } + ClientHttpRequestFactory unwrappedRequestFactory = requestFactory; + Field field = ReflectionUtils.findField( + AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); + ReflectionUtils.makeAccessible(field); + do { + unwrappedRequestFactory = (ClientHttpRequestFactory) ReflectionUtils + .getField(field, unwrappedRequestFactory); + } + while (unwrappedRequestFactory instanceof AbstractClientHttpRequestFactoryWrapper); + return unwrappedRequestFactory; + } + private ClientHttpRequestFactory detectRequestFactory() { for (Map.Entry candidate : REQUEST_FACTORY_CANDIDATES .entrySet()) { @@ -429,6 +506,13 @@ public class RestTemplateBuilder { return new SimpleClientHttpRequestFactory(); } + private Set append(Set set, T addition) { + Set result = new LinkedHashSet( + set == null ? Collections.emptySet() : set); + result.add(addition); + return Collections.unmodifiableSet(result); + } + private Set append(Set set, Collection additions) { Set result = new LinkedHashSet( set == null ? Collections.emptySet() : set); @@ -436,4 +520,57 @@ public class RestTemplateBuilder { return Collections.unmodifiableSet(result); } + private interface RequestFactoryCustomizer { + + void customize(ClientHttpRequestFactory factory); + + } + + private static abstract class TimeoutConfiguringRequestFactoryCustomizer + implements RequestFactoryCustomizer { + + private final int timeout; + + private final String methodName; + + TimeoutConfiguringRequestFactoryCustomizer(int timeout, String methodName) { + this.timeout = timeout; + this.methodName = methodName; + } + + @Override + public void customize(ClientHttpRequestFactory factory) { + ReflectionUtils.invokeMethod(findMethod(factory), factory, this.timeout); + } + + private Method findMethod(ClientHttpRequestFactory factory) { + Method method = ReflectionUtils.findMethod(factory.getClass(), + this.methodName, int.class); + if (method != null) { + return method; + } + throw new IllegalStateException("Request factory " + factory.getClass() + + " does not have a " + this.methodName + "(int) method"); + } + + } + + private static class ReadTimeoutRequestFactoryCustomizer + extends TimeoutConfiguringRequestFactoryCustomizer { + + ReadTimeoutRequestFactoryCustomizer(int readTimeout) { + super(readTimeout, "setReadTimeout"); + } + + } + + private static class ConnectTimeoutRequestFactoryCustomizer + extends TimeoutConfiguringRequestFactoryCustomizer { + + ConnectTimeoutRequestFactoryCustomizer(int connectTimeout) { + super(connectTimeout, "setConnectTimeout"); + } + + } + } diff --git a/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java b/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java index a24989917e8..bdc81116b05 100644 --- a/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/web/client/RestTemplateBuilderTests.java @@ -19,6 +19,8 @@ package org.springframework.boot.web.client; import java.util.Collections; import java.util.Set; +import com.squareup.okhttp.OkHttpClient; +import org.apache.http.client.config.RequestConfig; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -26,14 +28,19 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.springframework.http.client.BufferingClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.client.Netty4ClientHttpRequestFactory; +import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; +import org.springframework.http.client.OkHttpClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.client.support.BasicAuthorizationInterceptor; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.ResourceHttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter; +import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.client.MockRestServiceServer; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestTemplate; @@ -53,6 +60,7 @@ import static org.springframework.test.web.client.response.MockRestResponseCreat * * @author Stephane Nicoll * @author Phillip Webb + * @author Andy Wilkinson */ public class RestTemplateBuilderTests { @@ -356,6 +364,116 @@ public class RestTemplateBuilderTests { .isInstanceOf(HttpComponentsClientHttpRequestFactory.class); } + @Test + public void connectTimeoutCanBeConfiguredOnHttpComponentsRequestFactory() { + ClientHttpRequestFactory requestFactory = this.builder + .requestFactory(HttpComponentsClientHttpRequestFactory.class) + .setConnectTimeout(1234).build().getRequestFactory(); + assertThat(((RequestConfig) ReflectionTestUtils.getField(requestFactory, + "requestConfig")).getConnectTimeout()).isEqualTo(1234); + } + + @Test + public void readTimeoutCanBeConfiguredOnHttpComponentsRequestFactory() { + ClientHttpRequestFactory requestFactory = this.builder + .requestFactory(HttpComponentsClientHttpRequestFactory.class) + .setReadTimeout(1234).build().getRequestFactory(); + assertThat(((RequestConfig) ReflectionTestUtils.getField(requestFactory, + "requestConfig")).getSocketTimeout()).isEqualTo(1234); + } + + @Test + public void connectTimeoutCanBeConfiguredOnSimpleRequestFactory() { + ClientHttpRequestFactory requestFactory = this.builder + .requestFactory(SimpleClientHttpRequestFactory.class) + .setConnectTimeout(1234).build().getRequestFactory(); + assertThat(ReflectionTestUtils.getField(requestFactory, "connectTimeout")) + .isEqualTo(1234); + } + + @Test + public void readTimeoutCanBeConfiguredOnSimpleRequestFactory() { + ClientHttpRequestFactory requestFactory = this.builder + .requestFactory(SimpleClientHttpRequestFactory.class).setReadTimeout(1234) + .build().getRequestFactory(); + assertThat(ReflectionTestUtils.getField(requestFactory, "readTimeout")) + .isEqualTo(1234); + } + + @Test + public void connectTimeoutCanBeConfiguredOnNetty4RequestFactory() { + ClientHttpRequestFactory requestFactory = this.builder + .requestFactory(Netty4ClientHttpRequestFactory.class) + .setConnectTimeout(1234).build().getRequestFactory(); + assertThat(ReflectionTestUtils.getField(requestFactory, "connectTimeout")) + .isEqualTo(1234); + } + + @Test + public void readTimeoutCanBeConfiguredOnNetty4RequestFactory() { + ClientHttpRequestFactory requestFactory = this.builder + .requestFactory(Netty4ClientHttpRequestFactory.class).setReadTimeout(1234) + .build().getRequestFactory(); + assertThat(ReflectionTestUtils.getField(requestFactory, "readTimeout")) + .isEqualTo(1234); + } + + @Test + public void connectTimeoutCanBeConfiguredOnOkHttp2RequestFactory() { + ClientHttpRequestFactory requestFactory = this.builder + .requestFactory(OkHttpClientHttpRequestFactory.class) + .setConnectTimeout(1234).build().getRequestFactory(); + assertThat(((OkHttpClient) ReflectionTestUtils.getField(requestFactory, "client")) + .getConnectTimeout()).isEqualTo(1234); + } + + @Test + public void readTimeoutCanBeConfiguredOnOkHttp2RequestFactory() { + ClientHttpRequestFactory requestFactory = this.builder + .requestFactory(OkHttpClientHttpRequestFactory.class).setReadTimeout(1234) + .build().getRequestFactory(); + assertThat(((OkHttpClient) ReflectionTestUtils.getField(requestFactory, "client")) + .getReadTimeout()).isEqualTo(1234); + } + + @Test + public void connectTimeoutCanBeConfiguredOnOkHttp3RequestFactory() { + ClientHttpRequestFactory requestFactory = this.builder + .requestFactory(OkHttp3ClientHttpRequestFactory.class) + .setConnectTimeout(1234).build().getRequestFactory(); + assertThat(ReflectionTestUtils.getField( + ReflectionTestUtils.getField(requestFactory, "client"), "connectTimeout")) + .isEqualTo(1234); + } + + @Test + public void readTimeoutCanBeConfiguredOnOkHttp3RequestFactory() { + ClientHttpRequestFactory requestFactory = this.builder + .requestFactory(OkHttp3ClientHttpRequestFactory.class) + .setReadTimeout(1234).build().getRequestFactory(); + assertThat(ReflectionTestUtils.getField( + ReflectionTestUtils.getField(requestFactory, "client"), "readTimeout")) + .isEqualTo(1234); + } + + @Test + public void connectTimeoutCanBeConfiguredOnAWrappedRequestFactory() { + SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); + this.builder.requestFactory(new BufferingClientHttpRequestFactory(requestFactory)) + .setConnectTimeout(1234).build(); + assertThat(ReflectionTestUtils.getField(requestFactory, "connectTimeout")) + .isEqualTo(1234); + } + + @Test + public void readTimeoutCanBeConfiguredOnAWrappedRequestFactory() { + SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); + this.builder.requestFactory(new BufferingClientHttpRequestFactory(requestFactory)) + .setReadTimeout(1234).build(); + assertThat(ReflectionTestUtils.getField(requestFactory, "readTimeout")) + .isEqualTo(1234); + } + public static class RestTemplateSubclass extends RestTemplate { }