diff --git a/spring-core/src/main/java/org/springframework/core/task/SimpleAsyncTaskExecutor.java b/spring-core/src/main/java/org/springframework/core/task/SimpleAsyncTaskExecutor.java index e2d2363373f..e575c2d54e9 100644 --- a/spring-core/src/main/java/org/springframework/core/task/SimpleAsyncTaskExecutor.java +++ b/spring-core/src/main/java/org/springframework/core/task/SimpleAsyncTaskExecutor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -92,6 +92,8 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator @Nullable private Set activeThreads; + private boolean rejectTasksWhenLimitReached = false; + private volatile boolean active = true; @@ -190,6 +192,17 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator this.activeThreads = (timeout > 0 ? ConcurrentHashMap.newKeySet() : null); } + /** + * Specify whether to reject tasks when the concurrency limit has been reached, + * throwing {@link TaskRejectedException} on any further submission attempts. + *

The default is {@code false}, blocking the caller until the submission can + * be accepted. Switch this to {@code true} for immediate rejection instead. + * @since 6.2.6 + */ + public void setRejectTasksWhenLimitReached(boolean rejectTasksWhenLimitReached) { + this.rejectTasksWhenLimitReached = rejectTasksWhenLimitReached; + } + /** * Set the maximum number of parallel task executions allowed. * The default of -1 indicates no concurrency limit at all. @@ -372,13 +385,21 @@ public class SimpleAsyncTaskExecutor extends CustomizableThreadCreator * making {@code beforeAccess()} and {@code afterAccess()} * visible to the surrounding class. */ - private static class ConcurrencyThrottleAdapter extends ConcurrencyThrottleSupport { + private class ConcurrencyThrottleAdapter extends ConcurrencyThrottleSupport { @Override protected void beforeAccess() { super.beforeAccess(); } + @Override + protected void onLimitReached() { + if (rejectTasksWhenLimitReached) { + throw new TaskRejectedException("Concurrency limit reached: " + getConcurrencyLimit()); + } + super.onLimitReached(); + } + @Override protected void afterAccess() { super.afterAccess(); diff --git a/spring-core/src/main/java/org/springframework/util/ConcurrencyThrottleSupport.java b/spring-core/src/main/java/org/springframework/util/ConcurrencyThrottleSupport.java index 46da8e430ca..cf54df78e9c 100644 --- a/spring-core/src/main/java/org/springframework/util/ConcurrencyThrottleSupport.java +++ b/spring-core/src/main/java/org/springframework/util/ConcurrencyThrottleSupport.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -105,6 +105,7 @@ public abstract class ConcurrencyThrottleSupport implements Serializable { /** * To be invoked before the main execution logic of concrete subclasses. *

This implementation applies the concurrency throttle. + * @see #onLimitReached() * @see #afterAccess() */ protected void beforeAccess() { @@ -113,29 +114,12 @@ public abstract class ConcurrencyThrottleSupport implements Serializable { "Currently no invocations allowed - concurrency limit set to NO_CONCURRENCY"); } if (this.concurrencyLimit > 0) { - boolean debug = logger.isDebugEnabled(); this.concurrencyLock.lock(); try { - boolean interrupted = false; - while (this.concurrencyCount >= this.concurrencyLimit) { - if (interrupted) { - throw new IllegalStateException("Thread was interrupted while waiting for invocation access, " + - "but concurrency limit still does not allow for entering"); - } - if (debug) { - logger.debug("Concurrency count " + this.concurrencyCount + - " has reached limit " + this.concurrencyLimit + " - blocking"); - } - try { - this.concurrencyCondition.await(); - } - catch (InterruptedException ex) { - // Re-interrupt current thread, to allow other threads to react. - Thread.currentThread().interrupt(); - interrupted = true; - } + if (this.concurrencyCount >= this.concurrencyLimit) { + onLimitReached(); } - if (debug) { + if (logger.isDebugEnabled()) { logger.debug("Entering throttle at concurrency count " + this.concurrencyCount); } this.concurrencyCount++; @@ -146,6 +130,33 @@ public abstract class ConcurrencyThrottleSupport implements Serializable { } } + /** + * Triggered by {@link #beforeAccess()} when the concurrency limit has been reached. + * The default implementation blocks until the concurrency count allows for entering. + * @since 6.2.6 + */ + protected void onLimitReached() { + boolean interrupted = false; + while (this.concurrencyCount >= this.concurrencyLimit) { + if (interrupted) { + throw new IllegalStateException("Thread was interrupted while waiting for invocation access, " + + "but concurrency limit still does not allow for entering"); + } + if (logger.isDebugEnabled()) { + logger.debug("Concurrency count " + this.concurrencyCount + + " has reached limit " + this.concurrencyLimit + " - blocking"); + } + try { + this.concurrencyCondition.await(); + } + catch (InterruptedException ex) { + // Re-interrupt current thread, to allow other threads to react. + Thread.currentThread().interrupt(); + interrupted = true; + } + } + } + /** * To be invoked after the main execution logic of concrete subclasses. * @see #beforeAccess() diff --git a/spring-core/src/test/java/org/springframework/core/task/SimpleAsyncTaskExecutorTests.java b/spring-core/src/test/java/org/springframework/core/task/SimpleAsyncTaskExecutorTests.java index c7f4bd9d3b4..27ea6a7053f 100644 --- a/spring-core/src/test/java/org/springframework/core/task/SimpleAsyncTaskExecutorTests.java +++ b/spring-core/src/test/java/org/springframework/core/task/SimpleAsyncTaskExecutorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; import org.springframework.util.ConcurrencyThrottleSupport; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; @@ -31,6 +32,23 @@ import static org.assertj.core.api.Assertions.assertThatIllegalStateException; */ class SimpleAsyncTaskExecutorTests { + @Test + void isActiveUntilClose() { + SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor(); + assertThat(executor.isActive()).isTrue(); + assertThat(executor.isThrottleActive()).isFalse(); + executor.close(); + assertThat(executor.isActive()).isFalse(); + assertThat(executor.isThrottleActive()).isFalse(); + } + + @Test + void throwsExceptionWhenSuppliedWithNullRunnable() { + try (SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor()) { + assertThatIllegalArgumentException().isThrownBy(() -> executor.execute(null)); + } + } + @Test void cannotExecuteWhenConcurrencyIsSwitchedOff() { try (SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor()) { @@ -41,35 +59,34 @@ class SimpleAsyncTaskExecutorTests { } @Test - void throttleIsNotActiveByDefault() { + void taskRejectedWhenConcurrencyLimitReached() { try (SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor()) { - assertThat(executor.isThrottleActive()).as("Concurrency throttle must not default to being active (on)").isFalse(); + executor.setConcurrencyLimit(1); + executor.setRejectTasksWhenLimitReached(true); + assertThat(executor.isThrottleActive()).isTrue(); + executor.execute(new NoOpRunnable()); + assertThatExceptionOfType(TaskRejectedException.class).isThrownBy(() -> executor.execute(new NoOpRunnable())); } } @Test void threadNameGetsSetCorrectly() { - final String customPrefix = "chankPop#"; - final Object monitor = new Object(); - SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor(customPrefix); - ThreadNameHarvester task = new ThreadNameHarvester(monitor); - executeAndWait(executor, task, monitor); - assertThat(task.getThreadName()).startsWith(customPrefix); + String customPrefix = "chankPop#"; + Object monitor = new Object(); + try (SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor(customPrefix)) { + ThreadNameHarvester task = new ThreadNameHarvester(monitor); + executeAndWait(executor, task, monitor); + assertThat(task.getThreadName()).startsWith(customPrefix); + } } @Test void threadFactoryOverridesDefaults() { - final Object monitor = new Object(); - SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor(runnable -> new Thread(runnable, "test")); - ThreadNameHarvester task = new ThreadNameHarvester(monitor); - executeAndWait(executor, task, monitor); - assertThat(task.getThreadName()).isEqualTo("test"); - } - - @Test - void throwsExceptionWhenSuppliedWithNullRunnable() { - try (SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor()) { - assertThatIllegalArgumentException().isThrownBy(() -> executor.execute(null)); + Object monitor = new Object(); + try (SimpleAsyncTaskExecutor executor = new SimpleAsyncTaskExecutor(runnable -> new Thread(runnable, "test"))) { + ThreadNameHarvester task = new ThreadNameHarvester(monitor); + executeAndWait(executor, task, monitor); + assertThat(task.getThreadName()).isEqualTo("test"); } } @@ -89,7 +106,12 @@ class SimpleAsyncTaskExecutorTests { @Override public void run() { - // no-op + try { + Thread.sleep(1000); + } + catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + } } }