From 5dce82c48bc0b174838501c5a111b2de70822914 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Mon, 30 Oct 2023 15:35:56 -0600 Subject: [PATCH] Close Both Observations Depending on when a request is cancelled, the before and after observation starts and stops may be called out of order due to the order in which their doOnCancel handlers are invoked. To address this, the before filter-wrapper now always closes both the before observation and the after observation. Since the before filter- wrapper wraps the entire request, this ensures that either that was started is stopped, and either that has not been started yet cannot inadvertently be started by any unexpected ordering of events that follows. Closes gh-14031 --- .../ObservationWebFilterChainDecorator.java | 29 ++++- ...servationWebFilterChainDecoratorTests.java | 115 ++++++++++++++++++ 2 files changed, 140 insertions(+), 4 deletions(-) diff --git a/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java b/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java index 5cfa80ac0c..8644e1ea00 100644 --- a/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java +++ b/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java @@ -292,7 +292,13 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP @Override public void stop() { - this.currentObservation.get().stop(); + this.before.stop(); + this.after.stop(); + } + + private void close() { + this.before.close(); + this.after.close(); } @Override @@ -357,11 +363,11 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP start(); // @formatter:off return filter.filter(exchange, chain) - .doOnSuccess((v) -> stop()) - .doOnCancel(this::stop) + .doOnSuccess((v) -> close()) + .doOnCancel(this::close) .doOnError((t) -> { error(t); - stop(); + close(); }) .contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, this)); // @formatter:on @@ -433,6 +439,21 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP } } + private void close() { + try { + this.lock.lock(); + if (this.state.compareAndSet(1, 3)) { + this.observation.stop(); + } + else { + this.state.set(3); + } + } + finally { + this.lock.unlock(); + } + } + } } diff --git a/web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java b/web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java index 1266110e8c..44218bd52e 100644 --- a/web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java +++ b/web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java @@ -78,6 +78,98 @@ public class ObservationWebFilterChainDecoratorTests { verifyNoInteractions(handler); } + @Test + void decorateWhenTerminatingFilterThenObserves() { + AccumulatingObservationHandler handler = new AccumulatingObservationHandler(); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry); + WebFilterChain chain = mock(WebFilterChain.class); + given(chain.filter(any())).willReturn(Mono.error(() -> new Exception("ack"))); + WebFilterChain decorated = decorator.decorate(chain, + List.of(new BasicAuthenticationFilter(), new TerminatingFilter())); + Observation http = Observation.start("http", registry).contextualName("http"); + try { + decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build())) + .contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http)) + .block(); + } + catch (Exception ex) { + http.error(ex); + } + finally { + http.stop(); + } + handler.assertSpanStart(0, "http", null); + handler.assertSpanStart(1, "spring.security.filterchains", "http"); + handler.assertSpanStop(2, "security filterchain before"); + handler.assertSpanStart(3, "spring.security.filterchains", "http"); + handler.assertSpanStop(4, "security filterchain after"); + handler.assertSpanStop(5, "http"); + } + + @Test + void decorateWhenFilterErrorThenStopsObservation() { + AccumulatingObservationHandler handler = new AccumulatingObservationHandler(); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry); + WebFilterChain chain = mock(WebFilterChain.class); + WebFilterChain decorated = decorator.decorate(chain, List.of(new ErroringFilter())); + Observation http = Observation.start("http", registry).contextualName("http"); + try { + decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build())) + .contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http)) + .block(); + } + catch (Exception ex) { + http.error(ex); + } + finally { + http.stop(); + } + handler.assertSpanStart(0, "http", null); + handler.assertSpanStart(1, "spring.security.filterchains", "http"); + handler.assertSpanError(2); + handler.assertSpanStop(3, "security filterchain before"); + handler.assertSpanError(4); + handler.assertSpanStop(5, "http"); + } + + @Test + void decorateWhenErrorSignalThenStopsObservation() { + AccumulatingObservationHandler handler = new AccumulatingObservationHandler(); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(handler); + ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry); + WebFilterChain chain = mock(WebFilterChain.class); + given(chain.filter(any())).willReturn(Mono.error(() -> new Exception("ack"))); + WebFilterChain decorated = decorator.decorate(chain, List.of(new BasicAuthenticationFilter())); + Observation http = Observation.start("http", registry).contextualName("http"); + try { + decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build())) + .contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http)) + .block(); + } + catch (Exception ex) { + http.error(ex); + } + finally { + http.stop(); + } + handler.assertSpanStart(0, "http", null); + handler.assertSpanStart(1, "spring.security.filterchains", "http"); + handler.assertSpanStop(2, "security filterchain before"); + handler.assertSpanStart(3, "secured request", "security filterchain before"); + handler.assertSpanError(4); + handler.assertSpanStop(5, "secured request"); + handler.assertSpanStart(6, "spring.security.filterchains", "http"); + handler.assertSpanError(7); + handler.assertSpanStop(8, "security filterchain after"); + handler.assertSpanError(9); + handler.assertSpanStop(10, "http"); + } + // gh-12849 @Test void decorateWhenCustomAfterFilterThenObserves() { @@ -171,6 +263,24 @@ public class ObservationWebFilterChainDecoratorTests { } + static class ErroringFilter implements WebFilter { + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + return Mono.error(() -> new RuntimeException("ack")); + } + + } + + static class TerminatingFilter implements WebFilter { + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + return Mono.empty(); + } + + } + static class AccumulatingObservationHandler implements ObservationHandler { List contexts = new ArrayList<>(); @@ -246,6 +356,11 @@ public class ObservationWebFilterChainDecoratorTests { } } + private void assertSpanError(int index) { + Event event = this.contexts.get(index); + assertThat(event.event).isEqualTo("error"); + } + static class Event { String event;