From 74bb42b78f91fd409eb40b0d8bed6088935e915e Mon Sep 17 00:00:00 2001 From: Patrick Strawderman Date: Fri, 2 Feb 2024 10:13:13 -0800 Subject: [PATCH] Optimize Map methods in ServletAttributesMap ServletAttributesMap inherited default implementations of the size and isEmpty methods from AbstractMap which delegates to the Set returned by entrySet. ServletAttributesMap's entrySet method made this fairly expensive, since it would copy the attributes to a List, then use a Stream to build the Set. To avoid the cost, add implementations of isEmpty / size that don't need to call entrySet at all. Additionally, change entrySet to return a Set view that simply lazily delegates to the underlying servlet request for iteration. Closes gh-32197 --- .../function/DefaultServerRequest.java | 98 +++++++++++++++++-- .../function/DefaultServerRequestTests.java | 44 ++++++++- 2 files changed, 132 insertions(+), 10 deletions(-) diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java index a2223c29ea5..4775d37f2bc 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,10 +26,13 @@ import java.nio.charset.Charset; import java.security.Principal; import java.time.Instant; import java.util.AbstractMap; +import java.util.AbstractSet; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.Enumeration; +import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; @@ -55,6 +58,7 @@ import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.RequestPath; import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; @@ -73,6 +77,7 @@ import org.springframework.web.util.UriBuilder; * * @author Arjen Poutsma * @author Sam Brannen + * @author Patrick Strawderman * @since 5.2 */ class DefaultServerRequest implements ServerRequest { @@ -433,18 +438,77 @@ class DefaultServerRequest implements ServerRequest { @Override public void clear() { - List attributeNames = Collections.list(this.servletRequest.getAttributeNames()); - attributeNames.forEach(this.servletRequest::removeAttribute); + this.servletRequest.getAttributeNames().asIterator().forEachRemaining(this.servletRequest::removeAttribute); } @Override public Set> entrySet() { - return Collections.list(this.servletRequest.getAttributeNames()).stream() - .map(name -> { - Object value = this.servletRequest.getAttribute(name); - return new SimpleImmutableEntry<>(name, value); - }) - .collect(Collectors.toSet()); + return new AbstractSet<>() { + @Override + public Iterator> iterator() { + return new Iterator<>() { + + private final Iterator attributes = ServletAttributesMap.this.servletRequest.getAttributeNames().asIterator(); + + @Override + public boolean hasNext() { + return this.attributes.hasNext(); + } + + @Override + public Entry next() { + String attribute = this.attributes.next(); + Object value = ServletAttributesMap.this.servletRequest.getAttribute(attribute); + return new SimpleImmutableEntry<>(attribute, value); + } + }; + } + + @Override + public boolean isEmpty() { + return ServletAttributesMap.this.isEmpty(); + } + + @Override + public int size() { + return ServletAttributesMap.this.size(); + } + + @Override + public boolean contains(Object o) { + if (!(o instanceof Map.Entry entry)) { + return false; + } + String attribute = (String) entry.getKey(); + Object value = ServletAttributesMap.this.servletRequest.getAttribute(attribute); + return value != null && value.equals(entry.getValue()); + } + + @Override + public boolean addAll(@NonNull Collection> c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean remove(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean removeAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean retainAll(@NonNull Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(); + } + }; } @Override @@ -467,6 +531,22 @@ class DefaultServerRequest implements ServerRequest { this.servletRequest.removeAttribute(name); return value; } + + @Override + public int size() { + Enumeration attributes = this.servletRequest.getAttributeNames(); + int size = 0; + while (attributes.hasMoreElements()) { + size++; + attributes.nextElement(); + } + return size; + } + + @Override + public boolean isEmpty() { + return !this.servletRequest.getAttributeNames().hasMoreElements(); + } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java index 30707f1a8ec..bce1120ddcd 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java @@ -27,10 +27,12 @@ import java.security.Principal; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.OptionalLong; +import java.util.Set; import jakarta.servlet.http.Cookie; import jakarta.servlet.http.Part; @@ -61,8 +63,8 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** + * Tests for {@link DefaultServerRequest}. * @author Arjen Poutsma - * @since 5.1 */ class DefaultServerRequestTests { @@ -114,6 +116,46 @@ class DefaultServerRequestTests { assertThat(request.attribute("foo")).contains("bar"); } + @Test + void attributes() { + MockHttpServletRequest servletRequest = PathPatternsTestUtils.initRequest("GET", "/", true); + servletRequest.setAttribute("foo", "bar"); + servletRequest.setAttribute("baz", "qux"); + + DefaultServerRequest request = new DefaultServerRequest(servletRequest, this.messageConverters); + + Map attributesMap = request.attributes(); + assertThat(attributesMap).isNotEmpty(); + assertThat(attributesMap).containsEntry("foo", "bar"); + assertThat(attributesMap).containsEntry("baz", "qux"); + assertThat(attributesMap).doesNotContainEntry("foo", "blah"); + + Set> entrySet = attributesMap.entrySet(); + assertThat(entrySet).isNotEmpty(); + assertThat(entrySet).hasSize(attributesMap.size()); + assertThat(entrySet).contains(Map.entry("foo", "bar")); + assertThat(entrySet).contains(Map.entry("baz", "qux")); + assertThat(entrySet).doesNotContain(Map.entry("foo", "blah")); + assertThat(entrySet).isUnmodifiable(); + + assertThat(entrySet.iterator()).toIterable().contains(Map.entry("foo", "bar"), Map.entry("baz", "qux")); + Iterator attributes = servletRequest.getAttributeNames().asIterator(); + Iterator> entrySetIterator = entrySet.iterator(); + while (attributes.hasNext()) { + attributes.next(); + assertThat(entrySetIterator).hasNext(); + entrySetIterator.next(); + } + assertThat(entrySetIterator).isExhausted(); + + attributesMap.clear(); + assertThat(attributesMap).isEmpty(); + assertThat(attributesMap).hasSize(0); + assertThat(entrySet).isEmpty(); + assertThat(entrySet).hasSize(0); + assertThat(entrySet.iterator()).isExhausted(); + } + @Test void params() { MockHttpServletRequest servletRequest = PathPatternsTestUtils.initRequest("GET", "/", true);