Browse Source

Fix HttpHeaders and WebSocketHttpHeaders interop issues

Since HttpHeaders no longer implements MultiValueMap (see gh-33913),
a few interoperability issues have arisen between HttpHeaders and
WebSocketHttpHeaders.

To address those issues, this commit:

- Revises addAll(HttpHeaders), putAll(HttpHeaders), and putAll(Map) in
  HttpHeaders so that they no longer operate on the HttpHeaders.headers
  field.

- Overrides addAll(String, List), asSingleValueMap(), and
  asMultiValueMap() in WebSocketHttpHeaders.

- Deletes putAll(HttpHeaders), putAll(Map), and forEach(BiConsumer) in
  WebSocketHttpHeaders, since they do not need to be overridden.

This commit also removes unnecessarily overridden Javadoc in
WebSocketHttpHeaders and revises the implementation of several methods
in HttpHeaders so that they delegate to key methods such as get()
instead of directly accessing the HttpHeaders.headers field.

See gh-33913
Closes gh-35792
pull/34993/merge
Sam Brannen 1 month ago
parent
commit
4593f877dd
  1. 24
      spring-web/src/main/java/org/springframework/http/HttpHeaders.java
  2. 102
      spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java
  3. 107
      spring-websocket/src/test/java/org/springframework/web/socket/handler/WebSocketHttpHeadersTests.java
  4. 31
      spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java

24
spring-web/src/main/java/org/springframework/http/HttpHeaders.java

@ -1686,13 +1686,13 @@ public class HttpHeaders implements Serializable {
* @since 5.2.3 * @since 5.2.3
*/ */
public void clearContentHeaders() { public void clearContentHeaders() {
this.headers.remove(HttpHeaders.CONTENT_DISPOSITION); remove(HttpHeaders.CONTENT_DISPOSITION);
this.headers.remove(HttpHeaders.CONTENT_ENCODING); remove(HttpHeaders.CONTENT_ENCODING);
this.headers.remove(HttpHeaders.CONTENT_LANGUAGE); remove(HttpHeaders.CONTENT_LANGUAGE);
this.headers.remove(HttpHeaders.CONTENT_LENGTH); remove(HttpHeaders.CONTENT_LENGTH);
this.headers.remove(HttpHeaders.CONTENT_LOCATION); remove(HttpHeaders.CONTENT_LOCATION);
this.headers.remove(HttpHeaders.CONTENT_RANGE); remove(HttpHeaders.CONTENT_RANGE);
this.headers.remove(HttpHeaders.CONTENT_TYPE); remove(HttpHeaders.CONTENT_TYPE);
} }
/** /**
@ -1807,7 +1807,7 @@ public class HttpHeaders implements Serializable {
* @see #putAll(HttpHeaders) * @see #putAll(HttpHeaders)
*/ */
public void addAll(HttpHeaders headers) { public void addAll(HttpHeaders headers) {
this.headers.addAll(headers.headers); headers.forEach(this::addAll);
} }
/** /**
@ -1909,7 +1909,7 @@ public class HttpHeaders implements Serializable {
* @since 7.0 * @since 7.0
*/ */
public boolean hasHeaderValues(String headerName, List<String> values) { public boolean hasHeaderValues(String headerName, List<String> values) {
return ObjectUtils.nullSafeEquals(this.headers.get(headerName), values); return ObjectUtils.nullSafeEquals(get(headerName), values);
} }
/** /**
@ -1920,7 +1920,7 @@ public class HttpHeaders implements Serializable {
* @since 7.0 * @since 7.0
*/ */
public boolean containsHeaderValue(String headerName, String value) { public boolean containsHeaderValue(String headerName, String value) {
final List<String> values = this.headers.get(headerName); List<String> values = get(headerName);
if (values == null) { if (values == null) {
return false; return false;
} }
@ -1969,7 +1969,7 @@ public class HttpHeaders implements Serializable {
* @see #put(String, List) * @see #put(String, List)
*/ */
public void putAll(HttpHeaders headers) { public void putAll(HttpHeaders headers) {
this.headers.putAll(headers.headers); headers.forEach(this::put);
} }
/** /**
@ -1978,7 +1978,7 @@ public class HttpHeaders implements Serializable {
* @see #put(String, List) * @see #put(String, List)
*/ */
public void putAll(Map<? extends String, ? extends List<String>> headers) { public void putAll(Map<? extends String, ? extends List<String>> headers) {
this.headers.putAll(headers); headers.forEach(this::put);
} }
/** /**

102
spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java

@ -21,18 +21,19 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.function.BiConsumer;
import org.jspecify.annotations.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import org.springframework.util.MultiValueMap;
/** /**
* An {@link org.springframework.http.HttpHeaders} variant that adds support for * An {@link org.springframework.http.HttpHeaders} variant that adds support for
* the HTTP headers defined by the WebSocket specification RFC 6455. * the HTTP headers defined by the WebSocket specification RFC 6455.
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @author Sam Brannen
* @since 4.0 * @since 4.0
*/ */
public class WebSocketHttpHeaders extends HttpHeaders { public class WebSocketHttpHeaders extends HttpHeaders {
@ -181,40 +182,40 @@ public class WebSocketHttpHeaders extends HttpHeaders {
return getFirst(SEC_WEBSOCKET_VERSION); return getFirst(SEC_WEBSOCKET_VERSION);
} }
@Override
public @Nullable List<String> get(String headerName) {
return this.headers.get(headerName);
}
// Single string methods
/**
* Return the first header value for the given header name, if any.
* @param headerName the header name
* @return the first header value; or {@code null}
*/
@Override @Override
public @Nullable String getFirst(String headerName) { public @Nullable String getFirst(String headerName) {
return this.headers.getFirst(headerName); return this.headers.getFirst(headerName);
} }
/** @Override
* Add the given, single header value under the given name. public @Nullable List<String> put(String key, List<String> value) {
* @param headerName the header name return this.headers.put(key, value);
* @param headerValue the header value }
* @throws UnsupportedOperationException if adding headers is not supported
* @see #put(String, List) @Override
* @see #set(String, String) public @Nullable List<String> putIfAbsent(String headerName, List<String> headerValues) {
*/ return this.headers.putIfAbsent(headerName, headerValues);
}
@Override @Override
public void add(String headerName, @Nullable String headerValue) { public void add(String headerName, @Nullable String headerValue) {
this.headers.add(headerName, headerValue); this.headers.add(headerName, headerValue);
} }
/** /**
* Set the given, single header value under the given name. * {@inheritDoc}
* @param headerName the header name * @since 7.0
* @param headerValue the header value
* @throws UnsupportedOperationException if adding headers is not supported
* @see #put(String, List)
* @see #add(String, String)
*/ */
@Override
public void addAll(String headerName, List<? extends String> headerValues) {
this.headers.addAll(headerName, headerValues);
}
@Override @Override
public void set(String headerName, @Nullable String headerValue) { public void set(String headerName, @Nullable String headerValue) {
this.headers.set(headerName, headerValue); this.headers.set(headerName, headerValue);
@ -230,16 +231,32 @@ public class WebSocketHttpHeaders extends HttpHeaders {
return this.headers.toSingleValueMap(); return this.headers.toSingleValueMap();
} }
// Map implementation /**
* {@inheritDoc}
* @since 7.0
* @deprecated in favor of {@link #toSingleValueMap()} which performs a copy but
* ensures that collection-iterating methods like {@code entrySet()} are
* case-insensitive
*/
@Override @Override
public int size() { @Deprecated(since = "7.0", forRemoval = true)
return this.headers.size(); @SuppressWarnings("removal")
public Map<String, String> asSingleValueMap() {
return this.headers.asSingleValueMap();
} }
/**
* {@inheritDoc}
* @since 7.0
* @deprecated This method is provided for backward compatibility with APIs
* that would only accept maps. Generally avoid using HttpHeaders as a Map
* or MultiValueMap.
*/
@Override @Override
public boolean isEmpty() { @Deprecated(since = "7.0", forRemoval = true)
return this.headers.isEmpty(); @SuppressWarnings("removal")
public MultiValueMap<String, String> asMultiValueMap() {
return this.headers.asMultiValueMap();
} }
@Override @Override
@ -248,13 +265,13 @@ public class WebSocketHttpHeaders extends HttpHeaders {
} }
@Override @Override
public @Nullable List<String> get(String headerName) { public boolean isEmpty() {
return this.headers.get(headerName); return this.headers.isEmpty();
} }
@Override @Override
public @Nullable List<String> put(String key, List<String> value) { public int size() {
return this.headers.put(key, value); return this.headers.size();
} }
@Override @Override
@ -262,16 +279,6 @@ public class WebSocketHttpHeaders extends HttpHeaders {
return this.headers.remove(key); return this.headers.remove(key);
} }
@Override
public void putAll(HttpHeaders headers) {
this.headers.putAll(headers);
}
@Override
public void putAll(Map<? extends String, ? extends List<String>> m) {
this.headers.putAll(m);
}
@Override @Override
public void clear() { public void clear() {
this.headers.clear(); this.headers.clear();
@ -287,17 +294,6 @@ public class WebSocketHttpHeaders extends HttpHeaders {
return this.headers.headerSet(); return this.headers.headerSet();
} }
@Override
public void forEach(BiConsumer<? super String, ? super List<String>> action) {
this.headers.forEach(action);
}
@Override
public @Nullable List<String> putIfAbsent(String headerName, List<String> headerValues) {
return this.headers.putIfAbsent(headerName, headerValues);
}
@Override @Override
public boolean equals(@Nullable Object other) { public boolean equals(@Nullable Object other) {
return (this == other || (other instanceof WebSocketHttpHeaders that && return (this == other || (other instanceof WebSocketHttpHeaders that &&

107
spring-websocket/src/test/java/org/springframework/web/socket/handler/WebSocketHttpHeadersTests.java

@ -16,41 +16,122 @@
package org.springframework.web.socket.handler; package org.springframework.web.socket.handler;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.web.socket.WebSocketExtension; import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.WebSocketHttpHeaders; import org.springframework.web.socket.WebSocketHttpHeaders;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.entry;
/** /**
* Tests for {@link WebSocketHttpHeaders}. * Tests for {@link WebSocketHttpHeaders}.
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @author Sam Brannen
*/ */
class WebSocketHttpHeadersTests { class WebSocketHttpHeadersTests {
private WebSocketHttpHeaders headers; private WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
@BeforeEach
void setUp() {
headers = new WebSocketHttpHeaders();
}
@Test @Test
void parseWebSocketExtensions() { void parseWebSocketExtensions() {
List<String> extensions = new ArrayList<>(); var extensions = List.of("x-foo-extension, x-bar-extension", "x-test-extension");
extensions.add("x-foo-extension, x-bar-extension");
extensions.add("x-test-extension");
this.headers.put(WebSocketHttpHeaders.SEC_WEBSOCKET_EXTENSIONS, extensions); this.headers.put(WebSocketHttpHeaders.SEC_WEBSOCKET_EXTENSIONS, extensions);
List<WebSocketExtension> parsedExtensions = this.headers.getSecWebSocketExtensions(); var parsedExtensions = this.headers.getSecWebSocketExtensions();
assertThat(parsedExtensions).hasSize(3); assertThat(parsedExtensions).hasSize(3);
} }
@Test // gh-35792
void addAllViaWebSocketHttpHeadersApi() {
headers.add("green", "grape");
var otherHeaders = new HttpHeaders();
otherHeaders.add("yellow", "banana");
otherHeaders.add("red", "apple");
headers.addAll(otherHeaders);
assertThat(headers.toSingleValueMap()).containsOnly(
entry("green", "grape"),
entry("yellow", "banana"),
entry("red", "apple")
);
}
@Test // gh-35792
void addAllViaHttpHeadersApi() {
headers.add("yellow", "banana");
headers.add("red", "apple");
var otherHeaders = new HttpHeaders();
otherHeaders.add("green", "grape");
otherHeaders.addAll(headers);
assertThat(otherHeaders.toSingleValueMap()).containsOnly(
entry("green", "grape"),
entry("yellow", "banana"),
entry("red", "apple")
);
}
@Test // gh-35792
void putAllFromHttpHeadersViaWebSocketHttpHeadersApi() {
var otherHeaders = new HttpHeaders();
otherHeaders.add("yellow", "banana");
otherHeaders.add("red", "apple");
headers.putAll(otherHeaders);
assertThat(headers.toSingleValueMap()).containsOnly(
entry("yellow", "banana"),
entry("red", "apple")
);
}
@Test // gh-35792
void putAllFromHttpHeadersViaHttpHeadersApi() {
headers.add("yellow", "banana");
headers.add("red", "apple");
var otherHeaders = new HttpHeaders();
otherHeaders.putAll(headers);
assertThat(otherHeaders.toSingleValueMap()).containsOnly(
entry("yellow", "banana"),
entry("red", "apple")
);
}
@Test // gh-35792
void putAllFromMap() {
headers.putAll(Map.of("yellow", List.of("banana"), "red", List.of("apple")));
assertThat(headers.toSingleValueMap()).containsOnly(
entry("yellow", "banana"),
entry("red", "apple")
);
}
@Test // gh-35792
void setAllFromMap() {
headers.add("yellow", "lemon");
assertThat(headers.toSingleValueMap()).containsOnly(
entry("yellow", "lemon")
);
headers.setAll(Map.of("yellow", "banana", "red", "apple"));
assertThat(headers.toSingleValueMap()).containsOnly(
entry("yellow", "banana"), // not lemon
entry("red", "apple")
);
}
} }

31
spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java

@ -29,10 +29,12 @@ import org.springframework.http.ResponseEntity;
import org.springframework.web.client.HttpServerErrorException; import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.entry;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -42,19 +44,38 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
* Tests for {@link AbstractXhrTransport}. * Tests for {@link AbstractXhrTransport}.
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @author Sam Brannen
*/ */
class XhrTransportTests { class XhrTransportTests {
private final TestXhrTransport transport = new TestXhrTransport();
@Test @Test
void infoResponse() { void infoResponse() {
TestXhrTransport transport = new TestXhrTransport();
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK); transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK);
assertThat(transport.executeInfoRequest(URI.create("https://example.com/info"), null)).isEqualTo("body"); assertThat(transport.executeInfoRequest(URI.create("https://example.com/info"), null)).isEqualTo("body");
} }
@Test // gh-35792
void infoResponseWithWebSocketHttpHeaders() {
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK);
var headers = new WebSocketHttpHeaders();
headers.setSecWebSocketAccept("enigma");
headers.add("foo", "bar");
transport.executeInfoRequest(URI.create("https://example.com/info"), headers);
assertThat(transport.actualInfoHeaders).isNotNull();
assertThat(transport.actualInfoHeaders.toSingleValueMap()).containsExactly(
entry(WebSocketHttpHeaders.SEC_WEBSOCKET_ACCEPT, "enigma"),
entry("foo", "bar")
);
}
@Test @Test
void infoResponseError() { void infoResponseError() {
TestXhrTransport transport = new TestXhrTransport();
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.BAD_REQUEST); transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.BAD_REQUEST);
assertThatExceptionOfType(HttpServerErrorException.class).isThrownBy(() -> assertThatExceptionOfType(HttpServerErrorException.class).isThrownBy(() ->
transport.executeInfoRequest(URI.create("https://example.com/info"), null)); transport.executeInfoRequest(URI.create("https://example.com/info"), null));
@ -65,7 +86,6 @@ class XhrTransportTests {
HttpHeaders requestHeaders = new HttpHeaders(); HttpHeaders requestHeaders = new HttpHeaders();
requestHeaders.set("foo", "bar"); requestHeaders.set("foo", "bar");
requestHeaders.setContentType(MediaType.APPLICATION_JSON); requestHeaders.setContentType(MediaType.APPLICATION_JSON);
TestXhrTransport transport = new TestXhrTransport();
transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.NO_CONTENT); transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.NO_CONTENT);
URI url = URI.create("https://example.com"); URI url = URI.create("https://example.com");
transport.executeSendRequest(url, requestHeaders, new TextMessage("payload")); transport.executeSendRequest(url, requestHeaders, new TextMessage("payload"));
@ -76,7 +96,6 @@ class XhrTransportTests {
@Test @Test
void sendMessageError() { void sendMessageError() {
TestXhrTransport transport = new TestXhrTransport();
transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.BAD_REQUEST); transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.BAD_REQUEST);
URI url = URI.create("https://example.com"); URI url = URI.create("https://example.com");
assertThatExceptionOfType(HttpServerErrorException.class).isThrownBy(() -> assertThatExceptionOfType(HttpServerErrorException.class).isThrownBy(() ->
@ -93,7 +112,6 @@ class XhrTransportTests {
given(request.getHandshakeHeaders()).willReturn(handshakeHeaders); given(request.getHandshakeHeaders()).willReturn(handshakeHeaders);
given(request.getHttpRequestHeaders()).willReturn(new HttpHeaders()); given(request.getHttpRequestHeaders()).willReturn(new HttpHeaders());
TestXhrTransport transport = new TestXhrTransport();
WebSocketHandler handler = mock(); WebSocketHandler handler = mock();
transport.connectAsync(request, handler); transport.connectAsync(request, handler);
@ -124,10 +142,13 @@ class XhrTransportTests {
private HttpHeaders actualHandshakeHeaders; private HttpHeaders actualHandshakeHeaders;
private HttpHeaders actualInfoHeaders;
private XhrClientSockJsSession actualSession; private XhrClientSockJsSession actualSession;
@Override @Override
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) { protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
this.actualInfoHeaders = headers;
return this.infoResponseToReturn; return this.infoResponseToReturn;
} }

Loading…
Cancel
Save