This commit is contained in:
Rossen Stoyanchev 2017-09-05 17:47:02 -04:00
parent c98e01ad1f
commit 320bfdf413
2 changed files with 30 additions and 20 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2017 the original author or authors. * Copyright 2002-2017 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -23,18 +23,21 @@ import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
/** /**
* Header-based {@link WebSessionIdResolver}. * Request and response header-based {@link WebSessionIdResolver}.
* *
* @author Greg Turnquist * @author Greg Turnquist
* @since 5.0 * @since 5.0
*/ */
public class HeaderSessionIdResolver implements WebSessionIdResolver { public class HeaderWebSessionIdResolver implements WebSessionIdResolver {
private String headerName = "SESSION"; private String headerName = "SESSION";
/** /**
* Set the name of the session header to use for the session id. * Set the name of the session header to use for the session id.
* <p>By default set to "SESSION". * The name is used to extract the session id from the request headers as
* well to set the session id on the response headers.
* <p>By default set to {@literal "SESSION"}.
* @param headerName the header name * @param headerName the header name
*/ */
public void setHeaderName(String headerName) { public void setHeaderName(String headerName) {
@ -49,24 +52,22 @@ public class HeaderSessionIdResolver implements WebSessionIdResolver {
return this.headerName; return this.headerName;
} }
@Override @Override
public List<String> resolveSessionIds(ServerWebExchange exchange) { public List<String> resolveSessionIds(ServerWebExchange exchange) {
HttpHeaders headers = exchange.getRequest().getHeaders(); HttpHeaders headers = exchange.getRequest().getHeaders();
List<String> sessionHeaders = headers.get(this.getHeaderName()); return headers.getOrDefault(getHeaderName(), Collections.emptyList());
if (sessionHeaders == null) {
return Collections.emptyList();
}
return sessionHeaders;
} }
@Override @Override
public void setSessionId(ServerWebExchange exchange, String id) { public void setSessionId(ServerWebExchange exchange, String id) {
Assert.notNull(id, "'id' is required."); Assert.notNull(id, "'id' is required.");
exchange.getResponse().getHeaders().set(this.headerName, id); exchange.getResponse().getHeaders().set(getHeaderName(), id);
} }
@Override @Override
public void expireSession(ServerWebExchange exchange) { public void expireSession(ServerWebExchange exchange) {
this.setSessionId(exchange, ""); this.setSessionId(exchange, "");
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2017 the original author or authors. * Copyright 2002-2017 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,11 +15,6 @@
*/ */
package org.springframework.web.server.session; package org.springframework.web.server.session;
import static org.hamcrest.collection.IsCollectionWithSize.*;
import static org.hamcrest.core.Is.*;
import static org.hamcrest.core.IsCollectionContaining.*;
import static org.junit.Assert.*;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
@ -29,6 +24,7 @@ import java.util.UUID;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
@ -37,24 +33,36 @@ import org.springframework.web.server.WebSession;
import org.springframework.web.server.adapter.DefaultServerWebExchange; import org.springframework.web.server.adapter.DefaultServerWebExchange;
import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver; import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
import static org.hamcrest.collection.IsCollectionWithSize.hasSize;
import static org.hamcrest.core.Is.is;
import static org.hamcrest.core.IsCollectionContaining.hasItem;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
/** /**
* Tests using {@link HeaderSessionIdResolver}. * Tests using {@link HeaderWebSessionIdResolver}.
* *
* @author Greg Turnquist * @author Greg Turnquist
*/ */
public class HeaderSessionIdResolverTests { public class HeaderWebSessionIdResolverTests {
private static final Clock CLOCK = Clock.system(ZoneId.of("GMT")); private static final Clock CLOCK = Clock.system(ZoneId.of("GMT"));
private HeaderSessionIdResolver idResolver;
private HeaderWebSessionIdResolver idResolver;
private DefaultWebSessionManager manager; private DefaultWebSessionManager manager;
private ServerWebExchange exchange; private ServerWebExchange exchange;
@Before @Before
public void setUp() { public void setUp() {
this.idResolver = new HeaderSessionIdResolver(); this.idResolver = new HeaderWebSessionIdResolver();
this.manager = new DefaultWebSessionManager(); this.manager = new DefaultWebSessionManager();
this.manager.setSessionIdResolver(this.idResolver); this.manager.setSessionIdResolver(this.idResolver);
@ -172,4 +180,5 @@ public class HeaderSessionIdResolverTests {
private DefaultWebSession createDefaultWebSession(UUID sessionId) { private DefaultWebSession createDefaultWebSession(UUID sessionId) {
return new DefaultWebSession(() -> sessionId, CLOCK, (s, session) -> Mono.empty(), s -> Mono.empty()); return new DefaultWebSession(() -> sessionId, CLOCK, (s, session) -> Mono.empty(), s -> Mono.empty());
} }
} }