diff --git a/spring-context/src/main/java/org/springframework/scheduling/concurrent/ConcurrentTaskExecutor.java b/spring-context/src/main/java/org/springframework/scheduling/concurrent/ConcurrentTaskExecutor.java index 1edebb80de7..4fb6541d69a 100644 --- a/spring-context/src/main/java/org/springframework/scheduling/concurrent/ConcurrentTaskExecutor.java +++ b/spring-context/src/main/java/org/springframework/scheduling/concurrent/ConcurrentTaskExecutor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -199,6 +199,10 @@ public class ConcurrentTaskExecutor implements AsyncListenableTaskExecutor, Sche return adapter; } + Runnable decorateTaskIfNecessary(Runnable task) { + return (this.taskDecorator != null ? this.taskDecorator.decorate(task) : task); + } + /** * TaskExecutorAdapter subclass that wraps all provided Runnables and Callables diff --git a/spring-context/src/main/java/org/springframework/scheduling/concurrent/ConcurrentTaskScheduler.java b/spring-context/src/main/java/org/springframework/scheduling/concurrent/ConcurrentTaskScheduler.java index 01770c2c197..4d560429e7e 100644 --- a/spring-context/src/main/java/org/springframework/scheduling/concurrent/ConcurrentTaskScheduler.java +++ b/spring-context/src/main/java/org/springframework/scheduling/concurrent/ConcurrentTaskScheduler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,8 +20,10 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.Date; +import java.util.concurrent.Callable; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; @@ -39,6 +41,7 @@ import org.springframework.scheduling.support.TaskUtils; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.ErrorHandler; +import org.springframework.util.concurrent.ListenableFuture; /** * Adapter that takes a {@code java.util.concurrent.ScheduledExecutorService} and @@ -191,6 +194,7 @@ public class ConcurrentTaskScheduler extends ConcurrentTaskExecutor implements T * @see Clock#systemDefaultZone() */ public void setClock(Clock clock) { + Assert.notNull(clock, "Clock must not be null"); this.clock = clock; } @@ -200,6 +204,33 @@ public class ConcurrentTaskScheduler extends ConcurrentTaskExecutor implements T } + @Override + public void execute(Runnable task) { + super.execute(TaskUtils.decorateTaskWithErrorHandler(task, this.errorHandler, false)); + } + + @Override + public Future submit(Runnable task) { + return super.submit(TaskUtils.decorateTaskWithErrorHandler(task, this.errorHandler, false)); + } + + @Override + public Future submit(Callable task) { + return super.submit(new DelegatingErrorHandlingCallable<>(task, this.errorHandler)); + } + + @SuppressWarnings("deprecation") + @Override + public ListenableFuture submitListenable(Runnable task) { + return super.submitListenable(TaskUtils.decorateTaskWithErrorHandler(task, this.errorHandler, false)); + } + + @SuppressWarnings("deprecation") + @Override + public ListenableFuture submitListenable(Callable task) { + return super.submitListenable(new DelegatingErrorHandlingCallable<>(task, this.errorHandler)); + } + @Override @Nullable public ScheduledFuture schedule(Runnable task, Trigger trigger) { @@ -211,7 +242,9 @@ public class ConcurrentTaskScheduler extends ConcurrentTaskExecutor implements T else { ErrorHandler errorHandler = (this.errorHandler != null ? this.errorHandler : TaskUtils.getDefaultErrorHandler(true)); - return new ReschedulingRunnable(task, trigger, this.clock, scheduleExecutorToUse, errorHandler).schedule(); + return new ReschedulingRunnable( + decorateTaskIfNecessary(task), trigger, this.clock, scheduleExecutorToUse, errorHandler) + .schedule(); } } catch (RejectedExecutionException ex) { @@ -283,6 +316,7 @@ public class ConcurrentTaskScheduler extends ConcurrentTaskExecutor implements T private Runnable decorateTask(Runnable task, boolean isRepeatingTask) { Runnable result = TaskUtils.decorateTaskWithErrorHandler(task, this.errorHandler, isRepeatingTask); + result = decorateTaskIfNecessary(result); if (this.enterpriseConcurrentScheduler) { result = ManagedTaskBuilder.buildManagedTask(result, task.toString()); } diff --git a/spring-context/src/main/java/org/springframework/scheduling/concurrent/DelegatingErrorHandlingCallable.java b/spring-context/src/main/java/org/springframework/scheduling/concurrent/DelegatingErrorHandlingCallable.java new file mode 100644 index 00000000000..c30f209c53f --- /dev/null +++ b/spring-context/src/main/java/org/springframework/scheduling/concurrent/DelegatingErrorHandlingCallable.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.scheduling.concurrent; + +import java.lang.reflect.UndeclaredThrowableException; +import java.util.concurrent.Callable; + +import org.springframework.lang.Nullable; +import org.springframework.scheduling.support.TaskUtils; +import org.springframework.util.ErrorHandler; +import org.springframework.util.ReflectionUtils; + +/** + * {@link Callable} adapter for an {@link ErrorHandler}. + * + * @author Juergen Hoeller + * @since 6.2 + * @param the value type + */ +class DelegatingErrorHandlingCallable implements Callable { + + private final Callable delegate; + + private final ErrorHandler errorHandler; + + + public DelegatingErrorHandlingCallable(Callable delegate, @Nullable ErrorHandler errorHandler) { + this.delegate = delegate; + this.errorHandler = (errorHandler != null ? errorHandler : + TaskUtils.getDefaultErrorHandler(false)); + } + + + @Override + @Nullable + public V call() throws Exception { + try { + return this.delegate.call(); + } + catch (Throwable ex) { + try { + this.errorHandler.handleError(ex); + } + catch (UndeclaredThrowableException exToPropagate) { + ReflectionUtils.rethrowException(exToPropagate.getUndeclaredThrowable()); + } + return null; + } + } + +} diff --git a/spring-context/src/main/java/org/springframework/scheduling/concurrent/SimpleAsyncTaskScheduler.java b/spring-context/src/main/java/org/springframework/scheduling/concurrent/SimpleAsyncTaskScheduler.java index 4bd8f8e18a3..5e38b6218f8 100644 --- a/spring-context/src/main/java/org/springframework/scheduling/concurrent/SimpleAsyncTaskScheduler.java +++ b/spring-context/src/main/java/org/springframework/scheduling/concurrent/SimpleAsyncTaskScheduler.java @@ -19,6 +19,7 @@ package org.springframework.scheduling.concurrent; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.concurrent.Callable; import java.util.concurrent.Executor; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; @@ -41,7 +42,9 @@ import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.Trigger; import org.springframework.scheduling.support.DelegatingErrorHandlingRunnable; import org.springframework.scheduling.support.TaskUtils; +import org.springframework.util.Assert; import org.springframework.util.ErrorHandler; +import org.springframework.util.concurrent.ListenableFuture; /** * A simple implementation of Spring's {@link TaskScheduler} interface, using @@ -108,6 +111,9 @@ public class SimpleAsyncTaskScheduler extends SimpleAsyncTaskExecutor implements private final ExecutorLifecycleDelegate lifecycleDelegate = new ExecutorLifecycleDelegate(this.scheduledExecutor); + @Nullable + private ErrorHandler errorHandler; + private Clock clock = Clock.systemDefaultZone(); private int phase = DEFAULT_PHASE; @@ -119,13 +125,22 @@ public class SimpleAsyncTaskScheduler extends SimpleAsyncTaskExecutor implements private ApplicationContext applicationContext; + /** + * Provide an {@link ErrorHandler} strategy. + * @since 6.2 + */ + public void setErrorHandler(ErrorHandler errorHandler) { + Assert.notNull(errorHandler, "ErrorHandler must not be null"); + this.errorHandler = errorHandler; + } + /** * Set the clock to use for scheduling purposes. *

The default clock is the system clock for the default time zone. - * @since 5.3 * @see Clock#systemDefaultZone() */ public void setClock(Clock clock) { + Assert.notNull(clock, "Clock must not be null"); this.clock = clock; } @@ -194,7 +209,8 @@ public class SimpleAsyncTaskScheduler extends SimpleAsyncTaskExecutor implements } private Runnable taskOnSchedulerThread(Runnable task) { - return new DelegatingErrorHandlingRunnable(task, TaskUtils.getDefaultErrorHandler(true)); + return new DelegatingErrorHandlingRunnable(task, + (this.errorHandler != null ? this.errorHandler : TaskUtils.getDefaultErrorHandler(true))); } private Runnable scheduledTask(Runnable task) { @@ -202,7 +218,10 @@ public class SimpleAsyncTaskScheduler extends SimpleAsyncTaskExecutor implements } private void shutdownAwareErrorHandler(Throwable ex) { - if (this.scheduledExecutor.isTerminated()) { + if (this.errorHandler != null) { + this.errorHandler.handleError(ex); + } + else if (this.scheduledExecutor.isTerminated()) { LogFactory.getLog(getClass()).debug("Ignoring scheduled task exception after shutdown", ex); } else { @@ -211,12 +230,40 @@ public class SimpleAsyncTaskScheduler extends SimpleAsyncTaskExecutor implements } + @Override + public void execute(Runnable task) { + super.execute(TaskUtils.decorateTaskWithErrorHandler(task, this.errorHandler, false)); + } + + @Override + public Future submit(Runnable task) { + return super.submit(TaskUtils.decorateTaskWithErrorHandler(task, this.errorHandler, false)); + } + + @Override + public Future submit(Callable task) { + return super.submit(new DelegatingErrorHandlingCallable<>(task, this.errorHandler)); + } + + @SuppressWarnings("deprecation") + @Override + public ListenableFuture submitListenable(Runnable task) { + return super.submitListenable(TaskUtils.decorateTaskWithErrorHandler(task, this.errorHandler, false)); + } + + @SuppressWarnings("deprecation") + @Override + public ListenableFuture submitListenable(Callable task) { + return super.submitListenable(new DelegatingErrorHandlingCallable<>(task, this.errorHandler)); + } + @Override @Nullable public ScheduledFuture schedule(Runnable task, Trigger trigger) { try { Runnable delegate = scheduledTask(task); - ErrorHandler errorHandler = TaskUtils.getDefaultErrorHandler(true); + ErrorHandler errorHandler = + (this.errorHandler != null ? this.errorHandler : TaskUtils.getDefaultErrorHandler(true)); return new ReschedulingRunnable( delegate, trigger, this.clock, this.scheduledExecutor, errorHandler).schedule(); } diff --git a/spring-context/src/main/java/org/springframework/scheduling/concurrent/ThreadPoolTaskScheduler.java b/spring-context/src/main/java/org/springframework/scheduling/concurrent/ThreadPoolTaskScheduler.java index 9ab42bae99e..e9e512ad7fd 100644 --- a/spring-context/src/main/java/org/springframework/scheduling/concurrent/ThreadPoolTaskScheduler.java +++ b/spring-context/src/main/java/org/springframework/scheduling/concurrent/ThreadPoolTaskScheduler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,18 +21,23 @@ import java.time.Duration; import java.time.Instant; import java.util.Map; import java.util.concurrent.Callable; +import java.util.concurrent.Delayed; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.RejectedExecutionHandler; +import java.util.concurrent.RunnableScheduledFuture; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import org.springframework.core.task.AsyncListenableTaskExecutor; +import org.springframework.core.task.TaskDecorator; import org.springframework.core.task.TaskRejectedException; import org.springframework.lang.Nullable; import org.springframework.scheduling.SchedulingTaskExecutor; @@ -75,6 +80,9 @@ public class ThreadPoolTaskScheduler extends ExecutorConfigurationSupport private volatile boolean executeExistingDelayedTasksAfterShutdownPolicy = true; + @Nullable + private TaskDecorator taskDecorator; + @Nullable private volatile ErrorHandler errorHandler; @@ -145,6 +153,20 @@ public class ThreadPoolTaskScheduler extends ExecutorConfigurationSupport this.executeExistingDelayedTasksAfterShutdownPolicy = flag; } + /** + * Specify a custom {@link TaskDecorator} to be applied to any {@link Runnable} + * about to be executed. + *

Note that such a decorator is not being applied to the user-supplied + * {@code Runnable}/{@code Callable} but rather to the scheduled execution + * callback (a wrapper around the user-supplied task). + *

The primary use case is to set some execution context around the task's + * invocation, or to provide some monitoring/statistics for task execution. + * @since 6.2 + */ + public void setTaskDecorator(TaskDecorator taskDecorator) { + this.taskDecorator = taskDecorator; + } + /** * Set a custom {@link ErrorHandler} strategy. */ @@ -159,6 +181,7 @@ public class ThreadPoolTaskScheduler extends ExecutorConfigurationSupport * @see Clock#systemDefaultZone() */ public void setClock(Clock clock) { + Assert.notNull(clock, "Clock must not be null"); this.clock = clock; } @@ -212,6 +235,14 @@ public class ThreadPoolTaskScheduler extends ExecutorConfigurationSupport protected void afterExecute(Runnable task, Throwable ex) { ThreadPoolTaskScheduler.this.afterExecute(task, ex); } + @Override + protected RunnableScheduledFuture decorateTask(Runnable runnable, RunnableScheduledFuture task) { + return decorateTaskIfNecessary(task); + } + @Override + protected RunnableScheduledFuture decorateTask(Callable callable, RunnableScheduledFuture task) { + return decorateTaskIfNecessary(task); + } }; } @@ -310,12 +341,7 @@ public class ThreadPoolTaskScheduler extends ExecutorConfigurationSupport public Future submit(Callable task) { ExecutorService executor = getScheduledExecutor(); try { - Callable taskToUse = task; - ErrorHandler errorHandler = this.errorHandler; - if (errorHandler != null) { - taskToUse = new DelegatingErrorHandlingCallable<>(task, errorHandler); - } - return executor.submit(taskToUse); + return executor.submit(new DelegatingErrorHandlingCallable<>(task, this.errorHandler)); } catch (RejectedExecutionException ex) { throw new TaskRejectedException(executor, task, ex); @@ -447,32 +473,70 @@ public class ThreadPoolTaskScheduler extends ExecutorConfigurationSupport } + private RunnableScheduledFuture decorateTaskIfNecessary(RunnableScheduledFuture future) { + return (this.taskDecorator != null ? new DelegatingRunnableScheduledFuture<>(future, this.taskDecorator) : + future); + } + private Runnable errorHandlingTask(Runnable task, boolean isRepeatingTask) { return TaskUtils.decorateTaskWithErrorHandler(task, this.errorHandler, isRepeatingTask); } - private static class DelegatingErrorHandlingCallable implements Callable { + private static class DelegatingRunnableScheduledFuture implements RunnableScheduledFuture { - private final Callable delegate; + private final RunnableScheduledFuture future; - private final ErrorHandler errorHandler; + private final Runnable decoratedRunnable; - public DelegatingErrorHandlingCallable(Callable delegate, ErrorHandler errorHandler) { - this.delegate = delegate; - this.errorHandler = errorHandler; + public DelegatingRunnableScheduledFuture(RunnableScheduledFuture future, TaskDecorator taskDecorator) { + this.future = future; + this.decoratedRunnable = taskDecorator.decorate(this.future); } @Override - @Nullable - public V call() throws Exception { - try { - return this.delegate.call(); - } - catch (Throwable ex) { - this.errorHandler.handleError(ex); - return null; - } + public void run() { + this.decoratedRunnable.run(); + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return this.future.cancel(mayInterruptIfRunning); + } + + @Override + public boolean isCancelled() { + return this.future.isCancelled(); + } + + @Override + public boolean isDone() { + return this.future.isDone(); + } + + @Override + public V get() throws InterruptedException, ExecutionException { + return this.future.get(); + } + + @Override + public V get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + return this.future.get(timeout, unit); + } + + @Override + public boolean isPeriodic() { + return this.future.isPeriodic(); + } + + @Override + public long getDelay(TimeUnit unit) { + return this.future.getDelay(unit); + } + + @Override + public int compareTo(Delayed o) { + return this.future.compareTo(o); } } diff --git a/spring-context/src/test/java/org/springframework/scheduling/concurrent/ConcurrentTaskSchedulerTests.java b/spring-context/src/test/java/org/springframework/scheduling/concurrent/ConcurrentTaskSchedulerTests.java new file mode 100644 index 00000000000..1ef701fff6b --- /dev/null +++ b/spring-context/src/test/java/org/springframework/scheduling/concurrent/ConcurrentTaskSchedulerTests.java @@ -0,0 +1,244 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.scheduling.concurrent; + +import java.time.Instant; +import java.util.Date; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; + +import org.springframework.scheduling.Trigger; +import org.springframework.scheduling.TriggerContext; +import org.springframework.util.ErrorHandler; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +/** + * @author Juergen Hoeller + */ +class ConcurrentTaskSchedulerTests extends AbstractSchedulingTaskExecutorTests { + + private final CustomizableThreadFactory threadFactory = new CustomizableThreadFactory(); + + private final ConcurrentTaskScheduler scheduler = new ConcurrentTaskScheduler( + Executors.newScheduledThreadPool(1, threadFactory)); + + private final AtomicBoolean taskRun = new AtomicBoolean(); + + + @SuppressWarnings("deprecation") + @Override + protected org.springframework.core.task.AsyncListenableTaskExecutor buildExecutor() { + threadFactory.setThreadNamePrefix(this.threadNamePrefix); + scheduler.setTaskDecorator(runnable -> () -> { + taskRun.set(true); + runnable.run(); + }); + return scheduler; + } + + @Override + @AfterEach + void shutdownExecutor() { + for (Runnable task : ((ExecutorService) scheduler.getConcurrentExecutor()).shutdownNow()) { + if (task instanceof Future) { + ((Future) task).cancel(true); + } + } + } + + + @Test + @Override + void submitRunnableWithGetAfterShutdown() { + // decorated Future cannot be cancelled on shutdown with ConcurrentTaskScheduler (see above) + } + + @Test + @SuppressWarnings("deprecation") + @Override + void submitListenableRunnableWithGetAfterShutdown() { + // decorated Future cannot be cancelled on shutdown with ConcurrentTaskScheduler (see above) + } + + @Test + @Override + void submitCallableWithGetAfterShutdown() { + // decorated Future cannot be cancelled on shutdown with ConcurrentTaskScheduler (see above) + } + + @Test + @SuppressWarnings("deprecation") + @Override + void submitListenableCallableWithGetAfterShutdown() { + // decorated Future cannot be cancelled on shutdown with ConcurrentTaskScheduler (see above) + } + + + @Test + void executeFailingRunnableWithErrorHandler() { + TestTask task = new TestTask(this.testName, 0); + TestErrorHandler errorHandler = new TestErrorHandler(1); + scheduler.setErrorHandler(errorHandler); + scheduler.execute(task); + await(errorHandler); + assertThat(errorHandler.lastError).isNotNull(); + assertThat(taskRun.get()).isTrue(); + } + + @Test + void submitFailingRunnableWithErrorHandler() throws Exception { + TestTask task = new TestTask(this.testName, 0); + TestErrorHandler errorHandler = new TestErrorHandler(1); + scheduler.setErrorHandler(errorHandler); + Future future = scheduler.submit(task); + Object result = future.get(1000, TimeUnit.MILLISECONDS); + assertThat(future.isDone()).isTrue(); + assertThat(result).isNull(); + assertThat(errorHandler.lastError).isNotNull(); + assertThat(taskRun.get()).isTrue(); + } + + @Test + void submitFailingCallableWithErrorHandler() throws Exception { + TestCallable task = new TestCallable(this.testName, 0); + TestErrorHandler errorHandler = new TestErrorHandler(1); + scheduler.setErrorHandler(errorHandler); + Future future = scheduler.submit(task); + Object result = future.get(1000, TimeUnit.MILLISECONDS); + assertThat(future.isDone()).isTrue(); + assertThat(result).isNull(); + assertThat(errorHandler.lastError).isNotNull(); + assertThat(taskRun.get()).isTrue(); + } + + @Test + @SuppressWarnings("deprecation") + void scheduleOneTimeTask() throws Exception { + TestTask task = new TestTask(this.testName, 1); + Future future = scheduler.schedule(task, new Date()); + Object result = future.get(1000, TimeUnit.MILLISECONDS); + assertThat(result).isNull(); + assertThat(future.isDone()).isTrue(); + assertThat(taskRun.get()).isTrue(); + assertThreadNamePrefix(task); + } + + @Test + @SuppressWarnings("deprecation") + void scheduleOneTimeFailingTaskWithoutErrorHandler() { + TestTask task = new TestTask(this.testName, 0); + Future future = scheduler.schedule(task, new Date()); + assertThatExceptionOfType(ExecutionException.class).isThrownBy(() -> future.get(1000, TimeUnit.MILLISECONDS)); + assertThat(future.isDone()).isTrue(); + assertThat(taskRun.get()).isTrue(); + } + + @Test + @SuppressWarnings("deprecation") + void scheduleOneTimeFailingTaskWithErrorHandler() throws Exception { + TestTask task = new TestTask(this.testName, 0); + TestErrorHandler errorHandler = new TestErrorHandler(1); + scheduler.setErrorHandler(errorHandler); + Future future = scheduler.schedule(task, new Date()); + Object result = future.get(1000, TimeUnit.MILLISECONDS); + assertThat(future.isDone()).isTrue(); + assertThat(result).isNull(); + assertThat(errorHandler.lastError).isNotNull(); + assertThat(taskRun.get()).isTrue(); + } + + @RepeatedTest(20) + void scheduleMultipleTriggerTasks() throws Exception { + TestTask task = new TestTask(this.testName, 3); + Future future = scheduler.schedule(task, new TestTrigger(3)); + Object result = future.get(1000, TimeUnit.MILLISECONDS); + assertThat(result).isNull(); + await(task); + assertThat(taskRun.get()).isTrue(); + assertThreadNamePrefix(task); + } + + + private void await(TestTask task) { + await(task.latch); + } + + private void await(TestErrorHandler errorHandler) { + await(errorHandler.latch); + } + + private void await(CountDownLatch latch) { + try { + latch.await(1000, TimeUnit.MILLISECONDS); + } + catch (InterruptedException ex) { + throw new IllegalStateException(ex); + } + assertThat(latch.getCount()).as("latch did not count down").isEqualTo(0); + } + + + private static class TestErrorHandler implements ErrorHandler { + + private final CountDownLatch latch; + + private volatile Throwable lastError; + + TestErrorHandler(int expectedErrorCount) { + this.latch = new CountDownLatch(expectedErrorCount); + } + + @Override + public void handleError(Throwable t) { + this.lastError = t; + this.latch.countDown(); + } + } + + + private static class TestTrigger implements Trigger { + + private final int maxRunCount; + + private final AtomicInteger actualRunCount = new AtomicInteger(); + + TestTrigger(int maxRunCount) { + this.maxRunCount = maxRunCount; + } + + @Override + public Instant nextExecution(TriggerContext triggerContext) { + if (this.actualRunCount.incrementAndGet() > this.maxRunCount) { + return null; + } + return Instant.now(); + } + } + +} diff --git a/spring-context/src/test/java/org/springframework/scheduling/concurrent/SimpleAsyncTaskSchedulerTests.java b/spring-context/src/test/java/org/springframework/scheduling/concurrent/SimpleAsyncTaskSchedulerTests.java new file mode 100644 index 00000000000..9a7a751f749 --- /dev/null +++ b/spring-context/src/test/java/org/springframework/scheduling/concurrent/SimpleAsyncTaskSchedulerTests.java @@ -0,0 +1,229 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.scheduling.concurrent; + +import java.time.Instant; +import java.util.Date; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; + +import org.springframework.scheduling.Trigger; +import org.springframework.scheduling.TriggerContext; +import org.springframework.util.ErrorHandler; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Juergen Hoeller + * @since 6.2 + */ +class SimpleAsyncTaskSchedulerTests extends AbstractSchedulingTaskExecutorTests { + + private final SimpleAsyncTaskScheduler scheduler = new SimpleAsyncTaskScheduler(); + + private final AtomicBoolean taskRun = new AtomicBoolean(); + + + @SuppressWarnings("deprecation") + @Override + protected org.springframework.core.task.AsyncListenableTaskExecutor buildExecutor() { + scheduler.setTaskDecorator(runnable -> () -> { + taskRun.set(true); + runnable.run(); + }); + scheduler.setThreadNamePrefix(this.threadNamePrefix); + return scheduler; + } + + + @Test + @Override + void submitRunnableWithGetAfterShutdown() { + // decorated Future cannot be cancelled on shutdown with SimpleAsyncTaskScheduler + } + + @Test + @SuppressWarnings("deprecation") + @Override + void submitListenableRunnableWithGetAfterShutdown() { + // decorated Future cannot be cancelled on shutdown with SimpleAsyncTaskScheduler + } + + @Test + @Override + void submitCompletableRunnableWithGetAfterShutdown() { + // decorated Future cannot be cancelled on shutdown with SimpleAsyncTaskScheduler + } + + @Test + @Override + void submitCallableWithGetAfterShutdown() { + // decorated Future cannot be cancelled on shutdown with SimpleAsyncTaskScheduler + } + + @Test + @SuppressWarnings("deprecation") + @Override + void submitListenableCallableWithGetAfterShutdown() { + // decorated Future cannot be cancelled on shutdown with SimpleAsyncTaskScheduler + } + + @Test + @Override + void submitCompletableCallableWithGetAfterShutdown() { + // decorated Future cannot be cancelled on shutdown with SimpleAsyncTaskScheduler + } + + + @Test + void executeFailingRunnableWithErrorHandler() { + TestTask task = new TestTask(this.testName, 0); + TestErrorHandler errorHandler = new TestErrorHandler(1); + scheduler.setErrorHandler(errorHandler); + scheduler.execute(task); + await(errorHandler); + assertThat(errorHandler.lastError).isNotNull(); + assertThat(taskRun.get()).isTrue(); + } + + @Test + void submitFailingRunnableWithErrorHandler() throws Exception { + TestTask task = new TestTask(this.testName, 0); + TestErrorHandler errorHandler = new TestErrorHandler(1); + scheduler.setErrorHandler(errorHandler); + Future future = scheduler.submit(task); + Object result = future.get(1000, TimeUnit.MILLISECONDS); + assertThat(future.isDone()).isTrue(); + assertThat(result).isNull(); + assertThat(errorHandler.lastError).isNotNull(); + assertThat(taskRun.get()).isTrue(); + } + + @Test + void submitFailingCallableWithErrorHandler() throws Exception { + TestCallable task = new TestCallable(this.testName, 0); + TestErrorHandler errorHandler = new TestErrorHandler(1); + scheduler.setErrorHandler(errorHandler); + Future future = scheduler.submit(task); + Object result = future.get(1000, TimeUnit.MILLISECONDS); + assertThat(future.isDone()).isTrue(); + assertThat(result).isNull(); + assertThat(errorHandler.lastError).isNotNull(); + assertThat(taskRun.get()).isTrue(); + } + + @Test + @SuppressWarnings("deprecation") + void scheduleOneTimeTask() throws Exception { + TestTask task = new TestTask(this.testName, 1); + Future future = scheduler.schedule(task, new Date()); + Object result = future.get(1000, TimeUnit.MILLISECONDS); + assertThat(result).isNull(); + await(task); + assertThat(taskRun.get()).isTrue(); + assertThreadNamePrefix(task); + } + + @Test + @SuppressWarnings("deprecation") + void scheduleOneTimeFailingTaskWithErrorHandler() throws Exception { + TestTask task = new TestTask(this.testName, 0); + TestErrorHandler errorHandler = new TestErrorHandler(1); + scheduler.setErrorHandler(errorHandler); + Future future = scheduler.schedule(task, new Date()); + Object result = future.get(1000, TimeUnit.MILLISECONDS); + await(errorHandler); + assertThat(result).isNull(); + assertThat(errorHandler.lastError).isNotNull(); + assertThat(taskRun.get()).isTrue(); + } + + @RepeatedTest(20) + void scheduleMultipleTriggerTasks() throws Exception { + TestTask task = new TestTask(this.testName, 3); + Future future = scheduler.schedule(task, new TestTrigger(3)); + Object result = future.get(1000, TimeUnit.MILLISECONDS); + assertThat(result).isNull(); + await(task); + assertThat(taskRun.get()).isTrue(); + assertThreadNamePrefix(task); + } + + + private void await(TestTask task) { + await(task.latch); + } + + private void await(TestErrorHandler errorHandler) { + await(errorHandler.latch); + } + + private void await(CountDownLatch latch) { + try { + latch.await(1000, TimeUnit.MILLISECONDS); + } + catch (InterruptedException ex) { + throw new IllegalStateException(ex); + } + assertThat(latch.getCount()).as("latch did not count down").isEqualTo(0); + } + + + private static class TestErrorHandler implements ErrorHandler { + + private final CountDownLatch latch; + + private volatile Throwable lastError; + + TestErrorHandler(int expectedErrorCount) { + this.latch = new CountDownLatch(expectedErrorCount); + } + + @Override + public void handleError(Throwable t) { + this.lastError = t; + this.latch.countDown(); + } + } + + + private static class TestTrigger implements Trigger { + + private final int maxRunCount; + + private final AtomicInteger actualRunCount = new AtomicInteger(); + + TestTrigger(int maxRunCount) { + this.maxRunCount = maxRunCount; + } + + @Override + public Instant nextExecution(TriggerContext triggerContext) { + if (this.actualRunCount.incrementAndGet() > this.maxRunCount) { + return null; + } + return Instant.now(); + } + } + +} diff --git a/spring-context/src/test/java/org/springframework/scheduling/concurrent/ThreadPoolTaskSchedulerTests.java b/spring-context/src/test/java/org/springframework/scheduling/concurrent/ThreadPoolTaskSchedulerTests.java index 7226c0abd67..48c0192c138 100644 --- a/spring-context/src/test/java/org/springframework/scheduling/concurrent/ThreadPoolTaskSchedulerTests.java +++ b/spring-context/src/test/java/org/springframework/scheduling/concurrent/ThreadPoolTaskSchedulerTests.java @@ -22,6 +22,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.RepeatedTest; @@ -44,10 +45,16 @@ class ThreadPoolTaskSchedulerTests extends AbstractSchedulingTaskExecutorTests { private final ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); + private final AtomicBoolean taskRun = new AtomicBoolean(); + @SuppressWarnings("deprecation") @Override protected org.springframework.core.task.AsyncListenableTaskExecutor buildExecutor() { + scheduler.setTaskDecorator(runnable -> () -> { + taskRun.set(true); + runnable.run(); + }); scheduler.setThreadNamePrefix(this.threadNamePrefix); scheduler.afterPropertiesSet(); return scheduler; @@ -62,6 +69,7 @@ class ThreadPoolTaskSchedulerTests extends AbstractSchedulingTaskExecutorTests { scheduler.execute(task); await(errorHandler); assertThat(errorHandler.lastError).isNotNull(); + assertThat(taskRun.get()).isTrue(); } @Test @@ -74,6 +82,7 @@ class ThreadPoolTaskSchedulerTests extends AbstractSchedulingTaskExecutorTests { assertThat(future.isDone()).isTrue(); assertThat(result).isNull(); assertThat(errorHandler.lastError).isNotNull(); + assertThat(taskRun.get()).isTrue(); } @Test @@ -86,6 +95,7 @@ class ThreadPoolTaskSchedulerTests extends AbstractSchedulingTaskExecutorTests { assertThat(future.isDone()).isTrue(); assertThat(result).isNull(); assertThat(errorHandler.lastError).isNotNull(); + assertThat(taskRun.get()).isTrue(); } @Test @@ -96,6 +106,7 @@ class ThreadPoolTaskSchedulerTests extends AbstractSchedulingTaskExecutorTests { Object result = future.get(1000, TimeUnit.MILLISECONDS); assertThat(result).isNull(); assertThat(future.isDone()).isTrue(); + assertThat(taskRun.get()).isTrue(); assertThreadNamePrefix(task); } @@ -106,6 +117,7 @@ class ThreadPoolTaskSchedulerTests extends AbstractSchedulingTaskExecutorTests { Future future = scheduler.schedule(task, new Date()); assertThatExceptionOfType(ExecutionException.class).isThrownBy(() -> future.get(1000, TimeUnit.MILLISECONDS)); assertThat(future.isDone()).isTrue(); + assertThat(taskRun.get()).isTrue(); } @Test @@ -119,6 +131,7 @@ class ThreadPoolTaskSchedulerTests extends AbstractSchedulingTaskExecutorTests { assertThat(future.isDone()).isTrue(); assertThat(result).isNull(); assertThat(errorHandler.lastError).isNotNull(); + assertThat(taskRun.get()).isTrue(); } @RepeatedTest(20) @@ -128,6 +141,7 @@ class ThreadPoolTaskSchedulerTests extends AbstractSchedulingTaskExecutorTests { Object result = future.get(1000, TimeUnit.MILLISECONDS); assertThat(result).isNull(); await(task); + assertThat(taskRun.get()).isTrue(); assertThreadNamePrefix(task); } @@ -147,7 +161,7 @@ class ThreadPoolTaskSchedulerTests extends AbstractSchedulingTaskExecutorTests { catch (InterruptedException ex) { throw new IllegalStateException(ex); } - assertThat(latch.getCount()).as("latch did not count down,").isEqualTo(0); + assertThat(latch.getCount()).as("latch did not count down").isEqualTo(0); }