diff --git a/spring-web/src/main/java/org/springframework/web/filter/ServerHttpObservationFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ServerHttpObservationFilter.java index d2b8bfcac36..e56b297349d 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ServerHttpObservationFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ServerHttpObservationFilter.java @@ -23,6 +23,7 @@ import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import jakarta.servlet.AsyncEvent; import jakarta.servlet.AsyncListener; +import jakarta.servlet.DispatcherType; import jakarta.servlet.FilterChain; import jakarta.servlet.RequestDispatcher; import jakarta.servlet.ServletException; @@ -97,6 +98,11 @@ public class ServerHttpObservationFilter extends OncePerRequestFilter { return Optional.ofNullable((ServerRequestObservationContext) request.getAttribute(CURRENT_OBSERVATION_CONTEXT_ATTRIBUTE)); } + @Override + protected boolean shouldNotFilterAsyncDispatch() { + return false; + } + @Override @SuppressWarnings("try") protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) @@ -116,8 +122,9 @@ public class ServerHttpObservationFilter extends OncePerRequestFilter { if (request.isAsyncStarted()) { request.getAsyncContext().addListener(new ObservationAsyncListener(observation)); } - // Stop Observation right now if async processing has not been started. - else { + // scope is opened for ASYNC dispatches, but the observation will be closed + // by the async listener. + else if (request.getDispatcherType() != DispatcherType.ASYNC){ Throwable error = fetchException(request); if (error != null) { observation.error(error); @@ -176,7 +183,6 @@ public class ServerHttpObservationFilter extends OncePerRequestFilter { @Override public void onError(AsyncEvent event) { this.currentObservation.error(unwrapServletException(event.getThrowable())); - this.currentObservation.stop(); } } diff --git a/spring-web/src/test/java/org/springframework/web/filter/ServerHttpObservationFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ServerHttpObservationFilterTests.java index 80bdfbcbb27..9a7eeb4f411 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ServerHttpObservationFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ServerHttpObservationFilterTests.java @@ -16,15 +16,24 @@ package org.springframework.web.filter; +import java.io.IOException; + import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; +import jakarta.servlet.AsyncEvent; +import jakarta.servlet.AsyncListener; +import jakarta.servlet.DispatcherType; import jakarta.servlet.RequestDispatcher; import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.Test; import org.springframework.http.HttpMethod; import org.springframework.http.server.observation.ServerRequestObservationContext; +import org.springframework.web.testfixture.servlet.MockAsyncContext; import org.springframework.web.testfixture.servlet.MockFilterChain; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.testfixture.servlet.MockHttpServletResponse; @@ -41,18 +50,18 @@ class ServerHttpObservationFilterTests { private final TestObservationRegistry observationRegistry = TestObservationRegistry.create(); - private final ServerHttpObservationFilter filter = new ServerHttpObservationFilter(this.observationRegistry); - - private final MockFilterChain mockFilterChain = new MockFilterChain(); - private final MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.GET.name(), "/resource/test"); private final MockHttpServletResponse response = new MockHttpServletResponse(); + private MockFilterChain mockFilterChain = new MockFilterChain(); + + private ServerHttpObservationFilter filter = new ServerHttpObservationFilter(this.observationRegistry); + @Test - void filterShouldNotProcessAsyncDispatch() { - assertThat(this.filter.shouldNotFilterAsyncDispatch()).isTrue(); + void filterShouldProcessAsyncDispatch() { + assertThat(this.filter.shouldNotFilterAsyncDispatch()).isFalse(); } @Test @@ -68,6 +77,12 @@ class ServerHttpObservationFilterTests { assertThatHttpObservation().hasLowCardinalityKeyValue("outcome", "SUCCESS").hasBeenStopped(); } + @Test + void filterShouldOpenScope() throws Exception { + this.mockFilterChain = new MockFilterChain(new ScopeCheckingServlet(this.observationRegistry)); + filter.doFilter(this.request, this.response, this.mockFilterChain); + } + @Test void filterShouldAcceptNoOpObservationContext() throws Exception { ServerHttpObservationFilter filter = new ServerHttpObservationFilter(ObservationRegistry.NOOP); @@ -126,9 +141,52 @@ class ServerHttpObservationFilterTests { assertThatHttpObservation().hasLowCardinalityKeyValue("outcome", "SUCCESS").hasBeenStopped(); } + @Test + void shouldCloseObservationAfterAsyncError() throws Exception { + this.request.setAsyncSupported(true); + this.request.startAsync(); + this.filter.doFilter(this.request, this.response, this.mockFilterChain); + MockAsyncContext asyncContext = (MockAsyncContext) this.request.getAsyncContext(); + for (AsyncListener listener : asyncContext.getListeners()) { + listener.onError(new AsyncEvent(this.request.getAsyncContext(), new IllegalStateException("test error"))); + } + asyncContext.complete(); + assertThatHttpObservation().hasLowCardinalityKeyValue("exception", "IllegalStateException").hasBeenStopped(); + } + + @Test + void shouldNotCloseObservationDuringAsyncDispatch() throws Exception { + this.mockFilterChain = new MockFilterChain(new ScopeCheckingServlet(this.observationRegistry)); + this.request.setDispatcherType(DispatcherType.ASYNC); + this.filter.doFilter(this.request, this.response, this.mockFilterChain); + TestObservationRegistryAssert.assertThat(this.observationRegistry) + .hasObservationWithNameEqualTo("http.server.requests") + .that().isNotStopped(); + } + private TestObservationRegistryAssert.TestObservationRegistryAssertReturningObservationContextAssert assertThatHttpObservation() { + TestObservationRegistryAssert.assertThat(this.observationRegistry) + .hasNumberOfObservationsWithNameEqualTo("http.server.requests", 1); + return TestObservationRegistryAssert.assertThat(this.observationRegistry) - .hasObservationWithNameEqualTo("http.server.requests").that(); + .hasObservationWithNameEqualTo("http.server.requests") + .that() + .hasBeenStopped(); + } + + @SuppressWarnings("serial") + static class ScopeCheckingServlet extends HttpServlet { + + private final ObservationRegistry observationRegistry; + + public ScopeCheckingServlet(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + assertThat(this.observationRegistry.getCurrentObservation()).isNotNull(); + } } }