diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/security/servlet/ApplicationContextRequestMatcher.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/security/servlet/ApplicationContextRequestMatcher.java index 1e2b661b968..a56626e873e 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/security/servlet/ApplicationContextRequestMatcher.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/security/servlet/ApplicationContextRequestMatcher.java @@ -16,7 +16,6 @@ package org.springframework.boot.security.servlet; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; import javax.servlet.http.HttpServletRequest; @@ -44,7 +43,9 @@ public abstract class ApplicationContextRequestMatcher implements RequestMatc private final Class contextClass; - private final AtomicBoolean initialized = new AtomicBoolean(false); + private volatile boolean initialized; + + private final Object initializeLock = new Object(); public ApplicationContextRequestMatcher(Class contextClass) { Assert.notNull(contextClass, "Context class must not be null"); @@ -59,8 +60,13 @@ public abstract class ApplicationContextRequestMatcher implements RequestMatc return false; } Supplier context = () -> getContext(webApplicationContext); - if (this.initialized.compareAndSet(false, true)) { - initialized(context); + if (!this.initialized) { + synchronized (this.initializeLock) { + if (!this.initialized) { + initialized(context); + this.initialized = true; + } + } } return matches(request, context); } @@ -89,7 +95,7 @@ public abstract class ApplicationContextRequestMatcher implements RequestMatc * Method that can be implemented by subclasses that wish to initialize items the * first time that the matcher is called. This method will be called only once and * only if {@link #ignoreApplicationContext(WebApplicationContext)} returns - * {@code true}. Note that the supplied context will be based on the + * {@code false}. Note that the supplied context will be based on the * first request sent to the matcher. * @param context a supplier for the initialized context (may throw an exception) * @see #ignoreApplicationContext(WebApplicationContext) diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/security/servlet/ApplicationContextRequestMatcherTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/security/servlet/ApplicationContextRequestMatcherTests.java index de180d3de23..12448b3d127 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/security/servlet/ApplicationContextRequestMatcherTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/security/servlet/ApplicationContextRequestMatcherTests.java @@ -16,6 +16,10 @@ package org.springframework.boot.security.servlet; +import java.lang.Thread.UncaughtExceptionHandler; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; import javax.servlet.http.HttpServletRequest; @@ -26,6 +30,7 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockServletContext; +import org.springframework.util.ReflectionUtils; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.support.StaticWebApplicationContext; @@ -105,6 +110,31 @@ class ApplicationContextRequestMatcherTests { assertThat(matcher.matches(request)).isFalse(); } + @Test // gh-18211 + void matchesWhenConcurrentlyCalledWaitsForInitialize() { + ConcurrentApplicationContextRequestMatcher matcher = new ConcurrentApplicationContextRequestMatcher(); + StaticWebApplicationContext context = createWebApplicationContext(); + Runnable target = () -> matcher.matches(new MockHttpServletRequest(context.getServletContext())); + List threads = new ArrayList<>(); + AssertingUncaughtExceptionHandler exceptionHandler = new AssertingUncaughtExceptionHandler(); + for (int i = 0; i < 2; i++) { + Thread thread = new Thread(target); + thread.setUncaughtExceptionHandler(exceptionHandler); + threads.add(thread); + } + threads.forEach(Thread::start); + threads.forEach(this::join); + exceptionHandler.assertNoExceptions(); + } + + private void join(Thread thread) { + try { + thread.join(1000); + } + catch (InterruptedException ex) { + } + } + private StaticWebApplicationContext createWebApplicationContext() { StaticWebApplicationContext context = new StaticWebApplicationContext(); MockServletContext servletContext = new MockServletContext(); @@ -160,4 +190,47 @@ class ApplicationContextRequestMatcherTests { } + static class ConcurrentApplicationContextRequestMatcher extends ApplicationContextRequestMatcher { + + ConcurrentApplicationContextRequestMatcher() { + super(Object.class); + } + + private AtomicBoolean initialized = new AtomicBoolean(); + + @Override + protected void initialized(Supplier context) { + try { + Thread.sleep(200); + } + catch (InterruptedException ex) { + } + this.initialized.set(true); + } + + @Override + protected boolean matches(HttpServletRequest request, Supplier context) { + assertThat(this.initialized.get()).isTrue(); + return true; + } + + } + + private static class AssertingUncaughtExceptionHandler implements UncaughtExceptionHandler { + + private volatile Throwable ex; + + @Override + public void uncaughtException(Thread thead, Throwable ex) { + this.ex = ex; + } + + void assertNoExceptions() { + if (this.ex != null) { + ReflectionUtils.rethrowRuntimeException(this.ex); + } + } + + } + }