Add ServerResponse.async() in WebMvc.fn
This commit introduces a new async(Object) method in the WebMvc.fn, taking a asynchronous response as argument in the form of a CompletableFuture or Publisher. This allows for asynchronous setting of headers and status (and not just body, which was already possible). Closes gh-25828
This commit is contained in:
parent
c083b95ce1
commit
4e76a4780c
|
@ -0,0 +1,348 @@
|
|||
/*
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.servlet.function;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.function.Function;
|
||||
|
||||
import javax.servlet.AsyncContext;
|
||||
import javax.servlet.AsyncListener;
|
||||
import javax.servlet.ServletContext;
|
||||
import javax.servlet.ServletException;
|
||||
import javax.servlet.ServletRequest;
|
||||
import javax.servlet.ServletResponse;
|
||||
import javax.servlet.http.Cookie;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletRequestWrapper;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.reactivestreams.Publisher;
|
||||
import org.reactivestreams.Subscriber;
|
||||
import org.reactivestreams.Subscription;
|
||||
|
||||
import org.springframework.core.ReactiveAdapter;
|
||||
import org.springframework.core.ReactiveAdapterRegistry;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.ClassUtils;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.servlet.ModelAndView;
|
||||
|
||||
/**
|
||||
* Implementation of {@link ServerResponse} based on a {@link CompletableFuture}.
|
||||
*
|
||||
* @author Arjen Poutsma
|
||||
* @since 5.3
|
||||
* @see ServerResponse#async(Object)
|
||||
*/
|
||||
final class AsyncServerResponse extends ErrorHandlingServerResponse {
|
||||
|
||||
static final boolean reactiveStreamsPresent = ClassUtils.isPresent(
|
||||
"org.reactivestreams.Publisher", AsyncServerResponse.class.getClassLoader());
|
||||
|
||||
|
||||
private final CompletableFuture<ServerResponse> futureResponse;
|
||||
|
||||
|
||||
private AsyncServerResponse(CompletableFuture<ServerResponse> futureResponse) {
|
||||
this.futureResponse = futureResponse;
|
||||
}
|
||||
|
||||
@Override
|
||||
public HttpStatus statusCode() {
|
||||
return delegate(ServerResponse::statusCode);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int rawStatusCode() {
|
||||
return delegate(ServerResponse::rawStatusCode);
|
||||
}
|
||||
|
||||
@Override
|
||||
public HttpHeaders headers() {
|
||||
return delegate(ServerResponse::headers);
|
||||
}
|
||||
|
||||
@Override
|
||||
public MultiValueMap<String, Cookie> cookies() {
|
||||
return delegate(ServerResponse::cookies);
|
||||
}
|
||||
|
||||
private <R> R delegate(Function<ServerResponse, R> function) {
|
||||
ServerResponse response = this.futureResponse.getNow(null);
|
||||
if (response != null) {
|
||||
return function.apply(response);
|
||||
}
|
||||
else {
|
||||
throw new IllegalStateException("Future ServerResponse has not yet completed");
|
||||
}
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@Override
|
||||
public ModelAndView writeTo(HttpServletRequest request, HttpServletResponse response, Context context) {
|
||||
|
||||
SharedAsyncContextHttpServletRequest sharedRequest = new SharedAsyncContextHttpServletRequest(request);
|
||||
AsyncContext asyncContext = sharedRequest.startAsync(request, response);
|
||||
this.futureResponse.whenComplete((futureResponse, futureThrowable) -> {
|
||||
try {
|
||||
if (futureResponse != null) {
|
||||
ModelAndView mav = futureResponse.writeTo(sharedRequest, response, context);
|
||||
Assert.state(mav == null, "Asynchronous, rendering ServerResponse implementations are not " +
|
||||
"supported in WebMvc.fn. Please use WebFlux.fn instead.");
|
||||
}
|
||||
else if (futureThrowable != null) {
|
||||
handleError(futureThrowable, request, response, context);
|
||||
}
|
||||
}
|
||||
catch (Throwable throwable) {
|
||||
try {
|
||||
handleError(throwable, request, response, context);
|
||||
}
|
||||
catch (ServletException | IOException ex) {
|
||||
logger.warn("Asynchronous execution resulted in exception", ex);
|
||||
}
|
||||
}
|
||||
finally {
|
||||
asyncContext.complete();
|
||||
}
|
||||
});
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@SuppressWarnings({"unchecked"})
|
||||
public static ServerResponse create(Object o) {
|
||||
Assert.notNull(o, "Argument to async must not be null");
|
||||
|
||||
if (o instanceof CompletableFuture) {
|
||||
CompletableFuture<ServerResponse> futureResponse = (CompletableFuture<ServerResponse>) o;
|
||||
return new AsyncServerResponse(futureResponse);
|
||||
}
|
||||
else if (reactiveStreamsPresent) {
|
||||
ReactiveAdapter adapter = ReactiveAdapterRegistry.getSharedInstance().getAdapter(o.getClass());
|
||||
if (adapter != null) {
|
||||
Publisher<ServerResponse> publisher = adapter.toPublisher(o);
|
||||
CompletableFuture<ServerResponse> futureResponse = new CompletableFuture<>();
|
||||
publisher.subscribe(new ToFutureSubscriber(futureResponse));
|
||||
return new AsyncServerResponse(futureResponse);
|
||||
}
|
||||
}
|
||||
throw new IllegalArgumentException("Asynchronous type not supported: " + o.getClass());
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Subscriber that exposes the first result it receives via a CompletableFuture.
|
||||
*/
|
||||
private static final class ToFutureSubscriber implements Subscriber<ServerResponse> {
|
||||
|
||||
private final CompletableFuture<ServerResponse> future;
|
||||
|
||||
@Nullable
|
||||
private Subscription subscription;
|
||||
|
||||
|
||||
public ToFutureSubscriber(CompletableFuture<ServerResponse> future) {
|
||||
this.future = future;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSubscribe(Subscription s) {
|
||||
if (this.subscription == null) {
|
||||
this.subscription = s;
|
||||
s.request(1);
|
||||
}
|
||||
else {
|
||||
s.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNext(ServerResponse serverResponse) {
|
||||
if (!this.future.isDone()) {
|
||||
this.future.complete(serverResponse);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable t) {
|
||||
if (!this.future.isDone()) {
|
||||
this.future.completeExceptionally(t);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
if (!this.future.isDone()) {
|
||||
this.future.completeExceptionally(new IllegalStateException("Did not receive ServerResponse"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* HttpServletRequestWrapper that shares its AsyncContext between this
|
||||
* AsyncServerResponse class and other, subsequent ServerResponse
|
||||
* implementations, keeping track of how many contexts where
|
||||
* started with startAsync(). This way, we make sure that
|
||||
* {@link AsyncContext#complete()} only completes for the response that
|
||||
* finishes last, and is not closed prematurely.
|
||||
*/
|
||||
private static final class SharedAsyncContextHttpServletRequest extends HttpServletRequestWrapper {
|
||||
|
||||
private final AsyncContext asyncContext;
|
||||
|
||||
private final AtomicInteger startedContexts;
|
||||
|
||||
public SharedAsyncContextHttpServletRequest(HttpServletRequest request) {
|
||||
super(request);
|
||||
this.asyncContext = request.startAsync();
|
||||
this.startedContexts = new AtomicInteger(0);
|
||||
}
|
||||
|
||||
private SharedAsyncContextHttpServletRequest(HttpServletRequest request, AsyncContext asyncContext,
|
||||
AtomicInteger startedContexts) {
|
||||
super(request);
|
||||
this.asyncContext = asyncContext;
|
||||
this.startedContexts = startedContexts;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AsyncContext startAsync() throws IllegalStateException {
|
||||
this.startedContexts.incrementAndGet();
|
||||
return new SharedAsyncContext(this.asyncContext, this, this.asyncContext.getResponse(),
|
||||
this.startedContexts);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse)
|
||||
throws IllegalStateException {
|
||||
this.startedContexts.incrementAndGet();
|
||||
SharedAsyncContextHttpServletRequest sharedRequest;
|
||||
if (servletRequest instanceof SharedAsyncContextHttpServletRequest) {
|
||||
sharedRequest = (SharedAsyncContextHttpServletRequest) servletRequest;
|
||||
}
|
||||
else {
|
||||
sharedRequest = new SharedAsyncContextHttpServletRequest((HttpServletRequest) servletRequest,
|
||||
this.asyncContext, this.startedContexts);
|
||||
}
|
||||
return new SharedAsyncContext(this.asyncContext, sharedRequest, servletResponse, this.startedContexts);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AsyncContext getAsyncContext() {
|
||||
return new SharedAsyncContext(this.asyncContext, this, this.asyncContext.getResponse(), this.startedContexts);
|
||||
}
|
||||
|
||||
|
||||
private static final class SharedAsyncContext implements AsyncContext {
|
||||
|
||||
private final AsyncContext delegate;
|
||||
|
||||
private final AtomicInteger openContexts;
|
||||
|
||||
private final ServletRequest request;
|
||||
|
||||
private final ServletResponse response;
|
||||
|
||||
|
||||
public SharedAsyncContext(AsyncContext delegate, SharedAsyncContextHttpServletRequest request,
|
||||
ServletResponse response, AtomicInteger usageCount) {
|
||||
|
||||
this.delegate = delegate;
|
||||
this.request = request;
|
||||
this.response = response;
|
||||
this.openContexts = usageCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void complete() {
|
||||
if (this.openContexts.decrementAndGet() == 0) {
|
||||
this.delegate.complete();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public ServletRequest getRequest() {
|
||||
return this.request;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ServletResponse getResponse() {
|
||||
return this.response;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasOriginalRequestAndResponse() {
|
||||
return this.delegate.hasOriginalRequestAndResponse();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void dispatch() {
|
||||
this.delegate.dispatch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void dispatch(String path) {
|
||||
this.delegate.dispatch(path);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void dispatch(ServletContext context, String path) {
|
||||
this.delegate.dispatch(context, path);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start(Runnable run) {
|
||||
this.delegate.start(run);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addListener(AsyncListener listener) {
|
||||
this.delegate.addListener(listener);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addListener(AsyncListener listener,
|
||||
ServletRequest servletRequest,
|
||||
ServletResponse servletResponse) {
|
||||
|
||||
this.delegate.addListener(listener, servletRequest, servletResponse);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T extends AsyncListener> T createListener(Class<T> clazz) throws ServletException {
|
||||
return this.delegate.createListener(clazz);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTimeout(long timeout) {
|
||||
this.delegate.setTimeout(timeout);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeout() {
|
||||
return this.delegate.getTimeout();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -41,6 +41,8 @@ import org.reactivestreams.Subscriber;
|
|||
import org.reactivestreams.Subscription;
|
||||
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.core.ReactiveAdapter;
|
||||
import org.springframework.core.ReactiveAdapterRegistry;
|
||||
import org.springframework.core.io.InputStreamResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.core.io.support.ResourceRegion;
|
||||
|
@ -56,7 +58,6 @@ import org.springframework.http.converter.HttpMessageConverter;
|
|||
import org.springframework.http.server.ServletServerHttpResponse;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.ClassUtils;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.HttpMediaTypeNotAcceptableException;
|
||||
|
@ -71,9 +72,6 @@ import org.springframework.web.servlet.ModelAndView;
|
|||
*/
|
||||
final class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T> {
|
||||
|
||||
private static final boolean reactiveStreamsPresent = ClassUtils.isPresent(
|
||||
"org.reactivestreams.Publisher", DefaultEntityResponseBuilder.class.getClassLoader());
|
||||
|
||||
private static final Type RESOURCE_REGION_LIST_TYPE =
|
||||
new ParameterizedTypeReference<List<ResourceRegion>>() { }.getType();
|
||||
|
||||
|
@ -209,13 +207,14 @@ final class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T>
|
|||
return new CompletionStageEntityResponse(this.status, this.headers, this.cookies,
|
||||
completionStage, this.entityType);
|
||||
}
|
||||
else if (reactiveStreamsPresent && PublisherEntityResponse.isPublisher(this.entity)) {
|
||||
Publisher publisher = (Publisher) this.entity;
|
||||
return new PublisherEntityResponse(this.status, this.headers, this.cookies, publisher, this.entityType);
|
||||
}
|
||||
else {
|
||||
return new DefaultEntityResponse<>(this.status, this.headers, this.cookies, this.entity, this.entityType);
|
||||
else if (AsyncServerResponse.reactiveStreamsPresent) {
|
||||
ReactiveAdapter adapter = ReactiveAdapterRegistry.getSharedInstance().getAdapter(this.entity.getClass());
|
||||
if (adapter != null) {
|
||||
Publisher<T> publisher = adapter.toPublisher(this.entity);
|
||||
return new PublisherEntityResponse(this.status, this.headers, this.cookies, publisher, this.entityType);
|
||||
}
|
||||
}
|
||||
return new DefaultEntityResponse<>(this.status, this.headers, this.cookies, this.entity, this.entityType);
|
||||
}
|
||||
|
||||
|
||||
|
@ -325,7 +324,7 @@ final class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T>
|
|||
}
|
||||
|
||||
protected void tryWriteEntityWithMessageConverters(Object entity, HttpServletRequest request,
|
||||
HttpServletResponse response, ServerResponse.Context context) {
|
||||
HttpServletResponse response, ServerResponse.Context context) throws ServletException, IOException {
|
||||
try {
|
||||
writeEntityWithMessageConverters(entity, request, response, context);
|
||||
}
|
||||
|
@ -376,6 +375,9 @@ final class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T>
|
|||
handleError(throwable, servletRequest, servletResponse, context);
|
||||
}
|
||||
}
|
||||
catch (ServletException | IOException ex) {
|
||||
logger.warn("Asynchronous execution resulted in exception", ex);
|
||||
}
|
||||
finally {
|
||||
asyncContext.complete();
|
||||
}
|
||||
|
@ -385,6 +387,9 @@ final class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T>
|
|||
}
|
||||
|
||||
|
||||
/**
|
||||
* {@link EntityResponse} implementation for asynchronous {@link Publisher} bodies.
|
||||
*/
|
||||
private static class PublisherEntityResponse<T> extends DefaultEntityResponse<Publisher<T>> {
|
||||
|
||||
public PublisherEntityResponse(int statusCode, HttpHeaders headers,
|
||||
|
@ -403,11 +408,6 @@ final class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T>
|
|||
return null;
|
||||
}
|
||||
|
||||
public static boolean isPublisher(Object entity) {
|
||||
return (entity instanceof Publisher);
|
||||
}
|
||||
|
||||
|
||||
@SuppressWarnings("SubscriberImplementation")
|
||||
private class ProducingSubscriber implements Subscriber<T> {
|
||||
|
||||
|
@ -438,13 +438,23 @@ final class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T>
|
|||
public void onNext(T element) {
|
||||
HttpServletRequest servletRequest = (HttpServletRequest) this.asyncContext.getRequest();
|
||||
HttpServletResponse servletResponse = (HttpServletResponse) this.asyncContext.getResponse();
|
||||
tryWriteEntityWithMessageConverters(element, servletRequest, servletResponse, this.context);
|
||||
try {
|
||||
tryWriteEntityWithMessageConverters(element, servletRequest, servletResponse, this.context);
|
||||
}
|
||||
catch (ServletException | IOException ex) {
|
||||
onError(ex);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable t) {
|
||||
handleError(t, (HttpServletRequest) this.asyncContext.getRequest(),
|
||||
(HttpServletResponse) this.asyncContext.getResponse(), this.context);
|
||||
try {
|
||||
handleError(t, (HttpServletRequest) this.asyncContext.getRequest(),
|
||||
(HttpServletResponse) this.asyncContext.getResponse(), this.context);
|
||||
}
|
||||
catch (ServletException | IOException ex) {
|
||||
logger.warn("Asynchronous execution resulted in exception", ex);
|
||||
}
|
||||
this.asyncContext.complete();
|
||||
}
|
||||
|
||||
|
|
|
@ -17,21 +17,17 @@
|
|||
package org.springframework.web.servlet.function;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.net.URI;
|
||||
import java.time.Instant;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.EnumSet;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.BiFunction;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
import javax.servlet.ServletException;
|
||||
import javax.servlet.http.Cookie;
|
||||
|
@ -231,17 +227,17 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder {
|
|||
/**
|
||||
* Abstract base class for {@link ServerResponse} implementations.
|
||||
*/
|
||||
abstract static class AbstractServerResponse implements ServerResponse {
|
||||
abstract static class AbstractServerResponse extends ErrorHandlingServerResponse {
|
||||
|
||||
private static final Set<HttpMethod> SAFE_METHODS = EnumSet.of(HttpMethod.GET, HttpMethod.HEAD);
|
||||
|
||||
|
||||
final int statusCode;
|
||||
|
||||
private final HttpHeaders headers;
|
||||
|
||||
private final MultiValueMap<String, Cookie> cookies;
|
||||
|
||||
private final List<ErrorHandler<?>> errorHandlers = new ArrayList<>();
|
||||
|
||||
protected AbstractServerResponse(
|
||||
int statusCode, HttpHeaders headers, MultiValueMap<String, Cookie> cookies) {
|
||||
|
@ -252,14 +248,6 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder {
|
|||
CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap<>(cookies));
|
||||
}
|
||||
|
||||
protected <T extends ServerResponse> void addErrorHandler(Predicate<Throwable> predicate,
|
||||
BiFunction<Throwable, ServerRequest, T> errorHandler) {
|
||||
|
||||
Assert.notNull(predicate, "Predicate must not be null");
|
||||
Assert.notNull(errorHandler, "ErrorHandler must not be null");
|
||||
this.errorHandlers.add(new ErrorHandler<>(predicate, errorHandler));
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public final HttpStatus statusCode() {
|
||||
|
@ -338,55 +326,6 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder {
|
|||
HttpServletRequest request, HttpServletResponse response, Context context)
|
||||
throws ServletException, IOException;
|
||||
|
||||
@Nullable
|
||||
protected ModelAndView handleError(Throwable t, HttpServletRequest servletRequest,
|
||||
HttpServletResponse servletResponse, Context context) {
|
||||
|
||||
return this.errorHandlers.stream()
|
||||
.filter(errorHandler -> errorHandler.test(t))
|
||||
.findFirst()
|
||||
.map(errorHandler -> {
|
||||
ServerRequest serverRequest = (ServerRequest)
|
||||
servletRequest.getAttribute(RouterFunctions.REQUEST_ATTRIBUTE);
|
||||
ServerResponse serverResponse = errorHandler.handle(t, serverRequest);
|
||||
try {
|
||||
return serverResponse.writeTo(servletRequest, servletResponse, context);
|
||||
}
|
||||
catch (ServletException ex) {
|
||||
throw new IllegalStateException(ex);
|
||||
}
|
||||
catch (IOException ex) {
|
||||
throw new UncheckedIOException(ex);
|
||||
}
|
||||
})
|
||||
.orElseThrow(() -> new IllegalStateException(t));
|
||||
}
|
||||
|
||||
|
||||
private static class ErrorHandler<T extends ServerResponse> {
|
||||
|
||||
private final Predicate<Throwable> predicate;
|
||||
|
||||
private final BiFunction<Throwable, ServerRequest, T>
|
||||
responseProvider;
|
||||
|
||||
public ErrorHandler(Predicate<Throwable> predicate,
|
||||
BiFunction<Throwable, ServerRequest, T> responseProvider) {
|
||||
|
||||
Assert.notNull(predicate, "Predicate must not be null");
|
||||
Assert.notNull(responseProvider, "ResponseProvider must not be null");
|
||||
this.predicate = predicate;
|
||||
this.responseProvider = responseProvider;
|
||||
}
|
||||
|
||||
public boolean test(Throwable t) {
|
||||
return this.predicate.test(t);
|
||||
}
|
||||
|
||||
public T handle(Throwable t, ServerRequest serverRequest) {
|
||||
return this.responseProvider.apply(t, serverRequest);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -394,8 +333,7 @@ class DefaultServerResponseBuilder implements ServerResponse.BodyBuilder {
|
|||
|
||||
private final BiFunction<HttpServletRequest, HttpServletResponse, ModelAndView> writeFunction;
|
||||
|
||||
public WriterFunctionResponse(
|
||||
int statusCode, HttpHeaders headers, MultiValueMap<String, Cookie> cookies,
|
||||
public WriterFunctionResponse(int statusCode, HttpHeaders headers, MultiValueMap<String, Cookie> cookies,
|
||||
BiFunction<HttpServletRequest, HttpServletResponse, ModelAndView> writeFunction) {
|
||||
|
||||
super(statusCode, headers, cookies);
|
||||
|
|
|
@ -52,7 +52,7 @@ public interface EntityResponse<T> extends ServerResponse {
|
|||
/**
|
||||
* Create a builder with the given object.
|
||||
* @param t the object that represents the body of the response
|
||||
* @param <T> the type of element contained in the publisher
|
||||
* @param <T> the type of element contained in the entity
|
||||
* @return the created builder
|
||||
*/
|
||||
static <T> Builder<T> fromObject(T t) {
|
||||
|
@ -63,7 +63,7 @@ public interface EntityResponse<T> extends ServerResponse {
|
|||
* Create a builder with the given object and type reference.
|
||||
* @param t the object that represents the body of the response
|
||||
* @param entityType the type of the entity, used to capture the generic type
|
||||
* @param <T> the type of element contained in the publisher
|
||||
* @param <T> the type of element contained in the entity
|
||||
* @return the created builder
|
||||
*/
|
||||
static <T> Builder<T> fromObject(T t, ParameterizedTypeReference<T> entityType) {
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
/*
|
||||
* Copyright 2002-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.web.servlet.function;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.BiFunction;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
import javax.servlet.ServletException;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.servlet.ModelAndView;
|
||||
|
||||
/**
|
||||
* Base class for {@link ServerResponse} implementations with error handling.
|
||||
*
|
||||
* @author Arjen Poutsma
|
||||
* @since 5.3
|
||||
*/
|
||||
abstract class ErrorHandlingServerResponse implements ServerResponse {
|
||||
|
||||
protected final Log logger = LogFactory.getLog(getClass());
|
||||
|
||||
private final List<ErrorHandler<?>> errorHandlers = new ArrayList<>();
|
||||
|
||||
|
||||
protected final <T extends ServerResponse> void addErrorHandler(Predicate<Throwable> predicate,
|
||||
BiFunction<Throwable, ServerRequest, T> errorHandler) {
|
||||
|
||||
Assert.notNull(predicate, "Predicate must not be null");
|
||||
Assert.notNull(errorHandler, "ErrorHandler must not be null");
|
||||
this.errorHandlers.add(new ErrorHandler<>(predicate, errorHandler));
|
||||
}
|
||||
|
||||
@Nullable
|
||||
protected ModelAndView handleError(Throwable t, HttpServletRequest servletRequest,
|
||||
HttpServletResponse servletResponse, Context context) throws ServletException, IOException {
|
||||
|
||||
for (ErrorHandler<?> errorHandler : this.errorHandlers) {
|
||||
if (errorHandler.test(t)) {
|
||||
ServerRequest serverRequest = (ServerRequest)
|
||||
servletRequest.getAttribute(RouterFunctions.REQUEST_ATTRIBUTE);
|
||||
ServerResponse serverResponse = errorHandler.handle(t, serverRequest);
|
||||
return serverResponse.writeTo(servletRequest, servletResponse, context);
|
||||
}
|
||||
}
|
||||
throw new ServletException(t);
|
||||
}
|
||||
|
||||
|
||||
private static class ErrorHandler<T extends ServerResponse> {
|
||||
|
||||
private final Predicate<Throwable> predicate;
|
||||
|
||||
private final BiFunction<Throwable, ServerRequest, T> responseProvider;
|
||||
|
||||
public ErrorHandler(Predicate<Throwable> predicate, BiFunction<Throwable, ServerRequest, T> responseProvider) {
|
||||
Assert.notNull(predicate, "Predicate must not be null");
|
||||
Assert.notNull(responseProvider, "ResponseProvider must not be null");
|
||||
this.predicate = predicate;
|
||||
this.responseProvider = responseProvider;
|
||||
}
|
||||
|
||||
public boolean test(Throwable t) {
|
||||
return this.predicate.test(t);
|
||||
}
|
||||
|
||||
public T handle(Throwable t, ServerRequest serverRequest) {
|
||||
return this.responseProvider.apply(t, serverRequest);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -112,9 +112,8 @@ public interface HandlerFilterFunction<T extends ServerResponse, R extends Serve
|
|||
return (request, next) -> {
|
||||
try {
|
||||
T t = next.handle(request);
|
||||
if (t instanceof DefaultServerResponseBuilder.AbstractServerResponse) {
|
||||
((DefaultServerResponseBuilder.AbstractServerResponse) t)
|
||||
.addErrorHandler(predicate, errorHandler);
|
||||
if (t instanceof ErrorHandlingServerResponse) {
|
||||
((ErrorHandlingServerResponse) t).addErrorHandler(predicate, errorHandler);
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.util.Collection;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CompletionStage;
|
||||
import java.util.function.BiFunction;
|
||||
import java.util.function.Consumer;
|
||||
|
@ -36,6 +37,7 @@ import javax.servlet.http.HttpServletResponse;
|
|||
import org.reactivestreams.Publisher;
|
||||
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.core.ReactiveAdapterRegistry;
|
||||
import org.springframework.http.CacheControl;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
|
@ -216,6 +218,31 @@ public interface ServerResponse {
|
|||
return status(HttpStatus.UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a (built) response with the given asynchronous response.
|
||||
* Parameter {@code asyncResponse} can be a
|
||||
* {@link CompletableFuture CompletableFuture<ServerResponse>} or
|
||||
* {@link Publisher Publisher<ServerResponse>} (or any
|
||||
* asynchronous producer of a single {@code ServerResponse} that can be
|
||||
* adapted via the {@link ReactiveAdapterRegistry}).
|
||||
*
|
||||
* <p>This method can be used to set the response status code, headers, and
|
||||
* body based on an asynchronous result. If only the body is asynchronous,
|
||||
* {@link BodyBuilder#body(Object)} can be used instead.
|
||||
*
|
||||
* <p><strong>Note</strong> that
|
||||
* {@linkplain RenderingResponse rendering responses}, as returned by
|
||||
* {@link BodyBuilder#render}, are <strong>not</strong> supported as value
|
||||
* for {@code asyncResponse}. Use WebFlux.fn for asynchronous rendering.
|
||||
* @param asyncResponse a {@code CompletableFuture<ServerResponse>} or
|
||||
* {@code Publisher<ServerResponse>}
|
||||
* @return the asynchronous response
|
||||
* @since 5.3
|
||||
*/
|
||||
static ServerResponse async(Object asyncResponse) {
|
||||
return AsyncServerResponse.create(asyncResponse);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Defines a builder that adds headers to the response.
|
||||
|
@ -374,10 +401,13 @@ public interface ServerResponse {
|
|||
BodyBuilder contentType(MediaType contentType);
|
||||
|
||||
/**
|
||||
* Set the body of the response to the given {@code Object} and return it.
|
||||
* Set the body of the response to the given {@code Object} and return
|
||||
* it.
|
||||
*
|
||||
* <p>Asynchronous response bodies are supported by providing a {@link CompletionStage} or
|
||||
* {@link Publisher} as body.
|
||||
* <p>Asynchronous response bodies are supported by providing a
|
||||
* {@link CompletionStage} or {@link Publisher} as body (or any
|
||||
* asynchronous producer of a single entity that can be adapted via the
|
||||
* {@link ReactiveAdapterRegistry}).
|
||||
* @param body the body of the response
|
||||
* @return the built response
|
||||
*/
|
||||
|
|
Loading…
Reference in New Issue