Optimize Map methods in ServletAttributesMap

ServletAttributesMap inherited default implementations of the size
and isEmpty methods from AbstractMap which delegates to the Set returned
by entrySet. ServletAttributesMap's entrySet method made this fairly
expensive, since it would copy the attributes to a List, then use a
Stream to build the Set. To avoid the cost, add implementations of
isEmpty / size that don't need to call entrySet at all.

Additionally, change entrySet to return a Set view that simply lazily
delegates to the underlying servlet request for iteration.

Closes gh-32189
This commit is contained in:
Patrick Strawderman 2024-02-02 10:13:13 -08:00 committed by Brian Clozel
parent c04d4da9a3
commit 0fdf759896
2 changed files with 132 additions and 10 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2023 the original author or authors. * Copyright 2002-2024 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -26,10 +26,13 @@ import java.nio.charset.Charset;
import java.security.Principal; import java.security.Principal;
import java.time.Instant; import java.time.Instant;
import java.util.AbstractMap; import java.util.AbstractMap;
import java.util.AbstractSet;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
@ -57,6 +60,7 @@ import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.RequestPath; import org.springframework.http.server.RequestPath;
import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@ -80,6 +84,7 @@ import org.springframework.web.util.UriBuilder;
* *
* @author Arjen Poutsma * @author Arjen Poutsma
* @author Sam Brannen * @author Sam Brannen
* @author Patrick Strawderman
* @since 5.2 * @since 5.2
*/ */
class DefaultServerRequest implements ServerRequest { class DefaultServerRequest implements ServerRequest {
@ -469,18 +474,77 @@ class DefaultServerRequest implements ServerRequest {
@Override @Override
public void clear() { public void clear() {
List<String> attributeNames = Collections.list(this.servletRequest.getAttributeNames()); this.servletRequest.getAttributeNames().asIterator().forEachRemaining(this.servletRequest::removeAttribute);
attributeNames.forEach(this.servletRequest::removeAttribute);
} }
@Override @Override
public Set<Entry<String, Object>> entrySet() { public Set<Entry<String, Object>> entrySet() {
return Collections.list(this.servletRequest.getAttributeNames()).stream() return new AbstractSet<>() {
.map(name -> { @Override
Object value = this.servletRequest.getAttribute(name); public Iterator<Entry<String, Object>> iterator() {
return new SimpleImmutableEntry<>(name, value); return new Iterator<>() {
})
.collect(Collectors.toSet()); private final Iterator<String> attributes = ServletAttributesMap.this.servletRequest.getAttributeNames().asIterator();
@Override
public boolean hasNext() {
return this.attributes.hasNext();
}
@Override
public Entry<String, Object> next() {
String attribute = this.attributes.next();
Object value = ServletAttributesMap.this.servletRequest.getAttribute(attribute);
return new SimpleImmutableEntry<>(attribute, value);
}
};
}
@Override
public boolean isEmpty() {
return ServletAttributesMap.this.isEmpty();
}
@Override
public int size() {
return ServletAttributesMap.this.size();
}
@Override
public boolean contains(Object o) {
if (!(o instanceof Map.Entry<?,?> entry)) {
return false;
}
String attribute = (String) entry.getKey();
Object value = ServletAttributesMap.this.servletRequest.getAttribute(attribute);
return value != null && value.equals(entry.getValue());
}
@Override
public boolean addAll(@NonNull Collection<? extends Entry<String, Object>> c) {
throw new UnsupportedOperationException();
}
@Override
public boolean remove(Object o) {
throw new UnsupportedOperationException();
}
@Override
public boolean removeAll(Collection<?> c) {
throw new UnsupportedOperationException();
}
@Override
public boolean retainAll(@NonNull Collection<?> c) {
throw new UnsupportedOperationException();
}
@Override
public void clear() {
throw new UnsupportedOperationException();
}
};
} }
@Override @Override
@ -503,6 +567,22 @@ class DefaultServerRequest implements ServerRequest {
this.servletRequest.removeAttribute(name); this.servletRequest.removeAttribute(name);
return value; return value;
} }
@Override
public int size() {
Enumeration<String> attributes = this.servletRequest.getAttributeNames();
int size = 0;
while (attributes.hasMoreElements()) {
size++;
attributes.nextElement();
}
return size;
}
@Override
public boolean isEmpty() {
return !this.servletRequest.getAttributeNames().hasMoreElements();
}
} }

View File

@ -27,10 +27,12 @@ import java.security.Principal;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.Collections; import java.util.Collections;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.OptionalLong; import java.util.OptionalLong;
import java.util.Set;
import jakarta.servlet.http.Cookie; import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.Part; import jakarta.servlet.http.Part;
@ -62,8 +64,8 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
/** /**
* Tests for {@link DefaultServerRequest}.
* @author Arjen Poutsma * @author Arjen Poutsma
* @since 5.1
*/ */
class DefaultServerRequestTests { class DefaultServerRequestTests {
@ -115,6 +117,46 @@ class DefaultServerRequestTests {
assertThat(request.attribute("foo")).contains("bar"); assertThat(request.attribute("foo")).contains("bar");
} }
@Test
void attributes() {
MockHttpServletRequest servletRequest = PathPatternsTestUtils.initRequest("GET", "/", true);
servletRequest.setAttribute("foo", "bar");
servletRequest.setAttribute("baz", "qux");
DefaultServerRequest request = new DefaultServerRequest(servletRequest, this.messageConverters);
Map<String, Object> attributesMap = request.attributes();
assertThat(attributesMap).isNotEmpty();
assertThat(attributesMap).containsEntry("foo", "bar");
assertThat(attributesMap).containsEntry("baz", "qux");
assertThat(attributesMap).doesNotContainEntry("foo", "blah");
Set<Map.Entry<String, Object>> entrySet = attributesMap.entrySet();
assertThat(entrySet).isNotEmpty();
assertThat(entrySet).hasSize(attributesMap.size());
assertThat(entrySet).contains(Map.entry("foo", "bar"));
assertThat(entrySet).contains(Map.entry("baz", "qux"));
assertThat(entrySet).doesNotContain(Map.entry("foo", "blah"));
assertThat(entrySet).isUnmodifiable();
assertThat(entrySet.iterator()).toIterable().contains(Map.entry("foo", "bar"), Map.entry("baz", "qux"));
Iterator<String> attributes = servletRequest.getAttributeNames().asIterator();
Iterator<Map.Entry<String, Object>> entrySetIterator = entrySet.iterator();
while (attributes.hasNext()) {
attributes.next();
assertThat(entrySetIterator).hasNext();
entrySetIterator.next();
}
assertThat(entrySetIterator).isExhausted();
attributesMap.clear();
assertThat(attributesMap).isEmpty();
assertThat(attributesMap).hasSize(0);
assertThat(entrySet).isEmpty();
assertThat(entrySet).hasSize(0);
assertThat(entrySet.iterator()).isExhausted();
}
@Test @Test
void params() { void params() {
MockHttpServletRequest servletRequest = PathPatternsTestUtils.initRequest("GET", "/", true); MockHttpServletRequest servletRequest = PathPatternsTestUtils.initRequest("GET", "/", true);