diff --git a/spring-web-reactive/src/main/java/org/springframework/web/server/session/CookieWebSessionIdResolver.java b/spring-web-reactive/src/main/java/org/springframework/web/server/session/CookieWebSessionIdResolver.java index 51aeaae895b..8530117ccc4 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/server/session/CookieWebSessionIdResolver.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/server/session/CookieWebSessionIdResolver.java @@ -18,13 +18,11 @@ package org.springframework.web.server.session; import java.time.Duration; import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.Optional; +import java.util.stream.Collectors; import org.springframework.http.HttpCookie; import org.springframework.http.ServerHttpCookie; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.web.server.ServerWebExchange; @@ -77,10 +75,13 @@ public class CookieWebSessionIdResolver implements WebSessionIdResolver { @Override - public Optional resolveSessionId(ServerWebExchange exchange) { + public List resolveSessionId(ServerWebExchange exchange) { MultiValueMap cookieMap = exchange.getRequest().getCookies(); - HttpCookie cookie = cookieMap.getFirst(getCookieName()); - return (cookie != null ? Optional.of(cookie.getValue()) : Optional.empty()); + List cookies = cookieMap.get(getCookieName()); + if (cookies == null) { + return Collections.emptyList(); + } + return cookies.stream().map(HttpCookie::getValue).collect(Collectors.toList()); } @Override diff --git a/spring-web-reactive/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java b/spring-web-reactive/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java index 8c1756c313e..5bcb233b958 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java @@ -17,9 +17,10 @@ package org.springframework.web.server.session; import java.time.Clock; import java.time.Instant; -import java.util.Optional; +import java.util.List; import java.util.UUID; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.util.Assert; @@ -98,9 +99,8 @@ public class DefaultWebSessionManager implements WebSessionManager { @Override public Mono getSession(ServerWebExchange exchange) { - return Mono.fromCallable(() -> getSessionIdResolver().resolveSessionId(exchange)) - .where(Optional::isPresent) - .map(Optional::get) + return Flux.fromIterable(getSessionIdResolver().resolveSessionId(exchange)) + .next() .then(this.sessionStore::retrieveSession) .then(session -> validateSession(exchange, session)) .otherwiseIfEmpty(createSession(exchange)) @@ -147,8 +147,8 @@ public class DefaultWebSessionManager implements WebSessionManager { // Force explicit start session.start(); - Optional requestedId = getSessionIdResolver().resolveSessionId(exchange); - if (!requestedId.isPresent() || !session.getId().equals(requestedId.get())) { + List requestedIds = getSessionIdResolver().resolveSessionId(exchange); + if (requestedIds.isEmpty() || !session.getId().equals(requestedIds.get(0))) { this.sessionIdResolver.setSessionId(exchange, session.getId()); } return this.sessionStore.storeSession(session); diff --git a/spring-web-reactive/src/main/java/org/springframework/web/server/session/WebSessionIdResolver.java b/spring-web-reactive/src/main/java/org/springframework/web/server/session/WebSessionIdResolver.java index 151952716b0..a57a1cfa7ec 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/server/session/WebSessionIdResolver.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/server/session/WebSessionIdResolver.java @@ -15,7 +15,7 @@ */ package org.springframework.web.server.session; -import java.util.Optional; +import java.util.List; import org.springframework.web.server.ServerWebExchange; @@ -31,11 +31,11 @@ import org.springframework.web.server.ServerWebExchange; public interface WebSessionIdResolver { /** - * Resolve the session id associated with the request. + * Resolve the session id's associated with the request. * @param exchange the current exchange - * @return the session id if present + * @return the session id's or an empty list */ - Optional resolveSessionId(ServerWebExchange exchange); + List resolveSessionId(ServerWebExchange exchange); /** * Send the given session id to the client or if the session id is "null" diff --git a/spring-web-reactive/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java b/spring-web-reactive/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java index 5342f97f1a7..c7c17b8a6ae 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java @@ -19,7 +19,9 @@ import java.net.URI; import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.Optional; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import org.junit.Before; import org.junit.Test; @@ -27,9 +29,9 @@ import org.junit.Test; import org.springframework.http.HttpMethod; import org.springframework.http.server.reactive.MockServerHttpRequest; import org.springframework.http.server.reactive.MockServerHttpResponse; -import org.springframework.web.server.adapter.DefaultServerWebExchange; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; +import org.springframework.web.server.adapter.DefaultServerWebExchange; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -64,7 +66,7 @@ public class DefaultWebSessionManagerTests { @Test public void getSessionPassive() throws Exception { - this.idResolver.setIdToResolve(Optional.empty()); + this.idResolver.setIdsToResolve(Collections.emptyList()); WebSession session = this.manager.getSession(this.exchange).get(); assertNotNull(session); @@ -79,7 +81,7 @@ public class DefaultWebSessionManagerTests { @Test public void getSessionForceCreate() throws Exception { - this.idResolver.setIdToResolve(Optional.empty()); + this.idResolver.setIdsToResolve(Collections.emptyList()); WebSession session = this.manager.getSession(this.exchange).get(); session.start(); session.save(); @@ -92,7 +94,7 @@ public class DefaultWebSessionManagerTests { @Test public void getSessionAddAttribute() throws Exception { - this.idResolver.setIdToResolve(Optional.empty()); + this.idResolver.setIdsToResolve(Collections.emptyList()); WebSession session = this.manager.getSession(this.exchange).get(); session.getAttributes().put("foo", "bar"); session.save(); @@ -105,7 +107,7 @@ public class DefaultWebSessionManagerTests { DefaultWebSession existing = new DefaultWebSession("1", Clock.systemDefaultZone()); this.manager.getSessionStore().storeSession(existing); - this.idResolver.setIdToResolve(Optional.of("1")); + this.idResolver.setIdsToResolve(Collections.singletonList("1")); WebSession actual = this.manager.getSession(this.exchange).get(); assertSame(existing, actual); } @@ -118,7 +120,7 @@ public class DefaultWebSessionManagerTests { existing.setLastAccessTime(Instant.now(clock).minus(Duration.ofMinutes(31))); this.manager.getSessionStore().storeSession(existing); - this.idResolver.setIdToResolve(Optional.of("1")); + this.idResolver.setIdsToResolve(Collections.singletonList("1")); WebSession actual = this.manager.getSession(this.exchange).get(); assertNotSame(existing, actual); } @@ -126,13 +128,13 @@ public class DefaultWebSessionManagerTests { private static class TestWebSessionIdResolver implements WebSessionIdResolver { - private Optional idToResolve = Optional.empty(); + private List idsToResolve = new ArrayList<>(); private String id = null; - public void setIdToResolve(Optional idToResolve) { - this.idToResolve = idToResolve; + public void setIdsToResolve(List idsToResolve) { + this.idsToResolve = idsToResolve; } public String getId() { @@ -140,8 +142,8 @@ public class DefaultWebSessionManagerTests { } @Override - public Optional resolveSessionId(ServerWebExchange exchange) { - return this.idToResolve; + public List resolveSessionId(ServerWebExchange exchange) { + return this.idsToResolve; } @Override