Earlier detection of token authentication

Use a callback to detect token authentication (via inteceptor) thus
avoiding a potential race between that detection after the message is
sent on the inbound channel (via Executor) and the processing of the
CONNECTED frame returned from the broker on the outbound channel.

Closes gh-23160
This commit is contained in:
Rossen Stoyanchev 2019-07-03 15:22:56 +01:00
parent 5af9a8edae
commit 4e6e47b726
4 changed files with 75 additions and 10 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ package org.springframework.messaging.simp;
import java.security.Principal;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
@ -84,6 +85,10 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public static final String IGNORE_ERROR = "simpIgnoreError";
@Nullable
private Consumer<Principal> userCallback;
/**
* A constructor for creating new message headers.
* This constructor is protected. See factory methods in this and sub-classes.
@ -171,6 +176,9 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public void setUser(@Nullable Principal principal) {
setHeader(USER_HEADER, principal);
if (this.userCallback != null) {
this.userCallback.accept(principal);
}
}
/**
@ -181,6 +189,18 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
return (Principal) getHeader(USER_HEADER);
}
/**
* Provide a callback to be invoked if and when {@link #setUser(Principal)}
* is called. This is used internally on the inbound channel to detect
* token-based authentications through an interceptor.
* @param callback the callback to invoke
* @since 5.1.9
*/
public void setUserChangeCallback(Consumer<Principal> callback) {
Assert.notNull(callback, "'callback' is required");
this.userCallback = this.userCallback != null ? this.userCallback.andThen(callback) : callback;
}
@Override
public String getShortLogMessage(Object payload) {
if (getMessageType() == null) {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,11 +16,14 @@
package org.springframework.messaging.simp;
import java.security.Principal;
import java.util.Collections;
import java.util.function.Consumer;
import org.junit.Test;
import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;
/**
* Unit tests for SimpMessageHeaderAccessor.
@ -63,4 +66,35 @@ public class SimpMessageHeaderAccessorTests {
"{nativeKey=[nativeValue]} payload=p", accessor.getDetailedLogMessage("p"));
}
@Test
public void userChangeCallback() {
UserCallback userCallback = new UserCallback();
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create();
accessor.setUserChangeCallback(userCallback);
Principal user1 = mock(Principal.class);
accessor.setUser(user1);
assertEquals(user1, userCallback.getUser());
Principal user2 = mock(Principal.class);
accessor.setUser(user2);
assertEquals(user2, userCallback.getUser());
}
private static class UserCallback implements Consumer<Principal> {
private Principal user;
public Principal getUser() {
return this.user;
}
@Override
public void accept(Principal principal) {
this.user = principal;
}
}
}

View File

@ -258,9 +258,19 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
Assert.state(headerAccessor != null, "No StompHeaderAccessor");
StompCommand command = headerAccessor.getCommand();
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
headerAccessor.setSessionId(session.getId());
headerAccessor.setSessionAttributes(session.getAttributes());
headerAccessor.setUser(getUser(session));
if (isConnect) {
headerAccessor.setUserChangeCallback(user -> {
if (user != null && user != session.getPrincipal()) {
this.stompAuthentications.put(session.getId(), user);
}
});
}
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
if (!detectImmutableMessageInterceptor(outputChannel)) {
headerAccessor.setImmutable();
@ -270,8 +280,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
}
StompCommand command = headerAccessor.getCommand();
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
if (isConnect) {
this.stats.incrementConnectCount();
}
@ -284,12 +292,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
boolean sent = outputChannel.send(message);
if (sent) {
if (isConnect) {
Principal user = headerAccessor.getUser();
if (user != null && user != session.getPrincipal()) {
this.stompAuthentications.put(session.getId(), user);
}
}
if (this.eventPublisher != null) {
Principal user = getUser(session);
if (isConnect) {

View File

@ -378,6 +378,15 @@ public class StompSubProtocolHandlerTests {
Principal user = SimpMessageHeaderAccessor.getUser(message.getHeaders());
assertNotNull(user);
assertEquals("__pete__@gmail.com", user.getName());
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
handler.handleMessageToClient(this.session, message);
assertEquals(1, this.session.getSentMessages().size());
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
assertEquals("CONNECTED\n" + "user-name:__pete__@gmail.com\n" + "\n" + "\u0000",
textMessage.getPayload());
}
@Test