Invoke defaultRequest earlier in RestClient and WebClient

Closes gh-32053
This commit is contained in:
Arjen Poutsma 2024-01-25 11:09:38 +01:00
parent 218957f0e8
commit bc2257aaff
4 changed files with 50 additions and 8 deletions

View File

@ -181,7 +181,11 @@ final class DefaultRestClient implements RestClient {
} }
private RequestBodyUriSpec methodInternal(HttpMethod httpMethod) { private RequestBodyUriSpec methodInternal(HttpMethod httpMethod) {
return new DefaultRequestBodyUriSpec(httpMethod); DefaultRequestBodyUriSpec spec = new DefaultRequestBodyUriSpec(httpMethod);
if (this.defaultRequest != null) {
this.defaultRequest.accept(spec);
}
return spec;
} }
@Override @Override
@ -456,9 +460,6 @@ final class DefaultRestClient implements RestClient {
Observation observation = null; Observation observation = null;
URI uri = null; URI uri = null;
try { try {
if (DefaultRestClient.this.defaultRequest != null) {
DefaultRestClient.this.defaultRequest.accept(this);
}
uri = initUri(); uri = initUri();
HttpHeaders headers = initHeaders(); HttpHeaders headers = initHeaders();
ClientHttpRequest clientRequest = createRequest(uri); ClientHttpRequest clientRequest = createRequest(uri);

View File

@ -900,6 +900,29 @@ class RestClientIntegrationTests {
expectRequest(request -> assertThat(request.getHeader("foo")).isEqualTo("bar")); expectRequest(request -> assertThat(request.getHeader("foo")).isEqualTo("bar"));
} }
@ParameterizedRestClientTest
void defaultRequestOverride(ClientHttpRequestFactory requestFactory) {
startServer(requestFactory);
prepareResponse(response -> response.setHeader("Content-Type", "text/plain")
.setBody("Hello Spring!"));
RestClient headersClient = this.restClient.mutate()
.defaultRequest(request -> request.accept(MediaType.APPLICATION_JSON))
.build();
String result = headersClient.get()
.uri("/greeting")
.accept(MediaType.TEXT_PLAIN)
.retrieve()
.body(String.class);
assertThat(result).isEqualTo("Hello Spring!");
expectRequestCount(1);
expectRequest(request -> assertThat(request.getHeader("Accept")).isEqualTo(MediaType.TEXT_PLAIN_VALUE));
}
private void prepareResponse(Consumer<MockResponse> consumer) { private void prepareResponse(Consumer<MockResponse> consumer) {
MockResponse response = new MockResponse(); MockResponse response = new MockResponse();

View File

@ -177,7 +177,11 @@ final class DefaultWebClient implements WebClient {
} }
private RequestBodyUriSpec methodInternal(HttpMethod httpMethod) { private RequestBodyUriSpec methodInternal(HttpMethod httpMethod) {
return new DefaultRequestBodyUriSpec(httpMethod); DefaultRequestBodyUriSpec spec = new DefaultRequestBodyUriSpec(httpMethod);
if (this.defaultRequest != null) {
this.defaultRequest.accept(spec);
}
return spec;
} }
@Override @Override
@ -479,9 +483,6 @@ final class DefaultWebClient implements WebClient {
} }
private ClientRequest.Builder initRequestBuilder() { private ClientRequest.Builder initRequestBuilder() {
if (defaultRequest != null) {
defaultRequest.accept(this);
}
ClientRequest.Builder builder = ClientRequest.create(this.httpMethod, initUri()) ClientRequest.Builder builder = ClientRequest.create(this.httpMethod, initUri())
.headers(this::initHeaders) .headers(this::initHeaders)
.cookies(this::initCookies) .cookies(this::initCookies)

View File

@ -528,6 +528,23 @@ public class DefaultWebClientTests {
StepVerifier.create(responsePublisher).expectError(WebClientResponseException.class).verify(); StepVerifier.create(responsePublisher).expectError(WebClientResponseException.class).verify();
} }
@Test // gh-32053
void defaultRequestOverride() {
WebClient client = this.builder
.defaultRequest(spec -> spec.accept(MediaType.APPLICATION_JSON))
.build();
client.get().uri("/path")
.accept(MediaType.IMAGE_PNG)
.retrieve()
.bodyToMono(Void.class)
.block(Duration.ofSeconds(3));
ClientRequest request = verifyAndGetRequest();
assertThat(request.headers().getAccept()).containsExactly(MediaType.IMAGE_PNG);
}
private ClientRequest verifyAndGetRequest() { private ClientRequest verifyAndGetRequest() {
ClientRequest request = this.captor.getValue(); ClientRequest request = this.captor.getValue();