Execute preflight checks before interceptor chain

See gh-29509
This commit is contained in:
tianshuang 2022-11-17 18:14:18 +08:00 committed by rstoyanchev
parent da7ad71b7f
commit a1ce5dac0b
3 changed files with 48 additions and 15 deletions

View File

@ -110,7 +110,7 @@ public class HandlerExecutionChain {
/**
* Add the given interceptors to the end of this chain.
*/
public void addInterceptors(HandlerInterceptor... interceptors) {
public void addInterceptors(@Nullable HandlerInterceptor... interceptors) {
CollectionUtils.mergeArrayIntoCollection(interceptors, this.interceptorList);
}

View File

@ -663,9 +663,9 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
/**
* Update the HandlerExecutionChain for CORS-related handling.
* <p>For pre-flight requests, the default implementation replaces the selected
* handler with a simple HttpRequestHandler that invokes the configured
* {@link #setCorsProcessor}.
* <p>For pre-flight requests, the default implementation inserts a
* HandlerInterceptor that makes CORS-related checks and adds CORS headers.
* But does not abort the execution chain.
* <p>For actual requests, the default implementation inserts a
* HandlerInterceptor that makes CORS-related checks and adds CORS headers.
* @param request the current request
@ -675,15 +675,12 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
*/
protected HandlerExecutionChain getCorsHandlerExecutionChain(HttpServletRequest request,
HandlerExecutionChain chain, @Nullable CorsConfiguration config) {
if (CorsUtils.isPreFlightRequest(request)) {
HandlerInterceptor[] interceptors = chain.getInterceptors();
return new HandlerExecutionChain(new PreFlightHandler(config), interceptors);
}
else {
chain.addInterceptor(0, new CorsInterceptor(config));
return chain;
boolean isPreFlightRequest = CorsUtils.isPreFlightRequest(request);
if (isPreFlightRequest) {
chain = new HandlerExecutionChain(new PreFlightHandler(config), chain.getInterceptors());
}
chain.addInterceptor(0, new CorsInterceptor(config, isPreFlightRequest));
return chain;
}
@ -698,7 +695,7 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
@Override
public void handleRequest(HttpServletRequest request, HttpServletResponse response) throws IOException {
corsProcessor.processRequest(this.config, request, response);
// no-op
}
@Override
@ -713,9 +710,11 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
@Nullable
private final CorsConfiguration config;
private final boolean alwaysProceed;
public CorsInterceptor(@Nullable CorsConfiguration config) {
public CorsInterceptor(@Nullable CorsConfiguration config, boolean alwaysProceed) {
this.config = config;
this.alwaysProceed = alwaysProceed;
}
@Override
@ -728,7 +727,8 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
return true;
}
return corsProcessor.processRequest(this.config, request, response);
boolean proceed = corsProcessor.processRequest(this.config, request, response);
return this.alwaysProceed || proceed;
}
@Override

View File

@ -39,7 +39,9 @@ import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.context.support.StaticWebApplicationContext;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.ComplexWebApplicationContext;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.mvc.HttpRequestHandlerAdapter;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
@ -129,7 +131,9 @@ public class HandlerMethodMappingTests {
HandlerExecutionChain chain = this.mapping.getHandler(request);
assertThat(chain).isNotNull();
assertThat(chain.getInterceptorList()).isNotEmpty();
assertThat(chain.getHandler()).isInstanceOf(HttpRequestHandler.class);
chain.getInterceptorList().get(0).preHandle(request, response, chain.getHandler());
new HttpRequestHandlerAdapter().handle(request, response, chain.getHandler());
assertThat(response.getStatus()).isEqualTo(403);
@ -148,7 +152,9 @@ public class HandlerMethodMappingTests {
HandlerExecutionChain chain = this.mapping.getHandler(request);
assertThat(chain).isNotNull();
assertThat(chain.getInterceptorList()).isNotEmpty();
assertThat(chain.getHandler()).isInstanceOf(HttpRequestHandler.class);
chain.getInterceptorList().get(0).preHandle(request, response, chain.getHandler());
new HttpRequestHandlerAdapter().handle(request, response, chain.getHandler());
assertThat(response.getStatus()).isEqualTo(200);
@ -156,6 +162,33 @@ public class HandlerMethodMappingTests {
assertThat(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)).isEqualTo("GET");
}
@Test
public void abortInterceptorInPreFlightRequestWithCorsConfig() throws Exception {
this.mapping.registerMapping("/foo", this.handler, this.handler.getClass().getMethod("corsHandlerMethod"));
MockHttpServletRequest request = new MockHttpServletRequest("OPTIONS", "/foo");
request.addParameter("abort", "true");
request.addHeader(HttpHeaders.ORIGIN, "https://domain.com");
request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
MockHttpServletResponse response = new MockHttpServletResponse();
HandlerExecutionChain chain = this.mapping.getHandler(request);
assertThat(chain).isNotNull();
chain.addInterceptor(new ComplexWebApplicationContext.MyHandlerInterceptor1());
chain.addInterceptor(new ComplexWebApplicationContext.MyHandlerInterceptor2());
assertThat(chain.getInterceptorList().size()).isEqualTo(3);
assertThat(chain.getHandler()).isInstanceOf(HttpRequestHandler.class);
for (HandlerInterceptor interceptor : chain.getInterceptorList()) {
interceptor.preHandle(request, response, chain.getHandler());
}
new HttpRequestHandlerAdapter().handle(request, response, chain.getHandler());
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)).isEqualTo("https://domain.com");
assertThat(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)).isEqualTo("GET,HEAD");
}
@Test
public void detectHandlerMethodsInAncestorContexts() {
StaticApplicationContext cxt = new StaticApplicationContext();