Earlier detection of token authentication
Use a callback to detect token authentication (via inteceptor) thus avoiding a potential race between that detection after the message is sent on the inbound channel (via Executor) and the processing of the CONNECTED frame returned from the broker on the outbound channel. Closes gh-23160
This commit is contained in:
parent
5af9a8edae
commit
4e6e47b726
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2017 the original author or authors.
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,6 +19,7 @@ package org.springframework.messaging.simp;
|
|||
import java.security.Principal;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.messaging.Message;
|
||||
|
@ -84,6 +85,10 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
|
|||
public static final String IGNORE_ERROR = "simpIgnoreError";
|
||||
|
||||
|
||||
@Nullable
|
||||
private Consumer<Principal> userCallback;
|
||||
|
||||
|
||||
/**
|
||||
* A constructor for creating new message headers.
|
||||
* This constructor is protected. See factory methods in this and sub-classes.
|
||||
|
@ -171,6 +176,9 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
|
|||
|
||||
public void setUser(@Nullable Principal principal) {
|
||||
setHeader(USER_HEADER, principal);
|
||||
if (this.userCallback != null) {
|
||||
this.userCallback.accept(principal);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -181,6 +189,18 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
|
|||
return (Principal) getHeader(USER_HEADER);
|
||||
}
|
||||
|
||||
/**
|
||||
* Provide a callback to be invoked if and when {@link #setUser(Principal)}
|
||||
* is called. This is used internally on the inbound channel to detect
|
||||
* token-based authentications through an interceptor.
|
||||
* @param callback the callback to invoke
|
||||
* @since 5.1.9
|
||||
*/
|
||||
public void setUserChangeCallback(Consumer<Principal> callback) {
|
||||
Assert.notNull(callback, "'callback' is required");
|
||||
this.userCallback = this.userCallback != null ? this.userCallback.andThen(callback) : callback;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getShortLogMessage(Object payload) {
|
||||
if (getMessageType() == null) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2014 the original author or authors.
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,11 +16,14 @@
|
|||
|
||||
package org.springframework.messaging.simp;
|
||||
|
||||
import java.security.Principal;
|
||||
import java.util.Collections;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
/**
|
||||
* Unit tests for SimpMessageHeaderAccessor.
|
||||
|
@ -63,4 +66,35 @@ public class SimpMessageHeaderAccessorTests {
|
|||
"{nativeKey=[nativeValue]} payload=p", accessor.getDetailedLogMessage("p"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void userChangeCallback() {
|
||||
UserCallback userCallback = new UserCallback();
|
||||
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create();
|
||||
accessor.setUserChangeCallback(userCallback);
|
||||
|
||||
Principal user1 = mock(Principal.class);
|
||||
accessor.setUser(user1);
|
||||
assertEquals(user1, userCallback.getUser());
|
||||
|
||||
Principal user2 = mock(Principal.class);
|
||||
accessor.setUser(user2);
|
||||
assertEquals(user2, userCallback.getUser());
|
||||
}
|
||||
|
||||
|
||||
private static class UserCallback implements Consumer<Principal> {
|
||||
|
||||
private Principal user;
|
||||
|
||||
|
||||
public Principal getUser() {
|
||||
return this.user;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void accept(Principal principal) {
|
||||
this.user = principal;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -258,9 +258,19 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
|
|||
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
|
||||
Assert.state(headerAccessor != null, "No StompHeaderAccessor");
|
||||
|
||||
StompCommand command = headerAccessor.getCommand();
|
||||
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
|
||||
|
||||
headerAccessor.setSessionId(session.getId());
|
||||
headerAccessor.setSessionAttributes(session.getAttributes());
|
||||
headerAccessor.setUser(getUser(session));
|
||||
if (isConnect) {
|
||||
headerAccessor.setUserChangeCallback(user -> {
|
||||
if (user != null && user != session.getPrincipal()) {
|
||||
this.stompAuthentications.put(session.getId(), user);
|
||||
}
|
||||
});
|
||||
}
|
||||
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
|
||||
if (!detectImmutableMessageInterceptor(outputChannel)) {
|
||||
headerAccessor.setImmutable();
|
||||
|
@ -270,8 +280,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
|
|||
logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
|
||||
}
|
||||
|
||||
StompCommand command = headerAccessor.getCommand();
|
||||
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
|
||||
if (isConnect) {
|
||||
this.stats.incrementConnectCount();
|
||||
}
|
||||
|
@ -284,12 +292,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
|
|||
boolean sent = outputChannel.send(message);
|
||||
|
||||
if (sent) {
|
||||
if (isConnect) {
|
||||
Principal user = headerAccessor.getUser();
|
||||
if (user != null && user != session.getPrincipal()) {
|
||||
this.stompAuthentications.put(session.getId(), user);
|
||||
}
|
||||
}
|
||||
if (this.eventPublisher != null) {
|
||||
Principal user = getUser(session);
|
||||
if (isConnect) {
|
||||
|
|
|
@ -378,6 +378,15 @@ public class StompSubProtocolHandlerTests {
|
|||
Principal user = SimpMessageHeaderAccessor.getUser(message.getHeaders());
|
||||
assertNotNull(user);
|
||||
assertEquals("__pete__@gmail.com", user.getName());
|
||||
|
||||
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
|
||||
message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
|
||||
handler.handleMessageToClient(this.session, message);
|
||||
|
||||
assertEquals(1, this.session.getSentMessages().size());
|
||||
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
|
||||
assertEquals("CONNECTED\n" + "user-name:__pete__@gmail.com\n" + "\n" + "\u0000",
|
||||
textMessage.getPayload());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
Loading…
Reference in New Issue