diff --git a/spring-aop/src/main/java/org/springframework/aop/interceptor/AsyncExecutionInterceptor.java b/spring-aop/src/main/java/org/springframework/aop/interceptor/AsyncExecutionInterceptor.java index 8b98bcce660..d5026cb4346 100644 --- a/spring-aop/src/main/java/org/springframework/aop/interceptor/AsyncExecutionInterceptor.java +++ b/spring-aop/src/main/java/org/springframework/aop/interceptor/AsyncExecutionInterceptor.java @@ -18,9 +18,12 @@ package org.springframework.aop.interceptor; import java.lang.reflect.Method; import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Future; +import java.util.function.Supplier; import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInvocation; @@ -30,6 +33,7 @@ import org.springframework.core.BridgeMethodResolver; import org.springframework.core.Ordered; import org.springframework.core.task.AsyncListenableTaskExecutor; import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.lang.UsesJava8; import org.springframework.util.ClassUtils; import org.springframework.util.concurrent.ListenableFuture; @@ -68,6 +72,11 @@ import org.springframework.util.concurrent.ListenableFuture; public class AsyncExecutionInterceptor extends AsyncExecutionAspectSupport implements MethodInterceptor, Ordered { + // Java 8's CompletableFuture type present? + private static final boolean completableFuturePresent = ClassUtils.isPresent( + "java.util.concurrent.CompletableFuture", AsyncExecutionInterceptor.class.getClassLoader()); + + /** * Create a new {@code AsyncExecutionInterceptor}. * @param defaultExecutor the {@link Executor} (typically a Spring {@link AsyncTaskExecutor} @@ -124,6 +133,12 @@ public class AsyncExecutionInterceptor extends AsyncExecutionAspectSupport }; Class returnType = invocation.getMethod().getReturnType(); + if (completableFuturePresent) { + Future result = CompletableFutureDelegate.processCompletableFuture(returnType, task, executor); + if (result != null) { + return result; + } + } if (ListenableFuture.class.isAssignableFrom(returnType)) { return ((AsyncListenableTaskExecutor) executor).submitListenable(task); } @@ -154,4 +169,29 @@ public class AsyncExecutionInterceptor extends AsyncExecutionAspectSupport return Ordered.HIGHEST_PRECEDENCE; } + + /** + * Inner class to avoid a hard dependency on Java 8. + */ + @UsesJava8 + private static class CompletableFutureDelegate { + + public static Future processCompletableFuture(Class returnType, final Callable task, Executor executor) { + if (!CompletableFuture.class.isAssignableFrom(returnType)) { + return null; + } + return CompletableFuture.supplyAsync(new Supplier() { + @Override + public T get() { + try { + return task.call(); + } + catch (Throwable ex) { + throw new CompletionException(ex); + } + } + }, executor); + } + } + } diff --git a/spring-context/src/test/java/org/springframework/scheduling/annotation/AsyncExecutionTests.java b/spring-context/src/test/java/org/springframework/scheduling/annotation/AsyncExecutionTests.java index 9562fed9e4a..0d55ec7a20c 100644 --- a/spring-context/src/test/java/org/springframework/scheduling/annotation/AsyncExecutionTests.java +++ b/spring-context/src/test/java/org/springframework/scheduling/annotation/AsyncExecutionTests.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.util.HashMap; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; @@ -71,24 +72,48 @@ public class AsyncExecutionTests { assertEquals("20", future.get()); ListenableFuture listenableFuture = asyncTest.returnSomethingListenable(20); assertEquals("20", listenableFuture.get()); + CompletableFuture completableFuture = asyncTest.returnSomethingCompletable(20); + assertEquals("20", completableFuture.get()); - future = asyncTest.returnSomething(0); try { - future.get(); + asyncTest.returnSomething(0).get(); fail("Should have thrown ExecutionException"); } catch (ExecutionException ex) { assertTrue(ex.getCause() instanceof IllegalArgumentException); } - future = asyncTest.returnSomething(-1); try { - future.get(); + asyncTest.returnSomething(-1).get(); fail("Should have thrown ExecutionException"); } catch (ExecutionException ex) { assertTrue(ex.getCause() instanceof IOException); } + + try { + asyncTest.returnSomethingListenable(0).get(); + fail("Should have thrown ExecutionException"); + } + catch (ExecutionException ex) { + assertTrue(ex.getCause() instanceof IllegalArgumentException); + } + + try { + asyncTest.returnSomethingListenable(-1).get(); + fail("Should have thrown ExecutionException"); + } + catch (ExecutionException ex) { + assertTrue(ex.getCause() instanceof IOException); + } + + try { + asyncTest.returnSomethingCompletable(0).get(); + fail("Should have thrown ExecutionException"); + } + catch (ExecutionException ex) { + assertTrue(ex.getCause() instanceof IllegalArgumentException); + } } @Test @@ -163,6 +188,32 @@ public class AsyncExecutionTests { assertEquals("20", future.get()); ListenableFuture listenableFuture = asyncTest.returnSomethingListenable(20); assertEquals("20", listenableFuture.get()); + CompletableFuture completableFuture = asyncTest.returnSomethingCompletable(20); + assertEquals("20", completableFuture.get()); + + try { + asyncTest.returnSomething(0).get(); + fail("Should have thrown ExecutionException"); + } + catch (ExecutionException ex) { + assertTrue(ex.getCause() instanceof IllegalArgumentException); + } + + try { + asyncTest.returnSomethingListenable(0).get(); + fail("Should have thrown ExecutionException"); + } + catch (ExecutionException ex) { + assertTrue(ex.getCause() instanceof IllegalArgumentException); + } + + try { + asyncTest.returnSomethingCompletable(0).get(); + fail("Should have thrown ExecutionException"); + } + catch (ExecutionException ex) { + assertTrue(ex.getCause() instanceof IllegalArgumentException); + } } @Test @@ -397,8 +448,23 @@ public class AsyncExecutionTests { @Async public ListenableFuture returnSomethingListenable(int i) { assertTrue(!Thread.currentThread().getName().equals(originalThreadName)); + if (i == 0) { + throw new IllegalArgumentException(); + } + else if (i < 0) { + return AsyncResult.forExecutionException(new IOException()); + } return new AsyncResult(Integer.toString(i)); } + + @Async + public CompletableFuture returnSomethingCompletable(int i) { + assertTrue(!Thread.currentThread().getName().equals(originalThreadName)); + if (i == 0) { + throw new IllegalArgumentException(); + } + return CompletableFuture.completedFuture(Integer.toString(i)); + } } @@ -459,14 +525,29 @@ public class AsyncExecutionTests { public Future returnSomething(int i) { assertTrue(!Thread.currentThread().getName().equals(originalThreadName)); + if (i == 0) { + throw new IllegalArgumentException(); + } return new AsyncResult(Integer.toString(i)); } public ListenableFuture returnSomethingListenable(int i) { assertTrue(!Thread.currentThread().getName().equals(originalThreadName)); + if (i == 0) { + throw new IllegalArgumentException(); + } return new AsyncResult(Integer.toString(i)); } + @Async + public CompletableFuture returnSomethingCompletable(int i) { + assertTrue(!Thread.currentThread().getName().equals(originalThreadName)); + if (i == 0) { + throw new IllegalArgumentException(); + } + return CompletableFuture.completedFuture(Integer.toString(i)); + } + @Override public void destroy() { }