diff --git a/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java index cb5e913a74..6387a2ce47 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java @@ -36,6 +36,9 @@ import org.springframework.util.Assert; * @author Rob Winch */ public final class SecurityContextChannelInterceptor extends ChannelInterceptorAdapter implements ExecutorChannelInterceptor { + private final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext(); + private static final ThreadLocal ORIGINAL_CONTEXT = new ThreadLocal(); + private final String authenticationHeaderName; /** @@ -75,6 +78,9 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA } private void setup(Message message) { + SecurityContext currentContext = SecurityContextHolder.getContext(); + ORIGINAL_CONTEXT.set(currentContext); + Object user = message.getHeaders().get(authenticationHeaderName); if(!(user instanceof Authentication)) { return; @@ -86,6 +92,17 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA } private void cleanup() { - SecurityContextHolder.clearContext(); + SecurityContext originalContext = ORIGINAL_CONTEXT.get(); + ORIGINAL_CONTEXT.remove(); + + try { + if(EMPTY_CONTEXT.equals(originalContext)) { + SecurityContextHolder.clearContext(); + } else { + SecurityContextHolder.setContext(originalContext); + } + } catch(Throwable t) { + SecurityContextHolder.clearContext(); + } } } diff --git a/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java index 2a1c5d3f7a..8459534374 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java @@ -146,4 +146,19 @@ public class SecurityContextChannelInterceptorTests { assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); } + + @Test + public void restoresOriginalContext() throws Exception { + TestingAuthenticationToken original = new TestingAuthenticationToken("original", "original", "ROLE_USER"); + SecurityContextHolder.getContext().setAuthentication(original); + + messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, authentication); + interceptor.beforeHandle(messageBuilder.build(), channel, handler); + + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(authentication); + + interceptor.afterMessageHandled(messageBuilder.build(), channel, handler, null); + + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(original); + } } \ No newline at end of file