From 03a6f97e76e7b82deb6275cb6185a5b6ee1ad79d Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Fri, 16 Feb 2018 12:21:28 +0100 Subject: [PATCH] TestRestTemplate should not override request factory Previously `TestRestTemplate` would override the configured `ClientHttpRequestFactory` if the Apache HTTP client library was on classpath. This commit fixes two issues: 1. The existing `ClientHttpRequestFactory` is overridden *only* if it is using the Apache HTTP client variant, in order to wrap it with the `TestRestTemplate` custom support 2. Calling `withBasicAuth` will no longer directly use the request factory returned by the internal `RestTemplate`; if client interceptors are configured, the request factory is wrapped with an `InterceptingClientHttpRequestFactory`. If we don't unwrap it, interceptors are copied/applied twice in the newly created `TestRestTemplate` instance. For that, we need to use reflection as the underlying request factory is not accessible directly. Closes gh-8697 --- .../test/web/client/TestRestTemplate.java | 22 ++++++++++++++++--- .../web/client/TestRestTemplateTests.java | 12 ++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java index 656740cdb89..1a13885ab7f 100644 --- a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java +++ b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/web/client/TestRestTemplate.java @@ -17,6 +17,7 @@ package org.springframework.boot.test.web.client; import java.io.IOException; +import java.lang.reflect.Field; import java.net.URI; import java.util.ArrayList; import java.util.Arrays; @@ -45,12 +46,14 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.http.client.support.BasicAuthorizationInterceptor; import org.springframework.util.Assert; -import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; import org.springframework.web.client.DefaultResponseErrorHandler; import org.springframework.web.client.RequestCallback; import org.springframework.web.client.ResponseExtractor; @@ -135,7 +138,8 @@ public class TestRestTemplate { HttpClientOption... httpClientOptions) { Assert.notNull(restTemplate, "RestTemplate must not be null"); this.httpClientOptions = httpClientOptions; - if (ClassUtils.isPresent("org.apache.http.client.config.RequestConfig", null)) { + if (restTemplate.getRequestFactory().getClass().getName() + .equals("org.springframework.http.client.HttpComponentsClientHttpRequestFactory")) { restTemplate.setRequestFactory( new CustomHttpComponentsClientHttpRequestFactory(httpClientOptions)); } @@ -1021,7 +1025,7 @@ public class TestRestTemplate { RestTemplate restTemplate = new RestTemplate(); restTemplate.setMessageConverters(getRestTemplate().getMessageConverters()); restTemplate.setInterceptors(getRestTemplate().getInterceptors()); - restTemplate.setRequestFactory(getRestTemplate().getRequestFactory()); + restTemplate.setRequestFactory(getRequestFactory(getRestTemplate())); restTemplate.setUriTemplateHandler(getRestTemplate().getUriTemplateHandler()); TestRestTemplate testRestTemplate = new TestRestTemplate(restTemplate, username, password, this.httpClientOptions); @@ -1030,6 +1034,18 @@ public class TestRestTemplate { return testRestTemplate; } + private ClientHttpRequestFactory getRequestFactory(RestTemplate restTemplate) { + ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory(); + if (InterceptingClientHttpRequestFactory.class.isAssignableFrom(requestFactory.getClass())) { + Field requestFactoryField = ReflectionUtils + .findField(RestTemplate.class, "requestFactory"); + ReflectionUtils.makeAccessible(requestFactoryField); + requestFactory = (ClientHttpRequestFactory) + ReflectionUtils.getField(requestFactoryField, getRestTemplate()); + } + return requestFactory; + } + @SuppressWarnings({ "rawtypes", "unchecked" }) private RequestEntity createRequestEntityWithRootAppliedUri( RequestEntity requestEntity) { diff --git a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java index 43094e656ba..cd5976cd46a 100644 --- a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java +++ b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/web/client/TestRestTemplateTests.java @@ -37,6 +37,8 @@ import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.http.client.InterceptingClientHttpRequestFactory; +import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; +import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.client.support.BasicAuthorizationInterceptor; import org.springframework.mock.env.MockEnvironment; import org.springframework.mock.http.client.MockClientHttpRequest; @@ -82,6 +84,15 @@ public class TestRestTemplateTests { .isInstanceOf(HttpComponentsClientHttpRequestFactory.class); } + @Test + public void doNotReplaceCustomRequestFactory() { + RestTemplateBuilder builder = new RestTemplateBuilder() + .requestFactory(OkHttp3ClientHttpRequestFactory.class); + TestRestTemplate testRestTemplate = new TestRestTemplate(builder); + assertThat(testRestTemplate.getRestTemplate().getRequestFactory()) + .isInstanceOf(OkHttp3ClientHttpRequestFactory.class); + } + @Test public void getRootUriRootUriSetViaRestTemplateBuilder() { String rootUri = "http://example.com"; @@ -125,6 +136,7 @@ public class TestRestTemplateTests { @Test public void restOperationsAreAvailable() { RestTemplate delegate = mock(RestTemplate.class); + given(delegate.getRequestFactory()).willReturn(new SimpleClientHttpRequestFactory()); given(delegate.getUriTemplateHandler()) .willReturn(new DefaultUriBuilderFactory()); RestTemplateBuilder builder = mock(RestTemplateBuilder.class);