WebSessionStore performs expiration check on retrieve

Issue: SPR-15963
This commit is contained in:
Rossen Stoyanchev 2017-09-26 22:11:35 -04:00
parent fbb428f032
commit cb2deccb2d
6 changed files with 85 additions and 42 deletions

View File

@ -80,7 +80,6 @@ public class DefaultWebSessionManager implements WebSessionManager {
public Mono<WebSession> getSession(ServerWebExchange exchange) { public Mono<WebSession> getSession(ServerWebExchange exchange) {
return Mono.defer(() -> return Mono.defer(() ->
retrieveSession(exchange) retrieveSession(exchange)
.flatMap(session -> removeSessionIfExpired(exchange, session))
.flatMap(this.getSessionStore()::updateLastAccessTime) .flatMap(this.getSessionStore()::updateLastAccessTime)
.switchIfEmpty(this.sessionStore.createWebSession()) .switchIfEmpty(this.sessionStore.createWebSession())
.doOnNext(session -> exchange.getResponse().beforeCommit(() -> save(exchange, session)))); .doOnNext(session -> exchange.getResponse().beforeCommit(() -> save(exchange, session))));
@ -92,14 +91,6 @@ public class DefaultWebSessionManager implements WebSessionManager {
.next(); .next();
} }
private Mono<WebSession> removeSessionIfExpired(ServerWebExchange exchange, WebSession session) {
if (session.isExpired()) {
this.sessionIdResolver.expireSession(exchange);
return this.sessionStore.removeSession(session.getId()).then(Mono.empty());
}
return Mono.just(session);
}
private Mono<Void> save(ServerWebExchange exchange, WebSession session) { private Mono<Void> save(ServerWebExchange exchange, WebSession session) {
if (session.isExpired()) { if (session.isExpired()) {
return Mono.error(new IllegalStateException( return Mono.error(new IllegalStateException(
@ -110,11 +101,14 @@ public class DefaultWebSessionManager implements WebSessionManager {
} }
if (!session.isStarted()) { if (!session.isStarted()) {
if (hasNewSessionId(exchange, session)) {
this.sessionIdResolver.expireSession(exchange);
}
return Mono.empty(); return Mono.empty();
} }
if (hasNewSessionId(exchange, session)) { if (hasNewSessionId(exchange, session)) {
DefaultWebSessionManager.this.sessionIdResolver.setSessionId(exchange, session.getId()); this.sessionIdResolver.setSessionId(exchange, session.getId());
} }
return session.save(); return session.save();

View File

@ -77,7 +77,17 @@ public class InMemoryWebSessionStore implements WebSessionStore {
@Override @Override
public Mono<WebSession> retrieveSession(String id) { public Mono<WebSession> retrieveSession(String id) {
return (this.sessions.containsKey(id) ? Mono.just(this.sessions.get(id)) : Mono.empty()); WebSession session = this.sessions.get(id);
if (session == null) {
return Mono.empty();
}
else if (session.isExpired()) {
this.sessions.remove(id);
return Mono.empty();
}
else {
return Mono.just(session);
}
} }
@Override @Override

View File

@ -40,8 +40,10 @@ public interface WebSessionStore {
/** /**
* Return the WebSession for the given id. * Return the WebSession for the given id.
* <p><strong>Note:</strong> This method should perform an expiration check,
* remove the session if it has expired and return empty.
* @param sessionId the session to load * @param sessionId the session to load
* @return the session, or an empty {@code Mono}. * @return the session, or an empty {@code Mono} .
*/ */
Mono<WebSession> retrieveSession(String sessionId); Mono<WebSession> retrieveSession(String sessionId);

View File

@ -40,7 +40,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
/** /**
@ -64,20 +63,16 @@ public class DefaultWebSessionManagerTests {
@Mock @Mock
private WebSession createSession; private WebSession createSession;
@Mock
private WebSession retrieveSession;
@Mock @Mock
private WebSession updateSession; private WebSession updateSession;
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
when(this.store.createWebSession()).thenReturn(Mono.just(this.createSession)); when(this.store.createWebSession()).thenReturn(Mono.just(this.createSession));
when(this.store.updateLastAccessTime(any())).thenReturn(Mono.just(this.updateSession)); when(this.store.updateLastAccessTime(any())).thenReturn(Mono.just(this.updateSession));
when(this.store.retrieveSession(any())).thenReturn(Mono.just(this.retrieveSession));
when(this.createSession.save()).thenReturn(Mono.empty()); when(this.createSession.save()).thenReturn(Mono.empty());
when(this.updateSession.getId()).thenReturn("update-session-id"); when(this.updateSession.getId()).thenReturn("update-session-id");
when(this.retrieveSession.getId()).thenReturn("retrieve-session-id");
this.manager = new DefaultWebSessionManager(); this.manager = new DefaultWebSessionManager();
this.manager.setSessionIdResolver(this.idResolver); this.manager.setSessionIdResolver(this.idResolver);
@ -97,7 +92,6 @@ public class DefaultWebSessionManagerTests {
assertFalse(session.isStarted()); assertFalse(session.isStarted());
assertFalse(session.isExpired()); assertFalse(session.isExpired());
verifyZeroInteractions(this.retrieveSession, this.updateSession);
verify(this.createSession, never()).save(); verify(this.createSession, never()).save();
verify(this.idResolver, never()).setSessionId(any(), any()); verify(this.idResolver, never()).setSessionId(any(), any());
} }
@ -138,19 +132,6 @@ public class DefaultWebSessionManagerTests {
assertEquals(id, actual.getId()); assertEquals(id, actual.getId());
} }
@Test
public void existingSessionIsExpired() throws Exception {
String id = this.retrieveSession.getId();
when(this.retrieveSession.isExpired()).thenReturn(true);
when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.singletonList(id));
when(this.store.removeSession(any())).thenReturn(Mono.empty());
WebSession actual = this.manager.getSession(this.exchange).block();
assertEquals(this.createSession.getId(), actual.getId());
verify(this.store).removeSession(id);
verify(this.idResolver).expireSession(any());
}
@Test @Test
public void multipleSessionIds() throws Exception { public void multipleSessionIds() throws Exception {
WebSession existing = this.updateSession; WebSession existing = this.updateSession;

View File

@ -15,11 +15,16 @@
*/ */
package org.springframework.web.server.session; package org.springframework.web.server.session;
import java.time.Clock;
import java.time.Duration;
import org.junit.Test; import org.junit.Test;
import org.springframework.web.server.WebSession; import org.springframework.web.server.WebSession;
import static junit.framework.TestCase.assertSame;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
/** /**
@ -28,34 +33,34 @@ import static org.junit.Assert.assertTrue;
*/ */
public class InMemoryWebSessionStoreTests { public class InMemoryWebSessionStoreTests {
private InMemoryWebSessionStore sessionStore = new InMemoryWebSessionStore(); private InMemoryWebSessionStore store = new InMemoryWebSessionStore();
@Test @Test
public void constructorWhenImplicitStartCopiedThenCopyIsStarted() { public void constructorWhenImplicitStartCopiedThenCopyIsStarted() {
WebSession original = this.sessionStore.createWebSession().block(); WebSession original = this.store.createWebSession().block();
assertNotNull(original); assertNotNull(original);
original.getAttributes().put("foo", "bar"); original.getAttributes().put("foo", "bar");
WebSession copy = this.sessionStore.updateLastAccessTime(original).block(); WebSession copy = this.store.updateLastAccessTime(original).block();
assertNotNull(copy); assertNotNull(copy);
assertTrue(copy.isStarted()); assertTrue(copy.isStarted());
} }
@Test @Test
public void constructorWhenExplicitStartCopiedThenCopyIsStarted() { public void constructorWhenExplicitStartCopiedThenCopyIsStarted() {
WebSession original = this.sessionStore.createWebSession().block(); WebSession original = this.store.createWebSession().block();
assertNotNull(original); assertNotNull(original);
original.start(); original.start();
WebSession copy = this.sessionStore.updateLastAccessTime(original).block(); WebSession copy = this.store.updateLastAccessTime(original).block();
assertNotNull(copy); assertNotNull(copy);
assertTrue(copy.isStarted()); assertTrue(copy.isStarted());
} }
@Test @Test
public void startsSessionExplicitly() { public void startsSessionExplicitly() {
WebSession session = this.sessionStore.createWebSession().block(); WebSession session = this.store.createWebSession().block();
assertNotNull(session); assertNotNull(session);
session.start(); session.start();
assertTrue(session.isStarted()); assertTrue(session.isStarted());
@ -63,11 +68,27 @@ public class InMemoryWebSessionStoreTests {
@Test @Test
public void startsSessionImplicitly() { public void startsSessionImplicitly() {
WebSession session = this.sessionStore.createWebSession().block(); WebSession session = this.store.createWebSession().block();
assertNotNull(session); assertNotNull(session);
session.start(); session.start();
session.getAttributes().put("foo", "bar"); session.getAttributes().put("foo", "bar");
assertTrue(session.isStarted()); assertTrue(session.isStarted());
} }
@Test
public void retrieveExpiredSession() throws Exception {
WebSession session = this.store.createWebSession().block();
assertNotNull(session);
session.getAttributes().put("foo", "bar");
session.save();
String id = session.getId();
WebSession retrieved = this.store.retrieveSession(id).block();
assertNotNull(retrieved);
assertSame(session, retrieved);
this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofMinutes(31)));
WebSession retrievedAgain = this.store.retrieveSession(id).block();
assertNull(retrievedAgain);
}
} }

View File

@ -43,6 +43,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
/** /**
* Integration tests for with a server-side session. * Integration tests for with a server-side session.
@ -109,7 +110,7 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe
assertNull(response.getHeaders().get("Set-Cookie")); assertNull(response.getHeaders().get("Set-Cookie"));
assertEquals(2, this.handler.getSessionRequestCount()); assertEquals(2, this.handler.getSessionRequestCount());
// Now set the clock of the session back by 31 minutes // Now fast-forward by 31 minutes
InMemoryWebSessionStore store = (InMemoryWebSessionStore) this.sessionManager.getSessionStore(); InMemoryWebSessionStore store = (InMemoryWebSessionStore) this.sessionManager.getSessionStore();
WebSession session = store.retrieveSession(id).block(); WebSession session = store.retrieveSession(id).block();
assertNotNull(session); assertNotNull(session);
@ -125,6 +126,33 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe
assertEquals(1, this.handler.getSessionRequestCount()); assertEquals(1, this.handler.getSessionRequestCount());
} }
@Test
public void expiredSessionEnds() throws Exception {
// First request: no session yet, new session created
RequestEntity<Void> request = RequestEntity.get(createUri()).build();
ResponseEntity<Void> response = this.restTemplate.exchange(request, Void.class);
assertEquals(HttpStatus.OK, response.getStatusCode());
String id = extractSessionId(response.getHeaders());
assertNotNull(id);
assertEquals(1, this.handler.getSessionRequestCount());
// Now fast-forward by 31 minutes
InMemoryWebSessionStore store = (InMemoryWebSessionStore) this.sessionManager.getSessionStore();
store.setClock(Clock.offset(store.getClock(), Duration.ofMinutes(31)));
// Second request: session expires
URI uri = new URI("http://localhost:" + this.port + "/?expiredSession");
request = RequestEntity.get(uri).header("Cookie", "SESSION=" + id).build();
response = this.restTemplate.exchange(request, Void.class);
assertEquals(HttpStatus.OK, response.getStatusCode());
String value = response.getHeaders().getFirst("Set-Cookie");
assertNotNull(value);
assertTrue("Actual value: " + value, value.contains("Max-Age=0"));
}
@Test @Test
public void changeSessionId() throws Exception { public void changeSessionId() throws Exception {
@ -178,11 +206,18 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe
@Override @Override
public Mono<Void> handle(ServerWebExchange exchange) { public Mono<Void> handle(ServerWebExchange exchange) {
if (exchange.getRequest().getQueryParams().containsKey("changeId")) { if (exchange.getRequest().getQueryParams().containsKey("expiredSession")) {
return exchange.getSession().doOnNext(session -> {
// Don't do anything, leave it expired...
}).then();
}
else if (exchange.getRequest().getQueryParams().containsKey("changeId")) {
return exchange.getSession().flatMap(session -> return exchange.getSession().flatMap(session ->
session.changeSessionId().doOnSuccess(aVoid -> updateSessionAttribute(session))); session.changeSessionId().doOnSuccess(aVoid -> updateSessionAttribute(session)));
} }
return exchange.getSession().doOnSuccess(this::updateSessionAttribute).then(); else {
return exchange.getSession().doOnSuccess(this::updateSessionAttribute).then();
}
} }
private void updateSessionAttribute(WebSession session) { private void updateSessionAttribute(WebSession session) {