diff --git a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java index 3e0a81dc65..4756017d36 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java +++ b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java @@ -35,8 +35,6 @@ import java.util.Collection; import java.util.Collections; import java.util.EnumSet; import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.LinkedList; import java.util.List; import java.util.Locale; import java.util.Map; @@ -47,7 +45,9 @@ import java.util.stream.Collectors; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; @@ -78,7 +78,8 @@ public class HttpHeaders implements MultiValueMap, Serializable /** * The empty {@code HttpHeaders} instance (immutable). */ - public static final HttpHeaders EMPTY = new HttpHeaders(new LinkedHashMap<>(), true); + public static final HttpHeaders EMPTY = + new ReadOnlyHttpHeaders(new HttpHeaders(new LinkedMultiValueMap<>(0))); /** * The HTTP {@code Accept} header field name. * @see Section 5.3.2 of RFC 7231 @@ -397,35 +398,27 @@ public class HttpHeaders implements MultiValueMap, Serializable private static final DateTimeFormatter[] DATE_FORMATTERS = new DateTimeFormatter[] { DateTimeFormatter.RFC_1123_DATE_TIME, DateTimeFormatter.ofPattern("EEEE, dd-MMM-yy HH:mm:ss zz", Locale.US), - DateTimeFormatter.ofPattern("EEE MMM dd HH:mm:ss yyyy",Locale.US).withZone(GMT) + DateTimeFormatter.ofPattern("EEE MMM dd HH:mm:ss yyyy", Locale.US).withZone(GMT) }; - private final Map> headers; - - private final boolean readOnly; + final MultiValueMap headers; /** - * Constructs a new, empty instance of the {@code HttpHeaders} object. + * Construct a new, empty instance of the {@code HttpHeaders} object. */ public HttpHeaders() { - this(new LinkedCaseInsensitiveMap<>(8, Locale.ENGLISH), false); + this(CollectionUtils.toMultiValueMap( + new LinkedCaseInsensitiveMap<>(8, Locale.ENGLISH))); } /** - * Private constructor that can create read-only {@code HttpHeader} instances. + * Construct a new {@code HttpHeaders} instance backed by an existing map. */ - private HttpHeaders(Map> headers, boolean readOnly) { - if (readOnly) { - Map> map = new LinkedCaseInsensitiveMap<>(headers.size(), Locale.ENGLISH); - headers.forEach((key, valueList) -> map.put(key, Collections.unmodifiableList(valueList))); - this.headers = Collections.unmodifiableMap(map); - } - else { - this.headers = headers; - } - this.readOnly = readOnly; + public HttpHeaders(MultiValueMap headers) { + Assert.notNull(headers, "headers must not be null"); + this.headers = headers; } @@ -1474,8 +1467,7 @@ public class HttpHeaders implements MultiValueMap, Serializable @Override @Nullable public String getFirst(String headerName) { - List headerValues = this.headers.get(headerName); - return (headerValues != null ? headerValues.get(0) : null); + return this.headers.getFirst(headerName); } /** @@ -1488,19 +1480,17 @@ public class HttpHeaders implements MultiValueMap, Serializable */ @Override public void add(String headerName, @Nullable String headerValue) { - List headerValues = this.headers.computeIfAbsent(headerName, k -> new LinkedList<>()); - headerValues.add(headerValue); + this.headers.add(headerName, headerValue); } @Override public void addAll(String key, List values) { - List currentValues = this.headers.computeIfAbsent(key, k -> new LinkedList<>()); - currentValues.addAll(values); + this.headers.addAll(key, values); } @Override public void addAll(MultiValueMap values) { - values.forEach(this::addAll); + this.headers.addAll(values); } /** @@ -1513,21 +1503,17 @@ public class HttpHeaders implements MultiValueMap, Serializable */ @Override public void set(String headerName, @Nullable String headerValue) { - List headerValues = new LinkedList<>(); - headerValues.add(headerValue); - this.headers.put(headerName, headerValues); + this.headers.set(headerName, headerValue); } @Override public void setAll(Map values) { - values.forEach(this::set); + this.headers.setAll(values); } @Override public Map toSingleValueMap() { - LinkedHashMap singleValueMap = new LinkedHashMap<>(this.headers.size()); - this.headers.forEach((key, valueList) -> singleValueMap.put(key, valueList.get(0))); - return singleValueMap; + return this.headers.toSingleValueMap(); } @@ -1623,7 +1609,12 @@ public class HttpHeaders implements MultiValueMap, Serializable */ public static HttpHeaders readOnlyHttpHeaders(HttpHeaders headers) { Assert.notNull(headers, "HttpHeaders must not be null"); - return (headers.readOnly ? headers : new HttpHeaders(headers, true)); + if (headers instanceof ReadOnlyHttpHeaders) { + return headers; + } + else { + return new ReadOnlyHttpHeaders(headers); + } } } diff --git a/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java b/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java new file mode 100644 index 0000000000..39f64d4b6a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java @@ -0,0 +1,135 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.util.AbstractMap; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code HttpHeaders} object that can only be read, not written to. + * + * @author Brian Clozel + * @since 5.1 + */ +class ReadOnlyHttpHeaders extends HttpHeaders { + + private static final long serialVersionUID = -8578554704772377436L; + + @Nullable + private MediaType cachedContentType; + + ReadOnlyHttpHeaders(HttpHeaders headers) { + super(headers.headers); + } + + @Override + public MediaType getContentType() { + if (this.cachedContentType != null) { + return this.cachedContentType; + } + else { + MediaType contentType = super.getContentType(); + this.cachedContentType = contentType; + return contentType; + } + } + + @Override + public List get(Object key) { + List values = this.headers.get(key); + if (values != null) { + return Collections.unmodifiableList(values); + } + return values; + } + + @Override + public void add(String headerName, @Nullable String headerValue) { + throw new UnsupportedOperationException(); + } + + @Override + public void addAll(String key, List values) { + throw new UnsupportedOperationException(); + } + + @Override + public void addAll(MultiValueMap values) { + throw new UnsupportedOperationException(); + } + + @Override + public void set(String headerName, @Nullable String headerValue) { + throw new UnsupportedOperationException(); + } + + @Override + public void setAll(Map values) { + throw new UnsupportedOperationException(); + } + + @Override + public Map toSingleValueMap() { + return Collections.unmodifiableMap(this.headers.toSingleValueMap()); + } + + @Override + public Set keySet() { + return Collections.unmodifiableSet(this.headers.keySet()); + } + + @Override + public List put(String key, List value) { + throw new UnsupportedOperationException(); + } + + @Override + public List remove(Object key) { + throw new UnsupportedOperationException(); + } + + @Override + public void putAll(Map> map) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(); + } + + @Override + public Collection> values() { + return Collections.unmodifiableCollection(this.headers.values()); + } + + @Override + public Set>> entrySet() { + return Collections.unmodifiableSet(this.headers.entrySet().stream() + .map(AbstractMap.SimpleImmutableEntry::new) + .collect(Collectors.toSet())); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java index a88da3fb2b..29cd6228d3 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java @@ -153,6 +153,9 @@ public class ServletServerHttpResponse implements ServerHttpResponse { Assert.isInstanceOf(String.class, key, "Key must be a String-based header name"); Collection values1 = servletResponse.getHeaders((String) key); + if (headersWritten) { + return new ArrayList<>(values1); + } boolean isEmpty1 = CollectionUtils.isEmpty(values1); List values2 = super.get(key); diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java index 78c537e152..39dd29a9e5 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import reactor.core.publisher.Mono; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.http.HttpHeaders; /** * Abstract base class for listener-based server responses, e.g. Servlet 3.1 @@ -41,6 +42,10 @@ public abstract class AbstractListenerServerHttpResponse extends AbstractServerH super(dataBufferFactory); } + public AbstractListenerServerHttpResponse(DataBufferFactory dataBufferFactory, HttpHeaders headers) { + super(dataBufferFactory, headers); + } + @Override protected final Mono writeWithInternal(Publisher body) { diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java index 2b21a2ef4d..b8356dbd8c 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java @@ -75,9 +75,14 @@ public abstract class AbstractServerHttpResponse implements ServerHttpResponse { public AbstractServerHttpResponse(DataBufferFactory dataBufferFactory) { + this(dataBufferFactory, new HttpHeaders()); + } + + public AbstractServerHttpResponse(DataBufferFactory dataBufferFactory, HttpHeaders headers) { Assert.notNull(dataBufferFactory, "DataBufferFactory must not be null"); + Assert.notNull(headers, "HttpHeaders must not be null"); this.dataBufferFactory = dataBufferFactory; - this.headers = new HttpHeaders(); + this.headers = headers; this.cookies = new LinkedMultiValueMap<>(); } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java new file mode 100644 index 0000000000..5ffc0779ac --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.eclipse.jetty.http.HttpField; +import org.eclipse.jetty.http.HttpFields; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Jetty HTTP headers. + * + * @author Brian Clozel + * @since 5.1 + */ +class JettyHeadersAdapter implements MultiValueMap { + + private final HttpFields headers; + + JettyHeadersAdapter(HttpFields headers) { + this.headers = headers; + } + + @Override + public String getFirst(String key) { + return this.headers.get(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.headers.add(key, value); + } + + @Override + public void addAll(String key, List values) { + values.forEach(value -> add(key, value)); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach(this::addAll); + } + + @Override + public void set(String key, @Nullable String value) { + this.headers.put(key, value); + } + + @Override + public void setAll(Map values) { + values.forEach(this::set); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + Iterator iterator = this.headers.iterator(); + iterator.forEachRemaining(field -> { + if (!singleValueMap.containsKey(field.getName())) { + singleValueMap.put(field.getName(), field.getValue()); + } + }); + return singleValueMap; + } + + @Override + public int size() { + return this.headers.getFieldNamesCollection().size(); + } + + @Override + public boolean isEmpty() { + return this.headers.size() == 0; + } + + @Override + public boolean containsKey(Object key) { + if (key instanceof String) { + return this.headers.containsKey((String) key); + } + return false; + } + + @Override + public boolean containsValue(Object value) { + if (value instanceof String) { + return this.headers.stream() + .anyMatch(field -> field.contains((String) value)); + } + return false; + } + + @Nullable + @Override + public List get(Object key) { + if (key instanceof String) { + return this.headers.getValuesList((String) key); + } + return null; + } + + @Nullable + @Override + public List put(String key, List value) { + List oldValues = get(key); + this.headers.put(key, value); + return oldValues; + } + + @Nullable + @Override + public List remove(Object key) { + if (key instanceof String) { + List oldValues = get(key); + this.headers.remove((String) key); + return oldValues; + } + return null; + } + + @Override + public void putAll(Map> m) { + m.forEach(this::put); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.getFieldNamesCollection(); + } + + @Override + public Collection> values() { + return this.headers.getFieldNamesCollection().stream() + .map(this.headers::getValuesList).collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + private class EntryIterator implements Iterator>> { + + private Enumeration names = headers.getFieldNames(); + + @Override + public boolean hasNext() { + return this.names.hasMoreElements(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.nextElement()); + } + } + + private class HeaderEntry implements Entry> { + + private final String key; + + HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key.toString(); + } + + @Override + public List getValue() { + return headers.getValuesList(this.key); + } + + @Override + public List setValue(List value) { + List previousValues = headers.getValuesList(this.key); + headers.put(this.key, value); + return previousValues; + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java index 76bcef8549..bd97b1ef6f 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java @@ -17,15 +17,21 @@ package org.springframework.http.server.reactive; import java.io.IOException; +import java.net.URISyntaxException; import java.nio.ByteBuffer; import javax.servlet.AsyncContext; import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.eclipse.jetty.http.HttpFields; import org.eclipse.jetty.server.HttpOutput; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.server.Response; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.http.HttpHeaders; /** * {@link ServletHttpHandlerAdapter} extension that uses Jetty APIs for writing @@ -42,6 +48,12 @@ public class JettyHttpHandlerAdapter extends ServletHttpHandlerAdapter { } + @Override + protected ServletServerHttpRequest createRequest(HttpServletRequest request, AsyncContext context) + throws IOException, URISyntaxException { + return new JettyServerHttpRequest(request, context, getServletPath(), getDataBufferFactory(), getBufferSize()); + } + @Override protected ServletServerHttpResponse createResponse(HttpServletResponse response, AsyncContext context, ServletServerHttpRequest request) throws IOException { @@ -50,14 +62,38 @@ public class JettyHttpHandlerAdapter extends ServletHttpHandlerAdapter { response, context, getDataBufferFactory(), getBufferSize(), request); } + private static final class JettyServerHttpRequest extends ServletServerHttpRequest { + + JettyServerHttpRequest(HttpServletRequest request, AsyncContext asyncContext, + String servletPath, DataBufferFactory bufferFactory, int bufferSize) + throws IOException, URISyntaxException { + + super(createHeaders(request), request, asyncContext, servletPath, bufferFactory, bufferSize); + } + + private static HttpHeaders createHeaders(HttpServletRequest request) { + HttpFields fields = ((Request) request).getMetaData().getFields(); + return new HttpHeaders(new JettyHeadersAdapter(fields)); + } + } + private static final class JettyServerHttpResponse extends ServletServerHttpResponse { - public JettyServerHttpResponse(HttpServletResponse response, AsyncContext asyncContext, + JettyServerHttpResponse(HttpServletResponse response, AsyncContext asyncContext, DataBufferFactory bufferFactory, int bufferSize, ServletServerHttpRequest request) throws IOException { - super(response, asyncContext, bufferFactory, bufferSize, request); + super(createHeaders(response), response, asyncContext, bufferFactory, bufferSize, request); + } + + private static HttpHeaders createHeaders(HttpServletResponse response) { + HttpFields fields = ((Response) response).getHttpFields(); + return new HttpHeaders(new JettyHeadersAdapter(fields)); + } + + @Override + protected void applyHeaders() { } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java new file mode 100644 index 0000000000..6d68ceb1d8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java @@ -0,0 +1,217 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import io.netty.handler.codec.http.HttpHeaders; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Netty HTTP headers. + * + * @author Brian Clozel + * @since 5.1 + */ +class NettyHeadersAdapter implements MultiValueMap { + + private final HttpHeaders headers; + + NettyHeadersAdapter(HttpHeaders headers) { + this.headers = headers; + } + + @Override + @Nullable + public String getFirst(String key) { + return this.headers.get(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.headers.add(key, value); + } + + @Override + public void addAll(String key, List values) { + this.headers.add(key, values); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach(this.headers::add); + } + + @Override + public void set(String key, @Nullable String value) { + this.headers.set(key, value); + } + + @Override + public void setAll(Map values) { + values.forEach(this.headers::set); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + this.headers.entries() + .forEach(entry -> { + if (!singleValueMap.containsKey(entry.getKey())) { + singleValueMap.put(entry.getKey(), entry.getValue()); + } + }); + return singleValueMap; + } + + @Override + public int size() { + return this.headers.size(); + } + + @Override + public boolean isEmpty() { + return this.headers.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return (key instanceof String) && this.headers.contains((String) key); + } + + @Override + public boolean containsValue(Object value) { + return (value instanceof String) && + this.headers.entries().stream() + .anyMatch(entry -> value != null && value.equals(entry.getValue())); + } + + @Override + @Nullable + public List get(Object key) { + if (key instanceof String) { + return this.headers.getAll((String) key); + } + return null; + } + + @Nullable + @Override + public List put(String key, @Nullable List value) { + List previousValues = this.headers.getAll(key); + this.headers.add(key, value); + return previousValues; + } + + @Nullable + @Override + public List remove(Object key) { + if (key instanceof String) { + List previousValues = this.headers.getAll((String) key); + this.headers.remove((String) key); + return previousValues; + } + return null; + } + + @Override + public void putAll(Map> m) { + m.forEach(this.headers::add); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.names(); + } + + @Override + public Collection> values() { + return this.headers.names().stream() + .map(this.headers::getAll).collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + private class EntryIterator implements Iterator>> { + + private Iterator names = headers.names().iterator(); + + @Override + public boolean hasNext() { + return this.names.hasNext(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.next()); + } + } + + private class HeaderEntry implements Entry> { + + private final String key; + + HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public List getValue() { + return headers.getAll(this.key); + } + + @Override + public List setValue(List value) { + List previousValues = headers.getAll(this.key); + headers.set(this.key, value); + return previousValues; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java index fb3a88e3bb..884f2ccd51 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java @@ -125,11 +125,8 @@ class ReactorServerHttpRequest extends AbstractServerHttpRequest { } private static HttpHeaders initHeaders(HttpServerRequest channel) { - HttpHeaders headers = new HttpHeaders(); - for (String name : channel.requestHeaders().names()) { - headers.put(name, channel.requestHeaders().getAll(name)); - } - return headers; + NettyHeadersAdapter headersMap = new NettyHeadersAdapter(channel.requestHeaders()); + return new HttpHeaders(headersMap); } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java index e2dd16ed7f..b536a6d960 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java @@ -19,6 +19,7 @@ package org.springframework.http.server.reactive; import java.nio.file.Path; import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.cookie.Cookie; import io.netty.handler.codec.http.cookie.DefaultCookie; @@ -30,6 +31,7 @@ import reactor.netty.http.server.HttpServerResponse; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseCookie; import org.springframework.http.ZeroCopyHttpOutputMessage; import org.springframework.util.Assert; @@ -47,11 +49,16 @@ class ReactorServerHttpResponse extends AbstractServerHttpResponse implements Ze public ReactorServerHttpResponse(HttpServerResponse response, DataBufferFactory bufferFactory) { - super(bufferFactory); + super(bufferFactory, initHeaders(response)); Assert.notNull(response, "HttpServerResponse must not be null"); this.response = response; } + private static HttpHeaders initHeaders(HttpServerResponse channel) { + channel.responseHeaders().remove(HttpHeaderNames.TRANSFER_ENCODING); + NettyHeadersAdapter headersMap = new NettyHeadersAdapter(channel.responseHeaders()); + return new HttpHeaders(headersMap); + } @SuppressWarnings("unchecked") @Override @@ -80,11 +87,9 @@ class ReactorServerHttpResponse extends AbstractServerHttpResponse implements Ze @Override protected void applyHeaders() { - getHeaders().forEach((headerName, headerValues) -> { - for (String value : headerValues) { - this.response.responseHeaders().add(headerName, value); - } - }); + if (getHeaders().getContentLength() == -1) { + this.response.chunkedTransfer(true); + } } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java index b2437e268c..e2d6ec2634 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java @@ -69,12 +69,18 @@ class ServletServerHttpRequest extends AbstractServerHttpRequest { private final byte[] buffer; - public ServletServerHttpRequest(HttpServletRequest request, AsyncContext asyncContext, String servletPath, DataBufferFactory bufferFactory, int bufferSize) throws IOException, URISyntaxException { - super(initUri(request), request.getContextPath() + servletPath, initHeaders(request)); + this(createDefaultHttpHeaders(request), request, asyncContext, servletPath, bufferFactory, bufferSize); + } + + public ServletServerHttpRequest(HttpHeaders headers, HttpServletRequest request, AsyncContext asyncContext, + String servletPath, DataBufferFactory bufferFactory, int bufferSize) + throws IOException, URISyntaxException { + + super(initUri(request), request.getContextPath() + servletPath, initHeaders(headers, request)); Assert.notNull(bufferFactory, "'bufferFactory' must not be null"); Assert.isTrue(bufferSize > 0, "'bufferSize' must be higher than 0"); @@ -91,6 +97,18 @@ class ServletServerHttpRequest extends AbstractServerHttpRequest { this.bodyPublisher.registerReadListener(); } + + private static HttpHeaders createDefaultHttpHeaders(HttpServletRequest request) { + HttpHeaders headers = new HttpHeaders(); + for (Enumeration names = request.getHeaderNames(); names.hasMoreElements(); ) { + String name = (String) names.nextElement(); + for (Enumeration values = request.getHeaders(name); values.hasMoreElements(); ) { + headers.add(name, (String) values.nextElement()); + } + } + return headers; + } + private static URI initUri(HttpServletRequest request) throws URISyntaxException { Assert.notNull(request, "'request' must not be null"); StringBuffer url = request.getRequestURL(); @@ -101,16 +119,7 @@ class ServletServerHttpRequest extends AbstractServerHttpRequest { return new URI(url.toString()); } - private static HttpHeaders initHeaders(HttpServletRequest request) { - HttpHeaders headers = new HttpHeaders(); - for (Enumeration names = request.getHeaderNames(); - names.hasMoreElements(); ) { - String name = (String) names.nextElement(); - for (Enumeration values = request.getHeaders(name); - values.hasMoreElements(); ) { - headers.add(name, (String) values.nextElement()); - } - } + private static HttpHeaders initHeaders(HttpHeaders headers, HttpServletRequest request) { MediaType contentType = headers.getContentType(); if (contentType == null) { String requestContentType = request.getContentType(); @@ -231,7 +240,8 @@ class ServletServerHttpRequest extends AbstractServerHttpRequest { private final class RequestAsyncListener implements AsyncListener { @Override - public void onStartAsync(AsyncEvent event) {} + public void onStartAsync(AsyncEvent event) { + } @Override public void onTimeout(AsyncEvent event) { diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java index 81d9c0c549..9d11809031 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -33,6 +33,7 @@ import org.reactivestreams.Publisher; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseCookie; import org.springframework.lang.Nullable; @@ -62,11 +63,16 @@ class ServletServerHttpResponse extends AbstractListenerServerHttpResponse { private final ServletServerHttpRequest request; - public ServletServerHttpResponse(HttpServletResponse response, AsyncContext asyncContext, DataBufferFactory bufferFactory, int bufferSize, ServletServerHttpRequest request) throws IOException { - super(bufferFactory); + this(new HttpHeaders(), response, asyncContext, bufferFactory, bufferSize, request); + } + + public ServletServerHttpResponse(HttpHeaders headers, HttpServletResponse response, AsyncContext asyncContext, + DataBufferFactory bufferFactory, int bufferSize, ServletServerHttpRequest request) throws IOException { + + super(bufferFactory, headers); Assert.notNull(response, "HttpServletResponse must not be null"); Assert.notNull(bufferFactory, "DataBufferFactory must not be null"); diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java new file mode 100644 index 0000000000..e667a3e86c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java @@ -0,0 +1,237 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.tomcat.util.buf.MessageBytes; +import org.apache.tomcat.util.http.MimeHeaders; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Tomcat HTTP headers. + * + * @author Brian Clozel + * @since 5.1 + */ +class TomcatHeadersAdapter implements MultiValueMap { + + private final MimeHeaders headers; + + TomcatHeadersAdapter(MimeHeaders headers) { + this.headers = headers; + } + + @Override + public String getFirst(String key) { + return this.headers.getHeader(key); + } + + @Override + public void add(String key, String value) { + this.headers.addValue(key).setString(value); + } + + @Override + public void addAll(String key, List values) { + values.forEach(value -> add(key, value)); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach(this::addAll); + } + + @Override + public void set(String key, String value) { + this.headers.setValue(key).setString(value); + } + + @Override + public void setAll(Map values) { + values.forEach(this::set); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + this.keySet().forEach(key -> singleValueMap.put(key, getFirst(key))); + return singleValueMap; + } + + @Override + public int size() { + Enumeration names = this.headers.names(); + int size = 0; + while (names.hasMoreElements()) { + size++; + names.nextElement(); + } + return size; + } + + @Override + public boolean isEmpty() { + return this.headers.size() == 0; + } + + @Override + public boolean containsKey(Object key) { + if (key instanceof String) { + return this.headers.findHeader((String) key, 0) != -1; + } + return false; + } + + @Override + public boolean containsValue(Object value) { + if (value instanceof String) { + MessageBytes needle = MessageBytes.newInstance(); + needle.setString((String) value); + for (int i = 0; i < this.headers.size(); i++) { + if (this.headers.getValue(i).equals(needle)) { + return true; + } + } + } + return false; + } + + @Override + @Nullable + public List get(Object key) { + if (key instanceof String) { + return Collections.list(this.headers.values((String) key)); + } + return null; + } + + @Override + @Nullable + public List put(String key, List value) { + List previousValues = get(key); + value.forEach(v -> this.headers.addValue(key).setString(v)); + return previousValues; + } + + @Override + @Nullable + public List remove(Object key) { + if (key instanceof String) { + List previousValues = get(key); + this.headers.removeHeader((String) key); + return previousValues; + } + return null; + } + + @Override + public void putAll(Map> m) { + m.forEach(this::put); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + Set result = new HashSet<>(8); + Enumeration names = this.headers.names(); + while (names.hasMoreElements()) { + result.add(names.nextElement()); + } + return result; + } + + @Override + public Collection> values() { + return keySet().stream().map(this::get).collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + private class EntryIterator implements Iterator>> { + + private Enumeration names = headers.names(); + + @Override + public boolean hasNext() { + return this.names.hasMoreElements(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.nextElement()); + } + } + + private final class HeaderEntry implements Entry> { + + private final String key; + + private HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key; + } + + @Nullable + @Override + public List getValue() { + return get(this.key); + } + + @Nullable + @Override + public List setValue(List value) { + List previous = getValue(); + headers.removeHeader(this.key); + addAll(this.key, value); + return previous; + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java index 2041234428..89851e938a 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java @@ -17,6 +17,7 @@ package org.springframework.http.server.reactive; import java.io.IOException; +import java.lang.reflect.Field; import java.net.URISyntaxException; import java.nio.ByteBuffer; import javax.servlet.AsyncContext; @@ -27,17 +28,25 @@ import javax.servlet.http.HttpServletResponse; import org.apache.catalina.connector.CoyoteInputStream; import org.apache.catalina.connector.CoyoteOutputStream; +import org.apache.catalina.connector.RequestFacade; +import org.apache.catalina.connector.ResponseFacade; +import org.apache.coyote.Request; +import org.apache.coyote.Response; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; /** * {@link ServletHttpHandlerAdapter} extension that uses Tomcat APIs for reading * from the request and writing to the response with {@link ByteBuffer}. * * @author Violeta Georgieva + * @author Brian Clozel + * @author Brian Clozel * @since 5.0 * @see org.springframework.web.server.adapter.AbstractReactiveWebInitializer */ @@ -66,21 +75,39 @@ public class TomcatHttpHandlerAdapter extends ServletHttpHandlerAdapter { response, asyncContext, getDataBufferFactory(), getBufferSize(), request); } + private static final class TomcatServerHttpRequest extends ServletServerHttpRequest { - private final class TomcatServerHttpRequest extends ServletServerHttpRequest { + private static final Field COYOTE_REQUEST_FIELD = ReflectionUtils.findField(RequestFacade.class, "request"); - public TomcatServerHttpRequest(HttpServletRequest request, AsyncContext context, + private final int bufferSize; + + private final DataBufferFactory factory; + + static { + ReflectionUtils.makeAccessible(COYOTE_REQUEST_FIELD); + } + + TomcatServerHttpRequest(HttpServletRequest request, AsyncContext context, String servletPath, DataBufferFactory factory, int bufferSize) throws IOException, URISyntaxException { - super(request, context, servletPath, factory, bufferSize); + super(createTomcatHttpHeaders(request), request, context, servletPath, factory, bufferSize); + this.factory = factory; + this.bufferSize = bufferSize; + } + + private static HttpHeaders createTomcatHttpHeaders(HttpServletRequest request) { + Request tomcatRequest = ((org.apache.catalina.connector.Request) ReflectionUtils + .getField(COYOTE_REQUEST_FIELD, request)).getCoyoteRequest(); + TomcatHeadersAdapter headers = new TomcatHeadersAdapter(tomcatRequest.getMimeHeaders()); + return new HttpHeaders(headers); } @Override protected DataBuffer readFromInputStream() throws IOException { boolean release = true; - int capacity = getBufferSize(); - DataBuffer dataBuffer = getDataBufferFactory().allocateBuffer(capacity); + int capacity = this.bufferSize; + DataBuffer dataBuffer = this.factory.allocateBuffer(capacity); try { ByteBuffer byteBuffer = dataBuffer.asByteBuffer(0, capacity); @@ -111,10 +138,27 @@ public class TomcatHttpHandlerAdapter extends ServletHttpHandlerAdapter { private static final class TomcatServerHttpResponse extends ServletServerHttpResponse { - public TomcatServerHttpResponse(HttpServletResponse response, AsyncContext context, + private static final Field COYOTE_RESPONSE_FIELD = ReflectionUtils.findField(ResponseFacade.class, "response"); + + static { + ReflectionUtils.makeAccessible(COYOTE_RESPONSE_FIELD); + } + + TomcatServerHttpResponse(HttpServletResponse response, AsyncContext context, DataBufferFactory factory, int bufferSize, ServletServerHttpRequest request) throws IOException { - super(response, context, factory, bufferSize, request); + super(createTomcatHttpHeaders(response), response, context, factory, bufferSize, request); + } + + private static HttpHeaders createTomcatHttpHeaders(HttpServletResponse response) { + Response tomcatResponse = ((org.apache.catalina.connector.Response) ReflectionUtils + .getField(COYOTE_RESPONSE_FIELD, response)).getCoyoteResponse(); + TomcatHeadersAdapter headers = new TomcatHeadersAdapter(tomcatResponse.getMimeHeaders()); + return new HttpHeaders(headers); + } + + @Override + protected void applyHeaders() { } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java new file mode 100644 index 0000000000..3e817c906f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import io.undertow.util.HeaderMap; +import io.undertow.util.HeaderValues; +import io.undertow.util.HttpString; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Undertow HTTP headers. + * + * @author Brian Clozel + * @since 5.1 + */ +class UndertowHeadersAdapter implements MultiValueMap { + + private final HeaderMap headers; + + UndertowHeadersAdapter(HeaderMap headers) { + this.headers = headers; + } + + @Override + public String getFirst(String key) { + return this.headers.getFirst(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.headers.add(HttpString.tryFromString(key), value); + } + + @Override + @SuppressWarnings("unchecked") + public void addAll(String key, List values) { + this.headers.addAll(HttpString.tryFromString(key), (List) values); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach((key, list) -> this.headers.addAll(HttpString.tryFromString(key), list)); + } + + @Override + public void set(String key, @Nullable String value) { + this.headers.put(HttpString.tryFromString(key), value); + } + + @Override + public void setAll(Map values) { + values.forEach((key, list) -> this.headers.put(HttpString.tryFromString(key), list)); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + this.headers.forEach(values -> + singleValueMap.put(values.getHeaderName().toString(), values.getFirst())); + return singleValueMap; + } + + @Override + public int size() { + return this.headers.size(); + } + + @Override + public boolean isEmpty() { + return this.headers.size() == 0; + } + + @Override + public boolean containsKey(Object key) { + if (key instanceof String) { + return this.headers.contains((String) key); + } + return false; + } + + @Override + public boolean containsValue(Object value) { + if (value instanceof String) { + return this.headers.getHeaderNames().stream() + .map(this.headers::get) + .anyMatch(values -> values.contains(value)); + } + return false; + } + + @Override + @Nullable + public List get(Object key) { + if (key instanceof String) { + return this.headers.get((String) key); + } + return null; + } + + @Override + @Nullable + public List put(String key, List value) { + HeaderValues previousValues = this.headers.get(key); + this.headers.putAll(HttpString.tryFromString(key), value); + return previousValues; + } + + @Override + @Nullable + public List remove(Object key) { + if (key instanceof String) { + this.headers.remove((String) key); + } + return null; + } + + @Override + public void putAll(Map> m) { + m.forEach((key, values) -> + this.headers.putAll(HttpString.tryFromString(key), values)); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.getHeaderNames().stream() + .map(HttpString::toString) + .collect(Collectors.toSet()); + } + + @Override + public Collection> values() { + return this.headers.getHeaderNames().stream() + .map(this.headers::get) + .collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + private class EntryIterator implements Iterator>> { + + private Iterator names = headers.getHeaderNames().iterator(); + + @Override + public boolean hasNext() { + return this.names.hasNext(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.next()); + } + } + + private class HeaderEntry implements Entry> { + + private final HttpString key; + + HeaderEntry(HttpString key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key.toString(); + } + + @Override + public List getValue() { + return headers.get(this.key); + } + + @Override + public List setValue(List value) { + List previousValues = headers.get(this.key); + headers.putAll(this.key, value); + return previousValues; + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java index da6c061391..6c68f5a528 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java @@ -30,7 +30,6 @@ import io.undertow.connector.ByteBufferPool; import io.undertow.connector.PooledByteBuffer; import io.undertow.server.HttpServerExchange; import io.undertow.server.handlers.Cookie; -import io.undertow.util.HeaderValues; import org.xnio.channels.StreamSourceChannel; import reactor.core.publisher.Flux; @@ -79,11 +78,9 @@ class UndertowServerHttpRequest extends AbstractServerHttpRequest { } private static HttpHeaders initHeaders(HttpServerExchange exchange) { - HttpHeaders headers = new HttpHeaders(); - for (HeaderValues values : exchange.getRequestHeaders()) { - headers.put(values.getHeaderName().toString(), values); - } - return headers; + UndertowHeadersAdapter headersMap = + new UndertowHeadersAdapter(exchange.getRequestHeaders()); + return new HttpHeaders(headersMap); } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java index 1ad997896c..a9533379cd 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java @@ -25,7 +25,6 @@ import java.nio.file.StandardOpenOption; import io.undertow.server.HttpServerExchange; import io.undertow.server.handlers.Cookie; import io.undertow.server.handlers.CookieImpl; -import io.undertow.util.HttpString; import org.reactivestreams.Processor; import org.reactivestreams.Publisher; import org.xnio.channels.Channels; @@ -35,6 +34,7 @@ import reactor.core.publisher.Mono; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseCookie; import org.springframework.http.ZeroCopyHttpOutputMessage; import org.springframework.lang.Nullable; @@ -58,15 +58,21 @@ class UndertowServerHttpResponse extends AbstractListenerServerHttpResponse impl private StreamSinkChannel responseChannel; - public UndertowServerHttpResponse( + UndertowServerHttpResponse( HttpServerExchange exchange, DataBufferFactory bufferFactory, UndertowServerHttpRequest request) { - super(bufferFactory); + super(bufferFactory, createHeaders(exchange)); Assert.notNull(exchange, "HttpServerExchange must not be null"); this.exchange = exchange; this.request = request; } + private static HttpHeaders createHeaders(HttpServerExchange exchange) { + UndertowHeadersAdapter headersMap = + new UndertowHeadersAdapter(exchange.getResponseHeaders()); + return new HttpHeaders(headersMap); + } + @SuppressWarnings("unchecked") @Override @@ -85,8 +91,6 @@ class UndertowServerHttpResponse extends AbstractListenerServerHttpResponse impl @Override protected void applyHeaders() { - getHeaders().forEach((headerName, headerValues) -> - this.exchange.getResponseHeaders().addAll(HttpString.tryFromString(headerName), headerValues)); } @Override diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java index b673e1e463..f5027686de 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java @@ -87,7 +87,7 @@ public class DefaultCorsProcessor implements CorsProcessor { } private boolean responseHasCors(ServerHttpResponse response) { - return (response.getHeaders().getAccessControlAllowOrigin() != null); + return response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN) != null; } /** diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java index f5061f960e..d3e353bf90 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java @@ -272,7 +272,10 @@ public class HandshakeWebSocketService implements WebSocketService, Lifecycle { @Nullable String protocol, Map attributes) { URI uri = request.getURI(); - HttpHeaders headers = request.getHeaders(); + // Copy request headers, as they might be pooled and recycled by + // the server implementation once the handshake HTTP exchange is done. + HttpHeaders headers = new HttpHeaders(); + headers.addAll(request.getHeaders()); Mono principal = exchange.getPrincipal(); String logPrefix = exchange.getLogPrefix(); InetSocketAddress remoteAddress = request.getRemoteAddress();