diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletRequestHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/ServletRequestHeadersAdapter.java new file mode 100644 index 00000000000..7c1b49bf678 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/ServletRequestHeadersAdapter.java @@ -0,0 +1,527 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server; + +import java.util.AbstractSet; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +import jakarta.servlet.http.HttpServletRequest; +import org.jspecify.annotations.Nullable; + +import org.springframework.http.HttpHeaders; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Servlet request headers. + * + * @author Rossen Stoyanchev + * @since 7.0.5 + */ +final class ServletRequestHeadersAdapter implements MultiValueMap { + + private final HttpServletRequest request; + + + ServletRequestHeadersAdapter(HttpServletRequest request) { + this.request = request; + } + + + @Override + public @Nullable String getFirst(String key) { + return this.request.getHeader(key); + } + + @Override + public void add(String key, @Nullable String value) { + throw new UnsupportedOperationException(); + } + + @Override + public void addAll(String key, List values) { + throw new UnsupportedOperationException(); + } + + @Override + public void addAll(MultiValueMap map) { + throw new UnsupportedOperationException(); + } + + @Override + public void set(String key, @Nullable String value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setAll(Map map) { + throw new UnsupportedOperationException(); + } + + @Override + public Map toSingleValueMap() { + Map map = new LinkedHashMap<>(); + Enumeration names = this.request.getHeaderNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + map.put(name, this.request.getHeader(name)); + } + return map; + } + + @Override + public int size() { + Enumeration names = this.request.getHeaderNames(); + Set set = new LinkedHashSet<>(); + while (names.hasMoreElements()) { + set.add(names.nextElement().toLowerCase(Locale.ROOT)); + } + return set.size(); + } + + @Override + public boolean isEmpty() { + return !this.request.getHeaderNames().hasMoreElements(); + } + + @Override + public boolean containsKey(Object key) { + if (key instanceof String headerName) { + Enumeration names = this.request.getHeaderNames(); + while (names.hasMoreElements()) { + if (headerName.equalsIgnoreCase(names.nextElement())) { + return true; + } + } + } + return false; + } + + @Override + public boolean containsValue(Object rawValue) { + if (rawValue instanceof String text) { + Enumeration names = this.request.getHeaderNames(); + while (names.hasMoreElements()) { + Enumeration values = this.request.getHeaders(names.nextElement()); + while (values.hasMoreElements()) { + if (text.equals(values.nextElement())) { + return true; + } + } + } + } + return false; + } + + @Override + public @Nullable List get(Object key) { + if (key instanceof String headerName) { + Enumeration values = this.request.getHeaders(headerName); + if (values.hasMoreElements()) { + List result = new ArrayList<>(); + while (values.hasMoreElements()) { + result.add(values.nextElement()); + } + return result; + } + } + return null; + } + + @Override + public @Nullable List put(String key, List value) { + throw new UnsupportedOperationException(); + } + + @Override + public @Nullable List remove(Object key) { + throw new UnsupportedOperationException(); + } + + @Override + public void putAll(Map> map) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(); + } + + @Override + public Set keySet() { + Set set = new LinkedHashSet<>(); + Enumeration names = this.request.getHeaderNames(); + while (names.hasMoreElements()) { + set.add(names.nextElement()); + } + return set; + } + + @Override + public Collection> values() { + List> allValues = new ArrayList<>(); + Enumeration names = this.request.getHeaderNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + List currentValues = new ArrayList<>(); + Enumeration values = this.request.getHeaders(name); + while (values.hasMoreElements()) { + currentValues.add(values.nextElement()); + } + allValues.add(currentValues); + } + return allValues; + } + + @Override + public Set>> entrySet() { + return new AbstractSet<>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return ServletRequestHeadersAdapter.this.size(); + } + }; + } + + + @Override + public int hashCode() { + return Map.copyOf(this).hashCode(); + } + + @Override + public boolean equals(@Nullable Object other) { + return (this == other || + (other instanceof MultiValueMap that && Map.copyOf(this).equals(that))); + } + + @Override + public String toString() { + return HttpHeaders.formatHeaders(this); + } + + + /** + * Apply a wrapper that allows headers to be set or added, and treats those + * as overrides to the headers in the given MultiValueMap. + * @param headers the headers map to wrap + * @return the wrapper instance + */ + public static MultiValueMap overrideHeadersWrapper(MultiValueMap headers) { + return new OverrideHeaderWrapper(headers); + } + + + private class EntryIterator implements Iterator>> { + + private final Iterator names = ServletRequestHeadersAdapter.this.keySet().iterator(); + + @Override + public boolean hasNext() { + return this.names.hasNext(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.next()); + } + } + + + private final class HeaderEntry implements Entry> { + + private final String key; + + HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public @Nullable List getValue() { + return get(this.key); + } + + @Override + public @Nullable List setValue(List value) { + List previous = getValue(); + remove(this.key); + addAll(this.key, value); + return previous; + } + } + + + /** + * Wrapper that supports optional override values. + */ + private static class OverrideHeaderWrapper implements MultiValueMap { + + private final MultiValueMap delegate; + + private @Nullable MultiValueMap overrideHeaders; + + OverrideHeaderWrapper(MultiValueMap delegate) { + this.delegate = delegate; + } + + @Override + public @Nullable String getFirst(String key) { + String value = (this.overrideHeaders != null ? this.overrideHeaders.getFirst(key) : null); + return (value != null ? value : this.delegate.getFirst(key)); + } + + @Override + public void add(String key, @Nullable String value) { + initOverrideHeaders().add(key, value); + } + + @Override + public void addAll(String key, List values) { + initOverrideHeaders().addAll(key, values); + } + + @Override + public void addAll(MultiValueMap map) { + initOverrideHeaders().addAll(map); + } + + @Override + public void set(String key, @Nullable String value) { + initOverrideHeaders().set(key, value); + } + + @Override + public void setAll(Map map) { + initOverrideHeaders().setAll(map); + } + + @Override + public Map toSingleValueMap() { + Map map = this.delegate.toSingleValueMap(); + if (this.overrideHeaders != null) { + this.overrideHeaders.forEach((key, values) -> map.put(key, values.get(0))); + } + return map; + } + + @Override + public int size() { + if (this.overrideHeaders == null) { + return this.delegate.size(); + } + Set set = new LinkedHashSet<>(); + for (String name : this.delegate.keySet()) { + set.add(name.toLowerCase(Locale.ROOT)); + } + this.overrideHeaders.keySet().forEach(key -> set.add(key.toLowerCase(Locale.ROOT))); + return set.size(); + } + + @Override + public boolean isEmpty() { + return (this.delegate.isEmpty() && (this.overrideHeaders == null || this.overrideHeaders.isEmpty())); + } + + @Override + public boolean containsKey(Object key) { + if (key instanceof String headerName) { + if (this.delegate.containsKey(headerName)) { + return true; + } + if (this.overrideHeaders != null) { + return this.overrideHeaders.containsKey(headerName); + } + } + return false; + } + + @Override + public boolean containsValue(Object rawValue) { + if (rawValue instanceof String text) { + if (this.delegate.containsValue(text)) { + return true; + } + if (this.overrideHeaders != null) { + return this.overrideHeaders.containsValue(rawValue); + } + } + return false; + } + + @Override + public @Nullable List get(Object key) { + if (key instanceof String headerName) { + if (this.overrideHeaders != null) { + List values = this.overrideHeaders.get(headerName); + if (values != null) { + return values; + } + } + return this.delegate.get(headerName); + } + return null; + } + + @Override + public @Nullable List put(String key, List value) { + return initOverrideHeaders().put(key, value); + } + + @Override + public @Nullable List remove(Object key) { + return initOverrideHeaders().remove(key); + } + + @Override + public void putAll(Map> map) { + initOverrideHeaders().putAll(map); + } + + @Override + public void clear() { + initOverrideHeaders().clear(); + } + + @Override + public Set keySet() { + Set set = this.delegate.keySet(); + if (this.overrideHeaders != null) { + set.addAll(this.overrideHeaders.keySet()); + } + return set; + } + + @Override + public Collection> values() { + List> allValues = new ArrayList<>(); + for (String name : keySet()) { + if (this.overrideHeaders != null && this.overrideHeaders.containsKey(name)) { + allValues.add(this.overrideHeaders.get(name)); + } + else { + allValues.add(this.delegate.get(name)); + } + } + return allValues; + } + + @Override + public Set>> entrySet() { + return new AbstractSet<>() { + @Override + public Iterator>> iterator() { + return new OverrideHeaderWrapper.EntryIterator(); + } + + @Override + public int size() { + return OverrideHeaderWrapper.this.size(); + } + }; + } + + private MultiValueMap initOverrideHeaders() { + if (this.overrideHeaders == null) { + this.overrideHeaders = CollectionUtils.toMultiValueMap(new LinkedCaseInsensitiveMap<>(8, Locale.ROOT)); + } + return this.overrideHeaders; + } + + + @Override + public int hashCode() { + return Map.copyOf(this).hashCode(); + } + + @Override + public boolean equals(@Nullable Object other) { + return (this == other || + (other instanceof MultiValueMap that && Map.copyOf(this).equals(that))); + } + + @Override + public String toString() { + return HttpHeaders.formatHeaders(this); + } + + + private class EntryIterator implements Iterator>> { + + private final Iterator names = OverrideHeaderWrapper.this.keySet().iterator(); + + @Override + public boolean hasNext() { + return this.names.hasNext(); + } + + @Override + public Entry> next() { + return new OverrideHeaderWrapper.HeaderEntry(this.names.next()); + } + } + + + private final class HeaderEntry implements Entry> { + + private final String key; + + HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public @Nullable List getValue() { + return get(this.key); + } + + @Override + public @Nullable List setValue(List value) { + List previous = getValue(); + remove(this.key); + addAll(this.key, value); + return previous; + } + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletResponseHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/ServletResponseHeadersAdapter.java new file mode 100644 index 00000000000..d9456bd415e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/ServletResponseHeadersAdapter.java @@ -0,0 +1,269 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server; + +import java.util.AbstractSet; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import jakarta.servlet.http.HttpServletResponse; +import org.jspecify.annotations.Nullable; + +import org.springframework.http.HttpHeaders; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Servlet response headers. + * + * @author Rossen Stoyanchev + * @since 7.0.5 + */ +class ServletResponseHeadersAdapter implements MultiValueMap { + + private final HttpServletResponse response; + + + ServletResponseHeadersAdapter(HttpServletResponse response) { + this.response = response; + } + + + @Override + public String getFirst(String key) { + return this.response.getHeader(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.response.addHeader(key, value); + } + + @Override + public void addAll(String key, List values) { + values.forEach(value -> this.response.addHeader(key, value)); + } + + @Override + public void addAll(MultiValueMap map) { + for (Entry> entry : map.entrySet()) { + for (String value : entry.getValue()) { + this.response.addHeader(entry.getKey(), value); + } + } + } + + @Override + public void set(String key, @Nullable String value) { + this.response.setHeader(key, value); + } + + @Override + public void setAll(Map map) { + for (Entry entry : map.entrySet()) { + this.response.setHeader(entry.getKey(), entry.getValue()); + } + } + + @Override + public Map toSingleValueMap() { + Map map = new LinkedHashMap<>(); + Collection names = this.response.getHeaderNames(); + for (String name : names) { + map.put(name, this.response.getHeader(name)); + } + return map; + } + + @Override + public int size() { + return this.response.getHeaderNames().size(); + } + + @Override + public boolean isEmpty() { + return this.response.getHeaderNames().isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + if (key instanceof String headerName) { + return this.response.containsHeader(headerName); + } + return false; + } + + @Override + public boolean containsValue(Object rawValue) { + if (rawValue instanceof String text) { + for (String name : this.response.getHeaderNames()) { + Collection values = this.response.getHeaders(name); + for (String value : values) { + if (text.equals(value)) { + return true; + } + } + } + } + return false; + } + + @Override + public @Nullable List get(Object key) { + if (key instanceof String headerName) { + Collection values = this.response.getHeaders(headerName); + if (!values.isEmpty()) { + return (values instanceof List ? (List) values : new ArrayList<>(values)); + } + return (this.response.containsHeader(headerName) ? Collections.emptyList() : null); + } + return null; + } + + @Override + public @Nullable List put(String key, List values) { + List previous = remove(key); + for (String value : values) { + this.response.addHeader(key, value); + } + return previous; + } + + @Override + public @Nullable List remove(Object key) { + if (key instanceof String headerName) { + Collection previous = this.response.getHeaders(headerName); + if (previous != null) { + this.response.setHeader(headerName, null); + } + return (previous != null ? new ArrayList<>(previous) : null); + } + return null; + } + + @Override + public void putAll(Map> map) { + for (Entry> entry : map.entrySet()) { + this.response.setHeader(entry.getKey(), null); + for (String value : entry.getValue()) { + this.response.addHeader(entry.getKey(), value); + } + } + } + + @Override + public void clear() { + for (String headerName : this.response.getHeaderNames()) { + this.response.setHeader(headerName, null); + } + } + + @Override + public Set keySet() { + return new LinkedHashSet<>(this.response.getHeaderNames()); + } + + @Override + public Collection> values() { + List> allValues = new ArrayList<>(); + for (String name : this.response.getHeaderNames()) { + allValues.add(new ArrayList<>(this.response.getHeaders(name))); + } + return allValues; + } + + @Override + public Set>> entrySet() { + return new AbstractSet<>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return ServletResponseHeadersAdapter.this.size(); + } + }; + } + + + @Override + public int hashCode() { + return Map.copyOf(this).hashCode(); + } + + @Override + public boolean equals(@Nullable Object other) { + return (this == other || + (other instanceof MultiValueMap that && Map.copyOf(this).equals(that))); + } + + @Override + public String toString() { + return HttpHeaders.formatHeaders(this); + } + + + private class EntryIterator implements Iterator>> { + + private final Iterator names = + ServletResponseHeadersAdapter.this.response.getHeaderNames().iterator(); + + @Override + public boolean hasNext() { + return this.names.hasNext(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.next()); + } + } + + + private final class HeaderEntry implements Entry> { + + private final String key; + + HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public @Nullable List getValue() { + return ServletResponseHeadersAdapter.this.get(this.key); + } + + @Override + public @Nullable List setValue(List values) { + return ServletResponseHeadersAdapter.this.put(this.key, values); + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java index f6b99c94bb2..a2724f51b06 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStreamWriter; import java.io.Writer; +import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.URI; import java.net.URISyntaxException; @@ -41,6 +42,9 @@ import java.util.Map; import java.util.Set; import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequestWrapper; +import org.apache.catalina.connector.RequestFacade; +import org.apache.coyote.Request; import org.jspecify.annotations.Nullable; import org.springframework.http.HttpHeaders; @@ -48,7 +52,10 @@ import org.springframework.http.HttpMethod; import org.springframework.http.InvalidMediaTypeException; import org.springframework.http.MediaType; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; /** @@ -63,6 +70,9 @@ public class ServletServerHttpRequest implements ServerHttpRequest { protected static final Charset FORM_CHARSET = StandardCharsets.UTF_8; + private static final boolean TOMCAT_PRESENT = ClassUtils.isPresent( + "org.apache.tomcat.util.http.MimeHeaders", ServletServerHttpRequest.class.getClassLoader()); + private final HttpServletRequest servletRequest; @@ -156,16 +166,8 @@ public class ServletServerHttpRequest implements ServerHttpRequest { @Override public HttpHeaders getHeaders() { if (this.headers == null) { - this.headers = new HttpHeaders(); - - for (Enumeration names = this.servletRequest.getHeaderNames(); names.hasMoreElements();) { - String headerName = (String) names.nextElement(); - for (Enumeration headerValues = this.servletRequest.getHeaders(headerName); - headerValues.hasMoreElements();) { - String headerValue = (String) headerValues.nextElement(); - this.headers.add(headerName, headerValue); - } - } + MultiValueMap headersAdapter = initHeadersMultiValueMap(); + this.headers = new HttpHeaders(headersAdapter); // HttpServletRequest exposes some headers as properties: // we should include those if not already present @@ -203,10 +205,21 @@ public class ServletServerHttpRequest implements ServerHttpRequest { } } } - return this.headers; } + private MultiValueMap initHeadersMultiValueMap() { + MultiValueMap nativeHeaders = null; + if (TOMCAT_PRESENT) { + nativeHeaders = TomcatInitializer.createTomcatHttpHeaders(this.servletRequest); + } + if (nativeHeaders == null) { + nativeHeaders = new ServletRequestHeadersAdapter(this.servletRequest); + } + return ServletRequestHeadersAdapter.overrideHeadersWrapper(nativeHeaders); + } + + @Override public @Nullable Principal getPrincipal() { return this.servletRequest.getUserPrincipal(); @@ -317,6 +330,41 @@ public class ServletServerHttpRequest implements ServerHttpRequest { } + private static final class TomcatInitializer { + + private static final Field COYOTE_REQUEST_FIELD; + + static { + Field field = ReflectionUtils.findField(RequestFacade.class, "request"); + Assert.state(field != null, "Incompatible Tomcat implementation"); + ReflectionUtils.makeAccessible(field); + COYOTE_REQUEST_FIELD = field; + } + + public static @Nullable MultiValueMap createTomcatHttpHeaders(HttpServletRequest servletRequest) { + RequestFacade requestFacade = getRequestFacade(servletRequest); + if (requestFacade == null) { + return null; + } + Object field = ReflectionUtils.getField(COYOTE_REQUEST_FIELD, requestFacade); + Assert.state(field != null, "No Tomcat connector request"); + Request coyoteRequest = ((org.apache.catalina.connector.Request) field).getCoyoteRequest(); + return new TomcatHeadersAdapter(coyoteRequest.getMimeHeaders()); + } + + private static @Nullable RequestFacade getRequestFacade(HttpServletRequest request) { + if (request instanceof RequestFacade facade) { + return facade; + } + else if (request instanceof HttpServletRequestWrapper wrapper) { + HttpServletRequest wrappedRequest = (HttpServletRequest) wrapper.getRequest(); + return getRequestFacade(wrappedRequest); + } + return null; + } + } + + private final class AttributesMap extends AbstractMap { private @Nullable transient Set keySet; diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java index b1e0a13e6b6..843f603de3c 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java @@ -18,19 +18,13 @@ package org.springframework.http.server; import java.io.IOException; import java.io.OutputStream; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; import jakarta.servlet.http.HttpServletResponse; import org.jspecify.annotations.Nullable; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatusCode; -import org.springframework.http.MediaType; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; /** * {@link ServerHttpResponse} implementation that is based on a {@link HttpServletResponse}. @@ -59,7 +53,7 @@ public class ServletServerHttpResponse implements ServerHttpResponse { public ServletServerHttpResponse(HttpServletResponse servletResponse) { Assert.notNull(servletResponse, "HttpServletResponse must not be null"); this.servletResponse = servletResponse; - this.headers = new ServletResponseHttpHeaders(); + this.headers = new HttpHeaders(new ServletResponseHeadersAdapter(servletResponse)); } @@ -112,92 +106,8 @@ public class ServletServerHttpResponse implements ServerHttpResponse { private void writeHeaders() { if (!this.headersWritten) { - getHeaders().forEach((headerName, headerValues) -> { - for (String headerValue : headerValues) { - this.servletResponse.addHeader(headerName, headerValue); - } - }); - // HttpServletResponse exposes some headers as properties: we should include those if not already present - MediaType contentTypeHeader = this.headers.getContentType(); - if (this.servletResponse.getContentType() == null && contentTypeHeader != null) { - this.servletResponse.setContentType(contentTypeHeader.toString()); - } - if (this.servletResponse.getCharacterEncoding() == null && contentTypeHeader != null && - contentTypeHeader.getCharset() != null) { - this.servletResponse.setCharacterEncoding(contentTypeHeader.getCharset().name()); - } - long contentLength = getHeaders().getContentLength(); - if (contentLength != -1) { - this.servletResponse.setContentLengthLong(contentLength); - } this.headersWritten = true; } } - - /** - * Extends HttpHeaders with the ability to look up headers already present in - * the underlying HttpServletResponse. - * - *

The intent is merely to expose what is available through the HttpServletResponse - * i.e. the ability to look up specific header values by name. All other - * map-related operations (for example, iteration, removal, etc) apply only to values - * added directly through HttpHeaders methods. - * - * @since 4.0.3 - */ - private class ServletResponseHttpHeaders extends HttpHeaders { - - private static final long serialVersionUID = 3410708522401046302L; - - @Override - public boolean containsHeader(String key) { - return (super.containsHeader(key) || (get(key) != null)); - } - - @Override - public @Nullable String getFirst(String headerName) { - if (headerName.equalsIgnoreCase(CONTENT_TYPE)) { - // Content-Type is written as an override so check super first - String value = super.getFirst(headerName); - return (value != null ? value : servletResponse.getContentType()); - } - else { - String value = servletResponse.getHeader(headerName); - return (value != null ? value : super.getFirst(headerName)); - } - } - - @Override - public @Nullable List get(String headerName) { - if (headerName.equalsIgnoreCase(CONTENT_TYPE)) { - // Content-Type is written as an override so don't merge - String value = getFirst(headerName); - return (value != null ? Collections.singletonList(value) : null); - } - - Collection values1 = servletResponse.getHeaders(headerName); - if (headersWritten) { - return new ArrayList<>(values1); - } - boolean isEmpty1 = CollectionUtils.isEmpty(values1); - - List values2 = super.get(headerName); - boolean isEmpty2 = CollectionUtils.isEmpty(values2); - - if (isEmpty1 && isEmpty2) { - return null; - } - - List values = new ArrayList<>(); - if (!isEmpty1) { - values.addAll(values1); - } - if (!isEmpty2) { - values.addAll(values2); - } - return values; - } - } - } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java index 0bf8037b2e0..ba6a0388687 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java @@ -229,7 +229,7 @@ public class HttpEntityMethodProcessor extends AbstractMessageConverterMethodPro if (HttpHeaders.VARY.equals(key) && outputHeaders.containsHeader(HttpHeaders.VARY)) { List values = getVaryRequestHeadersToAdd(outputHeaders, entityHeaders); if (!values.isEmpty()) { - outputHeaders.setVary(values); + outputHeaders.addAll(HttpHeaders.VARY, values); } } else { diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandlerTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandlerTests.java index 2abfe70f767..c8fbbd59bda 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandlerTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandlerTests.java @@ -288,7 +288,7 @@ class ResponseBodyEmitterReturnValueHandlerTests { WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(this.webRequest); assertThat(asyncManager.getConcurrentResult()).isSameAs(ex); - assertThat(this.response.getContentType()).isNull(); + assertThat(this.response.getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM_VALUE); } @Test