WebMvc respects RouterFunction beans ordering

Closes gh-28595
This commit is contained in:
rstoyanchev 2022-06-13 16:32:14 +01:00
parent 97854d9fec
commit 52d0681ca1
4 changed files with 118 additions and 21 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -119,12 +119,11 @@ public class RouterFunctionMapping extends AbstractHandlerMapping implements Ini
} }
private List<RouterFunction<?>> routerFunctions() { private List<RouterFunction<?>> routerFunctions() {
List<RouterFunction<?>> functions = obtainApplicationContext() return obtainApplicationContext()
.getBeanProvider(RouterFunction.class) .getBeanProvider(RouterFunction.class)
.orderedStream() .orderedStream()
.map(router -> (RouterFunction<?>) router) .map(router -> (RouterFunction<?>) router)
.collect(Collectors.toList()); .collect(Collectors.toList());
return (!CollectionUtils.isEmpty(functions) ? functions : Collections.emptyList());
} }
private void logRouterFunctions(List<RouterFunction<?>> routerFunctions) { private void logRouterFunctions(List<RouterFunction<?>> routerFunctions) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.function.server.HandlerFunction; import org.springframework.web.reactive.function.server.HandlerFunction;
@ -72,6 +73,23 @@ public class RouterFunctionMappingTests {
.verify(); .verify();
} }
@Test
void empty() throws Exception {
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
context.refresh();
RouterFunctionMapping mapping = new RouterFunctionMapping();
mapping.setMessageReaders(this.codecConfigurer.getReaders());
mapping.setApplicationContext(context);
mapping.afterPropertiesSet();
Mono<Object> result = mapping.getHandler(createExchange("https://example.com/match"));
StepVerifier.create(result)
.expectComplete()
.verify();
}
@Test @Test
void changeParser() throws Exception { void changeParser() throws Exception {
HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().build(); HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().build();

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -19,11 +19,10 @@ package org.springframework.web.servlet.function.support;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.core.SpringProperties; import org.springframework.core.SpringProperties;
@ -135,7 +134,7 @@ public class RouterFunctionMapping extends AbstractHandlerMapping implements Ini
@Override @Override
public void afterPropertiesSet() throws Exception { public void afterPropertiesSet() throws Exception {
if (this.routerFunction == null) { if (this.routerFunction == null) {
initRouterFunction(); initRouterFunctions();
} }
if (CollectionUtils.isEmpty(this.messageConverters)) { if (CollectionUtils.isEmpty(this.messageConverters)) {
initMessageConverters(); initMessageConverters();
@ -154,20 +153,39 @@ public class RouterFunctionMapping extends AbstractHandlerMapping implements Ini
* Detect a all {@linkplain RouterFunction router functions} in the * Detect a all {@linkplain RouterFunction router functions} in the
* current application context. * current application context.
*/ */
@SuppressWarnings({"rawtypes", "unchecked"}) private void initRouterFunctions() {
private void initRouterFunction() { List<RouterFunction<?>> routerFunctions = routerFunctions();
ApplicationContext applicationContext = obtainApplicationContext();
Map<String, RouterFunction> beans =
(this.detectHandlerFunctionsInAncestorContexts ?
BeanFactoryUtils.beansOfTypeIncludingAncestors(applicationContext, RouterFunction.class) :
applicationContext.getBeansOfType(RouterFunction.class));
List<RouterFunction> routerFunctions = new ArrayList<>(beans.values());
this.routerFunction = routerFunctions.stream().reduce(RouterFunction::andOther).orElse(null); this.routerFunction = routerFunctions.stream().reduce(RouterFunction::andOther).orElse(null);
logRouterFunctions(routerFunctions); logRouterFunctions(routerFunctions);
} }
@SuppressWarnings("rawtypes") private List<RouterFunction<?>> routerFunctions() {
private void logRouterFunctions(List<RouterFunction> routerFunctions) { List<RouterFunction<?>> routerFunctions = new ArrayList<>();
if (this.detectHandlerFunctionsInAncestorContexts) {
detectRouterFunctionsInAncestorContexts(obtainApplicationContext(), routerFunctions);
}
obtainApplicationContext()
.getBeanProvider(RouterFunction.class)
.orderedStream()
.map(router -> (RouterFunction<?>) router)
.collect(Collectors.toCollection(() -> routerFunctions));
return routerFunctions;
}
private void detectRouterFunctionsInAncestorContexts(
ApplicationContext applicationContext, List<RouterFunction<?>> routerFunctions) {
ApplicationContext parentContext = applicationContext.getParent();
if (parentContext != null) {
detectRouterFunctionsInAncestorContexts(parentContext, routerFunctions);
parentContext.getBeanProvider(RouterFunction.class)
.orderedStream()
.map(router -> (RouterFunction<?>) router)
.collect(Collectors.toCollection(() -> routerFunctions));
}
}
private void logRouterFunctions(List<RouterFunction<?>> routerFunctions) {
if (mappingsLogger.isDebugEnabled()) { if (mappingsLogger.isDebugEnabled()) {
routerFunctions.forEach(function -> mappingsLogger.debug("Mapped " + function)); routerFunctions.forEach(function -> mappingsLogger.debug("Mapped " + function));
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2022 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -21,7 +21,10 @@ import java.util.List;
import java.util.Optional; import java.util.Optional;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.HandlerMapping;
@ -41,7 +44,7 @@ import static org.assertj.core.api.Assertions.assertThat;
*/ */
class RouterFunctionMappingTests { class RouterFunctionMappingTests {
private List<HttpMessageConverter<?>> messageConverters = Collections.emptyList(); private final List<HttpMessageConverter<?>> messageConverters = Collections.emptyList();
@Test @Test
void normal() throws Exception { void normal() throws Exception {
@ -71,6 +74,65 @@ class RouterFunctionMappingTests {
assertThat(result).isNull(); assertThat(result).isNull();
} }
@Test
void empty() throws Exception {
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
context.refresh();
RouterFunctionMapping mapping = new RouterFunctionMapping();
mapping.setMessageConverters(this.messageConverters);
mapping.setApplicationContext(context);
mapping.afterPropertiesSet();
MockHttpServletRequest request = createTestRequest("/match");
HandlerExecutionChain result = mapping.getHandler(request);
assertThat(result).isNull();
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void detectHandlerFunctionsInAncestorContexts(boolean detect) throws Exception {
HandlerFunction<ServerResponse> function1 = request -> ServerResponse.ok().build();
HandlerFunction<ServerResponse> function2 = request -> ServerResponse.ok().build();
HandlerFunction<ServerResponse> function3 = request -> ServerResponse.ok().build();
AnnotationConfigApplicationContext context1 = new AnnotationConfigApplicationContext();
context1.registerBean(RouterFunction.class, () -> RouterFunctions.route().GET("/fn1", function1).build());
context1.refresh();
AnnotationConfigApplicationContext context2 = new AnnotationConfigApplicationContext();
context2.registerBean(RouterFunction.class, () -> RouterFunctions.route().GET("/fn2", function2).build());
context2.setParent(context1);
context2.refresh();
AnnotationConfigApplicationContext context3 = new AnnotationConfigApplicationContext();
context3.registerBean(RouterFunction.class, () -> RouterFunctions.route().GET("/fn3", function3).build());
context3.setParent(context2);
context3.refresh();
RouterFunctionMapping mapping = new RouterFunctionMapping();
mapping.setDetectHandlerFunctionsInAncestorContexts(detect);
mapping.setMessageConverters(this.messageConverters);
mapping.setApplicationContext(context3);
mapping.afterPropertiesSet();
HandlerExecutionChain chain1 = mapping.getHandler(createTestRequest("/fn1"));
HandlerExecutionChain chain2 = mapping.getHandler(createTestRequest("/fn2"));
if (detect) {
assertThat(chain1).isNotNull().extracting(HandlerExecutionChain::getHandler).isSameAs(function1);
assertThat(chain2).isNotNull().extracting(HandlerExecutionChain::getHandler).isSameAs(function2);
}
else {
assertThat(chain1).isNull();
assertThat(chain2).isNull();
}
HandlerExecutionChain chain3 = mapping.getHandler(createTestRequest("/fn3"));
assertThat(chain3).isNotNull().extracting(HandlerExecutionChain::getHandler).isSameAs(function3);
}
@Test @Test
void changeParser() throws Exception { void changeParser() throws Exception {
HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().build(); HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().build();