Improve expired session check algorithm
1. Add session count threshold as am extra pre-condition. 2. Check pre-conditions for expiration checks on every request. Effectively an upper bound on how many sessions can be created before expiration checks are performed. Issue: SPR-17020
This commit is contained in:
parent
e9ed45ee3b
commit
32b75221b3
|
@ -20,6 +20,7 @@ import java.time.Clock;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
import java.time.ZoneId;
|
import java.time.ZoneId;
|
||||||
|
import java.time.temporal.ChronoUnit;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
@ -43,9 +44,6 @@ import org.springframework.web.server.WebSession;
|
||||||
*/
|
*/
|
||||||
public class InMemoryWebSessionStore implements WebSessionStore {
|
public class InMemoryWebSessionStore implements WebSessionStore {
|
||||||
|
|
||||||
/** Minimum period between expiration checks. */
|
|
||||||
private static final Duration EXPIRATION_CHECK_PERIOD = Duration.ofSeconds(60);
|
|
||||||
|
|
||||||
private static final IdGenerator idGenerator = new JdkIdGenerator();
|
private static final IdGenerator idGenerator = new JdkIdGenerator();
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,9 +51,7 @@ public class InMemoryWebSessionStore implements WebSessionStore {
|
||||||
|
|
||||||
private final ConcurrentMap<String, InMemoryWebSession> sessions = new ConcurrentHashMap<>();
|
private final ConcurrentMap<String, InMemoryWebSession> sessions = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
private volatile Instant nextExpirationCheckTime = Instant.now(this.clock).plus(EXPIRATION_CHECK_PERIOD);
|
private final ExpiredSessionChecker expiredSessionChecker = new ExpiredSessionChecker();
|
||||||
|
|
||||||
private final ReentrantLock expirationCheckLock = new ReentrantLock();
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -70,8 +66,7 @@ public class InMemoryWebSessionStore implements WebSessionStore {
|
||||||
public void setClock(Clock clock) {
|
public void setClock(Clock clock) {
|
||||||
Assert.notNull(clock, "Clock is required");
|
Assert.notNull(clock, "Clock is required");
|
||||||
this.clock = clock;
|
this.clock = clock;
|
||||||
// Force a check when clock changes..
|
this.expiredSessionChecker.removeExpiredSessions(clock.instant());
|
||||||
this.nextExpirationCheckTime = Instant.now(this.clock);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -84,49 +79,29 @@ public class InMemoryWebSessionStore implements WebSessionStore {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Mono<WebSession> createWebSession() {
|
public Mono<WebSession> createWebSession() {
|
||||||
return Mono.fromSupplier(InMemoryWebSession::new);
|
Instant now = this.clock.instant();
|
||||||
|
this.expiredSessionChecker.checkIfNecessary(now);
|
||||||
|
return Mono.fromSupplier(() -> new InMemoryWebSession(now));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Mono<WebSession> retrieveSession(String id) {
|
public Mono<WebSession> retrieveSession(String id) {
|
||||||
Instant currentTime = Instant.now(this.clock);
|
Instant now = this.clock.instant();
|
||||||
if (!this.sessions.isEmpty() && !currentTime.isBefore(this.nextExpirationCheckTime)) {
|
this.expiredSessionChecker.checkIfNecessary(now);
|
||||||
checkExpiredSessions(currentTime);
|
|
||||||
}
|
|
||||||
|
|
||||||
InMemoryWebSession session = this.sessions.get(id);
|
InMemoryWebSession session = this.sessions.get(id);
|
||||||
if (session == null) {
|
if (session == null) {
|
||||||
return Mono.empty();
|
return Mono.empty();
|
||||||
}
|
}
|
||||||
else if (session.isExpired(currentTime)) {
|
else if (session.isExpired(now)) {
|
||||||
this.sessions.remove(id);
|
this.sessions.remove(id);
|
||||||
return Mono.empty();
|
return Mono.empty();
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
session.updateLastAccessTime(currentTime);
|
session.updateLastAccessTime(now);
|
||||||
return Mono.just(session);
|
return Mono.just(session);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void checkExpiredSessions(Instant currentTime) {
|
|
||||||
if (this.expirationCheckLock.tryLock()) {
|
|
||||||
try {
|
|
||||||
Iterator<InMemoryWebSession> iterator = this.sessions.values().iterator();
|
|
||||||
while (iterator.hasNext()) {
|
|
||||||
InMemoryWebSession session = iterator.next();
|
|
||||||
if (session.isExpired(currentTime)) {
|
|
||||||
iterator.remove();
|
|
||||||
session.invalidate();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
finally {
|
|
||||||
this.nextExpirationCheckTime = currentTime.plus(EXPIRATION_CHECK_PERIOD);
|
|
||||||
this.expirationCheckLock.unlock();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Mono<Void> removeSession(String id) {
|
public Mono<Void> removeSession(String id) {
|
||||||
this.sessions.remove(id);
|
this.sessions.remove(id);
|
||||||
|
@ -137,7 +112,7 @@ public class InMemoryWebSessionStore implements WebSessionStore {
|
||||||
return Mono.fromSupplier(() -> {
|
return Mono.fromSupplier(() -> {
|
||||||
Assert.isInstanceOf(InMemoryWebSession.class, webSession);
|
Assert.isInstanceOf(InMemoryWebSession.class, webSession);
|
||||||
InMemoryWebSession session = (InMemoryWebSession) webSession;
|
InMemoryWebSession session = (InMemoryWebSession) webSession;
|
||||||
session.updateLastAccessTime(Instant.now(getClock()));
|
session.updateLastAccessTime(getClock().instant());
|
||||||
return session;
|
return session;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -157,8 +132,9 @@ public class InMemoryWebSessionStore implements WebSessionStore {
|
||||||
|
|
||||||
private final AtomicReference<State> state = new AtomicReference<>(State.NEW);
|
private final AtomicReference<State> state = new AtomicReference<>(State.NEW);
|
||||||
|
|
||||||
public InMemoryWebSession() {
|
|
||||||
this.creationTime = Instant.now(getClock());
|
public InMemoryWebSession(Instant creationTime) {
|
||||||
|
this.creationTime = creationTime;
|
||||||
this.lastAccessTime = this.creationTime;
|
this.lastAccessTime = this.creationTime;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,6 +232,57 @@ public class InMemoryWebSessionStore implements WebSessionStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private class ExpiredSessionChecker {
|
||||||
|
|
||||||
|
/** Max time before next expiration checks. */
|
||||||
|
private static final int CHECK_PERIOD = 60;
|
||||||
|
|
||||||
|
/** Max sessions that can be created before next expiration checks. */
|
||||||
|
private static final int SESSION_COUNT_THRESHOLD = 500;
|
||||||
|
|
||||||
|
|
||||||
|
private final ReentrantLock lock = new ReentrantLock();
|
||||||
|
|
||||||
|
private Instant nextCheckTime = Instant.now(clock).plus(CHECK_PERIOD, ChronoUnit.SECONDS);
|
||||||
|
|
||||||
|
private long lastSessionCount;
|
||||||
|
|
||||||
|
|
||||||
|
public void checkIfNecessary(Instant now) {
|
||||||
|
if (howManyCreated() > SESSION_COUNT_THRESHOLD || this.nextCheckTime.isBefore(now)) {
|
||||||
|
removeExpiredSessions(Instant.now(clock));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private long howManyCreated() {
|
||||||
|
return sessions.size() - this.lastSessionCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void removeExpiredSessions(Instant now) {
|
||||||
|
if (sessions.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (this.lock.tryLock()) {
|
||||||
|
try {
|
||||||
|
Iterator<InMemoryWebSession> iterator = sessions.values().iterator();
|
||||||
|
while (iterator.hasNext()) {
|
||||||
|
InMemoryWebSession session = iterator.next();
|
||||||
|
if (session.isExpired(now)) {
|
||||||
|
iterator.remove();
|
||||||
|
session.invalidate();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
finally {
|
||||||
|
this.nextCheckTime = clock.instant().plus(CHECK_PERIOD, ChronoUnit.SECONDS);
|
||||||
|
this.lastSessionCount = sessions.size();
|
||||||
|
this.lock.unlock();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
private enum State { NEW, STARTED, EXPIRED }
|
private enum State { NEW, STARTED, EXPIRED }
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,15 +18,17 @@ package org.springframework.web.server.session;
|
||||||
import java.time.Clock;
|
import java.time.Clock;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import org.springframework.beans.DirectFieldAccessor;
|
||||||
import org.springframework.web.server.WebSession;
|
import org.springframework.web.server.WebSession;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertSame;
|
import static junit.framework.TestCase.assertSame;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.assertNull;
|
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Unit tests for {@link InMemoryWebSessionStore}.
|
* Unit tests for {@link InMemoryWebSessionStore}.
|
||||||
|
@ -91,44 +93,57 @@ public class InMemoryWebSessionStoreTests {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void expirationChecks() {
|
public void expirationCheckBasedOnTimeWindow() {
|
||||||
// Create 3 sessions
|
|
||||||
WebSession session1 = this.store.createWebSession().block();
|
|
||||||
assertNotNull(session1);
|
|
||||||
session1.start();
|
|
||||||
session1.save().block();
|
|
||||||
|
|
||||||
WebSession session2 = this.store.createWebSession().block();
|
DirectFieldAccessor accessor = new DirectFieldAccessor(this.store);
|
||||||
assertNotNull(session2);
|
Map<?,?> sessions = (Map<?, ?>) accessor.getPropertyValue("sessions");
|
||||||
session2.start();
|
|
||||||
session2.save().block();
|
|
||||||
|
|
||||||
WebSession session3 = this.store.createWebSession().block();
|
// Create 100 sessions
|
||||||
assertNotNull(session3);
|
IntStream.range(0, 100).forEach(i -> insertSession());
|
||||||
session3.start();
|
|
||||||
session3.save().block();
|
|
||||||
|
|
||||||
// Fast-forward 31 minutes
|
// Force a new clock (31 min later) but don't use setter which would clean expired sessions
|
||||||
|
Clock newClock = Clock.offset(this.store.getClock(), Duration.ofMinutes(31));
|
||||||
|
accessor.setPropertyValue("clock", newClock);
|
||||||
|
|
||||||
|
assertEquals(100, sessions.size());
|
||||||
|
|
||||||
|
// Create 50 more which forces a time-based check (clock moved forward)
|
||||||
|
IntStream.range(0, 50).forEach(i -> insertSession());
|
||||||
|
assertEquals(50, sessions.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public void expirationCheckBasedOnSessionCount() {
|
||||||
|
|
||||||
|
DirectFieldAccessor accessor = new DirectFieldAccessor(this.store);
|
||||||
|
Map<String, WebSession> sessions = (Map<String, WebSession>) accessor.getPropertyValue("sessions");
|
||||||
|
|
||||||
|
// Create 100 sessions
|
||||||
|
IntStream.range(0, 100).forEach(i -> insertSession());
|
||||||
|
|
||||||
|
// Copy sessions (about to be expired)
|
||||||
|
Map<String, WebSession> expiredSessions = new HashMap<>(sessions);
|
||||||
|
|
||||||
|
// Set new clock which expires and removes above sessions
|
||||||
this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofMinutes(31)));
|
this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofMinutes(31)));
|
||||||
|
assertEquals(0, sessions.size());
|
||||||
|
|
||||||
// Create 2 more sessions
|
// Re-insert expired sessions
|
||||||
WebSession session4 = this.store.createWebSession().block();
|
sessions.putAll(expiredSessions);
|
||||||
assertNotNull(session4);
|
assertEquals(100, sessions.size());
|
||||||
session4.start();
|
|
||||||
session4.save().block();
|
|
||||||
|
|
||||||
WebSession session5 = this.store.createWebSession().block();
|
// Create 600 more to go over the threshold
|
||||||
assertNotNull(session5);
|
IntStream.range(0, 600).forEach(i -> insertSession());
|
||||||
session5.start();
|
assertEquals(600, sessions.size());
|
||||||
session5.save().block();
|
}
|
||||||
|
|
||||||
// Retrieve, forcing cleanup of all expired..
|
private WebSession insertSession() {
|
||||||
assertNull(this.store.retrieveSession(session1.getId()).block());
|
WebSession session = this.store.createWebSession().block();
|
||||||
assertNull(this.store.retrieveSession(session2.getId()).block());
|
assertNotNull(session);
|
||||||
assertNull(this.store.retrieveSession(session3.getId()).block());
|
session.start();
|
||||||
|
session.save().block();
|
||||||
assertNotNull(this.store.retrieveSession(session4.getId()).block());
|
return session;
|
||||||
assertNotNull(this.store.retrieveSession(session5.getId()).block());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue