Restore compatibility with MockRestServiceServer

Closes gh-17885
This commit is contained in:
Phillip Webb 2019-09-09 14:19:30 -07:00 committed by Stephane Nicoll
parent 3d5530d15d
commit ad32603635
7 changed files with 141 additions and 181 deletions

View File

@ -99,10 +99,7 @@ class TestRestTemplateTests {
RestTemplateBuilder builder = new RestTemplateBuilder().requestFactory(() -> customFactory);
TestRestTemplate testRestTemplate = new TestRestTemplate(builder).withBasicAuth("test", "test");
RestTemplate restTemplate = testRestTemplate.getRestTemplate();
assertThat(restTemplate.getRequestFactory().getClass().getName())
.contains("RestTemplateBuilderClientHttpRequestFactoryWrapper");
Object requestFactory = ReflectionTestUtils.getField(restTemplate.getRequestFactory(), "requestFactory");
assertThat(requestFactory).isEqualTo(customFactory).hasSameClassAs(customFactory);
assertThat(restTemplate.getRequestFactory()).isEqualTo(customFactory).hasSameClassAs(customFactory);
}
@Test
@ -203,28 +200,21 @@ class TestRestTemplateTests {
}
@Test
void withBasicAuthAddsBasicAuthClientFactoryWhenNotAlreadyPresent() throws Exception {
void withBasicAuthAddsBasicAuthWhenNotAlreadyPresent() throws Exception {
TestRestTemplate original = new TestRestTemplate();
TestRestTemplate basicAuth = original.withBasicAuth("user", "password");
assertThat(getConverterClasses(original)).containsExactlyElementsOf(getConverterClasses(basicAuth));
assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName())
.contains("RestTemplateBuilderClientHttpRequestFactoryWrapper");
assertThat(ReflectionTestUtils.getField(basicAuth.getRestTemplate().getRequestFactory(), "requestFactory"))
.isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class);
assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty();
assertBasicAuthorizationCredentials(original, null, null);
assertBasicAuthorizationCredentials(basicAuth, "user", "password");
}
@Test
void withBasicAuthReplacesBasicAuthClientFactoryWhenAlreadyPresent() throws Exception {
void withBasicAuthReplacesBasicAuthWhenAlreadyPresent() throws Exception {
TestRestTemplate original = new TestRestTemplate("foo", "bar").withBasicAuth("replace", "replace");
TestRestTemplate basicAuth = original.withBasicAuth("user", "password");
assertThat(getConverterClasses(basicAuth)).containsExactlyElementsOf(getConverterClasses(original));
assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName())
.contains("RestTemplateBuilderClientHttpRequestFactoryWrapper");
assertThat(ReflectionTestUtils.getField(basicAuth.getRestTemplate().getRequestFactory(), "requestFactory"))
.isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class);
assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty();
assertBasicAuthorizationCredentials(original, "replace", "replace");
assertBasicAuthorizationCredentials(basicAuth, "user", "password");
}
@ -347,11 +337,16 @@ class TestRestTemplateTests {
private void assertBasicAuthorizationCredentials(TestRestTemplate testRestTemplate, String username,
String password) throws Exception {
ClientHttpRequestFactory requestFactory = testRestTemplate.getRestTemplate().getRequestFactory();
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.POST);
assertThat(request.getHeaders()).containsKeys(HttpHeaders.AUTHORIZATION);
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly(
"Basic " + Base64Utils.encodeToString(String.format("%s:%s", username, password).getBytes()));
ClientHttpRequest request = ReflectionTestUtils.invokeMethod(testRestTemplate.getRestTemplate(),
"createRequest", URI.create("http://localhost"), HttpMethod.POST);
if (username == null) {
assertThat(request.getHeaders()).doesNotContainKey(HttpHeaders.AUTHORIZATION);
}
else {
assertThat(request.getHeaders()).containsKeys(HttpHeaders.AUTHORIZATION);
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly(
"Basic " + Base64Utils.encodeToString(String.format("%s:%s", username, password).getBytes()));
}
}

View File

@ -615,7 +615,7 @@ public class RestTemplateBuilder {
if (requestFactory != null) {
restTemplate.setRequestFactory(requestFactory);
}
addClientHttpRequestFactoryWrapper(restTemplate);
addClientHttpRequestInitializer(restTemplate);
if (!CollectionUtils.isEmpty(this.messageConverters)) {
restTemplate.setMessageConverters(new ArrayList<>(this.messageConverters));
}
@ -659,24 +659,12 @@ public class RestTemplateBuilder {
return requestFactory;
}
private void addClientHttpRequestFactoryWrapper(RestTemplate restTemplate) {
private void addClientHttpRequestInitializer(RestTemplate restTemplate) {
if (this.basicAuthentication == null && this.defaultHeaders.isEmpty() && this.requestCustomizers.isEmpty()) {
return;
}
List<ClientHttpRequestInterceptor> interceptors = null;
if (!restTemplate.getInterceptors().isEmpty()) {
// Stash and clear the interceptors so we can access the real factory
interceptors = new ArrayList<>(restTemplate.getInterceptors());
restTemplate.getInterceptors().clear();
}
ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory();
ClientHttpRequestFactory wrapper = new RestTemplateBuilderClientHttpRequestFactoryWrapper(requestFactory,
this.basicAuthentication, this.defaultHeaders, this.requestCustomizers);
restTemplate.setRequestFactory(wrapper);
// Restore the original interceptors
if (interceptors != null) {
restTemplate.getInterceptors().addAll(interceptors);
}
restTemplate.getClientHttpRequestInitializers().add(new RestTemplateBuilderClientHttpRequestInitializer(
this.basicAuthentication, this.defaultHeaders, this.requestCustomizers));
}
@SuppressWarnings("unchecked")

View File

@ -16,18 +16,15 @@
package org.springframework.boot.web.client;
import java.io.IOException;
import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.springframework.boot.util.LambdaSafe;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInitializer;
/**
* {@link ClientHttpRequestFactory} to apply customizations from the
@ -36,7 +33,7 @@ import org.springframework.http.client.ClientHttpRequestFactory;
* @author Dmytro Nosan
* @author Ilya Lukyanovich
*/
class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientHttpRequestFactoryWrapper {
class RestTemplateBuilderClientHttpRequestInitializer implements ClientHttpRequestInitializer {
private final BasicAuthentication basicAuthentication;
@ -44,10 +41,8 @@ class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientH
private final Set<RestTemplateRequestCustomizer<?>> requestCustomizers;
RestTemplateBuilderClientHttpRequestFactoryWrapper(ClientHttpRequestFactory requestFactory,
BasicAuthentication basicAuthentication, Map<String, List<String>> defaultHeaders,
Set<RestTemplateRequestCustomizer<?>> requestCustomizers) {
super(requestFactory);
RestTemplateBuilderClientHttpRequestInitializer(BasicAuthentication basicAuthentication,
Map<String, List<String>> defaultHeaders, Set<RestTemplateRequestCustomizer<?>> requestCustomizers) {
this.basicAuthentication = basicAuthentication;
this.defaultHeaders = defaultHeaders;
this.requestCustomizers = requestCustomizers;
@ -55,9 +50,7 @@ class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientH
@Override
@SuppressWarnings("unchecked")
protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod, ClientHttpRequestFactory requestFactory)
throws IOException {
ClientHttpRequest request = requestFactory.createRequest(uri, httpMethod);
public void initialize(ClientHttpRequest request) {
HttpHeaders headers = request.getHeaders();
if (this.basicAuthentication != null) {
this.basicAuthentication.applyTo(headers);
@ -65,7 +58,6 @@ class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientH
this.defaultHeaders.forEach(headers::putIfAbsent);
LambdaSafe.callbacks(RestTemplateRequestCustomizer.class, this.requestCustomizers, request)
.invoke((customizer) -> customizer.customize(request));
return request;
}
}

View File

@ -17,6 +17,7 @@
package org.springframework.boot.web.client;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestInitializer;
import org.springframework.web.client.RestTemplate;
/**
@ -28,6 +29,7 @@ import org.springframework.web.client.RestTemplate;
* @author Phillip Webb
* @since 2.2.0
* @see RestTemplateBuilder
* @see ClientHttpRequestInitializer
*/
@FunctionalInterface
public interface RestTemplateRequestCustomizer<T extends ClientHttpRequest> {

View File

@ -1,117 +0,0 @@
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.client;
import java.io.IOException;
import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.InOrder;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link RestTemplateBuilderClientHttpRequestFactoryWrapper}.
*
* @author Dmytro Nosan
* @author Ilya Lukyanovich
* @author Phillip Webb
*/
public class RestTemplateBuilderClientHttpRequestFactoryWrapperTests {
private ClientHttpRequestFactory requestFactory;
private final HttpHeaders headers = new HttpHeaders();
@BeforeEach
void setUp() throws IOException {
this.requestFactory = mock(ClientHttpRequestFactory.class);
ClientHttpRequest request = mock(ClientHttpRequest.class);
given(this.requestFactory.createRequest(any(), any())).willReturn(request);
given(request.getHeaders()).willReturn(this.headers);
}
@Test
void createRequestWhenHasBasicAuthAndNoAuthHeaderAddsHeader() throws IOException {
this.requestFactory = new RestTemplateBuilderClientHttpRequestFactoryWrapper(this.requestFactory,
new BasicAuthentication("spring", "boot", null), Collections.emptyMap(), Collections.emptySet());
ClientHttpRequest request = createRequest();
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly("Basic c3ByaW5nOmJvb3Q=");
}
@Test
void createRequestWhenHasBasicAuthAndExistingAuthHeaderDoesNotAddHeader() throws IOException {
this.headers.setBasicAuth("boot", "spring");
this.requestFactory = new RestTemplateBuilderClientHttpRequestFactoryWrapper(this.requestFactory,
new BasicAuthentication("spring", "boot", null), Collections.emptyMap(), Collections.emptySet());
ClientHttpRequest request = createRequest();
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).doesNotContain("Basic c3ByaW5nOmJvb3Q=");
}
@Test
void createRequestWhenHasDefaultHeadersAddsMissing() throws IOException {
this.headers.add("one", "existing");
Map<String, List<String>> defaultHeaders = new LinkedHashMap<>();
defaultHeaders.put("one", Collections.singletonList("1"));
defaultHeaders.put("two", Arrays.asList("2", "3"));
defaultHeaders.put("three", Collections.singletonList("4"));
this.requestFactory = new RestTemplateBuilderClientHttpRequestFactoryWrapper(this.requestFactory, null,
defaultHeaders, Collections.emptySet());
ClientHttpRequest request = createRequest();
assertThat(request.getHeaders().get("one")).containsExactly("existing");
assertThat(request.getHeaders().get("two")).containsExactly("2", "3");
assertThat(request.getHeaders().get("three")).containsExactly("4");
}
@Test
@SuppressWarnings("unchecked")
void createRequestWhenHasRequestCustomizersAppliesThemInOrder() throws IOException {
Set<RestTemplateRequestCustomizer<?>> customizers = new LinkedHashSet<>();
customizers.add(mock(RestTemplateRequestCustomizer.class));
customizers.add(mock(RestTemplateRequestCustomizer.class));
customizers.add(mock(RestTemplateRequestCustomizer.class));
this.requestFactory = new RestTemplateBuilderClientHttpRequestFactoryWrapper(this.requestFactory, null,
Collections.emptyMap(), customizers);
ClientHttpRequest request = createRequest();
InOrder inOrder = inOrder(customizers.toArray());
for (RestTemplateRequestCustomizer<?> customizer : customizers) {
inOrder.verify((RestTemplateRequestCustomizer<ClientHttpRequest>) customizer).customize(request);
}
}
private ClientHttpRequest createRequest() throws IOException {
return this.requestFactory.createRequest(URI.create("https://localhost:8080"), HttpMethod.POST);
}
}

View File

@ -0,0 +1,94 @@
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.client;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.junit.jupiter.api.Test;
import org.mockito.InOrder;
import org.springframework.http.HttpHeaders;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.mock.http.client.MockClientHttpRequest;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link RestTemplateBuilderClientHttpRequestInitializer}.
*
* @author Dmytro Nosan
* @author Ilya Lukyanovich
* @author Phillip Webb
*/
public class RestTemplateBuilderClientHttpRequestInitializerTests {
private final MockClientHttpRequest request = new MockClientHttpRequest();
@Test
void createRequestWhenHasBasicAuthAndNoAuthHeaderAddsHeader() throws IOException {
new RestTemplateBuilderClientHttpRequestInitializer(new BasicAuthentication("spring", "boot", null),
Collections.emptyMap(), Collections.emptySet()).initialize(this.request);
assertThat(this.request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly("Basic c3ByaW5nOmJvb3Q=");
}
@Test
void createRequestWhenHasBasicAuthAndExistingAuthHeaderDoesNotAddHeader() throws IOException {
this.request.getHeaders().setBasicAuth("boot", "spring");
new RestTemplateBuilderClientHttpRequestInitializer(new BasicAuthentication("spring", "boot", null),
Collections.emptyMap(), Collections.emptySet()).initialize(this.request);
assertThat(this.request.getHeaders().get(HttpHeaders.AUTHORIZATION)).doesNotContain("Basic c3ByaW5nOmJvb3Q=");
}
@Test
void createRequestWhenHasDefaultHeadersAddsMissing() throws IOException {
this.request.getHeaders().add("one", "existing");
Map<String, List<String>> defaultHeaders = new LinkedHashMap<>();
defaultHeaders.put("one", Collections.singletonList("1"));
defaultHeaders.put("two", Arrays.asList("2", "3"));
defaultHeaders.put("three", Collections.singletonList("4"));
new RestTemplateBuilderClientHttpRequestInitializer(null, defaultHeaders, Collections.emptySet())
.initialize(this.request);
assertThat(this.request.getHeaders().get("one")).containsExactly("existing");
assertThat(this.request.getHeaders().get("two")).containsExactly("2", "3");
assertThat(this.request.getHeaders().get("three")).containsExactly("4");
}
@Test
@SuppressWarnings("unchecked")
void createRequestWhenHasRequestCustomizersAppliesThemInOrder() throws IOException {
Set<RestTemplateRequestCustomizer<?>> customizers = new LinkedHashSet<>();
customizers.add(mock(RestTemplateRequestCustomizer.class));
customizers.add(mock(RestTemplateRequestCustomizer.class));
customizers.add(mock(RestTemplateRequestCustomizer.class));
new RestTemplateBuilderClientHttpRequestInitializer(null, Collections.emptyMap(), customizers)
.initialize(this.request);
InOrder inOrder = inOrder(customizers.toArray());
for (RestTemplateRequestCustomizer<?> customizer : customizers) {
inOrder.verify((RestTemplateRequestCustomizer<ClientHttpRequest>) customizer).customize(this.request);
}
}
}

View File

@ -38,6 +38,7 @@ import org.springframework.http.MediaType;
import org.springframework.http.client.BufferingClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInitializer;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.InterceptingClientHttpRequestFactory;
@ -309,8 +310,7 @@ class RestTemplateBuilderTests {
@Test
void basicAuthenticationShouldApply() throws Exception {
RestTemplate template = this.builder.basicAuthentication("spring", "boot", StandardCharsets.UTF_8).build();
ClientHttpRequestFactory requestFactory = template.getRequestFactory();
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.POST);
ClientHttpRequest request = createRequest(template);
assertThat(request.getHeaders()).containsOnlyKeys(HttpHeaders.AUTHORIZATION);
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly("Basic c3ByaW5nOmJvb3Q=");
}
@ -318,8 +318,7 @@ class RestTemplateBuilderTests {
@Test
void defaultHeaderAddsHeader() throws IOException {
RestTemplate template = this.builder.defaultHeader("spring", "boot").build();
ClientHttpRequestFactory requestFactory = template.getRequestFactory();
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.GET);
ClientHttpRequest request = createRequest(template);
assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("boot")));
}
@ -328,17 +327,23 @@ class RestTemplateBuilderTests {
String name = HttpHeaders.ACCEPT;
String[] values = { MediaType.APPLICATION_JSON_VALUE, MediaType.APPLICATION_XML_VALUE };
RestTemplate template = this.builder.defaultHeader(name, values).build();
ClientHttpRequestFactory requestFactory = template.getRequestFactory();
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.GET);
ClientHttpRequest request = createRequest(template);
assertThat(request.getHeaders()).contains(entry(name, Arrays.asList(values)));
}
@Test // gh-17885
void defaultHeaderWhenUsingMockRestServiceServerAddsHeader() throws IOException {
RestTemplate template = this.builder.defaultHeader("spring", "boot").build();
MockRestServiceServer.bindTo(template).build();
ClientHttpRequest request = createRequest(template);
assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("boot")));
}
@Test
void requestCustomizersAddsCustomizers() throws IOException {
RestTemplate template = this.builder
.requestCustomizers((request) -> request.getHeaders().add("spring", "framework")).build();
ClientHttpRequestFactory requestFactory = template.getRequestFactory();
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.GET);
ClientHttpRequest request = createRequest(template);
assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("framework")));
}
@ -347,8 +352,7 @@ class RestTemplateBuilderTests {
RestTemplate template = this.builder
.requestCustomizers((request) -> request.getHeaders().add("spring", "framework"))
.additionalRequestCustomizers((request) -> request.getHeaders().add("for", "java")).build();
ClientHttpRequestFactory requestFactory = template.getRequestFactory();
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.GET);
ClientHttpRequest request = createRequest(template);
assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("framework")))
.contains(entry("for", Collections.singletonList("java")));
}
@ -428,11 +432,8 @@ class RestTemplateBuilderTests {
assertThat(restTemplate.getErrorHandler()).isEqualTo(errorHandler);
ClientHttpRequestFactory actualRequestFactory = restTemplate.getRequestFactory();
assertThat(actualRequestFactory).isInstanceOf(InterceptingClientHttpRequestFactory.class);
ClientHttpRequestFactory authRequestFactory = (ClientHttpRequestFactory) ReflectionTestUtils
.getField(actualRequestFactory, "requestFactory");
assertThat(authRequestFactory)
.isInstanceOf(RestTemplateBuilderClientHttpRequestFactoryWrapper.class);
assertThat(authRequestFactory).hasFieldOrPropertyWithValue("requestFactory", requestFactory);
ClientHttpRequestInitializer initializer = restTemplate.getClientHttpRequestInitializers().get(0);
assertThat(initializer).isInstanceOf(RestTemplateBuilderClientHttpRequestInitializer.class);
}).build();
}
@ -589,6 +590,11 @@ class RestTemplateBuilderTests {
assertThat(template.getRequestFactory()).isInstanceOf(BufferingClientHttpRequestFactory.class);
}
private ClientHttpRequest createRequest(RestTemplate template) {
return ReflectionTestUtils.invokeMethod(template, "createRequest", URI.create("http://localhost"),
HttpMethod.GET);
}
static class RestTemplateSubclass extends RestTemplate {
}