diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerAdapter.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerAdapter.java index 7ffd2f12836..b6b2e13be13 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerAdapter.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerAdapter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,6 +38,7 @@ import org.springframework.web.bind.support.WebBindingInitializer; import org.springframework.web.method.HandlerMethod; import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.HandlerAdapter; +import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.reactive.HandlerResult; import org.springframework.web.reactive.result.method.InvocableHandlerMethod; import org.springframework.web.server.ServerWebExchange; @@ -206,6 +207,9 @@ public class RequestMappingHandlerAdapter implements HandlerAdapter, Application Assert.state(this.methodResolver != null, "Not initialized"); + // Success and error responses may use different content types + exchange.getAttributes().remove(HandlerMapping.PRODUCIBLE_MEDIA_TYPES_ATTRIBUTE); + InvocableHandlerMethod invocable = this.methodResolver.getExceptionHandlerMethod(exception, handlerMethod); if (invocable != null) { try { diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingExceptionHandlingIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingExceptionHandlingIntegrationTests.java index 04c081e0588..77b750331a5 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingExceptionHandlingIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingExceptionHandlingIntegrationTests.java @@ -17,6 +17,8 @@ package org.springframework.web.reactive.result.method.annotation; import java.io.IOException; +import java.util.Collections; +import java.util.Map; import org.junit.Test; import org.reactivestreams.Publisher; @@ -32,6 +34,7 @@ import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.client.HttpStatusCodeException; import org.springframework.web.reactive.config.EnableWebFlux; import static org.junit.Assert.*; @@ -74,7 +77,7 @@ public class RequestMappingExceptionHandlingIntegrationTests extends AbstractReq } @Test // SPR-16051 - public void exceptionAfterSeveralItems() throws Exception { + public void exceptionAfterSeveralItems() { try { performGet("/SPR-16051", new HttpHeaders(), String.class).getBody(); fail(); @@ -86,6 +89,21 @@ public class RequestMappingExceptionHandlingIntegrationTests extends AbstractReq } } + @Test // SPR-16318 + public void exceptionFromMethodWithProducesCondition() throws Exception { + try { + HttpHeaders headers = new HttpHeaders(); + headers.add("Accept", "text/csv, application/problem+json"); + performGet("/SPR-16318", headers, String.class).getBody(); + fail(); + } + catch (HttpStatusCodeException ex) { + assertEquals(500, ex.getRawStatusCode()); + assertEquals("application/problem+json;charset=UTF-8", ex.getResponseHeaders().getContentType().toString()); + assertEquals("{\"reason\":\"error\"}", ex.getResponseBodyAsString()); + } + } + private void doTest(String url, String expected) throws Exception { assertEquals(expected, performGet(url, new HttpHeaders(), String.class).getBody()); } @@ -118,7 +136,7 @@ public class RequestMappingExceptionHandlingIntegrationTests extends AbstractReq throw new RuntimeException("State", new IOException("IO")); } - @GetMapping("/mono-error") + @GetMapping(path = "/mono-error") public Publisher handleWithError() { return Mono.error(new IllegalArgumentException("Argument")); } @@ -134,6 +152,10 @@ public class RequestMappingExceptionHandlingIntegrationTests extends AbstractReq }); } + @GetMapping(path = "/SPR-16318", produces = "text/csv") + public String handleCsv() throws Exception { + throw new Spr16318Exception(); + } @ExceptionHandler public Publisher handleArgumentException(IOException ex) { @@ -149,6 +171,14 @@ public class RequestMappingExceptionHandlingIntegrationTests extends AbstractReq public ResponseEntity> handleStateException(IllegalStateException ex) { return ResponseEntity.ok(Mono.just("Recovered from error: " + ex.getMessage())); } + + @ExceptionHandler + public ResponseEntity> handle(Spr16318Exception ex) { + return ResponseEntity.status(500).body(Collections.singletonMap("reason", "error")); + } } + @SuppressWarnings("serial") + private static class Spr16318Exception extends Exception {} + } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java index fa2c907f22e..9df3010930f 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java @@ -1248,6 +1248,9 @@ public class DispatcherServlet extends FrameworkServlet { protected ModelAndView processHandlerException(HttpServletRequest request, HttpServletResponse response, @Nullable Object handler, Exception ex) throws Exception { + // Success and error responses may use different content types + request.removeAttribute(HandlerMapping.PRODUCIBLE_MEDIA_TYPES_ATTRIBUTE); + // Check registered HandlerExceptionResolvers... ModelAndView exMv = null; if (this.handlerExceptionResolvers != null) { diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ServletAnnotationControllerHandlerMethodTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ServletAnnotationControllerHandlerMethodTests.java index b0207a9b6c7..8317b8b2bb0 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ServletAnnotationControllerHandlerMethodTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ServletAnnotationControllerHandlerMethodTests.java @@ -1121,7 +1121,22 @@ public class ServletAnnotationControllerHandlerMethodTests extends AbstractServl @Test public void produces() throws Exception { - initServletWithControllers(ProducesController.class); + initServlet(wac -> { + List> converters = new ArrayList<>(); + converters.add(new MappingJackson2HttpMessageConverter()); + converters.add(new Jaxb2RootElementHttpMessageConverter()); + + RootBeanDefinition beanDef; + + beanDef = new RootBeanDefinition(RequestMappingHandlerAdapter.class); + beanDef.getPropertyValues().add("messageConverters", converters); + wac.registerBeanDefinition("handlerAdapter", beanDef); + + beanDef = new RootBeanDefinition(ExceptionHandlerExceptionResolver.class); + beanDef.getPropertyValues().add("messageConverters", converters); + wac.registerBeanDefinition("requestMappingResolver", beanDef); + + }, ProducesController.class); MockHttpServletRequest request = new MockHttpServletRequest("GET", "/something"); request.addHeader("Accept", "text/html"); @@ -1152,6 +1167,15 @@ public class ServletAnnotationControllerHandlerMethodTests extends AbstractServl response = new MockHttpServletResponse(); getServlet().service(request, response); assertEquals(406, response.getStatus()); + + // SPR-16318 + request = new MockHttpServletRequest("GET", "/something"); + request.addHeader("Accept", "text/csv,application/problem+json"); + response = new MockHttpServletResponse(); + getServlet().service(request, response); + assertEquals(500, response.getStatus()); + assertEquals("application/problem+json;charset=UTF-8", response.getContentType()); + assertEquals("{\"reason\":\"error\"}", response.getContentAsString()); } @Test @@ -3000,15 +3024,25 @@ public class ServletAnnotationControllerHandlerMethodTests extends AbstractServl @Controller public static class ProducesController { - @RequestMapping(value = "/something", produces = "text/html") + @GetMapping(path = "/something", produces = "text/html") public void handleHtml(Writer writer) throws IOException { writer.write("html"); } - @RequestMapping(value = "/something", produces = "application/xml") + @GetMapping(path = "/something", produces = "application/xml") public void handleXml(Writer writer) throws IOException { writer.write("xml"); } + + @GetMapping(path = "/something", produces = "text/csv") + public String handleCsv() { + throw new IllegalArgumentException(); + } + + @ExceptionHandler + public ResponseEntity> handle(IllegalArgumentException ex) { + return ResponseEntity.status(500).body(Collections.singletonMap("reason", "error")); + } } @Controller