diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java index f5c58d30389..5f1e0fb66d3 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java @@ -21,6 +21,7 @@ import java.nio.charset.Charset; import java.time.ZonedDateTime; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -29,6 +30,7 @@ import java.util.function.Function; import java.util.function.IntPredicate; import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.Collectors; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; @@ -84,21 +86,35 @@ class DefaultWebClient implements WebClient { @Nullable private final Consumer> defaultRequest; + private final List defaultStatusHandlers; + private final DefaultWebClientBuilder builder; DefaultWebClient(ExchangeFunction exchangeFunction, UriBuilderFactory uriBuilderFactory, @Nullable HttpHeaders defaultHeaders, @Nullable MultiValueMap defaultCookies, - @Nullable Consumer> defaultRequest, DefaultWebClientBuilder builder) { + @Nullable Consumer> defaultRequest, + @Nullable Map, Function>> statusHandlerMap, + DefaultWebClientBuilder builder) { this.exchangeFunction = exchangeFunction; this.uriBuilderFactory = uriBuilderFactory; this.defaultHeaders = defaultHeaders; this.defaultCookies = defaultCookies; this.defaultRequest = defaultRequest; + this.defaultStatusHandlers = initStatusHandlers(statusHandlerMap); this.builder = builder; } + private static List initStatusHandlers( + @Nullable Map, Function>> handlerMap) { + + return (CollectionUtils.isEmpty(handlerMap) ? Collections.emptyList() : + handlerMap.entrySet().stream() + .map(entry -> new DefaultResponseSpec.StatusHandler(entry.getKey(), entry.getValue())) + .collect(Collectors.toList())); + }; + @Override public RequestHeadersUriSpec get() { @@ -365,7 +381,8 @@ class DefaultWebClient implements WebClient { @Override public ResponseSpec retrieve() { - return new DefaultResponseSpec(exchange(), this::createRequest); + return new DefaultResponseSpec( + exchange(), this::createRequest, DefaultWebClient.this.defaultStatusHandlers); } private HttpRequest createRequest() { @@ -502,11 +519,18 @@ class DefaultWebClient implements WebClient { private final List statusHandlers = new ArrayList<>(1); + private final int defaultStatusHandlerCount; + + + DefaultResponseSpec( + Mono responseMono, Supplier requestSupplier, + List defaultStatusHandlers) { - DefaultResponseSpec(Mono responseMono, Supplier requestSupplier) { this.responseMono = responseMono; this.requestSupplier = requestSupplier; + this.statusHandlers.addAll(defaultStatusHandlers); this.statusHandlers.add(DEFAULT_STATUS_HANDLER); + this.defaultStatusHandlerCount = this.statusHandlers.size(); } @@ -516,10 +540,9 @@ class DefaultWebClient implements WebClient { Assert.notNull(statusCodePredicate, "StatusCodePredicate must not be null"); Assert.notNull(exceptionFunction, "Function must not be null"); - int index = this.statusHandlers.size() - 1; // Default handler always last + int index = this.statusHandlers.size() - this.defaultStatusHandlerCount; // Default handlers always last this.statusHandlers.add(index, new StatusHandler(statusCodePredicate, exceptionFunction)); return this; - } @Override diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientBuilder.java index b224801cba9..cd673871f6f 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientBuilder.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClientBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,8 +22,13 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Predicate; + +import reactor.core.publisher.Mono; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatusCode; import org.springframework.http.client.reactive.ClientHttpConnector; import org.springframework.http.client.reactive.HttpComponentsClientHttpConnector; import org.springframework.http.client.reactive.JdkClientHttpConnector; @@ -82,6 +87,9 @@ final class DefaultWebClientBuilder implements WebClient.Builder { @Nullable private Consumer> defaultRequest; + @Nullable + private Map, Function>> statusHandlers; + @Nullable private List filters; @@ -120,6 +128,7 @@ final class DefaultWebClientBuilder implements WebClient.Builder { this.defaultCookies = (other.defaultCookies != null ? new LinkedMultiValueMap<>(other.defaultCookies) : null); this.defaultRequest = other.defaultRequest; + this.statusHandlers = (other.statusHandlers != null ? new LinkedHashMap<>(other.statusHandlers) : null); this.filters = (other.filters != null ? new ArrayList<>(other.filters) : null); this.connector = other.connector; @@ -193,6 +202,15 @@ final class DefaultWebClientBuilder implements WebClient.Builder { return this; } + @Override + public WebClient.Builder defaultStatusHandler(Predicate statusPredicate, + Function> exceptionFunction) { + + this.statusHandlers = (this.statusHandlers != null ? this.statusHandlers : new LinkedHashMap<>()); + this.statusHandlers.put(statusPredicate, exceptionFunction); + return this; + } + @Override public WebClient.Builder filter(ExchangeFilterFunction filter) { Assert.notNull(filter, "ExchangeFilterFunction must not be null"); @@ -282,7 +300,9 @@ final class DefaultWebClientBuilder implements WebClient.Builder { return new DefaultWebClient(filteredExchange, initUriBuilderFactory(), defaultHeaders, defaultCookies, - this.defaultRequest, new DefaultWebClientBuilder(this)); + this.defaultRequest, + this.statusHandlers, + new DefaultWebClientBuilder(this)); } private ClientHttpConnector initConnector() { diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java index 6361c9b3740..1d71505963e 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -255,6 +255,20 @@ public interface WebClient { */ Builder defaultRequest(Consumer> defaultRequest); + /** + * Register a default + * {@link ResponseSpec#onStatus(Predicate, Function) status handler} to + * apply to every response. Such default handlers are applied in the + * order in which they are registered, and after any others that are + * registered for a specific response. + * @param statusPredicate to match responses with + * @param exceptionFunction to map the response to an error signal + * @return this builder + * @since 6.0 + */ + Builder defaultStatusHandler(Predicate statusPredicate, + Function> exceptionFunction); + /** * Add the given filter to the end of the filter chain. * @param filter the filter to be added to the chain diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java index 89bc58ae694..f1a585d1999 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultWebClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -423,6 +423,40 @@ public class DefaultWebClientTests { StepVerifier.create(result).expectErrorMessage("1").verify(); } + @Test + public void onStatusHandlerRegisteredGlobally() { + + ClientResponse response = ClientResponse.create(HttpStatus.BAD_REQUEST).build(); + given(exchangeFunction.exchange(any())).willReturn(Mono.just(response)); + + Mono result = this.builder + .defaultStatusHandler(HttpStatusCode::is4xxClientError, resp -> Mono.error(new IllegalStateException("1"))) + .defaultStatusHandler(HttpStatusCode::is4xxClientError, resp -> Mono.error(new IllegalStateException("2"))) + .build().get() + .uri("/path") + .retrieve() + .bodyToMono(Void.class); + + StepVerifier.create(result).expectErrorMessage("1").verify(); + } + + @Test + public void onStatusHandlerRegisteredGloballyHaveLowerPrecedence() { + + ClientResponse response = ClientResponse.create(HttpStatus.BAD_REQUEST).build(); + given(exchangeFunction.exchange(any())).willReturn(Mono.just(response)); + + Mono result = this.builder + .defaultStatusHandler(HttpStatusCode::is4xxClientError, resp -> Mono.error(new IllegalStateException("1"))) + .build().get() + .uri("/path") + .retrieve() + .onStatus(HttpStatusCode::is4xxClientError, resp -> Mono.error(new IllegalStateException("2"))) + .bodyToMono(Void.class); + + StepVerifier.create(result).expectErrorMessage("2").verify(); + } + @Test // gh-23880 @SuppressWarnings("unchecked") public void onStatusHandlersDefaultHandlerIsLast() {