diff --git a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java index ed1a6a14f35..c713f775ab6 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java +++ b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java @@ -1686,13 +1686,13 @@ public class HttpHeaders implements Serializable { * @since 5.2.3 */ public void clearContentHeaders() { - this.headers.remove(HttpHeaders.CONTENT_DISPOSITION); - this.headers.remove(HttpHeaders.CONTENT_ENCODING); - this.headers.remove(HttpHeaders.CONTENT_LANGUAGE); - this.headers.remove(HttpHeaders.CONTENT_LENGTH); - this.headers.remove(HttpHeaders.CONTENT_LOCATION); - this.headers.remove(HttpHeaders.CONTENT_RANGE); - this.headers.remove(HttpHeaders.CONTENT_TYPE); + remove(HttpHeaders.CONTENT_DISPOSITION); + remove(HttpHeaders.CONTENT_ENCODING); + remove(HttpHeaders.CONTENT_LANGUAGE); + remove(HttpHeaders.CONTENT_LENGTH); + remove(HttpHeaders.CONTENT_LOCATION); + remove(HttpHeaders.CONTENT_RANGE); + remove(HttpHeaders.CONTENT_TYPE); } /** @@ -1807,7 +1807,7 @@ public class HttpHeaders implements Serializable { * @see #putAll(HttpHeaders) */ public void addAll(HttpHeaders headers) { - this.headers.addAll(headers.headers); + headers.forEach(this::addAll); } /** @@ -1909,7 +1909,7 @@ public class HttpHeaders implements Serializable { * @since 7.0 */ public boolean hasHeaderValues(String headerName, List values) { - return ObjectUtils.nullSafeEquals(this.headers.get(headerName), values); + return ObjectUtils.nullSafeEquals(get(headerName), values); } /** @@ -1920,7 +1920,7 @@ public class HttpHeaders implements Serializable { * @since 7.0 */ public boolean containsHeaderValue(String headerName, String value) { - final List values = this.headers.get(headerName); + List values = get(headerName); if (values == null) { return false; } @@ -1969,7 +1969,7 @@ public class HttpHeaders implements Serializable { * @see #put(String, List) */ public void putAll(HttpHeaders headers) { - this.headers.putAll(headers.headers); + headers.forEach(this::put); } /** @@ -1978,7 +1978,7 @@ public class HttpHeaders implements Serializable { * @see #put(String, List) */ public void putAll(Map> headers) { - this.headers.putAll(headers); + headers.forEach(this::put); } /** diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java index 0ba36021d37..be90695e7b9 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java @@ -21,18 +21,19 @@ import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.function.BiConsumer; import org.jspecify.annotations.Nullable; import org.springframework.http.HttpHeaders; import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; /** * An {@link org.springframework.http.HttpHeaders} variant that adds support for * the HTTP headers defined by the WebSocket specification RFC 6455. * * @author Rossen Stoyanchev + * @author Sam Brannen * @since 4.0 */ public class WebSocketHttpHeaders extends HttpHeaders { @@ -181,40 +182,40 @@ public class WebSocketHttpHeaders extends HttpHeaders { return getFirst(SEC_WEBSOCKET_VERSION); } + @Override + public @Nullable List get(String headerName) { + return this.headers.get(headerName); + } - // Single string methods - - /** - * Return the first header value for the given header name, if any. - * @param headerName the header name - * @return the first header value; or {@code null} - */ @Override public @Nullable String getFirst(String headerName) { return this.headers.getFirst(headerName); } - /** - * Add the given, single header value under the given name. - * @param headerName the header name - * @param headerValue the header value - * @throws UnsupportedOperationException if adding headers is not supported - * @see #put(String, List) - * @see #set(String, String) - */ + @Override + public @Nullable List put(String key, List value) { + return this.headers.put(key, value); + } + + @Override + public @Nullable List putIfAbsent(String headerName, List headerValues) { + return this.headers.putIfAbsent(headerName, headerValues); + } + @Override public void add(String headerName, @Nullable String headerValue) { this.headers.add(headerName, headerValue); } /** - * Set the given, single header value under the given name. - * @param headerName the header name - * @param headerValue the header value - * @throws UnsupportedOperationException if adding headers is not supported - * @see #put(String, List) - * @see #add(String, String) + * {@inheritDoc} + * @since 7.0 */ + @Override + public void addAll(String headerName, List headerValues) { + this.headers.addAll(headerName, headerValues); + } + @Override public void set(String headerName, @Nullable String headerValue) { this.headers.set(headerName, headerValue); @@ -230,16 +231,32 @@ public class WebSocketHttpHeaders extends HttpHeaders { return this.headers.toSingleValueMap(); } - // Map implementation - + /** + * {@inheritDoc} + * @since 7.0 + * @deprecated in favor of {@link #toSingleValueMap()} which performs a copy but + * ensures that collection-iterating methods like {@code entrySet()} are + * case-insensitive + */ @Override - public int size() { - return this.headers.size(); + @Deprecated(since = "7.0", forRemoval = true) + @SuppressWarnings("removal") + public Map asSingleValueMap() { + return this.headers.asSingleValueMap(); } + /** + * {@inheritDoc} + * @since 7.0 + * @deprecated This method is provided for backward compatibility with APIs + * that would only accept maps. Generally avoid using HttpHeaders as a Map + * or MultiValueMap. + */ @Override - public boolean isEmpty() { - return this.headers.isEmpty(); + @Deprecated(since = "7.0", forRemoval = true) + @SuppressWarnings("removal") + public MultiValueMap asMultiValueMap() { + return this.headers.asMultiValueMap(); } @Override @@ -248,13 +265,13 @@ public class WebSocketHttpHeaders extends HttpHeaders { } @Override - public @Nullable List get(String headerName) { - return this.headers.get(headerName); + public boolean isEmpty() { + return this.headers.isEmpty(); } @Override - public @Nullable List put(String key, List value) { - return this.headers.put(key, value); + public int size() { + return this.headers.size(); } @Override @@ -262,16 +279,6 @@ public class WebSocketHttpHeaders extends HttpHeaders { return this.headers.remove(key); } - @Override - public void putAll(HttpHeaders headers) { - this.headers.putAll(headers); - } - - @Override - public void putAll(Map> m) { - this.headers.putAll(m); - } - @Override public void clear() { this.headers.clear(); @@ -287,17 +294,6 @@ public class WebSocketHttpHeaders extends HttpHeaders { return this.headers.headerSet(); } - @Override - public void forEach(BiConsumer> action) { - this.headers.forEach(action); - } - - @Override - public @Nullable List putIfAbsent(String headerName, List headerValues) { - return this.headers.putIfAbsent(headerName, headerValues); - } - - @Override public boolean equals(@Nullable Object other) { return (this == other || (other instanceof WebSocketHttpHeaders that && diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/handler/WebSocketHttpHeadersTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/handler/WebSocketHttpHeadersTests.java index f8005cdae3a..f5c47083e77 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/handler/WebSocketHttpHeadersTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/handler/WebSocketHttpHeadersTests.java @@ -16,41 +16,122 @@ package org.springframework.web.socket.handler; -import java.util.ArrayList; import java.util.List; +import java.util.Map; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.springframework.web.socket.WebSocketExtension; +import org.springframework.http.HttpHeaders; import org.springframework.web.socket.WebSocketHttpHeaders; import static org.assertj.core.api.Assertions.assertThat; - +import static org.assertj.core.api.Assertions.entry; /** * Tests for {@link WebSocketHttpHeaders}. * * @author Rossen Stoyanchev + * @author Sam Brannen */ class WebSocketHttpHeadersTests { - private WebSocketHttpHeaders headers; + private WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); - @BeforeEach - void setUp() { - headers = new WebSocketHttpHeaders(); - } @Test void parseWebSocketExtensions() { - List extensions = new ArrayList<>(); - extensions.add("x-foo-extension, x-bar-extension"); - extensions.add("x-test-extension"); + var extensions = List.of("x-foo-extension, x-bar-extension", "x-test-extension"); this.headers.put(WebSocketHttpHeaders.SEC_WEBSOCKET_EXTENSIONS, extensions); - List parsedExtensions = this.headers.getSecWebSocketExtensions(); + var parsedExtensions = this.headers.getSecWebSocketExtensions(); assertThat(parsedExtensions).hasSize(3); } + @Test // gh-35792 + void addAllViaWebSocketHttpHeadersApi() { + headers.add("green", "grape"); + + var otherHeaders = new HttpHeaders(); + otherHeaders.add("yellow", "banana"); + otherHeaders.add("red", "apple"); + + headers.addAll(otherHeaders); + + assertThat(headers.toSingleValueMap()).containsOnly( + entry("green", "grape"), + entry("yellow", "banana"), + entry("red", "apple") + ); + } + + @Test // gh-35792 + void addAllViaHttpHeadersApi() { + headers.add("yellow", "banana"); + headers.add("red", "apple"); + + var otherHeaders = new HttpHeaders(); + otherHeaders.add("green", "grape"); + + otherHeaders.addAll(headers); + + assertThat(otherHeaders.toSingleValueMap()).containsOnly( + entry("green", "grape"), + entry("yellow", "banana"), + entry("red", "apple") + ); + } + + @Test // gh-35792 + void putAllFromHttpHeadersViaWebSocketHttpHeadersApi() { + var otherHeaders = new HttpHeaders(); + otherHeaders.add("yellow", "banana"); + otherHeaders.add("red", "apple"); + + headers.putAll(otherHeaders); + + assertThat(headers.toSingleValueMap()).containsOnly( + entry("yellow", "banana"), + entry("red", "apple") + ); + } + + @Test // gh-35792 + void putAllFromHttpHeadersViaHttpHeadersApi() { + headers.add("yellow", "banana"); + headers.add("red", "apple"); + + var otherHeaders = new HttpHeaders(); + otherHeaders.putAll(headers); + + assertThat(otherHeaders.toSingleValueMap()).containsOnly( + entry("yellow", "banana"), + entry("red", "apple") + ); + } + + @Test // gh-35792 + void putAllFromMap() { + headers.putAll(Map.of("yellow", List.of("banana"), "red", List.of("apple"))); + + assertThat(headers.toSingleValueMap()).containsOnly( + entry("yellow", "banana"), + entry("red", "apple") + ); + } + + @Test // gh-35792 + void setAllFromMap() { + headers.add("yellow", "lemon"); + assertThat(headers.toSingleValueMap()).containsOnly( + entry("yellow", "lemon") + ); + + headers.setAll(Map.of("yellow", "banana", "red", "apple")); + + assertThat(headers.toSingleValueMap()).containsOnly( + entry("yellow", "banana"), // not lemon + entry("red", "apple") + ); + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java index aceccf58185..0129bde3c85 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java @@ -29,10 +29,12 @@ import org.springframework.http.ResponseEntity; import org.springframework.web.client.HttpServerErrorException; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketHttpHeaders; import org.springframework.web.socket.WebSocketSession; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.entry; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -42,19 +44,38 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; * Tests for {@link AbstractXhrTransport}. * * @author Rossen Stoyanchev + * @author Sam Brannen */ class XhrTransportTests { + private final TestXhrTransport transport = new TestXhrTransport(); + + @Test void infoResponse() { - TestXhrTransport transport = new TestXhrTransport(); transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK); assertThat(transport.executeInfoRequest(URI.create("https://example.com/info"), null)).isEqualTo("body"); } + @Test // gh-35792 + void infoResponseWithWebSocketHttpHeaders() { + transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK); + + var headers = new WebSocketHttpHeaders(); + headers.setSecWebSocketAccept("enigma"); + headers.add("foo", "bar"); + + transport.executeInfoRequest(URI.create("https://example.com/info"), headers); + + assertThat(transport.actualInfoHeaders).isNotNull(); + assertThat(transport.actualInfoHeaders.toSingleValueMap()).containsExactly( + entry(WebSocketHttpHeaders.SEC_WEBSOCKET_ACCEPT, "enigma"), + entry("foo", "bar") + ); + } + @Test void infoResponseError() { - TestXhrTransport transport = new TestXhrTransport(); transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.BAD_REQUEST); assertThatExceptionOfType(HttpServerErrorException.class).isThrownBy(() -> transport.executeInfoRequest(URI.create("https://example.com/info"), null)); @@ -65,7 +86,6 @@ class XhrTransportTests { HttpHeaders requestHeaders = new HttpHeaders(); requestHeaders.set("foo", "bar"); requestHeaders.setContentType(MediaType.APPLICATION_JSON); - TestXhrTransport transport = new TestXhrTransport(); transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.NO_CONTENT); URI url = URI.create("https://example.com"); transport.executeSendRequest(url, requestHeaders, new TextMessage("payload")); @@ -76,7 +96,6 @@ class XhrTransportTests { @Test void sendMessageError() { - TestXhrTransport transport = new TestXhrTransport(); transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.BAD_REQUEST); URI url = URI.create("https://example.com"); assertThatExceptionOfType(HttpServerErrorException.class).isThrownBy(() -> @@ -93,7 +112,6 @@ class XhrTransportTests { given(request.getHandshakeHeaders()).willReturn(handshakeHeaders); given(request.getHttpRequestHeaders()).willReturn(new HttpHeaders()); - TestXhrTransport transport = new TestXhrTransport(); WebSocketHandler handler = mock(); transport.connectAsync(request, handler); @@ -124,10 +142,13 @@ class XhrTransportTests { private HttpHeaders actualHandshakeHeaders; + private HttpHeaders actualInfoHeaders; + private XhrClientSockJsSession actualSession; @Override protected ResponseEntity executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) { + this.actualInfoHeaders = headers; return this.infoResponseToReturn; }