Refactor DefaultWebSession

Use copy constructor to refresh a session with lastAccessTime and a
save function referencing the current exchange. As a result both fields
are now final and ConfigurableWebSession is no longer needed.
This commit is contained in:
Rossen Stoyanchev 2017-07-14 22:50:32 +02:00
parent bf712957f6
commit 47b63150d1
6 changed files with 78 additions and 123 deletions

View File

@ -1,46 +0,0 @@
/*
* Copyright 2002-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.server.session;
import java.time.Instant;
import java.util.function.Supplier;
import reactor.core.publisher.Mono;
import org.springframework.web.server.WebSession;
/**
* Extend {@link WebSession} with management operations meant for internal use
* for example by implementations of {@link WebSessionManager}.
*
* @author Rossen Stoyanchev
* @since 5.0
*/
public interface ConfigurableWebSession extends WebSession {
/**
* Update the last access time for user-related session activity.
* @param time the time of access
*/
void setLastAccessTime(Instant time);
/**
* Set the operation to invoke when {@link WebSession#save()} is invoked.
* @param saveOperation the save operation
*/
void setSaveOperation(Supplier<Mono<Void>> saveOperation);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,18 +15,18 @@
*/
package org.springframework.web.server.session;
import java.io.Serializable;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.function.Function;
import reactor.core.publisher.Mono;
import org.springframework.util.Assert;
import org.springframework.web.server.WebSession;
/**
* Default implementation of {@link org.springframework.web.server.WebSession}.
@ -34,10 +34,7 @@ import org.springframework.util.Assert;
* @author Rossen Stoyanchev
* @since 5.0
*/
public class DefaultWebSession implements ConfigurableWebSession, Serializable {
private static final long serialVersionUID = -3567697426432961630L;
class DefaultWebSession implements WebSession {
private final String id;
@ -45,55 +42,66 @@ public class DefaultWebSession implements ConfigurableWebSession, Serializable {
private final Clock clock;
private final Function<WebSession, Mono<Void>> saveOperation;
private final Instant creationTime;
private volatile Instant lastAccessTime;
private final Instant lastAccessTime;
private volatile Duration maxIdleTime;
private AtomicReference<State> state = new AtomicReference<>();
private volatile transient Supplier<Mono<Void>> saveOperation;
private final AtomicReference<State> state;
/**
* Constructor to create a new session.
* Constructor for creating a brand, new session.
* @param id the session id
* @param clock for access to current time
*/
public DefaultWebSession(String id, Clock clock) {
DefaultWebSession(String id, Clock clock, Function<WebSession, Mono<Void>> saveOperation) {
Assert.notNull(id, "'id' is required.");
Assert.notNull(clock, "'clock' is required.");
this.id = id;
this.clock = clock;
this.saveOperation = saveOperation;
this.attributes = new ConcurrentHashMap<>();
this.creationTime = Instant.now(clock);
this.lastAccessTime = this.creationTime;
this.maxIdleTime = Duration.ofMinutes(30);
this.state.set(State.NEW);
this.state = new AtomicReference<>(State.NEW);
}
/**
* Constructor to load existing session.
* @param id the session id
* @param attributes the attributes of the session
* @param clock for access to current time
* @param creationTime the creation time
* Constructor to refresh an existing session for a new request.
* @param existingSession the session to recreate
* @param lastAccessTime the last access time
* @param maxIdleTime the configured maximum session idle time
* @param saveOperation save operation for the current request
*/
public DefaultWebSession(String id, Map<String, Object> attributes, Clock clock,
Instant creationTime, Instant lastAccessTime, Duration maxIdleTime) {
DefaultWebSession(DefaultWebSession existingSession, Instant lastAccessTime,
Function<WebSession, Mono<Void>> saveOperation) {
Assert.notNull(id, "'id' is required.");
Assert.notNull(clock, "'clock' is required.");
this.id = id;
this.attributes = new ConcurrentHashMap<>(attributes);
this.clock = clock;
this.creationTime = creationTime;
this.id = existingSession.id;
this.attributes = existingSession.attributes;
this.clock = existingSession.clock;
this.creationTime = existingSession.creationTime;
this.lastAccessTime = lastAccessTime;
this.maxIdleTime = maxIdleTime;
this.state.set(State.STARTED);
this.maxIdleTime = existingSession.maxIdleTime;
this.saveOperation = saveOperation;
this.state = existingSession.state;
}
/**
* For testing purposes.
*/
DefaultWebSession(DefaultWebSession existingSession, Instant lastAccessTime) {
this.id = existingSession.id;
this.attributes = existingSession.attributes;
this.clock = existingSession.clock;
this.creationTime = existingSession.creationTime;
this.lastAccessTime = lastAccessTime;
this.maxIdleTime = existingSession.maxIdleTime;
this.saveOperation = existingSession.saveOperation;
this.state = existingSession.state;
}
@ -112,11 +120,6 @@ public class DefaultWebSession implements ConfigurableWebSession, Serializable {
return this.creationTime;
}
@Override
public void setLastAccessTime(Instant lastAccessTime) {
this.lastAccessTime = lastAccessTime;
}
@Override
public Instant getLastAccessTime() {
return this.lastAccessTime;
@ -136,16 +139,6 @@ public class DefaultWebSession implements ConfigurableWebSession, Serializable {
return this.maxIdleTime;
}
@Override
public void setSaveOperation(Supplier<Mono<Void>> saveOperation) {
Assert.notNull(saveOperation, "'saveOperation' is required.");
this.saveOperation = saveOperation;
}
protected Supplier<Mono<Void>> getSaveOperation() {
return this.saveOperation;
}
@Override
public void start() {
@ -160,7 +153,7 @@ public class DefaultWebSession implements ConfigurableWebSession, Serializable {
@Override
public Mono<Void> save() {
return this.saveOperation.get();
return this.saveOperation.apply(this);
}
@Override

View File

@ -107,36 +107,29 @@ public class DefaultWebSessionManager implements WebSessionManager {
return Mono.defer(() ->
retrieveSession(exchange)
.flatMap(session -> removeSessionIfExpired(exchange, session))
.switchIfEmpty(createSession())
.doOnNext(session -> {
if (session instanceof ConfigurableWebSession) {
ConfigurableWebSession configurable = (ConfigurableWebSession) session;
configurable.setSaveOperation(() -> saveSession(exchange, session));
configurable.setLastAccessTime(Instant.now(getClock()));
}
exchange.getResponse().beforeCommit(session::save);
}));
.map(session -> {
Instant lastAccessTime = Instant.now(getClock());
return new DefaultWebSession(session, lastAccessTime, s -> saveSession(exchange, s));
})
.switchIfEmpty(createSession(exchange))
.doOnNext(session -> exchange.getResponse().beforeCommit(session::save)));
}
private Mono<WebSession> retrieveSession(ServerWebExchange exchange) {
private Mono<DefaultWebSession> retrieveSession(ServerWebExchange exchange) {
return Flux.fromIterable(getSessionIdResolver().resolveSessionIds(exchange))
.concatMap(this.sessionStore::retrieveSession)
.cast(DefaultWebSession.class)
.next();
}
private Mono<WebSession> removeSessionIfExpired(ServerWebExchange exchange, WebSession session) {
private Mono<DefaultWebSession> removeSessionIfExpired(ServerWebExchange exchange, DefaultWebSession session) {
if (session.isExpired()) {
this.sessionIdResolver.setSessionId(exchange, "");
this.sessionIdResolver.expireSession(exchange);
return this.sessionStore.removeSession(session.getId()).then(Mono.empty());
}
return Mono.just(session);
}
private Mono<DefaultWebSession> createSession() {
return Mono.fromSupplier(() ->
new DefaultWebSession(UUID.randomUUID().toString(), getClock()));
}
private Mono<Void> saveSession(ServerWebExchange exchange, WebSession session) {
if (session.isExpired()) {
return Mono.error(new IllegalStateException(
@ -165,4 +158,11 @@ public class DefaultWebSessionManager implements WebSessionManager {
return ids.isEmpty() || !session.getId().equals(ids.get(0));
}
private Mono<DefaultWebSession> createSession(ServerWebExchange exchange) {
return Mono.fromSupplier(() -> {
String id = UUID.randomUUID().toString();
return new DefaultWebSession(id, getClock(), sess -> saveSession(exchange, sess));
});
}
}

View File

@ -25,6 +25,7 @@ import java.util.List;
import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Mono;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.lang.Nullable;
@ -104,20 +105,22 @@ public class DefaultWebSessionManagerTests {
@Test
public void existingSession() throws Exception {
DefaultWebSession existing = new DefaultWebSession("1", Clock.systemDefaultZone());
DefaultWebSession existing = new DefaultWebSession("1", Clock.systemDefaultZone(), s -> Mono.empty());
this.manager.getSessionStore().storeSession(existing);
this.idResolver.setIdsToResolve(Collections.singletonList("1"));
WebSession actual = this.manager.getSession(this.exchange).block();
assertSame(existing, actual);
assertNotNull(actual);
assertEquals(existing.getId(), actual.getId());
}
@Test
public void existingSessionIsExpired() throws Exception {
Clock clock = Clock.systemDefaultZone();
DefaultWebSession existing = new DefaultWebSession("1", clock);
DefaultWebSession existing = new DefaultWebSession("1", clock, s -> Mono.empty());
existing.start();
existing.setLastAccessTime(Instant.now(clock).minus(Duration.ofMinutes(31)));
Instant lastAccessTime = Instant.now(clock).minus(Duration.ofMinutes(31));
existing = new DefaultWebSession(existing, lastAccessTime, s -> Mono.empty());
this.manager.getSessionStore().storeSession(existing);
this.idResolver.setIdsToResolve(Collections.singletonList("1"));
@ -127,12 +130,13 @@ public class DefaultWebSessionManagerTests {
@Test
public void multipleSessions() throws Exception {
DefaultWebSession existing = new DefaultWebSession("3", Clock.systemDefaultZone());
DefaultWebSession existing = new DefaultWebSession("3", Clock.systemDefaultZone(), s -> Mono.empty());
this.manager.getSessionStore().storeSession(existing);
this.idResolver.setIdsToResolve(Arrays.asList("1", "2", "3"));
WebSession actual = this.manager.getSession(this.exchange).block();
assertSame(existing, actual);
assertNotNull(actual);
assertEquals(existing.getId(), actual.getId());
}

View File

@ -20,6 +20,7 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
@ -37,7 +38,6 @@ import org.springframework.util.StringUtils;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebHandler;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import static org.junit.Assert.assertEquals;
@ -116,9 +116,12 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe
assertEquals(2, this.handler.getCount());
// Update lastAccessTime of the created session to -31 min
WebSession session = this.sessionManager.getSessionStore().retrieveSession(id).block();
((DefaultWebSession) session).setLastAccessTime(
Clock.offset(this.sessionManager.getClock(), Duration.ofMinutes(-31)).instant());
WebSessionStore store = this.sessionManager.getSessionStore();
DefaultWebSession session = (DefaultWebSession) store.retrieveSession(id).block();
assertNotNull(session);
Instant lastAccessTime = Clock.offset(this.sessionManager.getClock(), Duration.ofMinutes(-31)).instant();
session = new DefaultWebSession(session, lastAccessTime);
store.storeSession(session);
// Third request: expired session, new session created
request = RequestEntity.get(createUri("/")).header("Cookie", "SESSION=" + id).build();

View File

@ -15,8 +15,6 @@
*/
package org.springframework.web.reactive.result.method.annotation;
import java.time.Clock;
import io.reactivex.Single;
import org.junit.Test;
import reactor.core.publisher.Mono;
@ -32,11 +30,12 @@ import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.adapter.DefaultServerWebExchange;
import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
import org.springframework.web.server.session.DefaultWebSession;
import org.springframework.web.server.session.WebSessionManager;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
/**
* Unit tests for {@link WebSessionArgumentResolver}.
@ -62,7 +61,7 @@ public class WebSessionArgumentResolverTests {
public void resolverArgument() throws Exception {
BindingContext context = new BindingContext();
WebSession session = new DefaultWebSession("id", Clock.systemDefaultZone());
WebSession session = mock(WebSession.class);
WebSessionManager manager = exchange -> Mono.just(session);
MockServerHttpRequest request = MockServerHttpRequest.get("/").build();
ServerWebExchange exchange = new DefaultServerWebExchange(request, new MockServerHttpResponse(),
@ -74,11 +73,13 @@ public class WebSessionArgumentResolverTests {
param = this.testMethod.arg(Mono.class, WebSession.class);
actual = this.resolver.resolveArgument(param, context, exchange).block();
assertNotNull(actual);
assertTrue(Mono.class.isAssignableFrom(actual.getClass()));
assertSame(session, ((Mono<?>) actual).block());
param = this.testMethod.arg(Single.class, WebSession.class);
actual = this.resolver.resolveArgument(param, context, exchange).block();
assertNotNull(actual);
assertTrue(Single.class.isAssignableFrom(actual.getClass()));
assertSame(session, ((Single<?>) actual).blockingGet());
}