diff --git a/messaging/src/main/java/org/springframework/security/messaging/access/intercept/MessageMatcherDelegatingAuthorizationManager.java b/messaging/src/main/java/org/springframework/security/messaging/access/intercept/MessageMatcherDelegatingAuthorizationManager.java index 32d2a09016..68c1be5d7b 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/access/intercept/MessageMatcherDelegatingAuthorizationManager.java +++ b/messaging/src/main/java/org/springframework/security/messaging/access/intercept/MessageMatcherDelegatingAuthorizationManager.java @@ -18,6 +18,7 @@ package org.springframework.security.messaging.access.intercept; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.function.Supplier; import org.apache.commons.logging.Log; @@ -37,6 +38,7 @@ import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatche import org.springframework.util.AntPathMatcher; import org.springframework.util.Assert; import org.springframework.util.PathMatcher; +import org.springframework.util.function.SingletonSupplier; public final class MessageMatcherDelegatingAuthorizationManager implements AuthorizationManager> { @@ -87,6 +89,10 @@ public final class MessageMatcherDelegatingAuthorizationManager implements Autho SimpDestinationMessageMatcher simp = (SimpDestinationMessageMatcher) matcher; return new MessageAuthorizationContext<>(message, simp.extractPathVariables(message)); } + if (matcher instanceof Builder.LazySimpDestinationMessageMatcher) { + Builder.LazySimpDestinationMessageMatcher path = (Builder.LazySimpDestinationMessageMatcher) matcher; + return new MessageAuthorizationContext<>(message, path.extractPathVariables(message)); + } return new MessageAuthorizationContext<>(message); } @@ -192,8 +198,7 @@ public final class MessageMatcherDelegatingAuthorizationManager implements Autho private Builder.Constraint simpDestMatchers(SimpMessageType type, String... patterns) { List> matchers = new ArrayList<>(patterns.length); for (String pattern : patterns) { - Supplier> supplier = new Builder.PathMatcherMessageMatcherBuilder(pattern, type); - MessageMatcher matcher = new Builder.SupplierMessageMatcher(supplier); + MessageMatcher matcher = new LazySimpDestinationMessageMatcher(pattern, type); matchers.add(matcher); } return new Builder.Constraint(matchers); @@ -375,58 +380,33 @@ public final class MessageMatcherDelegatingAuthorizationManager implements Autho } - private static final class SupplierMessageMatcher implements MessageMatcher { + private final class LazySimpDestinationMessageMatcher implements MessageMatcher { - private final Supplier> supplier; + private final Supplier delegate; - private volatile MessageMatcher delegate; - - SupplierMessageMatcher(Supplier> supplier) { - this.supplier = supplier; + private LazySimpDestinationMessageMatcher(String pattern, SimpMessageType type) { + this.delegate = SingletonSupplier.of(() -> { + PathMatcher pathMatcher = Builder.this.pathMatcher.get(); + if (type == null) { + return new SimpDestinationMessageMatcher(pattern, pathMatcher); + } + if (SimpMessageType.MESSAGE == type) { + return SimpDestinationMessageMatcher.createMessageMatcher(pattern, pathMatcher); + } + if (SimpMessageType.SUBSCRIBE == type) { + return SimpDestinationMessageMatcher.createSubscribeMatcher(pattern, pathMatcher); + } + throw new IllegalStateException(type + " is not supported since it does not have a destination"); + }); } @Override public boolean matches(Message message) { - if (this.delegate == null) { - synchronized (this.supplier) { - if (this.delegate == null) { - this.delegate = this.supplier.get(); - } - } - } - return this.delegate.matches(message); + return this.delegate.get().matches(message); } - } - - private final class PathMatcherMessageMatcherBuilder implements Supplier> { - - private final String pattern; - - private final SimpMessageType type; - - private PathMatcherMessageMatcherBuilder(String pattern, SimpMessageType type) { - this.pattern = pattern; - this.type = type; - } - - private PathMatcher resolvePathMatcher() { - return Builder.this.pathMatcher.get(); - } - - @Override - public MessageMatcher get() { - PathMatcher pathMatcher = resolvePathMatcher(); - if (this.type == null) { - return new SimpDestinationMessageMatcher(this.pattern, pathMatcher); - } - if (SimpMessageType.MESSAGE == this.type) { - return SimpDestinationMessageMatcher.createMessageMatcher(this.pattern, pathMatcher); - } - if (SimpMessageType.SUBSCRIBE == this.type) { - return SimpDestinationMessageMatcher.createSubscribeMatcher(this.pattern, pathMatcher); - } - throw new IllegalStateException(this.type + " is not supported since it does not have a destination"); + Map extractPathVariables(Message message) { + return this.delegate.get().extractPathVariables(message); } } diff --git a/messaging/src/test/java/org/springframework/security/messaging/access/intercept/MessageMatcherDelegatingAuthorizationManagerTests.java b/messaging/src/test/java/org/springframework/security/messaging/access/intercept/MessageMatcherDelegatingAuthorizationManagerTests.java new file mode 100644 index 0000000000..07089d87f6 --- /dev/null +++ b/messaging/src/test/java/org/springframework/security/messaging/access/intercept/MessageMatcherDelegatingAuthorizationManagerTests.java @@ -0,0 +1,128 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.access.intercept; + +import java.util.Map; +import java.util.function.Supplier; + +import org.junit.jupiter.api.Test; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.support.GenericMessage; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.authorization.AuthorizationDecision; +import org.springframework.security.authorization.AuthorizationManager; +import org.springframework.security.core.Authentication; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link MessageMatcherDelegatingAuthorizationManager} + */ +public final class MessageMatcherDelegatingAuthorizationManagerTests { + + @Test + void checkWhenPermitAllThenPermits() { + AuthorizationManager> authorizationManager = builder().anyMessage().permitAll().build(); + Message message = new GenericMessage<>(new Object()); + assertThat(authorizationManager.check(mock(Supplier.class), message).isGranted()).isTrue(); + } + + @Test + void checkWhenAnyMessageHasRoleThenRequires() { + AuthorizationManager> authorizationManager = builder().anyMessage().hasRole("USER").build(); + Message message = new GenericMessage<>(new Object()); + Authentication user = new TestingAuthenticationToken("user", "password", "ROLE_USER"); + assertThat(authorizationManager.check(() -> user, message).isGranted()).isTrue(); + Authentication admin = new TestingAuthenticationToken("user", "password", "ROLE_ADMIN"); + assertThat(authorizationManager.check(() -> admin, message).isGranted()).isFalse(); + } + + @Test + void checkWhenSimpDestinationMatchesThenUses() { + AuthorizationManager> authorizationManager = builder().simpDestMatchers("destination").permitAll() + .anyMessage().denyAll().build(); + MessageHeaders headers = new MessageHeaders( + Map.of(SimpMessageHeaderAccessor.DESTINATION_HEADER, "destination")); + Message message = new GenericMessage<>(new Object(), headers); + assertThat(authorizationManager.check(mock(Supplier.class), message).isGranted()).isTrue(); + } + + @Test + void checkWhenNullDestinationHeaderMatchesThenUses() { + AuthorizationManager> authorizationManager = builder().nullDestMatcher().permitAll().anyMessage() + .denyAll().build(); + Message message = new GenericMessage<>(new Object()); + assertThat(authorizationManager.check(mock(Supplier.class), message).isGranted()).isTrue(); + MessageHeaders headers = new MessageHeaders( + Map.of(SimpMessageHeaderAccessor.DESTINATION_HEADER, "destination")); + message = new GenericMessage<>(new Object(), headers); + assertThat(authorizationManager.check(mock(Supplier.class), message).isGranted()).isFalse(); + } + + @Test + void checkWhenSimpTypeMatchesThenUses() { + AuthorizationManager> authorizationManager = builder().simpTypeMatchers(SimpMessageType.CONNECT) + .permitAll().anyMessage().denyAll().build(); + MessageHeaders headers = new MessageHeaders( + Map.of(SimpMessageHeaderAccessor.MESSAGE_TYPE_HEADER, SimpMessageType.CONNECT)); + Message message = new GenericMessage<>(new Object(), headers); + assertThat(authorizationManager.check(mock(Supplier.class), message).isGranted()).isTrue(); + } + + // gh-12540 + @Test + void checkWhenSimpDestinationMatchesThenVariablesExtracted() { + AuthorizationManager> authorizationManager = builder().simpDestMatchers("destination/{id}") + .access(variable("id").isEqualTo("3")).anyMessage().denyAll().build(); + MessageHeaders headers = new MessageHeaders( + Map.of(SimpMessageHeaderAccessor.DESTINATION_HEADER, "destination/3")); + Message message = new GenericMessage<>(new Object(), headers); + assertThat(authorizationManager.check(mock(Supplier.class), message).isGranted()).isTrue(); + } + + private MessageMatcherDelegatingAuthorizationManager.Builder builder() { + return MessageMatcherDelegatingAuthorizationManager.builder(); + } + + private Builder variable(String name) { + return new Builder(name); + + } + + private static final class Builder { + + private final String name; + + private Builder(String name) { + this.name = name; + } + + AuthorizationManager> isEqualTo(String value) { + return (authentication, object) -> { + String extracted = object.getVariables().get(this.name); + return new AuthorizationDecision(value.equals(extracted)); + }; + } + + } + +}