diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java index d3550fda539..924017ec489 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,9 +41,13 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.InputStreamResource; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.ResourceRegion; import org.springframework.http.CacheControl; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpRange; import org.springframework.http.HttpStatus; import org.springframework.http.InvalidMediaTypeException; import org.springframework.http.MediaType; @@ -70,6 +74,9 @@ final class DefaultEntityResponseBuilder implements EntityResponse.Builder private static final boolean reactiveStreamsPresent = ClassUtils.isPresent( "org.reactivestreams.Publisher", DefaultEntityResponseBuilder.class.getClassLoader()); + private static final Type RESOURCE_REGION_LIST_TYPE = + new ParameterizedTypeReference>() { }.getType(); + private final T entity; @@ -245,6 +252,11 @@ final class DefaultEntityResponseBuilder implements EntityResponse.Builder this.entityType = entityType; } + private static boolean isResource(T entity) { + return !(entity instanceof InputStreamResource) && + (entity instanceof Resource); + } + @Override public T entity() { return this.entity; @@ -267,13 +279,33 @@ final class DefaultEntityResponseBuilder implements EntityResponse.Builder ServletServerHttpResponse serverResponse = new ServletServerHttpResponse(response); MediaType contentType = getContentType(response); Class entityClass = entity.getClass(); + Type entityType = this.entityType; + + if (entityClass != InputStreamResource.class && Resource.class.isAssignableFrom(entityClass)) { + serverResponse.getHeaders().set(HttpHeaders.ACCEPT_RANGES, "bytes"); + String rangeHeader = request.getHeader(HttpHeaders.RANGE); + if (rangeHeader != null) { + Resource resource = (Resource) entity; + try { + List httpRanges = HttpRange.parseRanges(rangeHeader); + serverResponse.getServletResponse().setStatus(HttpStatus.PARTIAL_CONTENT.value()); + entity = HttpRange.toResourceRegions(httpRanges, resource); + entityClass = entity.getClass(); + entityType = RESOURCE_REGION_LIST_TYPE; + } + catch (IllegalArgumentException ex) { + serverResponse.getHeaders().set(HttpHeaders.CONTENT_RANGE, "bytes */" + resource.contentLength()); + serverResponse.getServletResponse().setStatus(HttpStatus.REQUESTED_RANGE_NOT_SATISFIABLE.value()); + } + } + } for (HttpMessageConverter messageConverter : context.messageConverters()) { if (messageConverter instanceof GenericHttpMessageConverter) { GenericHttpMessageConverter genericMessageConverter = (GenericHttpMessageConverter) messageConverter; - if (genericMessageConverter.canWrite(this.entityType, entityClass, contentType)) { - genericMessageConverter.write(entity, this.entityType, contentType, serverResponse); + if (genericMessageConverter.canWrite(entityType, entityClass, contentType)) { + genericMessageConverter.write(entity, entityType, contentType, serverResponse); return; } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/ResourceHandlerFunctionTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/ResourceHandlerFunctionTests.java index 71616ee32f3..ed24be9c27d 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/ResourceHandlerFunctionTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/ResourceHandlerFunctionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,9 @@ package org.springframework.web.servlet.function; import java.io.IOException; +import java.io.InputStream; import java.nio.file.Files; +import java.util.Arrays; import java.util.Collections; import java.util.EnumSet; import java.util.List; @@ -29,11 +31,13 @@ import org.junit.jupiter.api.Test; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.ResourceHttpMessageConverter; +import org.springframework.http.converter.ResourceRegionHttpMessageConverter; import org.springframework.web.servlet.ModelAndView; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.testfixture.servlet.MockHttpServletResponse; @@ -56,10 +60,11 @@ public class ResourceHandlerFunctionTests { @BeforeEach public void createContext() { this.messageConverter = new ResourceHttpMessageConverter(); + ResourceRegionHttpMessageConverter regionConverter = new ResourceRegionHttpMessageConverter(); this.context = new ServerResponse.Context() { @Override public List> messageConverters() { - return Collections.singletonList(messageConverter); + return Arrays.asList(messageConverter, regionConverter); } }; @@ -73,8 +78,7 @@ public class ResourceHandlerFunctionTests { ServerResponse response = this.handlerFunction.handle(request); assertThat(response.statusCode()).isEqualTo(HttpStatus.OK); - boolean condition = response instanceof EntityResponse; - assertThat(condition).isTrue(); + assertThat(response).isInstanceOf(EntityResponse.class); @SuppressWarnings("unchecked") EntityResponse entityResponse = (EntityResponse) response; assertThat(entityResponse.entity()).isEqualTo(this.resource); @@ -91,6 +95,61 @@ public class ResourceHandlerFunctionTests { assertThat(servletResponse.getContentLength()).isEqualTo(this.resource.contentLength()); } + @Test + public void getRange() throws IOException, ServletException { + MockHttpServletRequest servletRequest = new MockHttpServletRequest("GET", "/"); + servletRequest.addHeader("Range", "bytes=0-5"); + ServerRequest request = new DefaultServerRequest(servletRequest, Collections.singletonList(messageConverter)); + + ServerResponse response = this.handlerFunction.handle(request); + assertThat(response.statusCode()).isEqualTo(HttpStatus.OK); + assertThat(response).isInstanceOf(EntityResponse.class); + @SuppressWarnings("unchecked") + EntityResponse entityResponse = (EntityResponse) response; + assertThat(entityResponse.entity()).isEqualTo(this.resource); + + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + ModelAndView mav = response.writeTo(servletRequest, servletResponse, this.context); + assertThat(mav).isNull(); + + assertThat(servletResponse.getStatus()).isEqualTo(206); + byte[] expectedBytes = new byte[6]; + try (InputStream is = this.resource.getInputStream()) { + is.read(expectedBytes); + } + byte[] actualBytes = servletResponse.getContentAsByteArray(); + assertThat(actualBytes).isEqualTo(expectedBytes); + assertThat(servletResponse.getContentType()).isEqualTo(MediaType.TEXT_PLAIN_VALUE); + assertThat(servletResponse.getContentLength()).isEqualTo(6); + assertThat(servletResponse.getHeader(HttpHeaders.ACCEPT_RANGES)).isEqualTo("bytes"); + } + + @Test + public void getInvalidRange() throws IOException, ServletException { + MockHttpServletRequest servletRequest = new MockHttpServletRequest("GET", "/"); + servletRequest.addHeader("Range", "bytes=0-10, 0-10, 0-10, 0-10, 0-10, 0-10"); + ServerRequest request = new DefaultServerRequest(servletRequest, Collections.singletonList(messageConverter)); + + ServerResponse response = this.handlerFunction.handle(request); + assertThat(response.statusCode()).isEqualTo(HttpStatus.OK); + assertThat(response).isInstanceOf(EntityResponse.class); + @SuppressWarnings("unchecked") + EntityResponse entityResponse = (EntityResponse) response; + assertThat(entityResponse.entity()).isEqualTo(this.resource); + + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + ModelAndView mav = response.writeTo(servletRequest, servletResponse, this.context); + assertThat(mav).isNull(); + + assertThat(servletResponse.getStatus()).isEqualTo(416); + byte[] expectedBytes = Files.readAllBytes(this.resource.getFile().toPath()); + byte[] actualBytes = servletResponse.getContentAsByteArray(); + assertThat(actualBytes).isEqualTo(expectedBytes); + assertThat(servletResponse.getContentType()).isEqualTo(MediaType.TEXT_PLAIN_VALUE); + assertThat(servletResponse.getContentLength()).isEqualTo(this.resource.contentLength()); + assertThat(servletResponse.getHeader(HttpHeaders.ACCEPT_RANGES)).isEqualTo("bytes"); + } + @Test public void head() throws IOException, ServletException { MockHttpServletRequest servletRequest = new MockHttpServletRequest("HEAD", "/"); @@ -98,8 +157,7 @@ public class ResourceHandlerFunctionTests { ServerResponse response = this.handlerFunction.handle(request); assertThat(response.statusCode()).isEqualTo(HttpStatus.OK); - boolean condition = response instanceof EntityResponse; - assertThat(condition).isTrue(); + assertThat(response).isInstanceOf(EntityResponse.class); @SuppressWarnings("unchecked") EntityResponse entityResponse = (EntityResponse) response; assertThat(entityResponse.entity().getFilename()).isEqualTo(this.resource.getFilename()); @@ -136,4 +194,5 @@ public class ResourceHandlerFunctionTests { assertThat(actualBytes.length).isEqualTo(0); } + }