TestDispatcherServlet unwraps to find mock request
Issue: SPR-16695
This commit is contained in:
parent
2dde000475
commit
313308208e
|
|
@ -26,6 +26,8 @@ import javax.servlet.http.HttpServletResponse;
|
|||
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.mock.web.MockAsyncContext;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.context.WebApplicationContext;
|
||||
import org.springframework.web.context.request.NativeWebRequest;
|
||||
import org.springframework.web.context.request.async.CallableProcessingInterceptor;
|
||||
|
|
@ -35,6 +37,7 @@ import org.springframework.web.context.request.async.WebAsyncUtils;
|
|||
import org.springframework.web.servlet.DispatcherServlet;
|
||||
import org.springframework.web.servlet.HandlerExecutionChain;
|
||||
import org.springframework.web.servlet.ModelAndView;
|
||||
import org.springframework.web.util.WebUtils;
|
||||
|
||||
/**
|
||||
* A sub-class of {@code DispatcherServlet} that saves the result in an
|
||||
|
|
@ -68,8 +71,13 @@ final class TestDispatcherServlet extends DispatcherServlet {
|
|||
super.service(request, response);
|
||||
|
||||
if (request.getAsyncContext() != null) {
|
||||
MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class);
|
||||
Assert.notNull(mockRequest, "Expected MockHttpServletRequest");
|
||||
MockAsyncContext mockAsyncContext = ((MockAsyncContext) mockRequest.getAsyncContext());
|
||||
Assert.notNull(mockAsyncContext, "MockAsyncContext not found. Did request wrapper not delegate startAsync?");
|
||||
|
||||
CountDownLatch dispatchLatch = new CountDownLatch(1);
|
||||
((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(dispatchLatch::countDown);
|
||||
mockAsyncContext.addDispatchHandler(dispatchLatch::countDown);
|
||||
getMvcResult(request).setAsyncDispatchLatch(dispatchLatch);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2016 the original author or authors.
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
@ -19,9 +19,14 @@ package org.springframework.test.web.servlet.samples.standalone;
|
|||
import java.io.IOException;
|
||||
import java.security.Principal;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import javax.servlet.AsyncContext;
|
||||
import javax.servlet.AsyncListener;
|
||||
import javax.servlet.Filter;
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.ServletContext;
|
||||
import javax.servlet.ServletException;
|
||||
import javax.servlet.ServletRequest;
|
||||
import javax.servlet.ServletResponse;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletRequestWrapper;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
|
@ -120,10 +125,10 @@ public class FilterTests {
|
|||
.andExpect(model().attribute("principal", WrappingRequestResponseFilter.PRINCIPAL_NAME));
|
||||
}
|
||||
|
||||
@Test // SPR-16067
|
||||
public void filterWrapsRequestResponseWithAsyncDispatch() throws Exception {
|
||||
@Test // SPR-16067, SPR-16695
|
||||
public void filterWrapsRequestResponseAndPerformsAsyncDispatch() throws Exception {
|
||||
MockMvc mockMvc = standaloneSetup(new PersonController())
|
||||
.addFilters(new ShallowEtagHeaderFilter())
|
||||
.addFilters(new WrappingRequestResponseFilter(), new ShallowEtagHeaderFilter())
|
||||
.build();
|
||||
|
||||
MvcResult mvcResult = mockMvc.perform(get("/persons/1").accept(MediaType.APPLICATION_JSON))
|
||||
|
|
@ -189,10 +194,20 @@ public class FilterTests {
|
|||
FilterChain filterChain) throws ServletException, IOException {
|
||||
|
||||
filterChain.doFilter(new HttpServletRequestWrapper(request) {
|
||||
|
||||
@Override
|
||||
public Principal getUserPrincipal() {
|
||||
return () -> PRINCIPAL_NAME;
|
||||
}
|
||||
|
||||
// Like Spring Security does in HttpServlet3RequestFactory..
|
||||
|
||||
@Override
|
||||
public AsyncContext getAsyncContext() {
|
||||
return super.getAsyncContext() != null ?
|
||||
new AsyncContextWrapper(super.getAsyncContext()) : null;
|
||||
}
|
||||
|
||||
}, new HttpServletResponseWrapper(response));
|
||||
}
|
||||
}
|
||||
|
|
@ -206,4 +221,79 @@ public class FilterTests {
|
|||
response.sendRedirect("/login");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private static class AsyncContextWrapper implements AsyncContext {
|
||||
|
||||
private final AsyncContext delegate;
|
||||
|
||||
public AsyncContextWrapper(AsyncContext delegate) {
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ServletRequest getRequest() {
|
||||
return this.delegate.getRequest();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ServletResponse getResponse() {
|
||||
return this.delegate.getResponse();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasOriginalRequestAndResponse() {
|
||||
return this.delegate.hasOriginalRequestAndResponse();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void dispatch() {
|
||||
this.delegate.dispatch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void dispatch(String path) {
|
||||
this.delegate.dispatch(path);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void dispatch(ServletContext context, String path) {
|
||||
this.delegate.dispatch(context, path);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void complete() {
|
||||
this.delegate.complete();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start(Runnable run) {
|
||||
this.delegate.start(run);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addListener(AsyncListener listener) {
|
||||
this.delegate.addListener(listener);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addListener(AsyncListener listener, ServletRequest req, ServletResponse res) {
|
||||
this.delegate.addListener(listener, req, res);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T extends AsyncListener> T createListener(Class<T> clazz) throws ServletException {
|
||||
return this.delegate.createListener(clazz);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTimeout(long timeout) {
|
||||
this.delegate.setTimeout(timeout);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeout() {
|
||||
return this.delegate.getTimeout();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue