Improve expired session check algorithm

1. Add session count threshold as am extra pre-condition.
2. Check pre-conditions for expiration checks on every request.

Effectively an upper bound on how many sessions can be created before
expiration checks are performed.

Issue: SPR-17020
This commit is contained in:
Rossen Stoyanchev 2018-07-11 15:59:18 -04:00
parent e9ed45ee3b
commit 32b75221b3
2 changed files with 114 additions and 72 deletions

View File

@ -20,6 +20,7 @@ import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.time.ZoneId; import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
@ -43,9 +44,6 @@ import org.springframework.web.server.WebSession;
*/ */
public class InMemoryWebSessionStore implements WebSessionStore { public class InMemoryWebSessionStore implements WebSessionStore {
/** Minimum period between expiration checks. */
private static final Duration EXPIRATION_CHECK_PERIOD = Duration.ofSeconds(60);
private static final IdGenerator idGenerator = new JdkIdGenerator(); private static final IdGenerator idGenerator = new JdkIdGenerator();
@ -53,9 +51,7 @@ public class InMemoryWebSessionStore implements WebSessionStore {
private final ConcurrentMap<String, InMemoryWebSession> sessions = new ConcurrentHashMap<>(); private final ConcurrentMap<String, InMemoryWebSession> sessions = new ConcurrentHashMap<>();
private volatile Instant nextExpirationCheckTime = Instant.now(this.clock).plus(EXPIRATION_CHECK_PERIOD); private final ExpiredSessionChecker expiredSessionChecker = new ExpiredSessionChecker();
private final ReentrantLock expirationCheckLock = new ReentrantLock();
/** /**
@ -70,8 +66,7 @@ public class InMemoryWebSessionStore implements WebSessionStore {
public void setClock(Clock clock) { public void setClock(Clock clock) {
Assert.notNull(clock, "Clock is required"); Assert.notNull(clock, "Clock is required");
this.clock = clock; this.clock = clock;
// Force a check when clock changes.. this.expiredSessionChecker.removeExpiredSessions(clock.instant());
this.nextExpirationCheckTime = Instant.now(this.clock);
} }
/** /**
@ -84,49 +79,29 @@ public class InMemoryWebSessionStore implements WebSessionStore {
@Override @Override
public Mono<WebSession> createWebSession() { public Mono<WebSession> createWebSession() {
return Mono.fromSupplier(InMemoryWebSession::new); Instant now = this.clock.instant();
this.expiredSessionChecker.checkIfNecessary(now);
return Mono.fromSupplier(() -> new InMemoryWebSession(now));
} }
@Override @Override
public Mono<WebSession> retrieveSession(String id) { public Mono<WebSession> retrieveSession(String id) {
Instant currentTime = Instant.now(this.clock); Instant now = this.clock.instant();
if (!this.sessions.isEmpty() && !currentTime.isBefore(this.nextExpirationCheckTime)) { this.expiredSessionChecker.checkIfNecessary(now);
checkExpiredSessions(currentTime);
}
InMemoryWebSession session = this.sessions.get(id); InMemoryWebSession session = this.sessions.get(id);
if (session == null) { if (session == null) {
return Mono.empty(); return Mono.empty();
} }
else if (session.isExpired(currentTime)) { else if (session.isExpired(now)) {
this.sessions.remove(id); this.sessions.remove(id);
return Mono.empty(); return Mono.empty();
} }
else { else {
session.updateLastAccessTime(currentTime); session.updateLastAccessTime(now);
return Mono.just(session); return Mono.just(session);
} }
} }
private void checkExpiredSessions(Instant currentTime) {
if (this.expirationCheckLock.tryLock()) {
try {
Iterator<InMemoryWebSession> iterator = this.sessions.values().iterator();
while (iterator.hasNext()) {
InMemoryWebSession session = iterator.next();
if (session.isExpired(currentTime)) {
iterator.remove();
session.invalidate();
}
}
}
finally {
this.nextExpirationCheckTime = currentTime.plus(EXPIRATION_CHECK_PERIOD);
this.expirationCheckLock.unlock();
}
}
}
@Override @Override
public Mono<Void> removeSession(String id) { public Mono<Void> removeSession(String id) {
this.sessions.remove(id); this.sessions.remove(id);
@ -137,7 +112,7 @@ public class InMemoryWebSessionStore implements WebSessionStore {
return Mono.fromSupplier(() -> { return Mono.fromSupplier(() -> {
Assert.isInstanceOf(InMemoryWebSession.class, webSession); Assert.isInstanceOf(InMemoryWebSession.class, webSession);
InMemoryWebSession session = (InMemoryWebSession) webSession; InMemoryWebSession session = (InMemoryWebSession) webSession;
session.updateLastAccessTime(Instant.now(getClock())); session.updateLastAccessTime(getClock().instant());
return session; return session;
}); });
} }
@ -157,8 +132,9 @@ public class InMemoryWebSessionStore implements WebSessionStore {
private final AtomicReference<State> state = new AtomicReference<>(State.NEW); private final AtomicReference<State> state = new AtomicReference<>(State.NEW);
public InMemoryWebSession() {
this.creationTime = Instant.now(getClock()); public InMemoryWebSession(Instant creationTime) {
this.creationTime = creationTime;
this.lastAccessTime = this.creationTime; this.lastAccessTime = this.creationTime;
} }
@ -256,6 +232,57 @@ public class InMemoryWebSessionStore implements WebSessionStore {
} }
private class ExpiredSessionChecker {
/** Max time before next expiration checks. */
private static final int CHECK_PERIOD = 60;
/** Max sessions that can be created before next expiration checks. */
private static final int SESSION_COUNT_THRESHOLD = 500;
private final ReentrantLock lock = new ReentrantLock();
private Instant nextCheckTime = Instant.now(clock).plus(CHECK_PERIOD, ChronoUnit.SECONDS);
private long lastSessionCount;
public void checkIfNecessary(Instant now) {
if (howManyCreated() > SESSION_COUNT_THRESHOLD || this.nextCheckTime.isBefore(now)) {
removeExpiredSessions(Instant.now(clock));
}
}
private long howManyCreated() {
return sessions.size() - this.lastSessionCount;
}
public void removeExpiredSessions(Instant now) {
if (sessions.isEmpty()) {
return;
}
if (this.lock.tryLock()) {
try {
Iterator<InMemoryWebSession> iterator = sessions.values().iterator();
while (iterator.hasNext()) {
InMemoryWebSession session = iterator.next();
if (session.isExpired(now)) {
iterator.remove();
session.invalidate();
}
}
}
finally {
this.nextCheckTime = clock.instant().plus(CHECK_PERIOD, ChronoUnit.SECONDS);
this.lastSessionCount = sessions.size();
this.lock.unlock();
}
}
}
}
private enum State { NEW, STARTED, EXPIRED } private enum State { NEW, STARTED, EXPIRED }
} }

View File

@ -18,15 +18,17 @@ package org.springframework.web.server.session;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;
import org.junit.Test; import org.junit.Test;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.web.server.WebSession; import org.springframework.web.server.WebSession;
import static junit.framework.TestCase.assertSame; import static junit.framework.TestCase.assertSame;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.*;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
/** /**
* Unit tests for {@link InMemoryWebSessionStore}. * Unit tests for {@link InMemoryWebSessionStore}.
@ -91,44 +93,57 @@ public class InMemoryWebSessionStoreTests {
} }
@Test @Test
public void expirationChecks() { public void expirationCheckBasedOnTimeWindow() {
// Create 3 sessions
WebSession session1 = this.store.createWebSession().block();
assertNotNull(session1);
session1.start();
session1.save().block();
WebSession session2 = this.store.createWebSession().block(); DirectFieldAccessor accessor = new DirectFieldAccessor(this.store);
assertNotNull(session2); Map<?,?> sessions = (Map<?, ?>) accessor.getPropertyValue("sessions");
session2.start();
session2.save().block();
WebSession session3 = this.store.createWebSession().block(); // Create 100 sessions
assertNotNull(session3); IntStream.range(0, 100).forEach(i -> insertSession());
session3.start();
session3.save().block();
// Fast-forward 31 minutes // Force a new clock (31 min later) but don't use setter which would clean expired sessions
Clock newClock = Clock.offset(this.store.getClock(), Duration.ofMinutes(31));
accessor.setPropertyValue("clock", newClock);
assertEquals(100, sessions.size());
// Create 50 more which forces a time-based check (clock moved forward)
IntStream.range(0, 50).forEach(i -> insertSession());
assertEquals(50, sessions.size());
}
@Test
@SuppressWarnings("unchecked")
public void expirationCheckBasedOnSessionCount() {
DirectFieldAccessor accessor = new DirectFieldAccessor(this.store);
Map<String, WebSession> sessions = (Map<String, WebSession>) accessor.getPropertyValue("sessions");
// Create 100 sessions
IntStream.range(0, 100).forEach(i -> insertSession());
// Copy sessions (about to be expired)
Map<String, WebSession> expiredSessions = new HashMap<>(sessions);
// Set new clock which expires and removes above sessions
this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofMinutes(31))); this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofMinutes(31)));
assertEquals(0, sessions.size());
// Create 2 more sessions // Re-insert expired sessions
WebSession session4 = this.store.createWebSession().block(); sessions.putAll(expiredSessions);
assertNotNull(session4); assertEquals(100, sessions.size());
session4.start();
session4.save().block();
WebSession session5 = this.store.createWebSession().block(); // Create 600 more to go over the threshold
assertNotNull(session5); IntStream.range(0, 600).forEach(i -> insertSession());
session5.start(); assertEquals(600, sessions.size());
session5.save().block(); }
// Retrieve, forcing cleanup of all expired.. private WebSession insertSession() {
assertNull(this.store.retrieveSession(session1.getId()).block()); WebSession session = this.store.createWebSession().block();
assertNull(this.store.retrieveSession(session2.getId()).block()); assertNotNull(session);
assertNull(this.store.retrieveSession(session3.getId()).block()); session.start();
session.save().block();
assertNotNull(this.store.retrieveSession(session4.getId()).block()); return session;
assertNotNull(this.store.retrieveSession(session5.getId()).block());
} }
} }