From 2c9ed4608f74f76d8bbc59f3445e7ed6416507b6 Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Fri, 10 May 2024 18:29:30 +0100 Subject: [PATCH] Improve RequestAttributesThreadLocalAccessor Ensure access to request attributes after initial REQUEST dispatch is done, and the RequestAttributes markedCompleted. Closes gh-32296 --- .../RequestAttributesThreadLocalAccessor.java | 90 ++++++++++++++++++- ...estAttributesThreadLocalAccessorTests.java | 54 +++++++++++ 2 files changed, 143 insertions(+), 1 deletion(-) diff --git a/spring-web/src/main/java/org/springframework/web/context/request/RequestAttributesThreadLocalAccessor.java b/spring-web/src/main/java/org/springframework/web/context/request/RequestAttributesThreadLocalAccessor.java index 136127bc675..abec7aa4d41 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/RequestAttributesThreadLocalAccessor.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/RequestAttributesThreadLocalAccessor.java @@ -16,7 +16,12 @@ package org.springframework.web.context.request; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; + import io.micrometer.context.ThreadLocalAccessor; +import jakarta.servlet.http.HttpServletRequest; import org.springframework.lang.Nullable; @@ -26,6 +31,7 @@ import org.springframework.lang.Nullable; * {@link RequestAttributes} propagation. * * @author Tadaya Tsuyukubo + * @author Rossen Stoyanchev * @since 6.2 */ public class RequestAttributesThreadLocalAccessor implements ThreadLocalAccessor { @@ -44,7 +50,11 @@ public class RequestAttributesThreadLocalAccessor implements ThreadLocalAccessor @Override @Nullable public RequestAttributes getValue() { - return RequestContextHolder.getRequestAttributes(); + RequestAttributes request = RequestContextHolder.getRequestAttributes(); + if (request instanceof ServletRequestAttributes sra && !(sra instanceof SnapshotServletRequestAttributes)) { + request = new SnapshotServletRequestAttributes(sra); + } + return request; } @Override @@ -57,4 +67,82 @@ public class RequestAttributesThreadLocalAccessor implements ThreadLocalAccessor RequestContextHolder.resetRequestAttributes(); } + + /** + * ServletRequestAttributes that takes another instance, and makes a copy of the + * request attributes at present to provides extended read access during async + * handling when the DispatcherServlet has exited from the initial REQUEST dispatch + * and marked the request {@link ServletRequestAttributes#requestCompleted()}. + *

Note that beyond access to request attributes, here is no attempt to support + * setting or removing request attributes, nor to access session attributes after + * the initial REQUEST dispatch has exited. + */ + private static final class SnapshotServletRequestAttributes extends ServletRequestAttributes { + + private final ServletRequestAttributes delegate; + + private final Map attributeMap; + + public SnapshotServletRequestAttributes(ServletRequestAttributes requestAttributes) { + super(requestAttributes.getRequest(), requestAttributes.getResponse()); + this.delegate = requestAttributes; + this.attributeMap = getAttributes(requestAttributes.getRequest()); + } + + private static Map getAttributes(HttpServletRequest request) { + Map map = new HashMap<>(); + Enumeration names = request.getAttributeNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + map.put(name, request.getAttribute(name)); + } + return map; + } + + // Delegate methods that check isRequestActive() + + @Nullable + @Override + public Object getAttribute(String name, int scope) { + if (scope == RequestAttributes.SCOPE_REQUEST && !this.delegate.isRequestActive()) { + return this.attributeMap.get(name); + } + try { + return this.delegate.getAttribute(name, scope); + } + catch (IllegalStateException ex) { + if (scope == RequestAttributes.SCOPE_REQUEST) { + return this.attributeMap.get(name); + } + throw ex; + } + } + + @Override + public String[] getAttributeNames(int scope) { + if (scope == RequestAttributes.SCOPE_REQUEST && !this.delegate.isRequestActive()) { + return this.attributeMap.keySet().toArray(new String[0]); + } + try { + return this.delegate.getAttributeNames(scope); + } + catch (IllegalStateException ex) { + if (scope == RequestAttributes.SCOPE_REQUEST) { + return this.attributeMap.keySet().toArray(new String[0]); + } + throw ex; + } + } + + @Override + public void setAttribute(String name, Object value, int scope) { + this.delegate.setAttribute(name, value, scope); + } + + @Override + public void removeAttribute(String name, int scope) { + this.delegate.removeAttribute(name, scope); + } + } + } diff --git a/spring-web/src/test/java/org/springframework/web/context/request/RequestAttributesThreadLocalAccessorTests.java b/spring-web/src/test/java/org/springframework/web/context/request/RequestAttributesThreadLocalAccessorTests.java index dd51f5fcfb1..e0228291c7b 100644 --- a/spring-web/src/test/java/org/springframework/web/context/request/RequestAttributesThreadLocalAccessorTests.java +++ b/spring-web/src/test/java/org/springframework/web/context/request/RequestAttributesThreadLocalAccessorTests.java @@ -23,12 +23,18 @@ import io.micrometer.context.ContextRegistry; import io.micrometer.context.ContextSnapshot; import io.micrometer.context.ContextSnapshot.Scope; import io.micrometer.context.ContextSnapshotFactory; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.springframework.web.testfixture.servlet.MockHttpServletRequest; +import org.springframework.web.testfixture.servlet.MockHttpServletResponse; + import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.Mockito.mock; +import static org.springframework.web.context.request.RequestAttributes.SCOPE_REQUEST; /** * Tests for {@link RequestAttributesThreadLocalAccessor}. @@ -73,6 +79,54 @@ class RequestAttributesThreadLocalAccessorTests { assertThat(requestAfterScope).hasValueSatisfying(value -> assertThat(value).isSameAs(previousRequest)); } + @Test + void accessAfterRequestMarkedCompleted() { + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + servletRequest.setAttribute("k1", "v1"); + servletRequest.setAttribute("k2", "v2"); + + ServletRequestAttributes attributes = new ServletRequestAttributes(servletRequest, new MockHttpServletResponse()); + ContextSnapshot snapshot = getSnapshotFor(attributes); + attributes.requestCompleted(); // REQUEST dispatch ends, async handling continues + + try (Scope scope = snapshot.setThreadLocals()) { + RequestAttributes current = RequestContextHolder.getRequestAttributes(); + assertThat(current).isNotNull(); + assertThat(current.getAttributeNames(SCOPE_REQUEST)).containsExactly("k1", "k2"); + assertThat(current.getAttribute("k1", SCOPE_REQUEST)).isEqualTo("v1"); + assertThat(current.getAttribute("k2", SCOPE_REQUEST)).isEqualTo("v2"); + assertThatIllegalStateException().isThrownBy(() -> current.setAttribute("k3", "v3", SCOPE_REQUEST)); + } + } + + @Test + void accessBeforeRequestMarkedCompleted() { + + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + ServletRequestAttributes previous = new ServletRequestAttributes(servletRequest, new MockHttpServletResponse()); + + ContextSnapshot snapshot = getSnapshotFor(previous); + + RequestContextHolder.setRequestAttributes(previous); + try { + try (Scope scope = snapshot.setThreadLocals()) { + RequestAttributes attributes = RequestContextHolder.getRequestAttributes(); + assertThat(attributes).isNotNull(); + attributes.setAttribute("k1", "v1", SCOPE_REQUEST); + } + RequestAttributes attributes = RequestContextHolder.getRequestAttributes(); + assertThat(attributes).isNotNull(); + attributes.setAttribute("k2", "v2", SCOPE_REQUEST); + } + finally { + RequestContextHolder.resetRequestAttributes(); + } + + assertThat(previous.getAttributeNames(SCOPE_REQUEST)).containsExactly("k1", "k2"); + assertThat(previous.getAttribute("k1", SCOPE_REQUEST)).isEqualTo("v1"); + assertThat(previous.getAttribute("k2", SCOPE_REQUEST)).isEqualTo("v2"); + } + private ContextSnapshot getSnapshotFor(RequestAttributes request) { RequestContextHolder.setRequestAttributes(request); try {