Browse Source

Pass-through handling of Servlet headers

See gh-36334
pull/36335/head
rstoyanchev 1 month ago
parent
commit
7ea11baff9
  1. 527
      spring-web/src/main/java/org/springframework/http/server/ServletRequestHeadersAdapter.java
  2. 269
      spring-web/src/main/java/org/springframework/http/server/ServletResponseHeadersAdapter.java
  3. 70
      spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java
  4. 92
      spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java
  5. 2
      spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java
  6. 2
      spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandlerTests.java

527
spring-web/src/main/java/org/springframework/http/server/ServletRequestHeadersAdapter.java

@ -0,0 +1,527 @@ @@ -0,0 +1,527 @@
/*
* Copyright 2002-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.server;
import java.util.AbstractSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import jakarta.servlet.http.HttpServletRequest;
import org.jspecify.annotations.Nullable;
import org.springframework.http.HttpHeaders;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedCaseInsensitiveMap;
import org.springframework.util.MultiValueMap;
/**
* {@code MultiValueMap} implementation for wrapping Servlet request headers.
*
* @author Rossen Stoyanchev
* @since 7.0.5
*/
final class ServletRequestHeadersAdapter implements MultiValueMap<String, String> {
private final HttpServletRequest request;
ServletRequestHeadersAdapter(HttpServletRequest request) {
this.request = request;
}
@Override
public @Nullable String getFirst(String key) {
return this.request.getHeader(key);
}
@Override
public void add(String key, @Nullable String value) {
throw new UnsupportedOperationException();
}
@Override
public void addAll(String key, List<? extends String> values) {
throw new UnsupportedOperationException();
}
@Override
public void addAll(MultiValueMap<String, String> map) {
throw new UnsupportedOperationException();
}
@Override
public void set(String key, @Nullable String value) {
throw new UnsupportedOperationException();
}
@Override
public void setAll(Map<String, String> map) {
throw new UnsupportedOperationException();
}
@Override
public Map<String, String> toSingleValueMap() {
Map<String, String> map = new LinkedHashMap<>();
Enumeration<String> names = this.request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
map.put(name, this.request.getHeader(name));
}
return map;
}
@Override
public int size() {
Enumeration<String> names = this.request.getHeaderNames();
Set<String> set = new LinkedHashSet<>();
while (names.hasMoreElements()) {
set.add(names.nextElement().toLowerCase(Locale.ROOT));
}
return set.size();
}
@Override
public boolean isEmpty() {
return !this.request.getHeaderNames().hasMoreElements();
}
@Override
public boolean containsKey(Object key) {
if (key instanceof String headerName) {
Enumeration<String> names = this.request.getHeaderNames();
while (names.hasMoreElements()) {
if (headerName.equalsIgnoreCase(names.nextElement())) {
return true;
}
}
}
return false;
}
@Override
public boolean containsValue(Object rawValue) {
if (rawValue instanceof String text) {
Enumeration<String> names = this.request.getHeaderNames();
while (names.hasMoreElements()) {
Enumeration<String> values = this.request.getHeaders(names.nextElement());
while (values.hasMoreElements()) {
if (text.equals(values.nextElement())) {
return true;
}
}
}
}
return false;
}
@Override
public @Nullable List<String> get(Object key) {
if (key instanceof String headerName) {
Enumeration<String> values = this.request.getHeaders(headerName);
if (values.hasMoreElements()) {
List<String> result = new ArrayList<>();
while (values.hasMoreElements()) {
result.add(values.nextElement());
}
return result;
}
}
return null;
}
@Override
public @Nullable List<String> put(String key, List<String> value) {
throw new UnsupportedOperationException();
}
@Override
public @Nullable List<String> remove(Object key) {
throw new UnsupportedOperationException();
}
@Override
public void putAll(Map<? extends String, ? extends List<String>> map) {
throw new UnsupportedOperationException();
}
@Override
public void clear() {
throw new UnsupportedOperationException();
}
@Override
public Set<String> keySet() {
Set<String> set = new LinkedHashSet<>();
Enumeration<String> names = this.request.getHeaderNames();
while (names.hasMoreElements()) {
set.add(names.nextElement());
}
return set;
}
@Override
public Collection<List<String>> values() {
List<List<String>> allValues = new ArrayList<>();
Enumeration<String> names = this.request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
List<String> currentValues = new ArrayList<>();
Enumeration<String> values = this.request.getHeaders(name);
while (values.hasMoreElements()) {
currentValues.add(values.nextElement());
}
allValues.add(currentValues);
}
return allValues;
}
@Override
public Set<Entry<String, List<String>>> entrySet() {
return new AbstractSet<>() {
@Override
public Iterator<Entry<String, List<String>>> iterator() {
return new EntryIterator();
}
@Override
public int size() {
return ServletRequestHeadersAdapter.this.size();
}
};
}
@Override
public int hashCode() {
return Map.copyOf(this).hashCode();
}
@Override
public boolean equals(@Nullable Object other) {
return (this == other ||
(other instanceof MultiValueMap<?,?> that && Map.copyOf(this).equals(that)));
}
@Override
public String toString() {
return HttpHeaders.formatHeaders(this);
}
/**
* Apply a wrapper that allows headers to be set or added, and treats those
* as overrides to the headers in the given MultiValueMap.
* @param headers the headers map to wrap
* @return the wrapper instance
*/
public static MultiValueMap<String, String> overrideHeadersWrapper(MultiValueMap<String, String> headers) {
return new OverrideHeaderWrapper(headers);
}
private class EntryIterator implements Iterator<Entry<String, List<String>>> {
private final Iterator<String> names = ServletRequestHeadersAdapter.this.keySet().iterator();
@Override
public boolean hasNext() {
return this.names.hasNext();
}
@Override
public Entry<String, List<String>> next() {
return new HeaderEntry(this.names.next());
}
}
private final class HeaderEntry implements Entry<String, List<String>> {
private final String key;
HeaderEntry(String key) {
this.key = key;
}
@Override
public String getKey() {
return this.key;
}
@Override
public @Nullable List<String> getValue() {
return get(this.key);
}
@Override
public @Nullable List<String> setValue(List<String> value) {
List<String> previous = getValue();
remove(this.key);
addAll(this.key, value);
return previous;
}
}
/**
* Wrapper that supports optional override values.
*/
private static class OverrideHeaderWrapper implements MultiValueMap<String, String> {
private final MultiValueMap<String, String> delegate;
private @Nullable MultiValueMap<String, String> overrideHeaders;
OverrideHeaderWrapper(MultiValueMap<String, String> delegate) {
this.delegate = delegate;
}
@Override
public @Nullable String getFirst(String key) {
String value = (this.overrideHeaders != null ? this.overrideHeaders.getFirst(key) : null);
return (value != null ? value : this.delegate.getFirst(key));
}
@Override
public void add(String key, @Nullable String value) {
initOverrideHeaders().add(key, value);
}
@Override
public void addAll(String key, List<? extends String> values) {
initOverrideHeaders().addAll(key, values);
}
@Override
public void addAll(MultiValueMap<String, String> map) {
initOverrideHeaders().addAll(map);
}
@Override
public void set(String key, @Nullable String value) {
initOverrideHeaders().set(key, value);
}
@Override
public void setAll(Map<String, String> map) {
initOverrideHeaders().setAll(map);
}
@Override
public Map<String, String> toSingleValueMap() {
Map<String, String> map = this.delegate.toSingleValueMap();
if (this.overrideHeaders != null) {
this.overrideHeaders.forEach((key, values) -> map.put(key, values.get(0)));
}
return map;
}
@Override
public int size() {
if (this.overrideHeaders == null) {
return this.delegate.size();
}
Set<String> set = new LinkedHashSet<>();
for (String name : this.delegate.keySet()) {
set.add(name.toLowerCase(Locale.ROOT));
}
this.overrideHeaders.keySet().forEach(key -> set.add(key.toLowerCase(Locale.ROOT)));
return set.size();
}
@Override
public boolean isEmpty() {
return (this.delegate.isEmpty() && (this.overrideHeaders == null || this.overrideHeaders.isEmpty()));
}
@Override
public boolean containsKey(Object key) {
if (key instanceof String headerName) {
if (this.delegate.containsKey(headerName)) {
return true;
}
if (this.overrideHeaders != null) {
return this.overrideHeaders.containsKey(headerName);
}
}
return false;
}
@Override
public boolean containsValue(Object rawValue) {
if (rawValue instanceof String text) {
if (this.delegate.containsValue(text)) {
return true;
}
if (this.overrideHeaders != null) {
return this.overrideHeaders.containsValue(rawValue);
}
}
return false;
}
@Override
public @Nullable List<String> get(Object key) {
if (key instanceof String headerName) {
if (this.overrideHeaders != null) {
List<String> values = this.overrideHeaders.get(headerName);
if (values != null) {
return values;
}
}
return this.delegate.get(headerName);
}
return null;
}
@Override
public @Nullable List<String> put(String key, List<String> value) {
return initOverrideHeaders().put(key, value);
}
@Override
public @Nullable List<String> remove(Object key) {
return initOverrideHeaders().remove(key);
}
@Override
public void putAll(Map<? extends String, ? extends List<String>> map) {
initOverrideHeaders().putAll(map);
}
@Override
public void clear() {
initOverrideHeaders().clear();
}
@Override
public Set<String> keySet() {
Set<String> set = this.delegate.keySet();
if (this.overrideHeaders != null) {
set.addAll(this.overrideHeaders.keySet());
}
return set;
}
@Override
public Collection<List<String>> values() {
List<List<String>> allValues = new ArrayList<>();
for (String name : keySet()) {
if (this.overrideHeaders != null && this.overrideHeaders.containsKey(name)) {
allValues.add(this.overrideHeaders.get(name));
}
else {
allValues.add(this.delegate.get(name));
}
}
return allValues;
}
@Override
public Set<Entry<String, List<String>>> entrySet() {
return new AbstractSet<>() {
@Override
public Iterator<Entry<String, List<String>>> iterator() {
return new OverrideHeaderWrapper.EntryIterator();
}
@Override
public int size() {
return OverrideHeaderWrapper.this.size();
}
};
}
private MultiValueMap<String, String> initOverrideHeaders() {
if (this.overrideHeaders == null) {
this.overrideHeaders = CollectionUtils.toMultiValueMap(new LinkedCaseInsensitiveMap<>(8, Locale.ROOT));
}
return this.overrideHeaders;
}
@Override
public int hashCode() {
return Map.copyOf(this).hashCode();
}
@Override
public boolean equals(@Nullable Object other) {
return (this == other ||
(other instanceof MultiValueMap<?,?> that && Map.copyOf(this).equals(that)));
}
@Override
public String toString() {
return HttpHeaders.formatHeaders(this);
}
private class EntryIterator implements Iterator<Entry<String, List<String>>> {
private final Iterator<String> names = OverrideHeaderWrapper.this.keySet().iterator();
@Override
public boolean hasNext() {
return this.names.hasNext();
}
@Override
public Entry<String, List<String>> next() {
return new OverrideHeaderWrapper.HeaderEntry(this.names.next());
}
}
private final class HeaderEntry implements Entry<String, List<String>> {
private final String key;
HeaderEntry(String key) {
this.key = key;
}
@Override
public String getKey() {
return this.key;
}
@Override
public @Nullable List<String> getValue() {
return get(this.key);
}
@Override
public @Nullable List<String> setValue(List<String> value) {
List<String> previous = getValue();
remove(this.key);
addAll(this.key, value);
return previous;
}
}
}
}

269
spring-web/src/main/java/org/springframework/http/server/ServletResponseHeadersAdapter.java

@ -0,0 +1,269 @@ @@ -0,0 +1,269 @@
/*
* Copyright 2002-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.http.server;
import java.util.AbstractSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import jakarta.servlet.http.HttpServletResponse;
import org.jspecify.annotations.Nullable;
import org.springframework.http.HttpHeaders;
import org.springframework.util.MultiValueMap;
/**
* {@code MultiValueMap} implementation for wrapping Servlet response headers.
*
* @author Rossen Stoyanchev
* @since 7.0.5
*/
class ServletResponseHeadersAdapter implements MultiValueMap<String, String> {
private final HttpServletResponse response;
ServletResponseHeadersAdapter(HttpServletResponse response) {
this.response = response;
}
@Override
public String getFirst(String key) {
return this.response.getHeader(key);
}
@Override
public void add(String key, @Nullable String value) {
this.response.addHeader(key, value);
}
@Override
public void addAll(String key, List<? extends String> values) {
values.forEach(value -> this.response.addHeader(key, value));
}
@Override
public void addAll(MultiValueMap<String, String> map) {
for (Entry<String, List<String>> entry : map.entrySet()) {
for (String value : entry.getValue()) {
this.response.addHeader(entry.getKey(), value);
}
}
}
@Override
public void set(String key, @Nullable String value) {
this.response.setHeader(key, value);
}
@Override
public void setAll(Map<String, String> map) {
for (Entry<String, String> entry : map.entrySet()) {
this.response.setHeader(entry.getKey(), entry.getValue());
}
}
@Override
public Map<String, String> toSingleValueMap() {
Map<String, String> map = new LinkedHashMap<>();
Collection<String> names = this.response.getHeaderNames();
for (String name : names) {
map.put(name, this.response.getHeader(name));
}
return map;
}
@Override
public int size() {
return this.response.getHeaderNames().size();
}
@Override
public boolean isEmpty() {
return this.response.getHeaderNames().isEmpty();
}
@Override
public boolean containsKey(Object key) {
if (key instanceof String headerName) {
return this.response.containsHeader(headerName);
}
return false;
}
@Override
public boolean containsValue(Object rawValue) {
if (rawValue instanceof String text) {
for (String name : this.response.getHeaderNames()) {
Collection<String> values = this.response.getHeaders(name);
for (String value : values) {
if (text.equals(value)) {
return true;
}
}
}
}
return false;
}
@Override
public @Nullable List<String> get(Object key) {
if (key instanceof String headerName) {
Collection<String> values = this.response.getHeaders(headerName);
if (!values.isEmpty()) {
return (values instanceof List ? (List<String>) values : new ArrayList<>(values));
}
return (this.response.containsHeader(headerName) ? Collections.emptyList() : null);
}
return null;
}
@Override
public @Nullable List<String> put(String key, List<String> values) {
List<String> previous = remove(key);
for (String value : values) {
this.response.addHeader(key, value);
}
return previous;
}
@Override
public @Nullable List<String> remove(Object key) {
if (key instanceof String headerName) {
Collection<String> previous = this.response.getHeaders(headerName);
if (previous != null) {
this.response.setHeader(headerName, null);
}
return (previous != null ? new ArrayList<>(previous) : null);
}
return null;
}
@Override
public void putAll(Map<? extends String, ? extends List<String>> map) {
for (Entry<? extends String, ? extends List<String>> entry : map.entrySet()) {
this.response.setHeader(entry.getKey(), null);
for (String value : entry.getValue()) {
this.response.addHeader(entry.getKey(), value);
}
}
}
@Override
public void clear() {
for (String headerName : this.response.getHeaderNames()) {
this.response.setHeader(headerName, null);
}
}
@Override
public Set<String> keySet() {
return new LinkedHashSet<>(this.response.getHeaderNames());
}
@Override
public Collection<List<String>> values() {
List<List<String>> allValues = new ArrayList<>();
for (String name : this.response.getHeaderNames()) {
allValues.add(new ArrayList<>(this.response.getHeaders(name)));
}
return allValues;
}
@Override
public Set<Entry<String, List<String>>> entrySet() {
return new AbstractSet<>() {
@Override
public Iterator<Entry<String, List<String>>> iterator() {
return new EntryIterator();
}
@Override
public int size() {
return ServletResponseHeadersAdapter.this.size();
}
};
}
@Override
public int hashCode() {
return Map.copyOf(this).hashCode();
}
@Override
public boolean equals(@Nullable Object other) {
return (this == other ||
(other instanceof MultiValueMap<?,?> that && Map.copyOf(this).equals(that)));
}
@Override
public String toString() {
return HttpHeaders.formatHeaders(this);
}
private class EntryIterator implements Iterator<Entry<String, List<String>>> {
private final Iterator<String> names =
ServletResponseHeadersAdapter.this.response.getHeaderNames().iterator();
@Override
public boolean hasNext() {
return this.names.hasNext();
}
@Override
public Entry<String, List<String>> next() {
return new HeaderEntry(this.names.next());
}
}
private final class HeaderEntry implements Entry<String, List<String>> {
private final String key;
HeaderEntry(String key) {
this.key = key;
}
@Override
public String getKey() {
return this.key;
}
@Override
public @Nullable List<String> getValue() {
return ServletResponseHeadersAdapter.this.get(this.key);
}
@Override
public @Nullable List<String> setValue(List<String> values) {
return ServletResponseHeadersAdapter.this.put(this.key, values);
}
}
}

70
spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java

@ -22,6 +22,7 @@ import java.io.IOException; @@ -22,6 +22,7 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
@ -41,6 +42,9 @@ import java.util.Map; @@ -41,6 +42,9 @@ import java.util.Map;
import java.util.Set;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import org.apache.catalina.connector.RequestFacade;
import org.apache.coyote.Request;
import org.jspecify.annotations.Nullable;
import org.springframework.http.HttpHeaders;
@ -48,7 +52,10 @@ import org.springframework.http.HttpMethod; @@ -48,7 +52,10 @@ import org.springframework.http.HttpMethod;
import org.springframework.http.InvalidMediaTypeException;
import org.springframework.http.MediaType;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.LinkedCaseInsensitiveMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
/**
@ -63,6 +70,9 @@ public class ServletServerHttpRequest implements ServerHttpRequest { @@ -63,6 +70,9 @@ public class ServletServerHttpRequest implements ServerHttpRequest {
protected static final Charset FORM_CHARSET = StandardCharsets.UTF_8;
private static final boolean TOMCAT_PRESENT = ClassUtils.isPresent(
"org.apache.tomcat.util.http.MimeHeaders", ServletServerHttpRequest.class.getClassLoader());
private final HttpServletRequest servletRequest;
@ -156,16 +166,8 @@ public class ServletServerHttpRequest implements ServerHttpRequest { @@ -156,16 +166,8 @@ public class ServletServerHttpRequest implements ServerHttpRequest {
@Override
public HttpHeaders getHeaders() {
if (this.headers == null) {
this.headers = new HttpHeaders();
for (Enumeration<?> names = this.servletRequest.getHeaderNames(); names.hasMoreElements();) {
String headerName = (String) names.nextElement();
for (Enumeration<?> headerValues = this.servletRequest.getHeaders(headerName);
headerValues.hasMoreElements();) {
String headerValue = (String) headerValues.nextElement();
this.headers.add(headerName, headerValue);
}
}
MultiValueMap<String, String> headersAdapter = initHeadersMultiValueMap();
this.headers = new HttpHeaders(headersAdapter);
// HttpServletRequest exposes some headers as properties:
// we should include those if not already present
@ -203,10 +205,21 @@ public class ServletServerHttpRequest implements ServerHttpRequest { @@ -203,10 +205,21 @@ public class ServletServerHttpRequest implements ServerHttpRequest {
}
}
}
return this.headers;
}
private MultiValueMap<String, String> initHeadersMultiValueMap() {
MultiValueMap<String, String> nativeHeaders = null;
if (TOMCAT_PRESENT) {
nativeHeaders = TomcatInitializer.createTomcatHttpHeaders(this.servletRequest);
}
if (nativeHeaders == null) {
nativeHeaders = new ServletRequestHeadersAdapter(this.servletRequest);
}
return ServletRequestHeadersAdapter.overrideHeadersWrapper(nativeHeaders);
}
@Override
public @Nullable Principal getPrincipal() {
return this.servletRequest.getUserPrincipal();
@ -317,6 +330,41 @@ public class ServletServerHttpRequest implements ServerHttpRequest { @@ -317,6 +330,41 @@ public class ServletServerHttpRequest implements ServerHttpRequest {
}
private static final class TomcatInitializer {
private static final Field COYOTE_REQUEST_FIELD;
static {
Field field = ReflectionUtils.findField(RequestFacade.class, "request");
Assert.state(field != null, "Incompatible Tomcat implementation");
ReflectionUtils.makeAccessible(field);
COYOTE_REQUEST_FIELD = field;
}
public static @Nullable MultiValueMap<String, String> createTomcatHttpHeaders(HttpServletRequest servletRequest) {
RequestFacade requestFacade = getRequestFacade(servletRequest);
if (requestFacade == null) {
return null;
}
Object field = ReflectionUtils.getField(COYOTE_REQUEST_FIELD, requestFacade);
Assert.state(field != null, "No Tomcat connector request");
Request coyoteRequest = ((org.apache.catalina.connector.Request) field).getCoyoteRequest();
return new TomcatHeadersAdapter(coyoteRequest.getMimeHeaders());
}
private static @Nullable RequestFacade getRequestFacade(HttpServletRequest request) {
if (request instanceof RequestFacade facade) {
return facade;
}
else if (request instanceof HttpServletRequestWrapper wrapper) {
HttpServletRequest wrappedRequest = (HttpServletRequest) wrapper.getRequest();
return getRequestFacade(wrappedRequest);
}
return null;
}
}
private final class AttributesMap extends AbstractMap<String, Object> {
private @Nullable transient Set<String> keySet;

92
spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java

@ -18,19 +18,13 @@ package org.springframework.http.server; @@ -18,19 +18,13 @@ package org.springframework.http.server;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import jakarta.servlet.http.HttpServletResponse;
import org.jspecify.annotations.Nullable;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
/**
* {@link ServerHttpResponse} implementation that is based on a {@link HttpServletResponse}.
@ -59,7 +53,7 @@ public class ServletServerHttpResponse implements ServerHttpResponse { @@ -59,7 +53,7 @@ public class ServletServerHttpResponse implements ServerHttpResponse {
public ServletServerHttpResponse(HttpServletResponse servletResponse) {
Assert.notNull(servletResponse, "HttpServletResponse must not be null");
this.servletResponse = servletResponse;
this.headers = new ServletResponseHttpHeaders();
this.headers = new HttpHeaders(new ServletResponseHeadersAdapter(servletResponse));
}
@ -112,92 +106,8 @@ public class ServletServerHttpResponse implements ServerHttpResponse { @@ -112,92 +106,8 @@ public class ServletServerHttpResponse implements ServerHttpResponse {
private void writeHeaders() {
if (!this.headersWritten) {
getHeaders().forEach((headerName, headerValues) -> {
for (String headerValue : headerValues) {
this.servletResponse.addHeader(headerName, headerValue);
}
});
// HttpServletResponse exposes some headers as properties: we should include those if not already present
MediaType contentTypeHeader = this.headers.getContentType();
if (this.servletResponse.getContentType() == null && contentTypeHeader != null) {
this.servletResponse.setContentType(contentTypeHeader.toString());
}
if (this.servletResponse.getCharacterEncoding() == null && contentTypeHeader != null &&
contentTypeHeader.getCharset() != null) {
this.servletResponse.setCharacterEncoding(contentTypeHeader.getCharset().name());
}
long contentLength = getHeaders().getContentLength();
if (contentLength != -1) {
this.servletResponse.setContentLengthLong(contentLength);
}
this.headersWritten = true;
}
}
/**
* Extends HttpHeaders with the ability to look up headers already present in
* the underlying HttpServletResponse.
*
* <p>The intent is merely to expose what is available through the HttpServletResponse
* i.e. the ability to look up specific header values by name. All other
* map-related operations (for example, iteration, removal, etc) apply only to values
* added directly through HttpHeaders methods.
*
* @since 4.0.3
*/
private class ServletResponseHttpHeaders extends HttpHeaders {
private static final long serialVersionUID = 3410708522401046302L;
@Override
public boolean containsHeader(String key) {
return (super.containsHeader(key) || (get(key) != null));
}
@Override
public @Nullable String getFirst(String headerName) {
if (headerName.equalsIgnoreCase(CONTENT_TYPE)) {
// Content-Type is written as an override so check super first
String value = super.getFirst(headerName);
return (value != null ? value : servletResponse.getContentType());
}
else {
String value = servletResponse.getHeader(headerName);
return (value != null ? value : super.getFirst(headerName));
}
}
@Override
public @Nullable List<String> get(String headerName) {
if (headerName.equalsIgnoreCase(CONTENT_TYPE)) {
// Content-Type is written as an override so don't merge
String value = getFirst(headerName);
return (value != null ? Collections.singletonList(value) : null);
}
Collection<String> values1 = servletResponse.getHeaders(headerName);
if (headersWritten) {
return new ArrayList<>(values1);
}
boolean isEmpty1 = CollectionUtils.isEmpty(values1);
List<String> values2 = super.get(headerName);
boolean isEmpty2 = CollectionUtils.isEmpty(values2);
if (isEmpty1 && isEmpty2) {
return null;
}
List<String> values = new ArrayList<>();
if (!isEmpty1) {
values.addAll(values1);
}
if (!isEmpty2) {
values.addAll(values2);
}
return values;
}
}
}

2
spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java

@ -229,7 +229,7 @@ public class HttpEntityMethodProcessor extends AbstractMessageConverterMethodPro @@ -229,7 +229,7 @@ public class HttpEntityMethodProcessor extends AbstractMessageConverterMethodPro
if (HttpHeaders.VARY.equals(key) && outputHeaders.containsHeader(HttpHeaders.VARY)) {
List<String> values = getVaryRequestHeadersToAdd(outputHeaders, entityHeaders);
if (!values.isEmpty()) {
outputHeaders.setVary(values);
outputHeaders.addAll(HttpHeaders.VARY, values);
}
}
else {

2
spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandlerTests.java

@ -288,7 +288,7 @@ class ResponseBodyEmitterReturnValueHandlerTests { @@ -288,7 +288,7 @@ class ResponseBodyEmitterReturnValueHandlerTests {
WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(this.webRequest);
assertThat(asyncManager.getConcurrentResult()).isSameAs(ex);
assertThat(this.response.getContentType()).isNull();
assertThat(this.response.getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM_VALUE);
}
@Test

Loading…
Cancel
Save