WebFlux support for @SessionAttributes

Issue: SPR-15887
This commit is contained in:
Rossen Stoyanchev 2017-09-09 17:35:49 -04:00
parent bc470fca30
commit f76ac5bb32
10 changed files with 609 additions and 70 deletions

View File

@ -97,6 +97,8 @@ class ControllerMethodResolver {
private final Map<ControllerAdviceBean, ExceptionHandlerMethodResolver> exceptionHandlerAdviceCache = private final Map<ControllerAdviceBean, ExceptionHandlerMethodResolver> exceptionHandlerAdviceCache =
new LinkedHashMap<>(64); new LinkedHashMap<>(64);
private final Map<Class<?>, SessionAttributesHandler> sessionAttributesHandlerCache = new ConcurrentHashMap<>(64);
ControllerMethodResolver(ArgumentResolverConfigurer argumentResolvers, ControllerMethodResolver(ArgumentResolverConfigurer argumentResolvers,
List<HttpMessageReader<?>> messageReaders, ReactiveAdapterRegistry reactiveRegistry, List<HttpMessageReader<?>> messageReaders, ReactiveAdapterRegistry reactiveRegistry,
@ -154,6 +156,7 @@ class ControllerMethodResolver {
registrar.addIfModelAttribute(() -> new ErrorsMethodArgumentResolver(reactiveRegistry)); registrar.addIfModelAttribute(() -> new ErrorsMethodArgumentResolver(reactiveRegistry));
registrar.add(new ServerWebExchangeArgumentResolver(reactiveRegistry)); registrar.add(new ServerWebExchangeArgumentResolver(reactiveRegistry));
registrar.add(new PrincipalArgumentResolver(reactiveRegistry)); registrar.add(new PrincipalArgumentResolver(reactiveRegistry));
registrar.addIfRequestBody(readers -> new SessionStatusMethodArgumentResolver());
registrar.add(new WebSessionArgumentResolver(reactiveRegistry)); registrar.add(new WebSessionArgumentResolver(reactiveRegistry));
// Custom... // Custom...
@ -315,6 +318,25 @@ class ControllerMethodResolver {
return invocable; return invocable;
} }
/**
* Return the handler for the type-level {@code @SessionAttributes} annotation
* based on the given controller method.
*/
public SessionAttributesHandler getSessionAttributesHandler(HandlerMethod handlerMethod) {
Class<?> handlerType = handlerMethod.getBeanType();
SessionAttributesHandler result = this.sessionAttributesHandlerCache.get(handlerType);
if (result == null) {
synchronized (this.sessionAttributesHandlerCache) {
result = this.sessionAttributesHandlerCache.get(handlerType);
if (result == null) {
result = new SessionAttributesHandler(handlerType);
this.sessionAttributesHandlerCache.put(handlerType, result);
}
}
}
return result;
}
/** Filter for {@link InitBinder @InitBinder} methods. */ /** Filter for {@link InitBinder @InitBinder} methods. */
private static final ReflectionUtils.MethodFilter BINDER_METHODS = method -> private static final ReflectionUtils.MethodFilter BINDER_METHODS = method ->
@ -336,6 +358,7 @@ class ControllerMethodResolver {
private final List<HandlerMethodArgumentResolver> result = new ArrayList<>(); private final List<HandlerMethodArgumentResolver> result = new ArrayList<>();
private ArgumentResolverRegistrar(ArgumentResolverConfigurer resolvers, private ArgumentResolverRegistrar(ArgumentResolverConfigurer resolvers,
List<HttpMessageReader<?>> messageReaders, boolean modelAttribute) { List<HttpMessageReader<?>> messageReaders, boolean modelAttribute) {

View File

@ -23,12 +23,15 @@ import java.util.List;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.bind.annotation.InitBinder; import org.springframework.web.bind.annotation.InitBinder;
import org.springframework.web.bind.support.SessionStatus;
import org.springframework.web.bind.support.SimpleSessionStatus;
import org.springframework.web.bind.support.WebBindingInitializer; import org.springframework.web.bind.support.WebBindingInitializer;
import org.springframework.web.bind.support.WebExchangeDataBinder; import org.springframework.web.bind.support.WebExchangeDataBinder;
import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.BindingContext;
import org.springframework.web.reactive.HandlerResult; import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod; import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
/** /**
* Extends {@link BindingContext} with {@code @InitBinder} method initialization. * Extends {@link BindingContext} with {@code @InitBinder} method initialization.
@ -43,6 +46,11 @@ class InitBinderBindingContext extends BindingContext {
/* Simple BindingContext to help with the invoking @InitBinder methods */ /* Simple BindingContext to help with the invoking @InitBinder methods */
private final BindingContext binderMethodContext; private final BindingContext binderMethodContext;
private final SessionStatus sessionStatus = new SimpleSessionStatus();
@Nullable
private Runnable saveModelOperation;
InitBinderBindingContext(@Nullable WebBindingInitializer initializer, InitBinderBindingContext(@Nullable WebBindingInitializer initializer,
List<SyncInvocableHandlerMethod> binderMethods) { List<SyncInvocableHandlerMethod> binderMethods) {
@ -53,6 +61,15 @@ class InitBinderBindingContext extends BindingContext {
} }
/**
* Return the {@link SessionStatus} instance to use that can be used to
* signal that session processing is complete.
*/
public SessionStatus getSessionStatus() {
return this.sessionStatus;
}
@Override @Override
protected WebExchangeDataBinder initDataBinder(WebExchangeDataBinder dataBinder, protected WebExchangeDataBinder initDataBinder(WebExchangeDataBinder dataBinder,
ServerWebExchange exchange) { ServerWebExchange exchange) {
@ -87,4 +104,29 @@ class InitBinderBindingContext extends BindingContext {
} }
} }
/**
* Provide the context required to apply {@link #saveModel()} after the
* controller method has been invoked.
*/
public void setSessionContext(SessionAttributesHandler attributesHandler, WebSession session) {
this.saveModelOperation = () -> {
if (getSessionStatus().isComplete()) {
attributesHandler.cleanupAttributes(session);
}
else {
attributesHandler.storeAttributes(session, getModel().asMap());
}
};
}
/**
* Save model attributes in the session based on a type-level declarations
* in an {@code @SessionAttributes} annotation.
*/
public void saveModel() {
if (this.saveModelOperation != null) {
this.saveModelOperation.run();
}
}
} }

View File

@ -21,13 +21,11 @@ import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.MonoProcessor;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.core.Conventions;
import org.springframework.core.DefaultParameterNameDiscoverer; import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.MethodParameter; import org.springframework.core.MethodParameter;
import org.springframework.core.ParameterNameDiscoverer; import org.springframework.core.ParameterNameDiscoverer;
@ -39,7 +37,6 @@ import org.springframework.lang.Nullable;
import org.springframework.ui.Model; import org.springframework.ui.Model;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
import org.springframework.validation.BindingResult; import org.springframework.validation.BindingResult;
import org.springframework.validation.Errors; import org.springframework.validation.Errors;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
@ -115,7 +112,7 @@ public class ModelAttributeMethodArgumentResolver extends HandlerMethodArgumentR
() -> getClass().getSimpleName() + " does not support multi-value reactive type wrapper: " + () -> getClass().getSimpleName() + " does not support multi-value reactive type wrapper: " +
parameter.getGenericParameterType()); parameter.getGenericParameterType());
String name = getAttributeName(parameter); String name = ModelInitializer.getNameForParameter(parameter);
Mono<?> valueMono = prepareAttributeMono(name, valueType, context, exchange); Mono<?> valueMono = prepareAttributeMono(name, valueType, context, exchange);
Map<String, Object> model = context.getModel().asMap(); Map<String, Object> model = context.getModel().asMap();
@ -150,13 +147,6 @@ public class ModelAttributeMethodArgumentResolver extends HandlerMethodArgumentR
}); });
} }
private String getAttributeName(MethodParameter parameter) {
return Optional.ofNullable(parameter.getParameterAnnotation(ModelAttribute.class))
.filter(ann -> StringUtils.hasText(ann.value()))
.map(ModelAttribute::value)
.orElse(Conventions.getVariableNameForParameter(parameter));
}
private Mono<?> prepareAttributeMono(String attributeName, ResolvableType attributeType, private Mono<?> prepareAttributeMono(String attributeName, ResolvableType attributeType,
BindingContext context, ServerWebExchange exchange) { BindingContext context, ServerWebExchange exchange) {

View File

@ -19,9 +19,11 @@ package org.springframework.web.reactive.result.method.annotation;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.core.Conventions; import org.springframework.core.Conventions;
@ -31,8 +33,10 @@ import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType; import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.ModelAttribute; import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.BindingContext;
import org.springframework.web.reactive.HandlerResult; import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.reactive.result.method.InvocableHandlerMethod; import org.springframework.web.reactive.result.method.InvocableHandlerMethod;
@ -47,40 +51,72 @@ import org.springframework.web.server.ServerWebExchange;
*/ */
class ModelInitializer { class ModelInitializer {
private final ControllerMethodResolver methodResolver;
private final ReactiveAdapterRegistry adapterRegistry; private final ReactiveAdapterRegistry adapterRegistry;
public ModelInitializer(ReactiveAdapterRegistry adapterRegistry) { public ModelInitializer(ControllerMethodResolver methodResolver, ReactiveAdapterRegistry adapterRegistry) {
Assert.notNull(methodResolver, "ControllerMethodResolver is required");
Assert.notNull(adapterRegistry, "ReactiveAdapterRegistry is required");
this.methodResolver = methodResolver;
this.adapterRegistry = adapterRegistry; this.adapterRegistry = adapterRegistry;
} }
/** /**
* Initialize the default model in the given {@code BindingContext} through * Initialize the {@link org.springframework.ui.Model Model} based on a
* the {@code @ModelAttribute} methods and indicate when complete. * (type-level) {@code @SessionAttributes} annotation and
* <p>This will wait for {@code @ModelAttribute} methods that return * {@code @ModelAttribute} methods.
* {@code Mono<Void>} since those may be adding attributes asynchronously. * @param handlerMethod the target controller method
* However if methods return async attributes, those will be added to the * @param bindingContext the context containing the model
* model as-is and without waiting for them to be resolved.
* @param bindingContext the BindingContext with the default model
* @param attributeMethods the {@code @ModelAttribute} methods
* @param exchange the current exchange * @param exchange the current exchange
* @return a {@code Mono} for when the model is populated. * @return a {@code Mono} for when the model is populated.
*/ */
@SuppressWarnings("Convert2MethodRef") @SuppressWarnings("Convert2MethodRef")
public Mono<Void> initModel(BindingContext bindingContext, public Mono<Void> initModel(HandlerMethod handlerMethod, InitBinderBindingContext bindingContext,
List<InvocableHandlerMethod> attributeMethods, ServerWebExchange exchange) { ServerWebExchange exchange) {
List<InvocableHandlerMethod> modelMethods =
this.methodResolver.getModelAttributeMethods(handlerMethod);
SessionAttributesHandler sessionAttributesHandler =
this.methodResolver.getSessionAttributesHandler(handlerMethod);
if (!sessionAttributesHandler.hasSessionAttributes()) {
return invokeModelAttributeMethods(bindingContext, modelMethods, exchange);
}
return exchange.getSession()
.flatMap(session -> {
Map<String, Object> attributes = sessionAttributesHandler.retrieveAttributes(session);
bindingContext.getModel().mergeAttributes(attributes);
bindingContext.setSessionContext(sessionAttributesHandler, session);
return invokeModelAttributeMethods(bindingContext, modelMethods, exchange)
.doOnSuccess(aVoid -> {
findModelAttributes(handlerMethod, sessionAttributesHandler).forEach(name -> {
if (!bindingContext.getModel().containsAttribute(name)) {
Object value = session.getRequiredAttribute(name);
bindingContext.getModel().addAttribute(name, value);
}
});
});
});
}
@NotNull
private Mono<Void> invokeModelAttributeMethods(BindingContext bindingContext,
List<InvocableHandlerMethod> modelMethods, ServerWebExchange exchange) {
List<Mono<HandlerResult>> resultList = new ArrayList<>(); List<Mono<HandlerResult>> resultList = new ArrayList<>();
attributeMethods.forEach(invocable -> resultList.add(invocable.invoke(exchange, bindingContext))); modelMethods.forEach(invocable -> resultList.add(invocable.invoke(exchange, bindingContext)));
return Mono return Mono
.zip(resultList, objectArray -> { .zip(resultList, objectArray ->
return Arrays.stream(objectArray) Arrays.stream(objectArray)
.map(object -> handleResult(((HandlerResult) object), bindingContext)) .map(object -> handleResult(((HandlerResult) object), bindingContext))
.collect(Collectors.toList()); .collect(Collectors.toList()))
}) .flatMap(Mono::when);
.flatMap(completionList -> Mono.when(completionList));
} }
private Mono<Void> handleResult(HandlerResult handlerResult, BindingContext bindingContext) { private Mono<Void> handleResult(HandlerResult handlerResult, BindingContext bindingContext) {
@ -109,4 +145,35 @@ class ModelInitializer {
.orElse(Conventions.getVariableNameForParameter(param)); .orElse(Conventions.getVariableNameForParameter(param));
} }
/** Find {@code @ModelAttribute} arguments also listed as {@code @SessionAttributes}. */
private List<String> findModelAttributes(HandlerMethod handlerMethod,
SessionAttributesHandler sessionAttributesHandler) {
List<String> result = new ArrayList<>();
for (MethodParameter parameter : handlerMethod.getMethodParameters()) {
if (parameter.hasParameterAnnotation(ModelAttribute.class)) {
String name = getNameForParameter(parameter);
Class<?> paramType = parameter.getParameterType();
if (sessionAttributesHandler.isHandlerSessionAttribute(name, paramType)) {
result.add(name);
}
}
}
return result;
}
/**
* Derive the model attribute name for the given method parameter based on
* a {@code @ModelAttribute} parameter annotation (if present) or falling
* back on parameter type based conventions.
* @param parameter a descriptor for the method parameter
* @return the derived name
* @see Conventions#getVariableNameForParameter(MethodParameter)
*/
public static String getNameForParameter(MethodParameter parameter) {
ModelAttribute ann = parameter.getParameterAnnotation(ModelAttribute.class);
String name = (ann != null ? ann.value() : null);
return (StringUtils.hasText(name) ? name : Conventions.getVariableNameForParameter(parameter));
}
} }

View File

@ -169,7 +169,7 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, Application
this.methodResolver = new ControllerMethodResolver(this.argumentResolverConfigurer, this.methodResolver = new ControllerMethodResolver(this.argumentResolverConfigurer,
this.messageReaders, this.reactiveAdapterRegistry, this.applicationContext); this.messageReaders, this.reactiveAdapterRegistry, this.applicationContext);
this.modelInitializer = new ModelInitializer(this.reactiveAdapterRegistry); this.modelInitializer = new ModelInitializer(this.methodResolver, this.reactiveAdapterRegistry);
} }
@ -183,21 +183,20 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, Application
HandlerMethod handlerMethod = (HandlerMethod) handler; HandlerMethod handlerMethod = (HandlerMethod) handler;
Assert.state(this.methodResolver != null && this.modelInitializer != null, "Not initialized"); Assert.state(this.methodResolver != null && this.modelInitializer != null, "Not initialized");
BindingContext bindingContext = new InitBinderBindingContext( InitBinderBindingContext bindingContext = new InitBinderBindingContext(
getWebBindingInitializer(), this.methodResolver.getInitBinderMethods(handlerMethod)); getWebBindingInitializer(), this.methodResolver.getInitBinderMethods(handlerMethod));
List<InvocableHandlerMethod> modelAttributeMethods = InvocableHandlerMethod invocableMethod = this.methodResolver.getRequestMappingMethod(handlerMethod);
this.methodResolver.getModelAttributeMethods(handlerMethod);
Function<Throwable, Mono<HandlerResult>> exceptionHandler = Function<Throwable, Mono<HandlerResult>> exceptionHandler =
ex -> handleException(ex, handlerMethod, bindingContext, exchange); ex -> handleException(ex, handlerMethod, bindingContext, exchange);
return this.modelInitializer return this.modelInitializer
.initModel(bindingContext, modelAttributeMethods, exchange) .initModel(handlerMethod, bindingContext, exchange)
.then(Mono.defer(() -> this.methodResolver.getRequestMappingMethod(handlerMethod) .then(Mono.defer(() -> invocableMethod.invoke(exchange, bindingContext)))
.invoke(exchange, bindingContext)
.doOnNext(result -> result.setExceptionHandler(exceptionHandler)) .doOnNext(result -> result.setExceptionHandler(exceptionHandler))
.onErrorResume(exceptionHandler))); .doOnNext(result -> bindingContext.saveModel())
.onErrorResume(exceptionHandler);
} }
private Mono<HandlerResult> handleException(Throwable exception, HandlerMethod handlerMethod, private Mono<HandlerResult> handleException(Throwable exception, HandlerMethod handlerMethod,

View File

@ -0,0 +1,136 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.result.method.annotation;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.util.Assert;
import org.springframework.web.bind.annotation.SessionAttributes;
import org.springframework.web.server.WebSession;
/**
* Package-private class to assist {@link ModelInitializer} with managing model
* attributes in the {@link WebSession} based on model attribute names and types
* declared va {@link SessionAttributes @SessionAttributes}.
*
* @author Rossen Stoyanchev
* @since 5.0
*/
class SessionAttributesHandler {
private final Set<String> attributeNames = new HashSet<>();
private final Set<Class<?>> attributeTypes = new HashSet<>();
private final Set<String> knownAttributeNames = Collections.newSetFromMap(new ConcurrentHashMap<>(4));
/**
* Create a new instance for a controller type. Session attribute names and
* types are extracted from the {@code @SessionAttributes} annotation, if
* present, on the given type.
* @param handlerType the controller type
*/
public SessionAttributesHandler(Class<?> handlerType) {
SessionAttributes annotation =
AnnotatedElementUtils.findMergedAnnotation(handlerType, SessionAttributes.class);
if (annotation != null) {
this.attributeNames.addAll(Arrays.asList(annotation.names()));
this.attributeTypes.addAll(Arrays.asList(annotation.types()));
}
this.knownAttributeNames.addAll(this.attributeNames);
}
/**
* Whether the controller represented by this instance has declared any
* session attributes through an {@link SessionAttributes} annotation.
*/
public boolean hasSessionAttributes() {
return (!this.attributeNames.isEmpty() || !this.attributeTypes.isEmpty());
}
/**
* Whether the attribute name or type match the names and types specified
* via {@code @SessionAttributes} on the underlying controller.
* <p>Attributes successfully resolved through this method are "remembered"
* and subsequently used in {@link #retrieveAttributes(WebSession)}
* and also {@link #cleanupAttributes(WebSession)}.
* @param attributeName the attribute name to check
* @param attributeType the type for the attribute
*/
public boolean isHandlerSessionAttribute(String attributeName, Class<?> attributeType) {
Assert.notNull(attributeName, "Attribute name must not be null");
if (this.attributeNames.contains(attributeName) || this.attributeTypes.contains(attributeType)) {
this.knownAttributeNames.add(attributeName);
return true;
}
else {
return false;
}
}
/**
* Retrieve "known" attributes from the session, i.e. attributes listed
* by name in {@code @SessionAttributes} or attributes previously stored
* in the model that matched by type.
* @param session the current session
* @return a map with handler session attributes, possibly empty
*/
public Map<String, Object> retrieveAttributes(WebSession session) {
Map<String, Object> attributes = new HashMap<>();
this.knownAttributeNames.forEach(name -> {
Object value = session.getAttribute(name);
if (value != null) {
attributes.put(name, value);
}
});
return attributes;
}
/**
* Store a subset of the given attributes in the session. Attributes not
* declared as session attributes via {@code @SessionAttributes} are ignored.
* @param session the current session
* @param attributes candidate attributes for session storage
*/
public void storeAttributes(WebSession session, Map<String, ?> attributes) {
attributes.keySet().forEach(name -> {
Object value = attributes.get(name);
if (value != null && isHandlerSessionAttribute(name, value.getClass())) {
session.getAttributes().put(name, value);
}
});
}
/**
* Remove "known" attributes from the session, i.e. attributes listed
* by name in {@code @SessionAttributes} or attributes previously stored
* in the model that matched by type.
* @param session the current session
*/
public void cleanupAttributes(WebSession session) {
this.knownAttributeNames.forEach(name -> session.getAttributes().remove(name));
}
}

View File

@ -0,0 +1,50 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.result.method.annotation;
import org.springframework.core.MethodParameter;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.bind.support.SessionStatus;
import org.springframework.web.reactive.BindingContext;
import org.springframework.web.reactive.result.method.SyncHandlerMethodArgumentResolver;
import org.springframework.web.server.ServerWebExchange;
/**
* Resolver for a {@link SessionStatus} argument obtaining it from the
* {@link BindingContext}.
*
* @author Rossen Stoyanchev
* @since 5.0
*/
public class SessionStatusMethodArgumentResolver implements SyncHandlerMethodArgumentResolver {
@Override
public boolean supportsParameter(MethodParameter parameter) {
return SessionStatus.class == parameter.getParameterType();
}
@Nullable
@Override
public Object resolveArgumentValue(MethodParameter parameter, BindingContext bindingContext,
ServerWebExchange exchange) {
Assert.isInstanceOf(InitBinderBindingContext.class, bindingContext);
return ((InitBinderBindingContext) bindingContext).getSessionStatus();
}
}

View File

@ -46,7 +46,8 @@ import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod
import org.springframework.web.server.ResponseStatusException; import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import static org.junit.Assert.*; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/** /**
* Unit tests for {@link ControllerMethodResolver}. * Unit tests for {@link ControllerMethodResolver}.
@ -108,6 +109,7 @@ public class ControllerMethodResolverTests {
assertEquals(ErrorsMethodArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(ErrorsMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(ServerWebExchangeArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(ServerWebExchangeArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(PrincipalArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(PrincipalArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(SessionStatusMethodArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(WebSessionArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(WebSessionArgumentResolver.class, next(resolvers, index).getClass());
assertEquals(CustomArgumentResolver.class, next(resolvers, index).getClass()); assertEquals(CustomArgumentResolver.class, next(resolvers, index).getClass());

View File

@ -23,30 +23,40 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import rx.Single; import rx.Single;
import org.springframework.context.support.StaticApplicationContext;
import org.springframework.core.MethodIntrospector; import org.springframework.core.MethodIntrospector;
import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.annotation.AnnotationUtils; import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.lang.Nullable;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.ui.Model; import org.springframework.ui.Model;
import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils;
import org.springframework.validation.Validator; import org.springframework.validation.Validator;
import org.springframework.web.bind.WebDataBinder; import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.InitBinder; import org.springframework.web.bind.annotation.InitBinder;
import org.springframework.web.bind.annotation.ModelAttribute; import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.SessionAttributes;
import org.springframework.web.bind.support.ConfigurableWebBindingInitializer; import org.springframework.web.bind.support.ConfigurableWebBindingInitializer;
import org.springframework.web.bind.support.WebBindingInitializer; import org.springframework.web.bind.support.WebBindingInitializer;
import org.springframework.web.bind.support.WebExchangeDataBinder; import org.springframework.web.bind.support.WebExchangeDataBinder;
import org.springframework.web.reactive.BindingContext; import org.springframework.web.method.HandlerMethod;
import org.springframework.web.reactive.result.method.InvocableHandlerMethod; import org.springframework.web.method.ResolvableMethod;
import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod; import org.springframework.web.reactive.result.method.SyncInvocableHandlerMethod;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
/** /**
@ -55,31 +65,55 @@ import static org.mockito.Mockito.mock;
*/ */
public class ModelInitializerTests { public class ModelInitializerTests {
private final ModelInitializer modelInitializer = new ModelInitializer(new ReactiveAdapterRegistry()); private ModelInitializer modelInitializer;
private final ServerWebExchange exchange = MockServerHttpRequest.get("/path").toExchange(); private final ServerWebExchange exchange = MockServerHttpRequest.get("/path").toExchange();
@Before
public void setUp() throws Exception {
ReactiveAdapterRegistry adapterRegistry = new ReactiveAdapterRegistry();
ArgumentResolverConfigurer resolverConfigurer = new ArgumentResolverConfigurer();
resolverConfigurer.addCustomResolver(new ModelArgumentResolver(adapterRegistry));
ControllerMethodResolver methodResolver = new ControllerMethodResolver(
resolverConfigurer, Collections.emptyList(), adapterRegistry, new StaticApplicationContext());
this.modelInitializer = new ModelInitializer(methodResolver, adapterRegistry);
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test @Test
public void basic() throws Exception { public void initBinderMethod() throws Exception {
TestController controller = new TestController();
Validator validator = mock(Validator.class); Validator validator = mock(Validator.class);
TestController controller = new TestController();
controller.setValidator(validator); controller.setValidator(validator);
InitBinderBindingContext context = getBindingContext(controller);
List<SyncInvocableHandlerMethod> binderMethods = getBinderMethods(controller); Method method = ResolvableMethod.on(TestController.class).annotPresent(GetMapping.class).resolveMethod();
List<InvocableHandlerMethod> attributeMethods = getAttributeMethods(controller); HandlerMethod handlerMethod = new HandlerMethod(controller, method);
this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(Duration.ofMillis(5000));
WebBindingInitializer bindingInitializer = new ConfigurableWebBindingInitializer(); WebExchangeDataBinder binder = context.createDataBinder(this.exchange, "name");
BindingContext bindingContext = new InitBinderBindingContext(bindingInitializer, binderMethods);
this.modelInitializer.initModel(bindingContext, attributeMethods, this.exchange).block(Duration.ofMillis(5000));
WebExchangeDataBinder binder = bindingContext.createDataBinder(this.exchange, "name");
assertEquals(Collections.singletonList(validator), binder.getValidators()); assertEquals(Collections.singletonList(validator), binder.getValidators());
}
Map<String, Object> model = bindingContext.getModel().asMap(); @SuppressWarnings("unchecked")
@Test
public void modelAttributeMethods() throws Exception {
TestController controller = new TestController();
InitBinderBindingContext context = getBindingContext(controller);
Method method = ResolvableMethod.on(TestController.class).annotPresent(GetMapping.class).resolveMethod();
HandlerMethod handlerMethod = new HandlerMethod(controller, method);
this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(Duration.ofMillis(5000));
Map<String, Object> model = context.getModel().asMap();
assertEquals(5, model.size()); assertEquals(5, model.size());
Object value = model.get("bean"); Object value = model.get("bean");
@ -98,31 +132,101 @@ public class ModelInitializerTests {
assertEquals("Void Mono Method Bean", ((TestBean) value).getName()); assertEquals("Void Mono Method Bean", ((TestBean) value).getName());
} }
private List<SyncInvocableHandlerMethod> getBinderMethods(Object controller) { @Test
return MethodIntrospector public void saveModelAttributeToSession() throws Exception {
.selectMethods(controller.getClass(), BINDER_METHODS).stream() TestController controller = new TestController();
InitBinderBindingContext context = getBindingContext(controller);
Method method = ResolvableMethod.on(TestController.class).annotPresent(GetMapping.class).resolveMethod();
HandlerMethod handlerMethod = new HandlerMethod(controller, method);
this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(Duration.ofMillis(5000));
WebSession session = this.exchange.getSession().block(Duration.ZERO);
assertNotNull(session);
assertEquals(0, session.getAttributes().size());
context.saveModel();
assertEquals(1, session.getAttributes().size());
assertEquals("Bean", ((TestBean) session.getRequiredAttribute("bean")).getName());
}
@Test
public void retrieveModelAttributeFromSession() throws Exception {
WebSession session = this.exchange.getSession().block(Duration.ZERO);
assertNotNull(session);
TestBean testBean = new TestBean("Session Bean");
session.getAttributes().put("bean", testBean);
TestController controller = new TestController();
InitBinderBindingContext context = getBindingContext(controller);
Method method = ResolvableMethod.on(TestController.class).annotPresent(GetMapping.class).resolveMethod();
HandlerMethod handlerMethod = new HandlerMethod(controller, method);
this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(Duration.ofMillis(5000));
context.saveModel();
assertEquals(1, session.getAttributes().size());
assertEquals("Session Bean", ((TestBean) session.getRequiredAttribute("bean")).getName());
}
@Test
public void requiredSessionAttributeMissing() throws Exception {
TestController controller = new TestController();
InitBinderBindingContext context = getBindingContext(controller);
Method method = ResolvableMethod.on(TestController.class).annotPresent(PostMapping.class).resolveMethod();
HandlerMethod handlerMethod = new HandlerMethod(controller, method);
try {
this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(Duration.ofMillis(5000));
fail();
}
catch (IllegalArgumentException ex) {
assertEquals("Required attribute 'missing-bean' is missing.", ex.getMessage());
}
}
@Test
public void clearModelAttributeFromSession() throws Exception {
WebSession session = this.exchange.getSession().block(Duration.ZERO);
assertNotNull(session);
TestBean testBean = new TestBean("Session Bean");
session.getAttributes().put("bean", testBean);
TestController controller = new TestController();
InitBinderBindingContext context = getBindingContext(controller);
Method method = ResolvableMethod.on(TestController.class).annotPresent(GetMapping.class).resolveMethod();
HandlerMethod handlerMethod = new HandlerMethod(controller, method);
this.modelInitializer.initModel(handlerMethod, context, this.exchange).block(Duration.ofMillis(5000));
context.getSessionStatus().setComplete();
context.saveModel();
assertEquals(0, session.getAttributes().size());
}
@NotNull
private InitBinderBindingContext getBindingContext(Object controller) {
List<SyncInvocableHandlerMethod> binderMethods =
MethodIntrospector.selectMethods(controller.getClass(), BINDER_METHODS)
.stream()
.map(method -> new SyncInvocableHandlerMethod(controller, method)) .map(method -> new SyncInvocableHandlerMethod(controller, method))
.collect(Collectors.toList()); .collect(Collectors.toList());;
}
private List<InvocableHandlerMethod> getAttributeMethods(Object controller) { WebBindingInitializer bindingInitializer = new ConfigurableWebBindingInitializer();
return MethodIntrospector return new InitBinderBindingContext(bindingInitializer, binderMethods);
.selectMethods(controller.getClass(), ATTRIBUTE_METHODS).stream()
.map(method -> toInvocable(controller, method))
.collect(Collectors.toList());
}
private InvocableHandlerMethod toInvocable(Object controller, Method method) {
ModelArgumentResolver resolver = new ModelArgumentResolver(new ReactiveAdapterRegistry());
InvocableHandlerMethod handlerMethod = new InvocableHandlerMethod(controller, method);
handlerMethod.setArgumentResolvers(Collections.singletonList(resolver));
return handlerMethod;
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
@SessionAttributes({"bean", "missing-bean"})
private static class TestController { private static class TestController {
@Nullable
private Validator validator; private Validator validator;
@ -165,8 +269,12 @@ public class ModelInitializerTests {
.then(); .then();
} }
@RequestMapping @GetMapping
public void handle() {} public void handleGet() {}
@PostMapping
public void handlePost(@ModelAttribute("missing-bean") TestBean testBean) {}
} }

View File

@ -0,0 +1,122 @@
/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.result.method.annotation;
import java.time.Duration;
import java.util.HashSet;
import org.junit.Test;
import org.springframework.tests.sample.beans.TestBean;
import org.springframework.ui.ModelMap;
import org.springframework.web.bind.annotation.SessionAttributes;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.session.InMemoryWebSessionStore;
import static java.util.Arrays.asList;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
/**
* Test fixture with {@link SessionAttributesHandler}.
* @author Rossen Stoyanchev
*/
public class SessionAttributesHandlerTests {
private final SessionAttributesHandler sessionAttributesHandler =
new SessionAttributesHandler(TestController.class);
@Test
public void isSessionAttribute() throws Exception {
assertTrue(this.sessionAttributesHandler.isHandlerSessionAttribute("attr1", String.class));
assertTrue(this.sessionAttributesHandler.isHandlerSessionAttribute("attr2", String.class));
assertTrue(this.sessionAttributesHandler.isHandlerSessionAttribute("simple", TestBean.class));
assertFalse(this.sessionAttributesHandler.isHandlerSessionAttribute("simple", String.class));
}
@Test
public void retrieveAttributes() throws Exception {
WebSession session = new InMemoryWebSessionStore().createWebSession().block(Duration.ZERO);
assertNotNull(session);
session.getAttributes().put("attr1", "value1");
session.getAttributes().put("attr2", "value2");
session.getAttributes().put("attr3", new TestBean());
session.getAttributes().put("attr4", new TestBean());
assertEquals("Named attributes (attr1, attr2) should be 'known' right away",
new HashSet<>(asList("attr1", "attr2")),
sessionAttributesHandler.retrieveAttributes(session).keySet());
// Resolve 'attr3' by type
sessionAttributesHandler.isHandlerSessionAttribute("attr3", TestBean.class);
assertEquals("Named attributes (attr1, attr2) and resolved attribute (att3) should be 'known'",
new HashSet<>(asList("attr1", "attr2", "attr3")),
sessionAttributesHandler.retrieveAttributes(session).keySet());
}
@Test
public void cleanupAttributes() throws Exception {
WebSession session = new InMemoryWebSessionStore().createWebSession().block(Duration.ZERO);
assertNotNull(session);
session.getAttributes().put("attr1", "value1");
session.getAttributes().put("attr2", "value2");
session.getAttributes().put("attr3", new TestBean());
this.sessionAttributesHandler.cleanupAttributes(session);
assertNull(session.getAttributes().get("attr1"));
assertNull(session.getAttributes().get("attr2"));
assertNotNull(session.getAttributes().get("attr3"));
// Resolve 'attr3' by type
this.sessionAttributesHandler.isHandlerSessionAttribute("attr3", TestBean.class);
this.sessionAttributesHandler.cleanupAttributes(session);
assertNull(session.getAttributes().get("attr3"));
}
@Test
public void storeAttributes() throws Exception {
WebSession session = new InMemoryWebSessionStore().createWebSession().block(Duration.ZERO);
assertNotNull(session);
ModelMap model = new ModelMap();
model.put("attr1", "value1");
model.put("attr2", "value2");
model.put("attr3", new TestBean());
sessionAttributesHandler.storeAttributes(session, model);
assertEquals("value1", session.getAttributes().get("attr1"));
assertEquals("value2", session.getAttributes().get("attr2"));
assertTrue(session.getAttributes().get("attr3") instanceof TestBean);
}
@SessionAttributes(names = { "attr1", "attr2" }, types = { TestBean.class })
private static class TestController {
}
}