diff --git a/spring-context/src/test/java/org/springframework/context/i18n/LocaleContextThreadLocalAccessorTests.java b/spring-context/src/test/java/org/springframework/context/i18n/LocaleContextThreadLocalAccessorTests.java index d4b2e6263d..748397e485 100644 --- a/spring-context/src/test/java/org/springframework/context/i18n/LocaleContextThreadLocalAccessorTests.java +++ b/spring-context/src/test/java/org/springframework/context/i18n/LocaleContextThreadLocalAccessorTests.java @@ -17,72 +17,71 @@ package org.springframework.context.i18n; import java.util.Locale; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Stream; import io.micrometer.context.ContextRegistry; import io.micrometer.context.ContextSnapshot; import io.micrometer.context.ContextSnapshotFactory; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.springframework.lang.Nullable; - import static org.assertj.core.api.Assertions.assertThat; /** * Tests for {@link LocaleContextThreadLocalAccessor}. * * @author Tadaya Tsuyukubo + * @author Rossen Stoyanchev */ class LocaleContextThreadLocalAccessorTests { - private final ContextRegistry registry = new ContextRegistry() - .registerThreadLocalAccessor(new LocaleContextThreadLocalAccessor()); + private final ContextRegistry registry = + new ContextRegistry().registerThreadLocalAccessor(new LocaleContextThreadLocalAccessor()); - @AfterEach - void cleanUp() { - LocaleContextHolder.resetLocaleContext(); + + private static Stream propagation() { + LocaleContext previousContext = new SimpleLocaleContext(Locale.ENGLISH); + LocaleContext currentContext = new SimpleLocaleContext(Locale.ENGLISH); + return Stream.of(Arguments.of(null, currentContext), Arguments.of(previousContext, currentContext)); } @ParameterizedTest @MethodSource - @SuppressWarnings("try") - void propagation(@Nullable LocaleContext previous, LocaleContext current) throws Exception { - LocaleContextHolder.setLocaleContext(current); - ContextSnapshot snapshot = ContextSnapshotFactory.builder() - .contextRegistry(this.registry) - .clearMissing(true) - .build() - .captureAll(); + @SuppressWarnings({ "try", "unused" }) + void propagation(LocaleContext previousContext, LocaleContext currentContext) throws Exception { - AtomicReference previousHolder = new AtomicReference<>(); - AtomicReference currentHolder = new AtomicReference<>(); - CountDownLatch latch = new CountDownLatch(1); - new Thread(() -> { - LocaleContextHolder.setLocaleContext(previous); + ContextSnapshot snapshot = createContextSnapshotFor(currentContext); + + AtomicReference contextInScope = new AtomicReference<>(); + AtomicReference contextAfterScope = new AtomicReference<>(); + + Thread thread = new Thread(() -> { + LocaleContextHolder.setLocaleContext(previousContext); try (ContextSnapshot.Scope scope = snapshot.setThreadLocals()) { - currentHolder.set(LocaleContextHolder.getLocaleContext()); + contextInScope.set(LocaleContextHolder.getLocaleContext()); } - previousHolder.set(LocaleContextHolder.getLocaleContext()); - latch.countDown(); - }).start(); + contextAfterScope.set(LocaleContextHolder.getLocaleContext()); + }); - latch.await(1, TimeUnit.SECONDS); - assertThat(previousHolder).hasValueSatisfying(value -> assertThat(value).isSameAs(previous)); - assertThat(currentHolder).hasValueSatisfying(value -> assertThat(value).isSameAs(current)); + thread.start(); + thread.join(1000); + + assertThat(contextAfterScope).hasValueSatisfying(value -> assertThat(value).isSameAs(previousContext)); + assertThat(contextInScope).hasValueSatisfying(value -> assertThat(value).isSameAs(currentContext)); } - private static Stream propagation() { - LocaleContext previous = new SimpleLocaleContext(Locale.ENGLISH); - LocaleContext current = new SimpleLocaleContext(Locale.ENGLISH); - return Stream.of( - Arguments.of(null, current), - Arguments.of(previous, current) - ); + private ContextSnapshot createContextSnapshotFor(LocaleContext context) { + LocaleContextHolder.setLocaleContext(context); + try { + return ContextSnapshotFactory.builder() + .contextRegistry(this.registry).clearMissing(true).build() + .captureAll(); + } + finally { + LocaleContextHolder.resetLocaleContext(); + } } + } diff --git a/spring-web/src/test/java/org/springframework/web/context/request/RequestAttributesThreadLocalAccessorTests.java b/spring-web/src/test/java/org/springframework/web/context/request/RequestAttributesThreadLocalAccessorTests.java index 4a4434ed05..dd51f5fcfb 100644 --- a/spring-web/src/test/java/org/springframework/web/context/request/RequestAttributesThreadLocalAccessorTests.java +++ b/spring-web/src/test/java/org/springframework/web/context/request/RequestAttributesThreadLocalAccessorTests.java @@ -16,8 +16,6 @@ package org.springframework.web.context.request; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Stream; @@ -25,13 +23,10 @@ import io.micrometer.context.ContextRegistry; import io.micrometer.context.ContextSnapshot; import io.micrometer.context.ContextSnapshot.Scope; import io.micrometer.context.ContextSnapshotFactory; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.springframework.lang.Nullable; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; @@ -39,52 +34,55 @@ import static org.mockito.Mockito.mock; * Tests for {@link RequestAttributesThreadLocalAccessor}. * * @author Tadaya Tsuyukubo + * @author Rossen Stoyanchev */ class RequestAttributesThreadLocalAccessorTests { - private final ContextRegistry registry = new ContextRegistry() - .registerThreadLocalAccessor(new RequestAttributesThreadLocalAccessor()); + private final ContextRegistry registry = + new ContextRegistry().registerThreadLocalAccessor(new RequestAttributesThreadLocalAccessor()); - @AfterEach - void cleanUp() { - RequestContextHolder.resetRequestAttributes(); + + private static Stream propagation() { + RequestAttributes previous = mock(RequestAttributes.class); + RequestAttributes current = mock(RequestAttributes.class); + return Stream.of(Arguments.of(null, current), Arguments.of(previous, current)); } @ParameterizedTest @MethodSource @SuppressWarnings({ "try", "unused" }) - void propagation(@Nullable RequestAttributes previous, RequestAttributes current) throws Exception { - RequestContextHolder.setRequestAttributes(current); - ContextSnapshot snapshot = ContextSnapshotFactory.builder() - .contextRegistry(this.registry) - .clearMissing(true) - .build() - .captureAll(); + void propagation(RequestAttributes previousRequest, RequestAttributes currentRequest) throws Exception { - AtomicReference previousHolder = new AtomicReference<>(); - AtomicReference currentHolder = new AtomicReference<>(); - CountDownLatch latch = new CountDownLatch(1); - new Thread(() -> { - RequestContextHolder.setRequestAttributes(previous); + ContextSnapshot snapshot = getSnapshotFor(currentRequest); + + AtomicReference requestInScope = new AtomicReference<>(); + AtomicReference requestAfterScope = new AtomicReference<>(); + + Thread thread = new Thread(() -> { + RequestContextHolder.setRequestAttributes(previousRequest); try (Scope scope = snapshot.setThreadLocals()) { - currentHolder.set(RequestContextHolder.getRequestAttributes()); + requestInScope.set(RequestContextHolder.getRequestAttributes()); } - previousHolder.set(RequestContextHolder.getRequestAttributes()); - latch.countDown(); - }).start(); + requestAfterScope.set(RequestContextHolder.getRequestAttributes()); + }); - latch.await(1, TimeUnit.SECONDS); - assertThat(previousHolder).hasValueSatisfying(value -> assertThat(value).isSameAs(previous)); - assertThat(currentHolder).hasValueSatisfying(value -> assertThat(value).isSameAs(current)); + thread.start(); + thread.join(1000); + + assertThat(requestInScope).hasValueSatisfying(value -> assertThat(value).isSameAs(currentRequest)); + assertThat(requestAfterScope).hasValueSatisfying(value -> assertThat(value).isSameAs(previousRequest)); } - private static Stream propagation() { - RequestAttributes previous = mock(RequestAttributes.class); - RequestAttributes current = mock(RequestAttributes.class); - return Stream.of( - Arguments.of(null, current), - Arguments.of(previous, current) - ); + private ContextSnapshot getSnapshotFor(RequestAttributes request) { + RequestContextHolder.setRequestAttributes(request); + try { + return ContextSnapshotFactory.builder() + .contextRegistry(this.registry).clearMissing(true).build() + .captureAll(); + } + finally { + RequestContextHolder.resetRequestAttributes(); + } } }