Browse Source

Fix HttpHeaders and WebSocketHttpHeaders interop issues

Since HttpHeaders no longer implements MultiValueMap (see gh-33913),
a few interoperability issues have arisen between HttpHeaders and
WebSocketHttpHeaders.

To address those issues, this commit:

- Revises addAll(HttpHeaders), putAll(HttpHeaders), and putAll(Map) in
  HttpHeaders so that they no longer operate on the HttpHeaders.headers
  field.

- Overrides addAll(String, List), asSingleValueMap(), and
  asMultiValueMap() in WebSocketHttpHeaders.

- Deletes putAll(HttpHeaders), putAll(Map), and forEach(BiConsumer) in
  WebSocketHttpHeaders, since they do not need to be overridden.

This commit also removes unnecessarily overridden Javadoc in
WebSocketHttpHeaders and revises the implementation of several methods
in HttpHeaders so that they delegate to key methods such as get()
instead of directly accessing the HttpHeaders.headers field.

See gh-33913
Closes gh-35792
pull/34993/merge
Sam Brannen 1 month ago
parent
commit
4593f877dd
  1. 24
      spring-web/src/main/java/org/springframework/http/HttpHeaders.java
  2. 102
      spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java
  3. 107
      spring-websocket/src/test/java/org/springframework/web/socket/handler/WebSocketHttpHeadersTests.java
  4. 31
      spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java

24
spring-web/src/main/java/org/springframework/http/HttpHeaders.java

@ -1686,13 +1686,13 @@ public class HttpHeaders implements Serializable { @@ -1686,13 +1686,13 @@ public class HttpHeaders implements Serializable {
* @since 5.2.3
*/
public void clearContentHeaders() {
this.headers.remove(HttpHeaders.CONTENT_DISPOSITION);
this.headers.remove(HttpHeaders.CONTENT_ENCODING);
this.headers.remove(HttpHeaders.CONTENT_LANGUAGE);
this.headers.remove(HttpHeaders.CONTENT_LENGTH);
this.headers.remove(HttpHeaders.CONTENT_LOCATION);
this.headers.remove(HttpHeaders.CONTENT_RANGE);
this.headers.remove(HttpHeaders.CONTENT_TYPE);
remove(HttpHeaders.CONTENT_DISPOSITION);
remove(HttpHeaders.CONTENT_ENCODING);
remove(HttpHeaders.CONTENT_LANGUAGE);
remove(HttpHeaders.CONTENT_LENGTH);
remove(HttpHeaders.CONTENT_LOCATION);
remove(HttpHeaders.CONTENT_RANGE);
remove(HttpHeaders.CONTENT_TYPE);
}
/**
@ -1807,7 +1807,7 @@ public class HttpHeaders implements Serializable { @@ -1807,7 +1807,7 @@ public class HttpHeaders implements Serializable {
* @see #putAll(HttpHeaders)
*/
public void addAll(HttpHeaders headers) {
this.headers.addAll(headers.headers);
headers.forEach(this::addAll);
}
/**
@ -1909,7 +1909,7 @@ public class HttpHeaders implements Serializable { @@ -1909,7 +1909,7 @@ public class HttpHeaders implements Serializable {
* @since 7.0
*/
public boolean hasHeaderValues(String headerName, List<String> values) {
return ObjectUtils.nullSafeEquals(this.headers.get(headerName), values);
return ObjectUtils.nullSafeEquals(get(headerName), values);
}
/**
@ -1920,7 +1920,7 @@ public class HttpHeaders implements Serializable { @@ -1920,7 +1920,7 @@ public class HttpHeaders implements Serializable {
* @since 7.0
*/
public boolean containsHeaderValue(String headerName, String value) {
final List<String> values = this.headers.get(headerName);
List<String> values = get(headerName);
if (values == null) {
return false;
}
@ -1969,7 +1969,7 @@ public class HttpHeaders implements Serializable { @@ -1969,7 +1969,7 @@ public class HttpHeaders implements Serializable {
* @see #put(String, List)
*/
public void putAll(HttpHeaders headers) {
this.headers.putAll(headers.headers);
headers.forEach(this::put);
}
/**
@ -1978,7 +1978,7 @@ public class HttpHeaders implements Serializable { @@ -1978,7 +1978,7 @@ public class HttpHeaders implements Serializable {
* @see #put(String, List)
*/
public void putAll(Map<? extends String, ? extends List<String>> headers) {
this.headers.putAll(headers);
headers.forEach(this::put);
}
/**

102
spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java

@ -21,18 +21,19 @@ import java.util.Collections; @@ -21,18 +21,19 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import org.jspecify.annotations.Nullable;
import org.springframework.http.HttpHeaders;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MultiValueMap;
/**
* An {@link org.springframework.http.HttpHeaders} variant that adds support for
* the HTTP headers defined by the WebSocket specification RFC 6455.
*
* @author Rossen Stoyanchev
* @author Sam Brannen
* @since 4.0
*/
public class WebSocketHttpHeaders extends HttpHeaders {
@ -181,40 +182,40 @@ public class WebSocketHttpHeaders extends HttpHeaders { @@ -181,40 +182,40 @@ public class WebSocketHttpHeaders extends HttpHeaders {
return getFirst(SEC_WEBSOCKET_VERSION);
}
@Override
public @Nullable List<String> get(String headerName) {
return this.headers.get(headerName);
}
// Single string methods
/**
* Return the first header value for the given header name, if any.
* @param headerName the header name
* @return the first header value; or {@code null}
*/
@Override
public @Nullable String getFirst(String headerName) {
return this.headers.getFirst(headerName);
}
/**
* Add the given, single header value under the given name.
* @param headerName the header name
* @param headerValue the header value
* @throws UnsupportedOperationException if adding headers is not supported
* @see #put(String, List)
* @see #set(String, String)
*/
@Override
public @Nullable List<String> put(String key, List<String> value) {
return this.headers.put(key, value);
}
@Override
public @Nullable List<String> putIfAbsent(String headerName, List<String> headerValues) {
return this.headers.putIfAbsent(headerName, headerValues);
}
@Override
public void add(String headerName, @Nullable String headerValue) {
this.headers.add(headerName, headerValue);
}
/**
* Set the given, single header value under the given name.
* @param headerName the header name
* @param headerValue the header value
* @throws UnsupportedOperationException if adding headers is not supported
* @see #put(String, List)
* @see #add(String, String)
* {@inheritDoc}
* @since 7.0
*/
@Override
public void addAll(String headerName, List<? extends String> headerValues) {
this.headers.addAll(headerName, headerValues);
}
@Override
public void set(String headerName, @Nullable String headerValue) {
this.headers.set(headerName, headerValue);
@ -230,16 +231,32 @@ public class WebSocketHttpHeaders extends HttpHeaders { @@ -230,16 +231,32 @@ public class WebSocketHttpHeaders extends HttpHeaders {
return this.headers.toSingleValueMap();
}
// Map implementation
/**
* {@inheritDoc}
* @since 7.0
* @deprecated in favor of {@link #toSingleValueMap()} which performs a copy but
* ensures that collection-iterating methods like {@code entrySet()} are
* case-insensitive
*/
@Override
public int size() {
return this.headers.size();
@Deprecated(since = "7.0", forRemoval = true)
@SuppressWarnings("removal")
public Map<String, String> asSingleValueMap() {
return this.headers.asSingleValueMap();
}
/**
* {@inheritDoc}
* @since 7.0
* @deprecated This method is provided for backward compatibility with APIs
* that would only accept maps. Generally avoid using HttpHeaders as a Map
* or MultiValueMap.
*/
@Override
public boolean isEmpty() {
return this.headers.isEmpty();
@Deprecated(since = "7.0", forRemoval = true)
@SuppressWarnings("removal")
public MultiValueMap<String, String> asMultiValueMap() {
return this.headers.asMultiValueMap();
}
@Override
@ -248,13 +265,13 @@ public class WebSocketHttpHeaders extends HttpHeaders { @@ -248,13 +265,13 @@ public class WebSocketHttpHeaders extends HttpHeaders {
}
@Override
public @Nullable List<String> get(String headerName) {
return this.headers.get(headerName);
public boolean isEmpty() {
return this.headers.isEmpty();
}
@Override
public @Nullable List<String> put(String key, List<String> value) {
return this.headers.put(key, value);
public int size() {
return this.headers.size();
}
@Override
@ -262,16 +279,6 @@ public class WebSocketHttpHeaders extends HttpHeaders { @@ -262,16 +279,6 @@ public class WebSocketHttpHeaders extends HttpHeaders {
return this.headers.remove(key);
}
@Override
public void putAll(HttpHeaders headers) {
this.headers.putAll(headers);
}
@Override
public void putAll(Map<? extends String, ? extends List<String>> m) {
this.headers.putAll(m);
}
@Override
public void clear() {
this.headers.clear();
@ -287,17 +294,6 @@ public class WebSocketHttpHeaders extends HttpHeaders { @@ -287,17 +294,6 @@ public class WebSocketHttpHeaders extends HttpHeaders {
return this.headers.headerSet();
}
@Override
public void forEach(BiConsumer<? super String, ? super List<String>> action) {
this.headers.forEach(action);
}
@Override
public @Nullable List<String> putIfAbsent(String headerName, List<String> headerValues) {
return this.headers.putIfAbsent(headerName, headerValues);
}
@Override
public boolean equals(@Nullable Object other) {
return (this == other || (other instanceof WebSocketHttpHeaders that &&

107
spring-websocket/src/test/java/org/springframework/web/socket/handler/WebSocketHttpHeadersTests.java

@ -16,41 +16,122 @@ @@ -16,41 +16,122 @@
package org.springframework.web.socket.handler;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.WebSocketHttpHeaders;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.entry;
/**
* Tests for {@link WebSocketHttpHeaders}.
*
* @author Rossen Stoyanchev
* @author Sam Brannen
*/
class WebSocketHttpHeadersTests {
private WebSocketHttpHeaders headers;
private WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
@BeforeEach
void setUp() {
headers = new WebSocketHttpHeaders();
}
@Test
void parseWebSocketExtensions() {
List<String> extensions = new ArrayList<>();
extensions.add("x-foo-extension, x-bar-extension");
extensions.add("x-test-extension");
var extensions = List.of("x-foo-extension, x-bar-extension", "x-test-extension");
this.headers.put(WebSocketHttpHeaders.SEC_WEBSOCKET_EXTENSIONS, extensions);
List<WebSocketExtension> parsedExtensions = this.headers.getSecWebSocketExtensions();
var parsedExtensions = this.headers.getSecWebSocketExtensions();
assertThat(parsedExtensions).hasSize(3);
}
@Test // gh-35792
void addAllViaWebSocketHttpHeadersApi() {
headers.add("green", "grape");
var otherHeaders = new HttpHeaders();
otherHeaders.add("yellow", "banana");
otherHeaders.add("red", "apple");
headers.addAll(otherHeaders);
assertThat(headers.toSingleValueMap()).containsOnly(
entry("green", "grape"),
entry("yellow", "banana"),
entry("red", "apple")
);
}
@Test // gh-35792
void addAllViaHttpHeadersApi() {
headers.add("yellow", "banana");
headers.add("red", "apple");
var otherHeaders = new HttpHeaders();
otherHeaders.add("green", "grape");
otherHeaders.addAll(headers);
assertThat(otherHeaders.toSingleValueMap()).containsOnly(
entry("green", "grape"),
entry("yellow", "banana"),
entry("red", "apple")
);
}
@Test // gh-35792
void putAllFromHttpHeadersViaWebSocketHttpHeadersApi() {
var otherHeaders = new HttpHeaders();
otherHeaders.add("yellow", "banana");
otherHeaders.add("red", "apple");
headers.putAll(otherHeaders);
assertThat(headers.toSingleValueMap()).containsOnly(
entry("yellow", "banana"),
entry("red", "apple")
);
}
@Test // gh-35792
void putAllFromHttpHeadersViaHttpHeadersApi() {
headers.add("yellow", "banana");
headers.add("red", "apple");
var otherHeaders = new HttpHeaders();
otherHeaders.putAll(headers);
assertThat(otherHeaders.toSingleValueMap()).containsOnly(
entry("yellow", "banana"),
entry("red", "apple")
);
}
@Test // gh-35792
void putAllFromMap() {
headers.putAll(Map.of("yellow", List.of("banana"), "red", List.of("apple")));
assertThat(headers.toSingleValueMap()).containsOnly(
entry("yellow", "banana"),
entry("red", "apple")
);
}
@Test // gh-35792
void setAllFromMap() {
headers.add("yellow", "lemon");
assertThat(headers.toSingleValueMap()).containsOnly(
entry("yellow", "lemon")
);
headers.setAll(Map.of("yellow", "banana", "red", "apple"));
assertThat(headers.toSingleValueMap()).containsOnly(
entry("yellow", "banana"), // not lemon
entry("red", "apple")
);
}
}

31
spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java

@ -29,10 +29,12 @@ import org.springframework.http.ResponseEntity; @@ -29,10 +29,12 @@ import org.springframework.http.ResponseEntity;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.entry;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -42,19 +44,38 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -42,19 +44,38 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
* Tests for {@link AbstractXhrTransport}.
*
* @author Rossen Stoyanchev
* @author Sam Brannen
*/
class XhrTransportTests {
private final TestXhrTransport transport = new TestXhrTransport();
@Test
void infoResponse() {
TestXhrTransport transport = new TestXhrTransport();
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK);
assertThat(transport.executeInfoRequest(URI.create("https://example.com/info"), null)).isEqualTo("body");
}
@Test // gh-35792
void infoResponseWithWebSocketHttpHeaders() {
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK);
var headers = new WebSocketHttpHeaders();
headers.setSecWebSocketAccept("enigma");
headers.add("foo", "bar");
transport.executeInfoRequest(URI.create("https://example.com/info"), headers);
assertThat(transport.actualInfoHeaders).isNotNull();
assertThat(transport.actualInfoHeaders.toSingleValueMap()).containsExactly(
entry(WebSocketHttpHeaders.SEC_WEBSOCKET_ACCEPT, "enigma"),
entry("foo", "bar")
);
}
@Test
void infoResponseError() {
TestXhrTransport transport = new TestXhrTransport();
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.BAD_REQUEST);
assertThatExceptionOfType(HttpServerErrorException.class).isThrownBy(() ->
transport.executeInfoRequest(URI.create("https://example.com/info"), null));
@ -65,7 +86,6 @@ class XhrTransportTests { @@ -65,7 +86,6 @@ class XhrTransportTests {
HttpHeaders requestHeaders = new HttpHeaders();
requestHeaders.set("foo", "bar");
requestHeaders.setContentType(MediaType.APPLICATION_JSON);
TestXhrTransport transport = new TestXhrTransport();
transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.NO_CONTENT);
URI url = URI.create("https://example.com");
transport.executeSendRequest(url, requestHeaders, new TextMessage("payload"));
@ -76,7 +96,6 @@ class XhrTransportTests { @@ -76,7 +96,6 @@ class XhrTransportTests {
@Test
void sendMessageError() {
TestXhrTransport transport = new TestXhrTransport();
transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.BAD_REQUEST);
URI url = URI.create("https://example.com");
assertThatExceptionOfType(HttpServerErrorException.class).isThrownBy(() ->
@ -93,7 +112,6 @@ class XhrTransportTests { @@ -93,7 +112,6 @@ class XhrTransportTests {
given(request.getHandshakeHeaders()).willReturn(handshakeHeaders);
given(request.getHttpRequestHeaders()).willReturn(new HttpHeaders());
TestXhrTransport transport = new TestXhrTransport();
WebSocketHandler handler = mock();
transport.connectAsync(request, handler);
@ -124,10 +142,13 @@ class XhrTransportTests { @@ -124,10 +142,13 @@ class XhrTransportTests {
private HttpHeaders actualHandshakeHeaders;
private HttpHeaders actualInfoHeaders;
private XhrClientSockJsSession actualSession;
@Override
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
this.actualInfoHeaders = headers;
return this.infoResponseToReturn;
}

Loading…
Cancel
Save