diff --git a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java index 721c3259be..0ee5b88780 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java @@ -20,6 +20,7 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.time.ZoneId; +import java.time.temporal.ChronoUnit; import java.util.Iterator; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -43,9 +44,6 @@ import org.springframework.web.server.WebSession; */ public class InMemoryWebSessionStore implements WebSessionStore { - /** Minimum period between expiration checks. */ - private static final Duration EXPIRATION_CHECK_PERIOD = Duration.ofSeconds(60); - private static final IdGenerator idGenerator = new JdkIdGenerator(); @@ -53,9 +51,7 @@ public class InMemoryWebSessionStore implements WebSessionStore { private final ConcurrentMap sessions = new ConcurrentHashMap<>(); - private volatile Instant nextExpirationCheckTime = Instant.now(this.clock).plus(EXPIRATION_CHECK_PERIOD); - - private final ReentrantLock expirationCheckLock = new ReentrantLock(); + private final ExpiredSessionChecker expiredSessionChecker = new ExpiredSessionChecker(); /** @@ -70,8 +66,7 @@ public class InMemoryWebSessionStore implements WebSessionStore { public void setClock(Clock clock) { Assert.notNull(clock, "Clock is required"); this.clock = clock; - // Force a check when clock changes.. - this.nextExpirationCheckTime = Instant.now(this.clock); + this.expiredSessionChecker.removeExpiredSessions(clock.instant()); } /** @@ -84,49 +79,29 @@ public class InMemoryWebSessionStore implements WebSessionStore { @Override public Mono createWebSession() { - return Mono.fromSupplier(InMemoryWebSession::new); + Instant now = this.clock.instant(); + this.expiredSessionChecker.checkIfNecessary(now); + return Mono.fromSupplier(() -> new InMemoryWebSession(now)); } @Override public Mono retrieveSession(String id) { - Instant currentTime = Instant.now(this.clock); - if (!this.sessions.isEmpty() && !currentTime.isBefore(this.nextExpirationCheckTime)) { - checkExpiredSessions(currentTime); - } - + Instant now = this.clock.instant(); + this.expiredSessionChecker.checkIfNecessary(now); InMemoryWebSession session = this.sessions.get(id); if (session == null) { return Mono.empty(); } - else if (session.isExpired(currentTime)) { + else if (session.isExpired(now)) { this.sessions.remove(id); return Mono.empty(); } else { - session.updateLastAccessTime(currentTime); + session.updateLastAccessTime(now); return Mono.just(session); } } - private void checkExpiredSessions(Instant currentTime) { - if (this.expirationCheckLock.tryLock()) { - try { - Iterator iterator = this.sessions.values().iterator(); - while (iterator.hasNext()) { - InMemoryWebSession session = iterator.next(); - if (session.isExpired(currentTime)) { - iterator.remove(); - session.invalidate(); - } - } - } - finally { - this.nextExpirationCheckTime = currentTime.plus(EXPIRATION_CHECK_PERIOD); - this.expirationCheckLock.unlock(); - } - } - } - @Override public Mono removeSession(String id) { this.sessions.remove(id); @@ -137,7 +112,7 @@ public class InMemoryWebSessionStore implements WebSessionStore { return Mono.fromSupplier(() -> { Assert.isInstanceOf(InMemoryWebSession.class, webSession); InMemoryWebSession session = (InMemoryWebSession) webSession; - session.updateLastAccessTime(Instant.now(getClock())); + session.updateLastAccessTime(getClock().instant()); return session; }); } @@ -157,8 +132,9 @@ public class InMemoryWebSessionStore implements WebSessionStore { private final AtomicReference state = new AtomicReference<>(State.NEW); - public InMemoryWebSession() { - this.creationTime = Instant.now(getClock()); + + public InMemoryWebSession(Instant creationTime) { + this.creationTime = creationTime; this.lastAccessTime = this.creationTime; } @@ -256,6 +232,57 @@ public class InMemoryWebSessionStore implements WebSessionStore { } + private class ExpiredSessionChecker { + + /** Max time before next expiration checks. */ + private static final int CHECK_PERIOD = 60; + + /** Max sessions that can be created before next expiration checks. */ + private static final int SESSION_COUNT_THRESHOLD = 500; + + + private final ReentrantLock lock = new ReentrantLock(); + + private Instant nextCheckTime = Instant.now(clock).plus(CHECK_PERIOD, ChronoUnit.SECONDS); + + private long lastSessionCount; + + + public void checkIfNecessary(Instant now) { + if (howManyCreated() > SESSION_COUNT_THRESHOLD || this.nextCheckTime.isBefore(now)) { + removeExpiredSessions(Instant.now(clock)); + } + } + + private long howManyCreated() { + return sessions.size() - this.lastSessionCount; + } + + public void removeExpiredSessions(Instant now) { + if (sessions.isEmpty()) { + return; + } + if (this.lock.tryLock()) { + try { + Iterator iterator = sessions.values().iterator(); + while (iterator.hasNext()) { + InMemoryWebSession session = iterator.next(); + if (session.isExpired(now)) { + iterator.remove(); + session.invalidate(); + } + } + } + finally { + this.nextCheckTime = clock.instant().plus(CHECK_PERIOD, ChronoUnit.SECONDS); + this.lastSessionCount = sessions.size(); + this.lock.unlock(); + } + } + } + } + + private enum State { NEW, STARTED, EXPIRED } } diff --git a/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java b/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java index aac388493c..b469fc8ac2 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java @@ -18,15 +18,17 @@ package org.springframework.web.server.session; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.IntStream; import org.junit.Test; +import org.springframework.beans.DirectFieldAccessor; import org.springframework.web.server.WebSession; import static junit.framework.TestCase.assertSame; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * Unit tests for {@link InMemoryWebSessionStore}. @@ -91,44 +93,57 @@ public class InMemoryWebSessionStoreTests { } @Test - public void expirationChecks() { - // Create 3 sessions - WebSession session1 = this.store.createWebSession().block(); - assertNotNull(session1); - session1.start(); - session1.save().block(); + public void expirationCheckBasedOnTimeWindow() { - WebSession session2 = this.store.createWebSession().block(); - assertNotNull(session2); - session2.start(); - session2.save().block(); + DirectFieldAccessor accessor = new DirectFieldAccessor(this.store); + Map sessions = (Map) accessor.getPropertyValue("sessions"); - WebSession session3 = this.store.createWebSession().block(); - assertNotNull(session3); - session3.start(); - session3.save().block(); + // Create 100 sessions + IntStream.range(0, 100).forEach(i -> insertSession()); - // Fast-forward 31 minutes + // Force a new clock (31 min later) but don't use setter which would clean expired sessions + Clock newClock = Clock.offset(this.store.getClock(), Duration.ofMinutes(31)); + accessor.setPropertyValue("clock", newClock); + + assertEquals(100, sessions.size()); + + // Create 50 more which forces a time-based check (clock moved forward) + IntStream.range(0, 50).forEach(i -> insertSession()); + assertEquals(50, sessions.size()); + } + + @Test + @SuppressWarnings("unchecked") + public void expirationCheckBasedOnSessionCount() { + + DirectFieldAccessor accessor = new DirectFieldAccessor(this.store); + Map sessions = (Map) accessor.getPropertyValue("sessions"); + + // Create 100 sessions + IntStream.range(0, 100).forEach(i -> insertSession()); + + // Copy sessions (about to be expired) + Map expiredSessions = new HashMap<>(sessions); + + // Set new clock which expires and removes above sessions this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofMinutes(31))); + assertEquals(0, sessions.size()); - // Create 2 more sessions - WebSession session4 = this.store.createWebSession().block(); - assertNotNull(session4); - session4.start(); - session4.save().block(); + // Re-insert expired sessions + sessions.putAll(expiredSessions); + assertEquals(100, sessions.size()); - WebSession session5 = this.store.createWebSession().block(); - assertNotNull(session5); - session5.start(); - session5.save().block(); + // Create 600 more to go over the threshold + IntStream.range(0, 600).forEach(i -> insertSession()); + assertEquals(600, sessions.size()); + } - // Retrieve, forcing cleanup of all expired.. - assertNull(this.store.retrieveSession(session1.getId()).block()); - assertNull(this.store.retrieveSession(session2.getId()).block()); - assertNull(this.store.retrieveSession(session3.getId()).block()); - - assertNotNull(this.store.retrieveSession(session4.getId()).block()); - assertNotNull(this.store.retrieveSession(session5.getId()).block()); + private WebSession insertSession() { + WebSession session = this.store.createWebSession().block(); + assertNotNull(session); + session.start(); + session.save().block(); + return session; } }