diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/AsyncSupportConfigurer.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/AsyncSupportConfigurer.java index a45b99bfcf..34cb26eb25 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/AsyncSupportConfigurer.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/AsyncSupportConfigurer.java @@ -16,6 +16,7 @@ package org.springframework.web.servlet.config.annotation; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -32,6 +33,7 @@ import org.springframework.web.context.request.async.DeferredResultProcessingInt * Helps with configuring options for asynchronous request processing. * * @author Rossen Stoyanchev + * @author Réda Housni Alaoui * @since 3.2 */ public class AsyncSupportConfigurer { @@ -44,6 +46,8 @@ public class AsyncSupportConfigurer { private final List deferredResultInterceptors = new ArrayList<>(); + private @Nullable Duration sseHeartbeatPeriod; + /** * The provided task executor is used for the following: @@ -99,6 +103,14 @@ public class AsyncSupportConfigurer { return this; } + /** + * Configure the SSE heartbeat period. + * @param sseHeartbeatPeriod The SSE heartbeat period + */ + public AsyncSupportConfigurer setSseHeartbeatPeriod(Duration sseHeartbeatPeriod) { + this.sseHeartbeatPeriod = sseHeartbeatPeriod; + return this; + } protected @Nullable AsyncTaskExecutor getTaskExecutor() { return this.taskExecutor; @@ -116,4 +128,8 @@ public class AsyncSupportConfigurer { return this.deferredResultInterceptors; } + protected @Nullable Duration getSseHeartbeatPeriod() { + return this.sseHeartbeatPeriod; + } + } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java index 8d031b7990..207a9ddd34 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Optional; import jakarta.servlet.ServletContext; import org.jspecify.annotations.Nullable; @@ -693,6 +694,7 @@ public class WebMvcConfigurationSupport implements ApplicationContextAware, Serv } adapter.setCallableInterceptors(configurer.getCallableInterceptors()); adapter.setDeferredResultInterceptors(configurer.getDeferredResultInterceptors()); + Optional.ofNullable(configurer.getSseHeartbeatPeriod()).ifPresent(adapter::setSseHeartbeatPeriod); return adapter; } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutor.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutor.java deleted file mode 100644 index 438ecab647..0000000000 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutor.java +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright 2002-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.servlet.mvc.method.annotation; - - -import java.io.IOException; -import java.time.Duration; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ScheduledFuture; - -import org.jspecify.annotations.Nullable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.context.SmartLifecycle; -import org.springframework.http.MediaType; -import org.springframework.scheduling.TaskScheduler; - -/** - * @author Réda Housni Alaoui - */ -public class DefaultSseEmitterHeartbeatExecutor implements SmartLifecycle, SseEmitterHeartbeatExecutor { - - private static final Logger LOGGER = LoggerFactory.getLogger(DefaultSseEmitterHeartbeatExecutor.class); - - private final TaskScheduler taskScheduler; - private final Set emitters = ConcurrentHashMap.newKeySet(); - - private final Object lifecycleMonitor = new Object(); - - private Duration period = Duration.ofSeconds(5); - private String eventName = "ping"; - private String eventObject = "ping"; - - private volatile boolean running; - @Nullable - private volatile ScheduledFuture taskFuture; - - public DefaultSseEmitterHeartbeatExecutor(TaskScheduler taskScheduler) { - this.taskScheduler = taskScheduler; - } - - public void setPeriod(Duration period) { - this.period = period; - } - - public void setEventName(String eventName) { - this.eventName = eventName; - } - - public void setEventObject(String eventObject) { - this.eventObject = eventObject; - } - - @Override - public void start() { - synchronized (lifecycleMonitor) { - taskFuture = taskScheduler.scheduleAtFixedRate(this::ping, period); - running = true; - } - } - - @Override - public void register(SseEmitter emitter) { - Runnable closeCallback = () -> emitters.remove(emitter); - emitter.onCompletion(closeCallback); - emitter.onError(t -> closeCallback.run()); - emitter.onTimeout(closeCallback); - - emitters.add(emitter); - } - - @Override - public void stop() { - synchronized (lifecycleMonitor) { - ScheduledFuture future = taskFuture; - if (future != null) { - future.cancel(true); - } - emitters.clear(); - running = false; - } - } - - @Override - public boolean isRunning() { - return running; - } - - boolean isRegistered(SseEmitter emitter) { - return emitters.contains(emitter); - } - - private void ping() { - LOGGER.atDebug().log(() -> "Pinging %s emitter(s)".formatted(emitters.size())); - - for (SseEmitter emitter : emitters) { - if (Thread.currentThread().isInterrupted()) { - return; - } - LOGGER.trace("Pinging {}", emitter); - SseEmitter.SseEventBuilder eventBuilder = SseEmitter.event().name(eventName).data(eventObject, MediaType.TEXT_PLAIN); - try { - emitter.send(eventBuilder); - } catch (IOException | RuntimeException e) { - // According to SseEmitter's Javadoc, the container itself will call SseEmitter#completeWithError - LOGGER.debug(e.getMessage()); - } - } - } -} diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java index 577d86b215..8a3aa2ea13 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java @@ -17,11 +17,13 @@ package org.springframework.web.servlet.mvc.method.annotation; import java.lang.reflect.Method; +import java.time.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; @@ -54,6 +56,8 @@ import org.springframework.http.converter.ByteArrayHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter; import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.scheduling.concurrent.SimpleAsyncTaskScheduler; import org.springframework.ui.ModelMap; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -202,8 +206,9 @@ public class RequestMappingHandlerAdapter extends AbstractHandlerMethodAdapter private final Map> modelAttributeAdviceCache = new LinkedHashMap<>(); - @Nullable - private SseEmitterHeartbeatExecutor sseEmitterHeartbeatExecutor; + private TaskScheduler taskScheduler = new SimpleAsyncTaskScheduler(); + + private @Nullable Duration sseHeartbeatPeriod; /** * Provide resolvers for custom argument types. Custom resolvers are ordered @@ -530,10 +535,17 @@ public class RequestMappingHandlerAdapter extends AbstractHandlerMethodAdapter } /** - * Set the {@link SseEmitterHeartbeatExecutor} that will be used to periodically prob the SSE connection health + * Set the {@link TaskScheduler} */ - public void setSseEmitterHeartbeatExecutor(@Nullable SseEmitterHeartbeatExecutor sseEmitterHeartbeatExecutor) { - this.sseEmitterHeartbeatExecutor = sseEmitterHeartbeatExecutor; + public void setTaskScheduler(TaskScheduler taskScheduler) { + this.taskScheduler = taskScheduler; + } + + /** + * Sets the heartbeat period that will be used to periodically prob the SSE connection health + */ + public void setSseHeartbeatPeriod(@Nullable Duration sseHeartbeatPeriod) { + this.sseHeartbeatPeriod = sseHeartbeatPeriod; } /** @@ -743,9 +755,12 @@ public class RequestMappingHandlerAdapter extends AbstractHandlerMethodAdapter handlers.add(new ModelAndViewMethodReturnValueHandler()); handlers.add(new ModelMethodProcessor()); handlers.add(new ViewMethodReturnValueHandler()); + + SseEmitterHeartbeatExecutor sseEmitterHeartbeatExecutor = Optional.ofNullable(sseHeartbeatPeriod) + .map(period -> new SseEmitterHeartbeatExecutor(taskScheduler, period)).orElse(null); handlers.add(new ResponseBodyEmitterReturnValueHandler(getMessageConverters(), this.reactiveAdapterRegistry, this.taskExecutor, this.contentNegotiationManager, - initViewResolvers(), initLocaleResolver(), this.sseEmitterHeartbeatExecutor)); + initViewResolvers(), initLocaleResolver(), sseEmitterHeartbeatExecutor)); handlers.add(new StreamingResponseBodyReturnValueHandler()); handlers.add(new HttpEntityMethodProcessor(getMessageConverters(), this.contentNegotiationManager, this.requestResponseBodyAdvice, this.errorResponseInterceptors)); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java index 6cfda938f3..23b06ee323 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitter.java @@ -18,14 +18,18 @@ package org.springframework.web.servlet.mvc.method.annotation; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.Collections; import java.util.LinkedHashSet; +import java.util.Optional; import java.util.Set; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.server.ServerHttpResponse; @@ -41,10 +45,13 @@ import org.springframework.web.servlet.ModelAndView; * @author Juergen Hoeller * @author Sam Brannen * @author Brian Clozel + * @author Réda Housni Alaoui * @since 4.2 */ public class SseEmitter extends ResponseBodyEmitter { + private static final Logger LOGGER = LoggerFactory.getLogger(SseEmitter.class); + private static final MediaType TEXT_PLAIN = new MediaType("text", "plain", StandardCharsets.UTF_8); /** @@ -52,6 +59,8 @@ public class SseEmitter extends ResponseBodyEmitter { */ private final Lock writeLock = new ReentrantLock(); + private volatile @Nullable Long lastEmissionNanoTime; + /** * Create a new SseEmitter instance. */ @@ -134,12 +143,31 @@ public class SseEmitter extends ResponseBodyEmitter { this.writeLock.lock(); try { super.send(dataToSend); + this.lastEmissionNanoTime = System.nanoTime(); } finally { this.writeLock.unlock(); } } + void notifyOfHeartbeatTick(Duration heartbeatPeriod) { + boolean skip = Optional.ofNullable(lastEmissionNanoTime) + .map(lastEmissionNanoTime -> System.nanoTime() - lastEmissionNanoTime) + .map(nanoTimeElapsedSinceLastEmission -> nanoTimeElapsedSinceLastEmission < heartbeatPeriod.toNanos()) + .orElse(false); + if (skip) { + return; + } + LOGGER.trace("Sending heartbeat to {}", this); + SseEmitter.SseEventBuilder eventBuilder = SseEmitter.event().name("ping").data("ping", MediaType.TEXT_PLAIN); + try { + send(eventBuilder); + } catch (IOException | RuntimeException e) { + // According to SseEmitter's Javadoc, the container itself will call SseEmitter#completeWithError + LOGGER.debug(e.getMessage()); + } + } + @Override public String toString() { return "SseEmitter@" + ObjectUtils.getIdentityHexString(this); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterHeartbeatExecutor.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterHeartbeatExecutor.java index 0b19305daa..084ec89c3a 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterHeartbeatExecutor.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterHeartbeatExecutor.java @@ -16,10 +16,74 @@ package org.springframework.web.servlet.mvc.method.annotation; + +import java.time.Duration; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ScheduledFuture; + +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.scheduling.TaskScheduler; + /** * @author Réda Housni Alaoui */ -public interface SseEmitterHeartbeatExecutor { +class SseEmitterHeartbeatExecutor { - void register(SseEmitter emitter); + private static final Logger LOGGER = LoggerFactory.getLogger(SseEmitterHeartbeatExecutor.class); + + private final TaskScheduler taskScheduler; + private final Set emitters = ConcurrentHashMap.newKeySet(); + + private final Object lifecycleMonitor = new Object(); + + private final Duration period; + + @Nullable + private volatile ScheduledFuture taskFuture; + + public SseEmitterHeartbeatExecutor(TaskScheduler taskScheduler, Duration period) { + this.taskScheduler = taskScheduler; + this.period = period; + } + + public void register(SseEmitter emitter) { + startIfNeeded(); + + Runnable closeCallback = () -> emitters.remove(emitter); + emitter.onCompletion(closeCallback); + emitter.onError(t -> closeCallback.run()); + emitter.onTimeout(closeCallback); + + emitters.add(emitter); + } + + boolean isRegistered(SseEmitter emitter) { + return emitters.contains(emitter); + } + + private void startIfNeeded() { + if (taskFuture != null) { + return; + } + synchronized (lifecycleMonitor) { + if (taskFuture != null) { + return; + } + taskFuture = taskScheduler.scheduleAtFixedRate(this::notifyEmitters, period); + } + } + + private void notifyEmitters() { + LOGGER.atDebug().log(() -> "Notifying %s emitter(s)".formatted(emitters.size())); + + for (SseEmitter emitter : emitters) { + if (Thread.currentThread().isInterrupted()) { + return; + } + emitter.notifyOfHeartbeatTick(period); + } + } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutorTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterHeartbeatExecutorTests.java similarity index 83% rename from spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutorTests.java rename to spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterHeartbeatExecutorTests.java index 325748bbac..d3b39c30a1 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutorTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterHeartbeatExecutorTests.java @@ -43,28 +43,26 @@ import org.springframework.scheduling.Trigger; /** * @author Réda Housni Alaoui */ -class DefaultSseEmitterHeartbeatExecutorTests { +class SseEmitterHeartbeatExecutorTests { private static final MediaType TEXT_PLAIN_UTF8 = new MediaType("text", "plain", StandardCharsets.UTF_8); private TestTaskScheduler taskScheduler; - private DefaultSseEmitterHeartbeatExecutor executor; @BeforeEach void beforeEach() { this.taskScheduler = new TestTaskScheduler(); - executor = new DefaultSseEmitterHeartbeatExecutor(taskScheduler); } @Test @DisplayName("It sends heartbeat at a fixed rate") void test1() { - executor.start(); - assertThat(taskScheduler.fixedRateTask).isNotNull(); - assertThat(taskScheduler.fixedRatePeriod).isEqualTo(Duration.ofSeconds(5)); + SseEmitterHeartbeatExecutor executor = new SseEmitterHeartbeatExecutor(taskScheduler, Duration.ofSeconds(5)); TestEmitter emitter = createEmitter(); executor.register(emitter.emitter()); + assertThat(taskScheduler.fixedRateTask).isNotNull(); + assertThat(taskScheduler.fixedRatePeriod).isEqualTo(Duration.ofSeconds(5)); taskScheduler.fixedRateTask.run(); emitter.handler.assertSentObjectCount(3); @@ -77,7 +75,7 @@ class DefaultSseEmitterHeartbeatExecutorTests { @Test @DisplayName("Emitter is unregistered on completion") void test2() { - executor.start(); + SseEmitterHeartbeatExecutor executor = new SseEmitterHeartbeatExecutor(taskScheduler, Duration.ofSeconds(5)); TestEmitter emitter = createEmitter(); executor.register(emitter.emitter()); @@ -90,7 +88,7 @@ class DefaultSseEmitterHeartbeatExecutorTests { @Test @DisplayName("Emitter is unregistered on error") void test3() { - executor.start(); + SseEmitterHeartbeatExecutor executor = new SseEmitterHeartbeatExecutor(taskScheduler, Duration.ofSeconds(5)); TestEmitter emitter = createEmitter(); executor.register(emitter.emitter()); @@ -103,7 +101,7 @@ class DefaultSseEmitterHeartbeatExecutorTests { @Test @DisplayName("Emitter is unregistered on timeout") void test4() { - executor.start(); + SseEmitterHeartbeatExecutor executor = new SseEmitterHeartbeatExecutor(taskScheduler, Duration.ofSeconds(5)); TestEmitter emitter = createEmitter(); executor.register(emitter.emitter()); @@ -116,33 +114,22 @@ class DefaultSseEmitterHeartbeatExecutorTests { @Test @DisplayName("Emitters are unregistered on executor shutdown") void test5() { - executor.start(); + SseEmitterHeartbeatExecutor executor = new SseEmitterHeartbeatExecutor(taskScheduler, Duration.ofSeconds(5)); TestEmitter emitter = createEmitter(); executor.register(emitter.emitter()); assertThat(executor.isRegistered(emitter.emitter)).isTrue(); - executor.stop(); - assertThat(executor.isRegistered(emitter.emitter)).isFalse(); - } - - @Test - @DisplayName("The task schedule is canceled on executor shutdown") - void test6() { - executor.start(); - executor.stop(); - assertThat(taskScheduler.fixedRateFuture.canceled).isTrue(); - assertThat(taskScheduler.fixedRateFuture.interrupted).isTrue(); } @Test @DisplayName("The task never throws") - void test7() { - executor.start(); - assertThat(taskScheduler.fixedRateTask).isNotNull(); + void test6() { + SseEmitterHeartbeatExecutor executor = new SseEmitterHeartbeatExecutor(taskScheduler, Duration.ofSeconds(5)); TestEmitter emitter = createEmitter(); executor.register(emitter.emitter()); + assertThat(taskScheduler.fixedRateTask).isNotNull(); emitter.handler.exceptionToThrowOnSend = new RuntimeException(); assertThatCode(() -> taskScheduler.fixedRateTask.run()).doesNotThrowAnyException(); @@ -150,49 +137,14 @@ class DefaultSseEmitterHeartbeatExecutorTests { @Test @DisplayName("The heartbeat rate can be customized") - void test8() { - executor.setPeriod(Duration.ofSeconds(30)); - executor.start(); + void test7() { + SseEmitterHeartbeatExecutor executor = new SseEmitterHeartbeatExecutor(taskScheduler, Duration.ofSeconds(30)); + TestEmitter emitter = createEmitter(); + executor.register(emitter.emitter()); assertThat(taskScheduler.fixedRateTask).isNotNull(); assertThat(taskScheduler.fixedRatePeriod).isEqualTo(Duration.ofSeconds(30)); } - @Test - @DisplayName("The heartbeat event name can be customized") - void test9() { - executor.setEventName("foo"); - executor.start(); - assertThat(taskScheduler.fixedRateTask).isNotNull(); - - TestEmitter emitter = createEmitter(); - executor.register(emitter.emitter()); - taskScheduler.fixedRateTask.run(); - - emitter.handler.assertSentObjectCount(3); - emitter.handler.assertObject(0, "event:foo\ndata:", TEXT_PLAIN_UTF8); - emitter.handler.assertObject(1, "ping", MediaType.TEXT_PLAIN); - emitter.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); - emitter.handler.assertWriteCount(1); - } - - @Test - @DisplayName("The heartbeat event object can be customized") - void test10() { - executor.setEventObject("foo"); - executor.start(); - assertThat(taskScheduler.fixedRateTask).isNotNull(); - - TestEmitter emitter = createEmitter(); - executor.register(emitter.emitter()); - taskScheduler.fixedRateTask.run(); - - emitter.handler.assertSentObjectCount(3); - emitter.handler.assertObject(0, "event:ping\ndata:", TEXT_PLAIN_UTF8); - emitter.handler.assertObject(1, "foo", MediaType.TEXT_PLAIN); - emitter.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); - emitter.handler.assertWriteCount(1); - } - private TestEmitter createEmitter() { SseEmitter sseEmitter = new SseEmitter(); TestEmitterHandler handler = new TestEmitterHandler();