diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java index a2223c29ea5..4775d37f2bc 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultServerRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,10 +26,13 @@ import java.nio.charset.Charset; import java.security.Principal; import java.time.Instant; import java.util.AbstractMap; +import java.util.AbstractSet; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.Enumeration; +import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; @@ -55,6 +58,7 @@ import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.RequestPath; import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; @@ -73,6 +77,7 @@ import org.springframework.web.util.UriBuilder; * * @author Arjen Poutsma * @author Sam Brannen + * @author Patrick Strawderman * @since 5.2 */ class DefaultServerRequest implements ServerRequest { @@ -433,18 +438,77 @@ class DefaultServerRequest implements ServerRequest { @Override public void clear() { - List attributeNames = Collections.list(this.servletRequest.getAttributeNames()); - attributeNames.forEach(this.servletRequest::removeAttribute); + this.servletRequest.getAttributeNames().asIterator().forEachRemaining(this.servletRequest::removeAttribute); } @Override public Set> entrySet() { - return Collections.list(this.servletRequest.getAttributeNames()).stream() - .map(name -> { - Object value = this.servletRequest.getAttribute(name); - return new SimpleImmutableEntry<>(name, value); - }) - .collect(Collectors.toSet()); + return new AbstractSet<>() { + @Override + public Iterator> iterator() { + return new Iterator<>() { + + private final Iterator attributes = ServletAttributesMap.this.servletRequest.getAttributeNames().asIterator(); + + @Override + public boolean hasNext() { + return this.attributes.hasNext(); + } + + @Override + public Entry next() { + String attribute = this.attributes.next(); + Object value = ServletAttributesMap.this.servletRequest.getAttribute(attribute); + return new SimpleImmutableEntry<>(attribute, value); + } + }; + } + + @Override + public boolean isEmpty() { + return ServletAttributesMap.this.isEmpty(); + } + + @Override + public int size() { + return ServletAttributesMap.this.size(); + } + + @Override + public boolean contains(Object o) { + if (!(o instanceof Map.Entry entry)) { + return false; + } + String attribute = (String) entry.getKey(); + Object value = ServletAttributesMap.this.servletRequest.getAttribute(attribute); + return value != null && value.equals(entry.getValue()); + } + + @Override + public boolean addAll(@NonNull Collection> c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean remove(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean removeAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean retainAll(@NonNull Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(); + } + }; } @Override @@ -467,6 +531,22 @@ class DefaultServerRequest implements ServerRequest { this.servletRequest.removeAttribute(name); return value; } + + @Override + public int size() { + Enumeration attributes = this.servletRequest.getAttributeNames(); + int size = 0; + while (attributes.hasMoreElements()) { + size++; + attributes.nextElement(); + } + return size; + } + + @Override + public boolean isEmpty() { + return !this.servletRequest.getAttributeNames().hasMoreElements(); + } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java index 30707f1a8ec..bce1120ddcd 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultServerRequestTests.java @@ -27,10 +27,12 @@ import java.security.Principal; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.OptionalLong; +import java.util.Set; import jakarta.servlet.http.Cookie; import jakarta.servlet.http.Part; @@ -61,8 +63,8 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** + * Tests for {@link DefaultServerRequest}. * @author Arjen Poutsma - * @since 5.1 */ class DefaultServerRequestTests { @@ -114,6 +116,46 @@ class DefaultServerRequestTests { assertThat(request.attribute("foo")).contains("bar"); } + @Test + void attributes() { + MockHttpServletRequest servletRequest = PathPatternsTestUtils.initRequest("GET", "/", true); + servletRequest.setAttribute("foo", "bar"); + servletRequest.setAttribute("baz", "qux"); + + DefaultServerRequest request = new DefaultServerRequest(servletRequest, this.messageConverters); + + Map attributesMap = request.attributes(); + assertThat(attributesMap).isNotEmpty(); + assertThat(attributesMap).containsEntry("foo", "bar"); + assertThat(attributesMap).containsEntry("baz", "qux"); + assertThat(attributesMap).doesNotContainEntry("foo", "blah"); + + Set> entrySet = attributesMap.entrySet(); + assertThat(entrySet).isNotEmpty(); + assertThat(entrySet).hasSize(attributesMap.size()); + assertThat(entrySet).contains(Map.entry("foo", "bar")); + assertThat(entrySet).contains(Map.entry("baz", "qux")); + assertThat(entrySet).doesNotContain(Map.entry("foo", "blah")); + assertThat(entrySet).isUnmodifiable(); + + assertThat(entrySet.iterator()).toIterable().contains(Map.entry("foo", "bar"), Map.entry("baz", "qux")); + Iterator attributes = servletRequest.getAttributeNames().asIterator(); + Iterator> entrySetIterator = entrySet.iterator(); + while (attributes.hasNext()) { + attributes.next(); + assertThat(entrySetIterator).hasNext(); + entrySetIterator.next(); + } + assertThat(entrySetIterator).isExhausted(); + + attributesMap.clear(); + assertThat(attributesMap).isEmpty(); + assertThat(attributesMap).hasSize(0); + assertThat(entrySet).isEmpty(); + assertThat(entrySet).hasSize(0); + assertThat(entrySet.iterator()).isExhausted(); + } + @Test void params() { MockHttpServletRequest servletRequest = PathPatternsTestUtils.initRequest("GET", "/", true);