diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java index 6afef13aa3c..c35fb3bd93f 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java @@ -181,15 +181,15 @@ public class HttpEntityMethodProcessor extends AbstractMessageConverterMethodPro HttpHeaders outputHeaders = outputMessage.getHeaders(); HttpHeaders entityHeaders = responseEntity.getHeaders(); - if (outputHeaders.containsKey(HttpHeaders.VARY) && entityHeaders.containsKey(HttpHeaders.VARY)) { - List values = getVaryRequestHeadersToAdd(outputHeaders, entityHeaders); - if (!values.isEmpty()) { - outputHeaders.setVary(values); - } - } if (!entityHeaders.isEmpty()) { entityHeaders.forEach((key, value) -> { - if (!outputHeaders.containsKey(key)) { + if (HttpHeaders.VARY.equals(key) && outputHeaders.containsKey(HttpHeaders.VARY)) { + List values = getVaryRequestHeadersToAdd(outputHeaders, entityHeaders); + if (!values.isEmpty()) { + outputHeaders.setVary(values); + } + } + else { outputHeaders.put(key, value); } }); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessorMockTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessorMockTests.java index 2a479f407a8..417b3f11aa2 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessorMockTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessorMockTests.java @@ -542,6 +542,59 @@ public class HttpEntityMethodProcessorMockTests { assertConditionalResponse(HttpStatus.OK, "body", etagValue, -1); } + @Test + public void varyHeader() throws Exception { + String[] entityValues = {"Accept-Language", "User-Agent"}; + String[] existingValues = {}; + String[] expected = {"Accept-Language, User-Agent"}; + testVaryHeader(entityValues, existingValues, expected); + } + + @Test + public void varyHeaderWithExistingWildcard() throws Exception { + String[] entityValues = {"Accept-Language"}; + String[] existingValues = {"*"}; + String[] expected = {"*"}; + testVaryHeader(entityValues, existingValues, expected); + } + + @Test + public void varyHeaderWithExistingCommaValues() throws Exception { + String[] entityValues = {"Accept-Language", "User-Agent"}; + String[] existingValues = {"Accept-Encoding", "Accept-Language"}; + String[] expected = {"Accept-Encoding", "Accept-Language", "User-Agent"}; + testVaryHeader(entityValues, existingValues, expected); + } + + @Test + public void varyHeaderWithExistingCommaSeparatedValues() throws Exception { + String[] entityValues = {"Accept-Language", "User-Agent"}; + String[] existingValues = {"Accept-Encoding, Accept-Language"}; + String[] expected = {"Accept-Encoding, Accept-Language", "User-Agent"}; + testVaryHeader(entityValues, existingValues, expected); + } + + @Test + public void handleReturnValueVaryHeader() throws Exception { + String[] entityValues = {"Accept-Language", "User-Agent"}; + String[] existingValues = {"Accept-Encoding, Accept-Language"}; + String[] expected = {"Accept-Encoding, Accept-Language", "User-Agent"}; + testVaryHeader(entityValues, existingValues, expected); + } + + + private void testVaryHeader(String[] entityValues, String[] existingValues, String[] expected) throws Exception { + ResponseEntity returnValue = ResponseEntity.ok().varyBy(entityValues).body("Foo"); + for (String value : existingValues) { + servletResponse.addHeader("Vary", value); + } + initStringMessageConversion(MediaType.TEXT_PLAIN); + processor.handleReturnValue(returnValue, returnTypeResponseEntity, mavContainer, webRequest); + + assertTrue(mavContainer.isRequestHandled()); + assertEquals(Arrays.asList(expected), servletResponse.getHeaders("Vary")); + verify(stringHttpMessageConverter).write(eq("Foo"), eq(MediaType.TEXT_PLAIN), isA(HttpOutputMessage.class)); + } private void initStringMessageConversion(MediaType accepted) { given(stringHttpMessageConverter.canWrite(String.class, null)).willReturn(true); @@ -554,8 +607,7 @@ public class HttpEntityMethodProcessorMockTests { verify(stringHttpMessageConverter).write(eq(body), eq(MediaType.TEXT_PLAIN), outputMessage.capture()); } - private void assertConditionalResponse(HttpStatus status, String body, - String etag, long lastModified) throws Exception { + private void assertConditionalResponse(HttpStatus status, String body, String etag, long lastModified) throws Exception { assertEquals(status.value(), servletResponse.getStatus()); assertTrue(mavContainer.isRequestHandled()); if (body != null) {