diff --git a/spring-test/src/main/java/org/springframework/test/web/client/MockMvcClientHttpRequestFactory.java b/spring-test/src/main/java/org/springframework/test/web/client/MockMvcClientHttpRequestFactory.java index d34d51ed194..a38ab07e4ba 100644 --- a/spring-test/src/main/java/org/springframework/test/web/client/MockMvcClientHttpRequestFactory.java +++ b/spring-test/src/main/java/org/springframework/test/web/client/MockMvcClientHttpRequestFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,19 +31,26 @@ import org.springframework.mock.http.client.MockClientHttpRequest; import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.test.web.servlet.MockMvc; -import org.springframework.test.web.servlet.MvcResult; -import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder; import org.springframework.util.Assert; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.request; + /** * A {@link ClientHttpRequestFactory} for requests executed via {@link MockMvc}. * + *

As of 5.0 this class also implements + * {@link org.springframework.http.client.AsyncClientHttpRequestFactory + * AsyncClientHttpRequestFactory}. However note that + * {@link org.springframework.web.client.AsyncRestTemplate} and related classes + * have been deprecated at the same time. + * * @author Rossen Stoyanchev * @since 3.2 */ -public class MockMvcClientHttpRequestFactory implements ClientHttpRequestFactory { +@SuppressWarnings("deprecation") +public class MockMvcClientHttpRequestFactory + implements ClientHttpRequestFactory, org.springframework.http.client.AsyncClientHttpRequestFactory { private final MockMvc mockMvc; @@ -55,31 +62,46 @@ public class MockMvcClientHttpRequestFactory implements ClientHttpRequestFactory @Override - public ClientHttpRequest createRequest(final URI uri, final HttpMethod httpMethod) throws IOException { + public ClientHttpRequest createRequest(final URI uri, final HttpMethod httpMethod) { return new MockClientHttpRequest(httpMethod, uri) { @Override public ClientHttpResponse executeInternal() throws IOException { - try { - MockHttpServletRequestBuilder requestBuilder = request(httpMethod, uri); - requestBuilder.content(getBodyAsBytes()); - requestBuilder.headers(getHeaders()); - MvcResult mvcResult = MockMvcClientHttpRequestFactory.this.mockMvc.perform(requestBuilder).andReturn(); - MockHttpServletResponse servletResponse = mvcResult.getResponse(); - HttpStatus status = HttpStatus.valueOf(servletResponse.getStatus()); - byte[] body = servletResponse.getContentAsByteArray(); - HttpHeaders headers = getResponseHeaders(servletResponse); - MockClientHttpResponse clientResponse = new MockClientHttpResponse(body, status); - clientResponse.getHeaders().putAll(headers); - return clientResponse; - } - catch (Exception ex) { - byte[] body = ex.toString().getBytes(StandardCharsets.UTF_8); - return new MockClientHttpResponse(body, HttpStatus.INTERNAL_SERVER_ERROR); - } + return getClientHttpResponse(httpMethod, uri, getHeaders(), getBodyAsBytes()); } }; } + @Override + public org.springframework.http.client.AsyncClientHttpRequest createAsyncRequest(URI uri, HttpMethod method) { + return new org.springframework.mock.http.client.MockAsyncClientHttpRequest(method, uri) { + @Override + protected ClientHttpResponse executeInternal() throws IOException { + return getClientHttpResponse(method, uri, getHeaders(), getBodyAsBytes()); + } + }; + } + + private ClientHttpResponse getClientHttpResponse( + HttpMethod httpMethod, URI uri, HttpHeaders requestHeaders, byte[] requestBody) { + + try { + MockHttpServletResponse servletResponse = mockMvc + .perform(request(httpMethod, uri).content(requestBody).headers(requestHeaders)) + .andReturn() + .getResponse(); + + HttpStatus status = HttpStatus.valueOf(servletResponse.getStatus()); + byte[] body = servletResponse.getContentAsByteArray(); + MockClientHttpResponse clientResponse = new MockClientHttpResponse(body, status); + clientResponse.getHeaders().putAll(getResponseHeaders(servletResponse)); + return clientResponse; + } + catch (Exception ex) { + byte[] body = ex.toString().getBytes(StandardCharsets.UTF_8); + return new MockClientHttpResponse(body, HttpStatus.INTERNAL_SERVER_ERROR); + } + } + private HttpHeaders getResponseHeaders(MockHttpServletResponse response) { HttpHeaders headers = new HttpHeaders(); for (String name : response.getHeaderNames()) { diff --git a/spring-test/src/test/java/org/springframework/test/web/client/samples/MockMvcClientHttpRequestFactoryTests.java b/spring-test/src/test/java/org/springframework/test/web/client/samples/MockMvcClientHttpRequestFactoryTests.java index d6e03de0892..0b72717cc45 100644 --- a/spring-test/src/test/java/org/springframework/test/web/client/samples/MockMvcClientHttpRequestFactoryTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/client/samples/MockMvcClientHttpRequestFactoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2011 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Configuration; +import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Controller; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; @@ -30,16 +31,19 @@ import org.springframework.test.context.web.WebAppConfiguration; import org.springframework.test.web.client.MockMvcClientHttpRequestFactory; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.util.concurrent.ListenableFuture; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.client.AsyncRestTemplate; import org.springframework.web.client.RestTemplate; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter; -import static org.junit.Assert.*; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; +import static org.junit.Assert.assertEquals; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + /** * Tests that use a {@link RestTemplate} configured with a @@ -57,21 +61,29 @@ public class MockMvcClientHttpRequestFactoryTests { @Autowired private WebApplicationContext wac; - private RestTemplate restTemplate; + private MockMvc mockMvc; @Before public void setup() { - MockMvc mockMvc = MockMvcBuilders.webAppContextSetup(this.wac).alwaysExpect(status().isOk()).build(); - this.restTemplate = new RestTemplate(new MockMvcClientHttpRequestFactory(mockMvc)); + this.mockMvc = MockMvcBuilders.webAppContextSetup(this.wac).alwaysExpect(status().isOk()).build(); } @Test public void test() throws Exception { - String result = this.restTemplate.getForObject("/foo", String.class); + RestTemplate template = new RestTemplate(new MockMvcClientHttpRequestFactory(this.mockMvc)); + String result = template.getForObject("/foo", String.class); assertEquals("bar", result); } + @Test + @SuppressWarnings("deprecation") + public void testAsyncTemplate() throws Exception { + AsyncRestTemplate template = new AsyncRestTemplate(new MockMvcClientHttpRequestFactory(this.mockMvc)); + ListenableFuture> entity = template.getForEntity("/foo", String.class); + assertEquals("bar", entity.get().getBody()); + } + @EnableWebMvc @Configuration