diff --git a/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java index 3007318e94f..fb7abfeb8ed 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java @@ -31,6 +31,8 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.util.Assert; import org.springframework.util.DigestUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.context.request.ServletWebRequest; import org.springframework.web.util.ContentCachingResponseWrapper; import org.springframework.web.util.WebUtils; @@ -117,11 +119,12 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { HttpServletResponse rawResponse = (HttpServletResponse) wrapper.getResponse(); if (isEligibleForEtag(request, wrapper, wrapper.getStatus(), wrapper.getContentInputStream())) { - String responseETag = generateETagHeaderValue(wrapper.getContentInputStream(), this.writeWeakETag); - rawResponse.setHeader(HttpHeaders.ETAG, responseETag); - String requestETag = request.getHeader(HttpHeaders.IF_NONE_MATCH); - if (requestETag != null && ("*".equals(requestETag) || compareETagHeaderValue(requestETag, responseETag))) { - rawResponse.setStatus(HttpServletResponse.SC_NOT_MODIFIED); + String eTag = wrapper.getHeader(HttpHeaders.ETAG); + if (!StringUtils.hasText(eTag)) { + eTag = generateETagHeaderValue(wrapper.getContentInputStream(), this.writeWeakETag); + rawResponse.setHeader(HttpHeaders.ETAG, eTag); + } + if (new ServletWebRequest(request, rawResponse).checkNotModified(eTag)) { return; } } @@ -224,15 +227,19 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { @Override public ServletOutputStream getOutputStream() throws IOException { - return (isContentCachingDisabled(this.request) ? + return (isContentCachingDisabled(this.request) || hasETag() ? getResponse().getOutputStream() : super.getOutputStream()); } @Override public PrintWriter getWriter() throws IOException { - return (isContentCachingDisabled(this.request) ? + return (isContentCachingDisabled(this.request) || hasETag()? getResponse().getWriter() : super.getWriter()); } + + private boolean hasETag() { + return StringUtils.hasText(getHeader(HttpHeaders.ETAG)); + } } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessorTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessorTests.java index 10b31886546..1b809cf281e 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessorTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,8 +19,13 @@ package org.springframework.web.servlet.mvc.method.annotation; import java.io.Serializable; import java.lang.reflect.Method; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.annotation.JsonTypeName; import org.junit.jupiter.api.BeforeEach; @@ -28,6 +33,7 @@ import org.junit.jupiter.api.Test; import org.springframework.core.MethodParameter; import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.http.converter.ByteArrayHttpMessageConverter; @@ -42,6 +48,7 @@ import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.bind.support.WebDataBinderFactory; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.ServletWebRequest; +import org.springframework.web.filter.ShallowEtagHeaderFilter; import org.springframework.web.method.HandlerMethod; import org.springframework.web.method.support.ModelAndViewContainer; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; @@ -200,6 +207,41 @@ public class HttpEntityMethodProcessorTests { assertThat(servletResponse.getContentAsString()).isEqualTo("Foo"); } + @Test // SPR-13423 + public void handleReturnValueWithETagAndETagFilter() throws Exception { + + String eTagValue = "\"deadb33f8badf00d\""; + String content = "body"; + + Method method = getClass().getDeclaredMethod("handle"); + MethodParameter returnType = new MethodParameter(method, -1); + + FilterChain chain = (req, res) -> { + ResponseEntity returnValue = ResponseEntity.ok().eTag(eTagValue).body(content); + try { + ServletWebRequest requestToUse = + new ServletWebRequest((HttpServletRequest) req, (HttpServletResponse) res); + + new HttpEntityMethodProcessor(Collections.singletonList(new StringHttpMessageConverter())) + .handleReturnValue(returnValue, returnType, mavContainer, requestToUse); + + assertThat(this.servletResponse.getContentAsString()) + .as("Response body was cached? It should be written directly to the raw response") + .isEqualTo(content); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + }; + + this.servletRequest.setMethod("GET"); + new ShallowEtagHeaderFilter().doFilter(this.servletRequest, this.servletResponse, chain); + + assertThat(this.servletResponse.getStatus()).isEqualTo(200); + assertThat(this.servletResponse.getHeader(HttpHeaders.ETAG)).isEqualTo(eTagValue); + assertThat(this.servletResponse.getContentAsString()).isEqualTo(content); + } + @SuppressWarnings("unused") private void handle(HttpEntity> arg1, HttpEntity arg2) {