From ce7278aaf4f20348862267c2081c20dc5bd77128 Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Wed, 3 Oct 2018 21:18:24 +0200 Subject: [PATCH] Optimize HTTP headers management Several benchmarks underlined a few hotspots for CPU and GC pressure in the Spring Framework codebase: 1. `org.springframework.util.MimeType.(String, String, Map)` 2. `org.springframework.util.LinkedCaseInsensitiveMap.convertKey(String)` Both are linked with HTTP request headers parsing and response headers writin during the exchange processing phase. 1) is linked to repeated calls to `HttpHeaders.getContentType` within a single request handling. The media type parsing operation is expensive and the result doesn't change between calls, since the request headers are immutable at that point. This commit improves this by caching the parsed `MediaType` for the `"Content-Type"` request header in the `ReadOnlyHttpHeaders` class. This change is available for both Spring MVC and Spring WebFlux. 2) is linked to insertions/lookups in the `LinkedCaseInsensitiveMap`, which is the data structure behind `HttpHeaders`. Those operations are creating a lot of garbage (including a lot of `String` created by `toLowerCase`). We could choose a more efficient data structure for storing HTTP headers data. As a first step, this commit is focusing on Spring WebFlux and introduces `MultiValueMap` implementations mapped by native HTTP headers for the following servers: Tomcat, Jetty, Netty and Undertow. Such implementations avoid unnecessary copying of the headers and leverages as much as possible optimized operations provided by the native implementations. This change has a few consequences: * `HttpHeaders` can now wrap a `MultiValueMap` directly * The default constructor of `HttpHeaders` is still backed by a `LinkedCaseInsensitiveMap` * The HTTP request headers for the websocket HTTP handshake now need to be cloned, because native headers are likely to be pooled/recycled by the server implementation, hence gone when the initial HTTP exchange is done Issue: SPR-17250 --- .../org/springframework/http/HttpHeaders.java | 61 ++--- .../http/ReadOnlyHttpHeaders.java | 135 ++++++++++ .../server/ServletServerHttpResponse.java | 3 + .../AbstractListenerServerHttpResponse.java | 7 +- .../reactive/AbstractServerHttpResponse.java | 7 +- .../server/reactive/JettyHeadersAdapter.java | 222 ++++++++++++++++ .../reactive/JettyHttpHandlerAdapter.java | 40 ++- .../server/reactive/NettyHeadersAdapter.java | 217 ++++++++++++++++ .../reactive/ReactorServerHttpRequest.java | 7 +- .../reactive/ReactorServerHttpResponse.java | 17 +- .../reactive/ServletServerHttpRequest.java | 36 ++- .../reactive/ServletServerHttpResponse.java | 10 +- .../server/reactive/TomcatHeadersAdapter.java | 237 ++++++++++++++++++ .../reactive/TomcatHttpHandlerAdapter.java | 58 ++++- .../reactive/UndertowHeadersAdapter.java | 222 ++++++++++++++++ .../reactive/UndertowServerHttpRequest.java | 9 +- .../reactive/UndertowServerHttpResponse.java | 14 +- .../cors/reactive/DefaultCorsProcessor.java | 2 +- .../support/HandshakeWebSocketService.java | 5 +- 19 files changed, 1224 insertions(+), 85 deletions(-) create mode 100644 spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java create mode 100644 spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java create mode 100644 spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java create mode 100644 spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java create mode 100644 spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java diff --git a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java index 3e0a81dc65..4756017d36 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java +++ b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java @@ -35,8 +35,6 @@ import java.util.Collection; import java.util.Collections; import java.util.EnumSet; import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.LinkedList; import java.util.List; import java.util.Locale; import java.util.Map; @@ -47,7 +45,9 @@ import java.util.stream.Collectors; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; @@ -78,7 +78,8 @@ public class HttpHeaders implements MultiValueMap, Serializable /** * The empty {@code HttpHeaders} instance (immutable). */ - public static final HttpHeaders EMPTY = new HttpHeaders(new LinkedHashMap<>(), true); + public static final HttpHeaders EMPTY = + new ReadOnlyHttpHeaders(new HttpHeaders(new LinkedMultiValueMap<>(0))); /** * The HTTP {@code Accept} header field name. * @see Section 5.3.2 of RFC 7231 @@ -397,35 +398,27 @@ public class HttpHeaders implements MultiValueMap, Serializable private static final DateTimeFormatter[] DATE_FORMATTERS = new DateTimeFormatter[] { DateTimeFormatter.RFC_1123_DATE_TIME, DateTimeFormatter.ofPattern("EEEE, dd-MMM-yy HH:mm:ss zz", Locale.US), - DateTimeFormatter.ofPattern("EEE MMM dd HH:mm:ss yyyy",Locale.US).withZone(GMT) + DateTimeFormatter.ofPattern("EEE MMM dd HH:mm:ss yyyy", Locale.US).withZone(GMT) }; - private final Map> headers; - - private final boolean readOnly; + final MultiValueMap headers; /** - * Constructs a new, empty instance of the {@code HttpHeaders} object. + * Construct a new, empty instance of the {@code HttpHeaders} object. */ public HttpHeaders() { - this(new LinkedCaseInsensitiveMap<>(8, Locale.ENGLISH), false); + this(CollectionUtils.toMultiValueMap( + new LinkedCaseInsensitiveMap<>(8, Locale.ENGLISH))); } /** - * Private constructor that can create read-only {@code HttpHeader} instances. + * Construct a new {@code HttpHeaders} instance backed by an existing map. */ - private HttpHeaders(Map> headers, boolean readOnly) { - if (readOnly) { - Map> map = new LinkedCaseInsensitiveMap<>(headers.size(), Locale.ENGLISH); - headers.forEach((key, valueList) -> map.put(key, Collections.unmodifiableList(valueList))); - this.headers = Collections.unmodifiableMap(map); - } - else { - this.headers = headers; - } - this.readOnly = readOnly; + public HttpHeaders(MultiValueMap headers) { + Assert.notNull(headers, "headers must not be null"); + this.headers = headers; } @@ -1474,8 +1467,7 @@ public class HttpHeaders implements MultiValueMap, Serializable @Override @Nullable public String getFirst(String headerName) { - List headerValues = this.headers.get(headerName); - return (headerValues != null ? headerValues.get(0) : null); + return this.headers.getFirst(headerName); } /** @@ -1488,19 +1480,17 @@ public class HttpHeaders implements MultiValueMap, Serializable */ @Override public void add(String headerName, @Nullable String headerValue) { - List headerValues = this.headers.computeIfAbsent(headerName, k -> new LinkedList<>()); - headerValues.add(headerValue); + this.headers.add(headerName, headerValue); } @Override public void addAll(String key, List values) { - List currentValues = this.headers.computeIfAbsent(key, k -> new LinkedList<>()); - currentValues.addAll(values); + this.headers.addAll(key, values); } @Override public void addAll(MultiValueMap values) { - values.forEach(this::addAll); + this.headers.addAll(values); } /** @@ -1513,21 +1503,17 @@ public class HttpHeaders implements MultiValueMap, Serializable */ @Override public void set(String headerName, @Nullable String headerValue) { - List headerValues = new LinkedList<>(); - headerValues.add(headerValue); - this.headers.put(headerName, headerValues); + this.headers.set(headerName, headerValue); } @Override public void setAll(Map values) { - values.forEach(this::set); + this.headers.setAll(values); } @Override public Map toSingleValueMap() { - LinkedHashMap singleValueMap = new LinkedHashMap<>(this.headers.size()); - this.headers.forEach((key, valueList) -> singleValueMap.put(key, valueList.get(0))); - return singleValueMap; + return this.headers.toSingleValueMap(); } @@ -1623,7 +1609,12 @@ public class HttpHeaders implements MultiValueMap, Serializable */ public static HttpHeaders readOnlyHttpHeaders(HttpHeaders headers) { Assert.notNull(headers, "HttpHeaders must not be null"); - return (headers.readOnly ? headers : new HttpHeaders(headers, true)); + if (headers instanceof ReadOnlyHttpHeaders) { + return headers; + } + else { + return new ReadOnlyHttpHeaders(headers); + } } } diff --git a/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java b/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java new file mode 100644 index 0000000000..39f64d4b6a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java @@ -0,0 +1,135 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.util.AbstractMap; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code HttpHeaders} object that can only be read, not written to. + * + * @author Brian Clozel + * @since 5.1 + */ +class ReadOnlyHttpHeaders extends HttpHeaders { + + private static final long serialVersionUID = -8578554704772377436L; + + @Nullable + private MediaType cachedContentType; + + ReadOnlyHttpHeaders(HttpHeaders headers) { + super(headers.headers); + } + + @Override + public MediaType getContentType() { + if (this.cachedContentType != null) { + return this.cachedContentType; + } + else { + MediaType contentType = super.getContentType(); + this.cachedContentType = contentType; + return contentType; + } + } + + @Override + public List get(Object key) { + List values = this.headers.get(key); + if (values != null) { + return Collections.unmodifiableList(values); + } + return values; + } + + @Override + public void add(String headerName, @Nullable String headerValue) { + throw new UnsupportedOperationException(); + } + + @Override + public void addAll(String key, List values) { + throw new UnsupportedOperationException(); + } + + @Override + public void addAll(MultiValueMap values) { + throw new UnsupportedOperationException(); + } + + @Override + public void set(String headerName, @Nullable String headerValue) { + throw new UnsupportedOperationException(); + } + + @Override + public void setAll(Map values) { + throw new UnsupportedOperationException(); + } + + @Override + public Map toSingleValueMap() { + return Collections.unmodifiableMap(this.headers.toSingleValueMap()); + } + + @Override + public Set keySet() { + return Collections.unmodifiableSet(this.headers.keySet()); + } + + @Override + public List put(String key, List value) { + throw new UnsupportedOperationException(); + } + + @Override + public List remove(Object key) { + throw new UnsupportedOperationException(); + } + + @Override + public void putAll(Map> map) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(); + } + + @Override + public Collection> values() { + return Collections.unmodifiableCollection(this.headers.values()); + } + + @Override + public Set>> entrySet() { + return Collections.unmodifiableSet(this.headers.entrySet().stream() + .map(AbstractMap.SimpleImmutableEntry::new) + .collect(Collectors.toSet())); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java index a88da3fb2b..29cd6228d3 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java @@ -153,6 +153,9 @@ public class ServletServerHttpResponse implements ServerHttpResponse { Assert.isInstanceOf(String.class, key, "Key must be a String-based header name"); Collection values1 = servletResponse.getHeaders((String) key); + if (headersWritten) { + return new ArrayList<>(values1); + } boolean isEmpty1 = CollectionUtils.isEmpty(values1); List values2 = super.get(key); diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java index 78c537e152..39dd29a9e5 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import reactor.core.publisher.Mono; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.http.HttpHeaders; /** * Abstract base class for listener-based server responses, e.g. Servlet 3.1 @@ -41,6 +42,10 @@ public abstract class AbstractListenerServerHttpResponse extends AbstractServerH super(dataBufferFactory); } + public AbstractListenerServerHttpResponse(DataBufferFactory dataBufferFactory, HttpHeaders headers) { + super(dataBufferFactory, headers); + } + @Override protected final Mono writeWithInternal(Publisher body) { diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java index 2b21a2ef4d..b8356dbd8c 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java @@ -75,9 +75,14 @@ public abstract class AbstractServerHttpResponse implements ServerHttpResponse { public AbstractServerHttpResponse(DataBufferFactory dataBufferFactory) { + this(dataBufferFactory, new HttpHeaders()); + } + + public AbstractServerHttpResponse(DataBufferFactory dataBufferFactory, HttpHeaders headers) { Assert.notNull(dataBufferFactory, "DataBufferFactory must not be null"); + Assert.notNull(headers, "HttpHeaders must not be null"); this.dataBufferFactory = dataBufferFactory; - this.headers = new HttpHeaders(); + this.headers = headers; this.cookies = new LinkedMultiValueMap<>(); } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java new file mode 100644 index 0000000000..5ffc0779ac --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.eclipse.jetty.http.HttpField; +import org.eclipse.jetty.http.HttpFields; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Jetty HTTP headers. + * + * @author Brian Clozel + * @since 5.1 + */ +class JettyHeadersAdapter implements MultiValueMap { + + private final HttpFields headers; + + JettyHeadersAdapter(HttpFields headers) { + this.headers = headers; + } + + @Override + public String getFirst(String key) { + return this.headers.get(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.headers.add(key, value); + } + + @Override + public void addAll(String key, List values) { + values.forEach(value -> add(key, value)); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach(this::addAll); + } + + @Override + public void set(String key, @Nullable String value) { + this.headers.put(key, value); + } + + @Override + public void setAll(Map values) { + values.forEach(this::set); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + Iterator iterator = this.headers.iterator(); + iterator.forEachRemaining(field -> { + if (!singleValueMap.containsKey(field.getName())) { + singleValueMap.put(field.getName(), field.getValue()); + } + }); + return singleValueMap; + } + + @Override + public int size() { + return this.headers.getFieldNamesCollection().size(); + } + + @Override + public boolean isEmpty() { + return this.headers.size() == 0; + } + + @Override + public boolean containsKey(Object key) { + if (key instanceof String) { + return this.headers.containsKey((String) key); + } + return false; + } + + @Override + public boolean containsValue(Object value) { + if (value instanceof String) { + return this.headers.stream() + .anyMatch(field -> field.contains((String) value)); + } + return false; + } + + @Nullable + @Override + public List get(Object key) { + if (key instanceof String) { + return this.headers.getValuesList((String) key); + } + return null; + } + + @Nullable + @Override + public List put(String key, List value) { + List oldValues = get(key); + this.headers.put(key, value); + return oldValues; + } + + @Nullable + @Override + public List remove(Object key) { + if (key instanceof String) { + List oldValues = get(key); + this.headers.remove((String) key); + return oldValues; + } + return null; + } + + @Override + public void putAll(Map> m) { + m.forEach(this::put); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.getFieldNamesCollection(); + } + + @Override + public Collection> values() { + return this.headers.getFieldNamesCollection().stream() + .map(this.headers::getValuesList).collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + private class EntryIterator implements Iterator>> { + + private Enumeration names = headers.getFieldNames(); + + @Override + public boolean hasNext() { + return this.names.hasMoreElements(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.nextElement()); + } + } + + private class HeaderEntry implements Entry> { + + private final String key; + + HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key.toString(); + } + + @Override + public List getValue() { + return headers.getValuesList(this.key); + } + + @Override + public List setValue(List value) { + List previousValues = headers.getValuesList(this.key); + headers.put(this.key, value); + return previousValues; + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java index 76bcef8549..bd97b1ef6f 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java @@ -17,15 +17,21 @@ package org.springframework.http.server.reactive; import java.io.IOException; +import java.net.URISyntaxException; import java.nio.ByteBuffer; import javax.servlet.AsyncContext; import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.eclipse.jetty.http.HttpFields; import org.eclipse.jetty.server.HttpOutput; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.server.Response; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.http.HttpHeaders; /** * {@link ServletHttpHandlerAdapter} extension that uses Jetty APIs for writing @@ -42,6 +48,12 @@ public class JettyHttpHandlerAdapter extends ServletHttpHandlerAdapter { } + @Override + protected ServletServerHttpRequest createRequest(HttpServletRequest request, AsyncContext context) + throws IOException, URISyntaxException { + return new JettyServerHttpRequest(request, context, getServletPath(), getDataBufferFactory(), getBufferSize()); + } + @Override protected ServletServerHttpResponse createResponse(HttpServletResponse response, AsyncContext context, ServletServerHttpRequest request) throws IOException { @@ -50,14 +62,38 @@ public class JettyHttpHandlerAdapter extends ServletHttpHandlerAdapter { response, context, getDataBufferFactory(), getBufferSize(), request); } + private static final class JettyServerHttpRequest extends ServletServerHttpRequest { + + JettyServerHttpRequest(HttpServletRequest request, AsyncContext asyncContext, + String servletPath, DataBufferFactory bufferFactory, int bufferSize) + throws IOException, URISyntaxException { + + super(createHeaders(request), request, asyncContext, servletPath, bufferFactory, bufferSize); + } + + private static HttpHeaders createHeaders(HttpServletRequest request) { + HttpFields fields = ((Request) request).getMetaData().getFields(); + return new HttpHeaders(new JettyHeadersAdapter(fields)); + } + } + private static final class JettyServerHttpResponse extends ServletServerHttpResponse { - public JettyServerHttpResponse(HttpServletResponse response, AsyncContext asyncContext, + JettyServerHttpResponse(HttpServletResponse response, AsyncContext asyncContext, DataBufferFactory bufferFactory, int bufferSize, ServletServerHttpRequest request) throws IOException { - super(response, asyncContext, bufferFactory, bufferSize, request); + super(createHeaders(response), response, asyncContext, bufferFactory, bufferSize, request); + } + + private static HttpHeaders createHeaders(HttpServletResponse response) { + HttpFields fields = ((Response) response).getHttpFields(); + return new HttpHeaders(new JettyHeadersAdapter(fields)); + } + + @Override + protected void applyHeaders() { } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java new file mode 100644 index 0000000000..6d68ceb1d8 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java @@ -0,0 +1,217 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import io.netty.handler.codec.http.HttpHeaders; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Netty HTTP headers. + * + * @author Brian Clozel + * @since 5.1 + */ +class NettyHeadersAdapter implements MultiValueMap { + + private final HttpHeaders headers; + + NettyHeadersAdapter(HttpHeaders headers) { + this.headers = headers; + } + + @Override + @Nullable + public String getFirst(String key) { + return this.headers.get(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.headers.add(key, value); + } + + @Override + public void addAll(String key, List values) { + this.headers.add(key, values); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach(this.headers::add); + } + + @Override + public void set(String key, @Nullable String value) { + this.headers.set(key, value); + } + + @Override + public void setAll(Map values) { + values.forEach(this.headers::set); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + this.headers.entries() + .forEach(entry -> { + if (!singleValueMap.containsKey(entry.getKey())) { + singleValueMap.put(entry.getKey(), entry.getValue()); + } + }); + return singleValueMap; + } + + @Override + public int size() { + return this.headers.size(); + } + + @Override + public boolean isEmpty() { + return this.headers.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return (key instanceof String) && this.headers.contains((String) key); + } + + @Override + public boolean containsValue(Object value) { + return (value instanceof String) && + this.headers.entries().stream() + .anyMatch(entry -> value != null && value.equals(entry.getValue())); + } + + @Override + @Nullable + public List get(Object key) { + if (key instanceof String) { + return this.headers.getAll((String) key); + } + return null; + } + + @Nullable + @Override + public List put(String key, @Nullable List value) { + List previousValues = this.headers.getAll(key); + this.headers.add(key, value); + return previousValues; + } + + @Nullable + @Override + public List remove(Object key) { + if (key instanceof String) { + List previousValues = this.headers.getAll((String) key); + this.headers.remove((String) key); + return previousValues; + } + return null; + } + + @Override + public void putAll(Map> m) { + m.forEach(this.headers::add); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.names(); + } + + @Override + public Collection> values() { + return this.headers.names().stream() + .map(this.headers::getAll).collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + private class EntryIterator implements Iterator>> { + + private Iterator names = headers.names().iterator(); + + @Override + public boolean hasNext() { + return this.names.hasNext(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.next()); + } + } + + private class HeaderEntry implements Entry> { + + private final String key; + + HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public List getValue() { + return headers.getAll(this.key); + } + + @Override + public List setValue(List value) { + List previousValues = headers.getAll(this.key); + headers.set(this.key, value); + return previousValues; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java index fb3a88e3bb..884f2ccd51 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java @@ -125,11 +125,8 @@ class ReactorServerHttpRequest extends AbstractServerHttpRequest { } private static HttpHeaders initHeaders(HttpServerRequest channel) { - HttpHeaders headers = new HttpHeaders(); - for (String name : channel.requestHeaders().names()) { - headers.put(name, channel.requestHeaders().getAll(name)); - } - return headers; + NettyHeadersAdapter headersMap = new NettyHeadersAdapter(channel.requestHeaders()); + return new HttpHeaders(headersMap); } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java index e2dd16ed7f..b536a6d960 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java @@ -19,6 +19,7 @@ package org.springframework.http.server.reactive; import java.nio.file.Path; import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.cookie.Cookie; import io.netty.handler.codec.http.cookie.DefaultCookie; @@ -30,6 +31,7 @@ import reactor.netty.http.server.HttpServerResponse; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseCookie; import org.springframework.http.ZeroCopyHttpOutputMessage; import org.springframework.util.Assert; @@ -47,11 +49,16 @@ class ReactorServerHttpResponse extends AbstractServerHttpResponse implements Ze public ReactorServerHttpResponse(HttpServerResponse response, DataBufferFactory bufferFactory) { - super(bufferFactory); + super(bufferFactory, initHeaders(response)); Assert.notNull(response, "HttpServerResponse must not be null"); this.response = response; } + private static HttpHeaders initHeaders(HttpServerResponse channel) { + channel.responseHeaders().remove(HttpHeaderNames.TRANSFER_ENCODING); + NettyHeadersAdapter headersMap = new NettyHeadersAdapter(channel.responseHeaders()); + return new HttpHeaders(headersMap); + } @SuppressWarnings("unchecked") @Override @@ -80,11 +87,9 @@ class ReactorServerHttpResponse extends AbstractServerHttpResponse implements Ze @Override protected void applyHeaders() { - getHeaders().forEach((headerName, headerValues) -> { - for (String value : headerValues) { - this.response.responseHeaders().add(headerName, value); - } - }); + if (getHeaders().getContentLength() == -1) { + this.response.chunkedTransfer(true); + } } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java index b2437e268c..e2d6ec2634 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java @@ -69,12 +69,18 @@ class ServletServerHttpRequest extends AbstractServerHttpRequest { private final byte[] buffer; - public ServletServerHttpRequest(HttpServletRequest request, AsyncContext asyncContext, String servletPath, DataBufferFactory bufferFactory, int bufferSize) throws IOException, URISyntaxException { - super(initUri(request), request.getContextPath() + servletPath, initHeaders(request)); + this(createDefaultHttpHeaders(request), request, asyncContext, servletPath, bufferFactory, bufferSize); + } + + public ServletServerHttpRequest(HttpHeaders headers, HttpServletRequest request, AsyncContext asyncContext, + String servletPath, DataBufferFactory bufferFactory, int bufferSize) + throws IOException, URISyntaxException { + + super(initUri(request), request.getContextPath() + servletPath, initHeaders(headers, request)); Assert.notNull(bufferFactory, "'bufferFactory' must not be null"); Assert.isTrue(bufferSize > 0, "'bufferSize' must be higher than 0"); @@ -91,6 +97,18 @@ class ServletServerHttpRequest extends AbstractServerHttpRequest { this.bodyPublisher.registerReadListener(); } + + private static HttpHeaders createDefaultHttpHeaders(HttpServletRequest request) { + HttpHeaders headers = new HttpHeaders(); + for (Enumeration names = request.getHeaderNames(); names.hasMoreElements(); ) { + String name = (String) names.nextElement(); + for (Enumeration values = request.getHeaders(name); values.hasMoreElements(); ) { + headers.add(name, (String) values.nextElement()); + } + } + return headers; + } + private static URI initUri(HttpServletRequest request) throws URISyntaxException { Assert.notNull(request, "'request' must not be null"); StringBuffer url = request.getRequestURL(); @@ -101,16 +119,7 @@ class ServletServerHttpRequest extends AbstractServerHttpRequest { return new URI(url.toString()); } - private static HttpHeaders initHeaders(HttpServletRequest request) { - HttpHeaders headers = new HttpHeaders(); - for (Enumeration names = request.getHeaderNames(); - names.hasMoreElements(); ) { - String name = (String) names.nextElement(); - for (Enumeration values = request.getHeaders(name); - values.hasMoreElements(); ) { - headers.add(name, (String) values.nextElement()); - } - } + private static HttpHeaders initHeaders(HttpHeaders headers, HttpServletRequest request) { MediaType contentType = headers.getContentType(); if (contentType == null) { String requestContentType = request.getContentType(); @@ -231,7 +240,8 @@ class ServletServerHttpRequest extends AbstractServerHttpRequest { private final class RequestAsyncListener implements AsyncListener { @Override - public void onStartAsync(AsyncEvent event) {} + public void onStartAsync(AsyncEvent event) { + } @Override public void onTimeout(AsyncEvent event) { diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java index 81d9c0c549..9d11809031 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -33,6 +33,7 @@ import org.reactivestreams.Publisher; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseCookie; import org.springframework.lang.Nullable; @@ -62,11 +63,16 @@ class ServletServerHttpResponse extends AbstractListenerServerHttpResponse { private final ServletServerHttpRequest request; - public ServletServerHttpResponse(HttpServletResponse response, AsyncContext asyncContext, DataBufferFactory bufferFactory, int bufferSize, ServletServerHttpRequest request) throws IOException { - super(bufferFactory); + this(new HttpHeaders(), response, asyncContext, bufferFactory, bufferSize, request); + } + + public ServletServerHttpResponse(HttpHeaders headers, HttpServletResponse response, AsyncContext asyncContext, + DataBufferFactory bufferFactory, int bufferSize, ServletServerHttpRequest request) throws IOException { + + super(bufferFactory, headers); Assert.notNull(response, "HttpServletResponse must not be null"); Assert.notNull(bufferFactory, "DataBufferFactory must not be null"); diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java new file mode 100644 index 0000000000..e667a3e86c --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java @@ -0,0 +1,237 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.tomcat.util.buf.MessageBytes; +import org.apache.tomcat.util.http.MimeHeaders; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Tomcat HTTP headers. + * + * @author Brian Clozel + * @since 5.1 + */ +class TomcatHeadersAdapter implements MultiValueMap { + + private final MimeHeaders headers; + + TomcatHeadersAdapter(MimeHeaders headers) { + this.headers = headers; + } + + @Override + public String getFirst(String key) { + return this.headers.getHeader(key); + } + + @Override + public void add(String key, String value) { + this.headers.addValue(key).setString(value); + } + + @Override + public void addAll(String key, List values) { + values.forEach(value -> add(key, value)); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach(this::addAll); + } + + @Override + public void set(String key, String value) { + this.headers.setValue(key).setString(value); + } + + @Override + public void setAll(Map values) { + values.forEach(this::set); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + this.keySet().forEach(key -> singleValueMap.put(key, getFirst(key))); + return singleValueMap; + } + + @Override + public int size() { + Enumeration names = this.headers.names(); + int size = 0; + while (names.hasMoreElements()) { + size++; + names.nextElement(); + } + return size; + } + + @Override + public boolean isEmpty() { + return this.headers.size() == 0; + } + + @Override + public boolean containsKey(Object key) { + if (key instanceof String) { + return this.headers.findHeader((String) key, 0) != -1; + } + return false; + } + + @Override + public boolean containsValue(Object value) { + if (value instanceof String) { + MessageBytes needle = MessageBytes.newInstance(); + needle.setString((String) value); + for (int i = 0; i < this.headers.size(); i++) { + if (this.headers.getValue(i).equals(needle)) { + return true; + } + } + } + return false; + } + + @Override + @Nullable + public List get(Object key) { + if (key instanceof String) { + return Collections.list(this.headers.values((String) key)); + } + return null; + } + + @Override + @Nullable + public List put(String key, List value) { + List previousValues = get(key); + value.forEach(v -> this.headers.addValue(key).setString(v)); + return previousValues; + } + + @Override + @Nullable + public List remove(Object key) { + if (key instanceof String) { + List previousValues = get(key); + this.headers.removeHeader((String) key); + return previousValues; + } + return null; + } + + @Override + public void putAll(Map> m) { + m.forEach(this::put); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + Set result = new HashSet<>(8); + Enumeration names = this.headers.names(); + while (names.hasMoreElements()) { + result.add(names.nextElement()); + } + return result; + } + + @Override + public Collection> values() { + return keySet().stream().map(this::get).collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + private class EntryIterator implements Iterator>> { + + private Enumeration names = headers.names(); + + @Override + public boolean hasNext() { + return this.names.hasMoreElements(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.nextElement()); + } + } + + private final class HeaderEntry implements Entry> { + + private final String key; + + private HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key; + } + + @Nullable + @Override + public List getValue() { + return get(this.key); + } + + @Nullable + @Override + public List setValue(List value) { + List previous = getValue(); + headers.removeHeader(this.key); + addAll(this.key, value); + return previous; + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java index 2041234428..89851e938a 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java @@ -17,6 +17,7 @@ package org.springframework.http.server.reactive; import java.io.IOException; +import java.lang.reflect.Field; import java.net.URISyntaxException; import java.nio.ByteBuffer; import javax.servlet.AsyncContext; @@ -27,17 +28,25 @@ import javax.servlet.http.HttpServletResponse; import org.apache.catalina.connector.CoyoteInputStream; import org.apache.catalina.connector.CoyoteOutputStream; +import org.apache.catalina.connector.RequestFacade; +import org.apache.catalina.connector.ResponseFacade; +import org.apache.coyote.Request; +import org.apache.coyote.Response; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; /** * {@link ServletHttpHandlerAdapter} extension that uses Tomcat APIs for reading * from the request and writing to the response with {@link ByteBuffer}. * * @author Violeta Georgieva + * @author Brian Clozel + * @author Brian Clozel * @since 5.0 * @see org.springframework.web.server.adapter.AbstractReactiveWebInitializer */ @@ -66,21 +75,39 @@ public class TomcatHttpHandlerAdapter extends ServletHttpHandlerAdapter { response, asyncContext, getDataBufferFactory(), getBufferSize(), request); } + private static final class TomcatServerHttpRequest extends ServletServerHttpRequest { - private final class TomcatServerHttpRequest extends ServletServerHttpRequest { + private static final Field COYOTE_REQUEST_FIELD = ReflectionUtils.findField(RequestFacade.class, "request"); - public TomcatServerHttpRequest(HttpServletRequest request, AsyncContext context, + private final int bufferSize; + + private final DataBufferFactory factory; + + static { + ReflectionUtils.makeAccessible(COYOTE_REQUEST_FIELD); + } + + TomcatServerHttpRequest(HttpServletRequest request, AsyncContext context, String servletPath, DataBufferFactory factory, int bufferSize) throws IOException, URISyntaxException { - super(request, context, servletPath, factory, bufferSize); + super(createTomcatHttpHeaders(request), request, context, servletPath, factory, bufferSize); + this.factory = factory; + this.bufferSize = bufferSize; + } + + private static HttpHeaders createTomcatHttpHeaders(HttpServletRequest request) { + Request tomcatRequest = ((org.apache.catalina.connector.Request) ReflectionUtils + .getField(COYOTE_REQUEST_FIELD, request)).getCoyoteRequest(); + TomcatHeadersAdapter headers = new TomcatHeadersAdapter(tomcatRequest.getMimeHeaders()); + return new HttpHeaders(headers); } @Override protected DataBuffer readFromInputStream() throws IOException { boolean release = true; - int capacity = getBufferSize(); - DataBuffer dataBuffer = getDataBufferFactory().allocateBuffer(capacity); + int capacity = this.bufferSize; + DataBuffer dataBuffer = this.factory.allocateBuffer(capacity); try { ByteBuffer byteBuffer = dataBuffer.asByteBuffer(0, capacity); @@ -111,10 +138,27 @@ public class TomcatHttpHandlerAdapter extends ServletHttpHandlerAdapter { private static final class TomcatServerHttpResponse extends ServletServerHttpResponse { - public TomcatServerHttpResponse(HttpServletResponse response, AsyncContext context, + private static final Field COYOTE_RESPONSE_FIELD = ReflectionUtils.findField(ResponseFacade.class, "response"); + + static { + ReflectionUtils.makeAccessible(COYOTE_RESPONSE_FIELD); + } + + TomcatServerHttpResponse(HttpServletResponse response, AsyncContext context, DataBufferFactory factory, int bufferSize, ServletServerHttpRequest request) throws IOException { - super(response, context, factory, bufferSize, request); + super(createTomcatHttpHeaders(response), response, context, factory, bufferSize, request); + } + + private static HttpHeaders createTomcatHttpHeaders(HttpServletResponse response) { + Response tomcatResponse = ((org.apache.catalina.connector.Response) ReflectionUtils + .getField(COYOTE_RESPONSE_FIELD, response)).getCoyoteResponse(); + TomcatHeadersAdapter headers = new TomcatHeadersAdapter(tomcatResponse.getMimeHeaders()); + return new HttpHeaders(headers); + } + + @Override + protected void applyHeaders() { } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java new file mode 100644 index 0000000000..3e817c906f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import io.undertow.util.HeaderMap; +import io.undertow.util.HeaderValues; +import io.undertow.util.HttpString; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Undertow HTTP headers. + * + * @author Brian Clozel + * @since 5.1 + */ +class UndertowHeadersAdapter implements MultiValueMap { + + private final HeaderMap headers; + + UndertowHeadersAdapter(HeaderMap headers) { + this.headers = headers; + } + + @Override + public String getFirst(String key) { + return this.headers.getFirst(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.headers.add(HttpString.tryFromString(key), value); + } + + @Override + @SuppressWarnings("unchecked") + public void addAll(String key, List values) { + this.headers.addAll(HttpString.tryFromString(key), (List) values); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach((key, list) -> this.headers.addAll(HttpString.tryFromString(key), list)); + } + + @Override + public void set(String key, @Nullable String value) { + this.headers.put(HttpString.tryFromString(key), value); + } + + @Override + public void setAll(Map values) { + values.forEach((key, list) -> this.headers.put(HttpString.tryFromString(key), list)); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + this.headers.forEach(values -> + singleValueMap.put(values.getHeaderName().toString(), values.getFirst())); + return singleValueMap; + } + + @Override + public int size() { + return this.headers.size(); + } + + @Override + public boolean isEmpty() { + return this.headers.size() == 0; + } + + @Override + public boolean containsKey(Object key) { + if (key instanceof String) { + return this.headers.contains((String) key); + } + return false; + } + + @Override + public boolean containsValue(Object value) { + if (value instanceof String) { + return this.headers.getHeaderNames().stream() + .map(this.headers::get) + .anyMatch(values -> values.contains(value)); + } + return false; + } + + @Override + @Nullable + public List get(Object key) { + if (key instanceof String) { + return this.headers.get((String) key); + } + return null; + } + + @Override + @Nullable + public List put(String key, List value) { + HeaderValues previousValues = this.headers.get(key); + this.headers.putAll(HttpString.tryFromString(key), value); + return previousValues; + } + + @Override + @Nullable + public List remove(Object key) { + if (key instanceof String) { + this.headers.remove((String) key); + } + return null; + } + + @Override + public void putAll(Map> m) { + m.forEach((key, values) -> + this.headers.putAll(HttpString.tryFromString(key), values)); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.getHeaderNames().stream() + .map(HttpString::toString) + .collect(Collectors.toSet()); + } + + @Override + public Collection> values() { + return this.headers.getHeaderNames().stream() + .map(this.headers::get) + .collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + private class EntryIterator implements Iterator>> { + + private Iterator names = headers.getHeaderNames().iterator(); + + @Override + public boolean hasNext() { + return this.names.hasNext(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.next()); + } + } + + private class HeaderEntry implements Entry> { + + private final HttpString key; + + HeaderEntry(HttpString key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key.toString(); + } + + @Override + public List getValue() { + return headers.get(this.key); + } + + @Override + public List setValue(List value) { + List previousValues = headers.get(this.key); + headers.putAll(this.key, value); + return previousValues; + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java index da6c061391..6c68f5a528 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java @@ -30,7 +30,6 @@ import io.undertow.connector.ByteBufferPool; import io.undertow.connector.PooledByteBuffer; import io.undertow.server.HttpServerExchange; import io.undertow.server.handlers.Cookie; -import io.undertow.util.HeaderValues; import org.xnio.channels.StreamSourceChannel; import reactor.core.publisher.Flux; @@ -79,11 +78,9 @@ class UndertowServerHttpRequest extends AbstractServerHttpRequest { } private static HttpHeaders initHeaders(HttpServerExchange exchange) { - HttpHeaders headers = new HttpHeaders(); - for (HeaderValues values : exchange.getRequestHeaders()) { - headers.put(values.getHeaderName().toString(), values); - } - return headers; + UndertowHeadersAdapter headersMap = + new UndertowHeadersAdapter(exchange.getRequestHeaders()); + return new HttpHeaders(headersMap); } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java index 1ad997896c..a9533379cd 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java @@ -25,7 +25,6 @@ import java.nio.file.StandardOpenOption; import io.undertow.server.HttpServerExchange; import io.undertow.server.handlers.Cookie; import io.undertow.server.handlers.CookieImpl; -import io.undertow.util.HttpString; import org.reactivestreams.Processor; import org.reactivestreams.Publisher; import org.xnio.channels.Channels; @@ -35,6 +34,7 @@ import reactor.core.publisher.Mono; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseCookie; import org.springframework.http.ZeroCopyHttpOutputMessage; import org.springframework.lang.Nullable; @@ -58,15 +58,21 @@ class UndertowServerHttpResponse extends AbstractListenerServerHttpResponse impl private StreamSinkChannel responseChannel; - public UndertowServerHttpResponse( + UndertowServerHttpResponse( HttpServerExchange exchange, DataBufferFactory bufferFactory, UndertowServerHttpRequest request) { - super(bufferFactory); + super(bufferFactory, createHeaders(exchange)); Assert.notNull(exchange, "HttpServerExchange must not be null"); this.exchange = exchange; this.request = request; } + private static HttpHeaders createHeaders(HttpServerExchange exchange) { + UndertowHeadersAdapter headersMap = + new UndertowHeadersAdapter(exchange.getResponseHeaders()); + return new HttpHeaders(headersMap); + } + @SuppressWarnings("unchecked") @Override @@ -85,8 +91,6 @@ class UndertowServerHttpResponse extends AbstractListenerServerHttpResponse impl @Override protected void applyHeaders() { - getHeaders().forEach((headerName, headerValues) -> - this.exchange.getResponseHeaders().addAll(HttpString.tryFromString(headerName), headerValues)); } @Override diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java index b673e1e463..f5027686de 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java @@ -87,7 +87,7 @@ public class DefaultCorsProcessor implements CorsProcessor { } private boolean responseHasCors(ServerHttpResponse response) { - return (response.getHeaders().getAccessControlAllowOrigin() != null); + return response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN) != null; } /** diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java index f5061f960e..d3e353bf90 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java @@ -272,7 +272,10 @@ public class HandshakeWebSocketService implements WebSocketService, Lifecycle { @Nullable String protocol, Map attributes) { URI uri = request.getURI(); - HttpHeaders headers = request.getHeaders(); + // Copy request headers, as they might be pooled and recycled by + // the server implementation once the handshake HTTP exchange is done. + HttpHeaders headers = new HttpHeaders(); + headers.addAll(request.getHeaders()); Mono principal = exchange.getPrincipal(); String logPrefix = exchange.getLogPrefix(); InetSocketAddress remoteAddress = request.getRemoteAddress();