From fab889009a617a63634be0754844297da4437cdb Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Fri, 30 Aug 2024 20:08:36 +0200 Subject: [PATCH 1/2] Support multiple async starts in MockHttpServletRequest Closes gh-33457 --- .../mock/web/MockHttpServletRequest.java | 16 ++++- .../mock/web/MockHttpServletRequestTests.java | 68 +++++++++++++++++++ .../servlet/MockHttpServletRequest.java | 16 ++++- 3 files changed, 98 insertions(+), 2 deletions(-) diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java index 0880440d02..37256c15ef 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java @@ -44,6 +44,8 @@ import java.util.TimeZone; import java.util.stream.Collectors; import jakarta.servlet.AsyncContext; +import jakarta.servlet.AsyncEvent; +import jakarta.servlet.AsyncListener; import jakarta.servlet.DispatcherType; import jakarta.servlet.RequestDispatcher; import jakarta.servlet.ServletConnection; @@ -920,7 +922,19 @@ public class MockHttpServletRequest implements HttpServletRequest { public AsyncContext startAsync(ServletRequest request, @Nullable ServletResponse response) { Assert.state(this.asyncSupported, "Async not supported"); this.asyncStarted = true; - this.asyncContext = new MockAsyncContext(request, response); + MockAsyncContext newAsyncContext = new MockAsyncContext(request, response); + if (this.asyncContext != null) { + try { + AsyncEvent startEvent = new AsyncEvent(newAsyncContext); + for (AsyncListener asyncListener : this.asyncContext.getListeners()) { + asyncListener.onStartAsync(startEvent); + } + } + catch (IOException ex) { + // ignore failures + } + } + this.asyncContext = newAsyncContext; return this.asyncContext; } diff --git a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java index 963846482b..2dd438f31b 100644 --- a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java +++ b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java @@ -30,6 +30,9 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.AsyncEvent; +import jakarta.servlet.AsyncListener; import jakarta.servlet.http.Cookie; import org.junit.jupiter.api.Test; @@ -663,6 +666,44 @@ class MockHttpServletRequestTests { request.getDateHeader(HttpHeaders.IF_MODIFIED_SINCE)); } + @Test + void shouldRejectAsyncStartsIfUnsupported() { + assertThat(request.isAsyncStarted()).isFalse(); + assertThatIllegalStateException().isThrownBy(request::startAsync); + } + + @Test + void startAsyncShouldUpdateRequestState() { + assertThat(request.isAsyncStarted()).isFalse(); + request.setAsyncSupported(true); + AsyncContext asyncContext = request.startAsync(); + assertThat(request.isAsyncStarted()).isTrue(); + } + + @Test + void shouldNotifyAsyncListeners() { + request.setAsyncSupported(true); + AsyncContext asyncContext = request.startAsync(); + TestAsyncListener testAsyncListener = new TestAsyncListener(); + asyncContext.addListener(testAsyncListener); + asyncContext.complete(); + assertThat(testAsyncListener.events).hasSize(1); + assertThat(testAsyncListener.events.get(0)).extracting("name").isEqualTo("onComplete"); + } + + @Test + void shouldNotifyAsyncListenersWhenNewAsyncStarted() { + request.setAsyncSupported(true); + AsyncContext asyncContext = request.startAsync(); + TestAsyncListener testAsyncListener = new TestAsyncListener(); + asyncContext.addListener(testAsyncListener); + AsyncContext newAsyncContext = request.startAsync(); + assertThat(testAsyncListener.events).hasSize(1); + ListenerEvent listenerEvent = testAsyncListener.events.get(0); + assertThat(listenerEvent).extracting("name").isEqualTo("onStartAsync"); + assertThat(listenerEvent.event.getAsyncContext()).isEqualTo(newAsyncContext); + } + private void assertEqualEnumerations(Enumeration enum1, Enumeration enum2) { int count = 0; while (enum1.hasMoreElements()) { @@ -672,4 +713,31 @@ class MockHttpServletRequestTests { } } + static class TestAsyncListener implements AsyncListener { + + List events = new ArrayList<>(); + + @Override + public void onComplete(AsyncEvent asyncEvent) throws IOException { + this.events.add(new ListenerEvent("onComplete", asyncEvent)); + } + + @Override + public void onTimeout(AsyncEvent asyncEvent) throws IOException { + this.events.add(new ListenerEvent("onTimeout", asyncEvent)); + } + + @Override + public void onError(AsyncEvent asyncEvent) throws IOException { + this.events.add(new ListenerEvent("onError", asyncEvent)); + } + + @Override + public void onStartAsync(AsyncEvent asyncEvent) throws IOException { + this.events.add(new ListenerEvent("onStartAsync", asyncEvent)); + } + } + + record ListenerEvent(String name, AsyncEvent event) {} + } diff --git a/spring-web/src/testFixtures/java/org/springframework/web/testfixture/servlet/MockHttpServletRequest.java b/spring-web/src/testFixtures/java/org/springframework/web/testfixture/servlet/MockHttpServletRequest.java index 59eb0b977c..31ac9d1482 100644 --- a/spring-web/src/testFixtures/java/org/springframework/web/testfixture/servlet/MockHttpServletRequest.java +++ b/spring-web/src/testFixtures/java/org/springframework/web/testfixture/servlet/MockHttpServletRequest.java @@ -44,6 +44,8 @@ import java.util.TimeZone; import java.util.stream.Collectors; import jakarta.servlet.AsyncContext; +import jakarta.servlet.AsyncEvent; +import jakarta.servlet.AsyncListener; import jakarta.servlet.DispatcherType; import jakarta.servlet.RequestDispatcher; import jakarta.servlet.ServletConnection; @@ -921,7 +923,19 @@ public class MockHttpServletRequest implements HttpServletRequest { public AsyncContext startAsync(ServletRequest request, @Nullable ServletResponse response) { Assert.state(this.asyncSupported, "Async not supported"); this.asyncStarted = true; - this.asyncContext = new MockAsyncContext(request, response); + MockAsyncContext newAsyncContext = new MockAsyncContext(request, response); + if (this.asyncContext != null) { + try { + AsyncEvent startEvent = new AsyncEvent(newAsyncContext); + for (AsyncListener asyncListener : this.asyncContext.getListeners()) { + asyncListener.onStartAsync(startEvent); + } + } + catch (IOException ex) { + // ignore failures + } + } + this.asyncContext = newAsyncContext; return this.asyncContext; } From debba6545ba5b39b0a1f485b67091b6be76ccfca Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Fri, 30 Aug 2024 20:09:21 +0200 Subject: [PATCH 2/2] Listen to multiple async operations in ServerHttpObservationFilter Prior to this commit, the `ServerHttpObservationFilter` was fixed to re-enable instrumentation for async dispatches. This fix involves using an AsyncListener to be notified of exchange completion. This change was incomplete, as this would not work in some cases. If another filter starts the async mode and initiates an ASYNC dispatch, before async handling at the controller level, the async listener is not registered against subsequent async starts. This commit not only ensures that the async listener registers against new async starts, but also ensure that the initial creation and registration only happens during the initial REQUEST dispatch. Fixes gh-33451 --- .../filter/ServerHttpObservationFilter.java | 7 ++-- .../ServerHttpObservationFilterTests.java | 42 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/filter/ServerHttpObservationFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ServerHttpObservationFilter.java index e56b297349..3c4222e986 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ServerHttpObservationFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ServerHttpObservationFilter.java @@ -118,13 +118,13 @@ public class ServerHttpObservationFilter extends OncePerRequestFilter { throw ex; } finally { - // If async is started, register a listener for completion notification. - if (request.isAsyncStarted()) { + // If async is started during the first dispatch, register a listener for completion notification. + if (request.isAsyncStarted() && request.getDispatcherType() == DispatcherType.REQUEST) { request.getAsyncContext().addListener(new ObservationAsyncListener(observation)); } // scope is opened for ASYNC dispatches, but the observation will be closed // by the async listener. - else if (request.getDispatcherType() != DispatcherType.ASYNC){ + else if (!isAsyncDispatch(request)) { Throwable error = fetchException(request); if (error != null) { observation.error(error); @@ -168,6 +168,7 @@ public class ServerHttpObservationFilter extends OncePerRequestFilter { @Override public void onStartAsync(AsyncEvent event) { + event.getAsyncContext().addListener(this); } @Override diff --git a/spring-web/src/test/java/org/springframework/web/filter/ServerHttpObservationFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ServerHttpObservationFilterTests.java index 1edc589b13..ddbbad37cf 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ServerHttpObservationFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ServerHttpObservationFilterTests.java @@ -21,11 +21,16 @@ import java.io.IOException; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; +import jakarta.servlet.AsyncContext; import jakarta.servlet.AsyncEvent; import jakarta.servlet.AsyncListener; import jakarta.servlet.DispatcherType; +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; import jakarta.servlet.RequestDispatcher; import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; @@ -139,6 +144,21 @@ class ServerHttpObservationFilterTests { assertThatHttpObservation().hasLowCardinalityKeyValue("outcome", "SUCCESS").hasBeenStopped(); } + @Test + void shouldRegisterListenerForAsyncStarts() throws Exception { + CustomAsyncFilter customAsyncFilter = new CustomAsyncFilter(); + this.mockFilterChain = new MockFilterChain(new NoOpServlet(), customAsyncFilter); + this.request.setAsyncSupported(true); + this.request.setDispatcherType(DispatcherType.REQUEST); + this.filter.doFilter(this.request, this.response, this.mockFilterChain); + customAsyncFilter.asyncContext.dispatch(); + this.request.setDispatcherType(DispatcherType.ASYNC); + AsyncContext newAsyncContext = this.request.startAsync(); + newAsyncContext.complete(); + + assertThatHttpObservation().hasLowCardinalityKeyValue("outcome", "SUCCESS").hasBeenStopped(); + } + @Test void shouldCloseObservationAfterAsyncError() throws Exception { this.request.setAsyncSupported(true); @@ -187,4 +207,26 @@ class ServerHttpObservationFilterTests { } } + @SuppressWarnings("serial") + static class NoOpServlet extends HttpServlet { + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + + } + + } + + static class CustomAsyncFilter implements Filter { + + AsyncContext asyncContext; + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException { + this.asyncContext = servletRequest.startAsync(); + filterChain.doFilter(servletRequest, servletResponse); + } + + } + }