From b3c0fbb02d55cb26cc5e907f2e5103f119cf0849 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Thu, 24 Aug 2023 13:47:51 +0200 Subject: [PATCH] Guard command completion listener against unsupported observation context. We now no longer attempt to complete the Observation if the context is not a MongoDB one. For commands that target the admin database and run within a parent observation, we still might have an Observation but that one points to the parent invocation and not the MongoDB one as we do not record commands for the admin database. Closes #4481 --- .../MongoObservationCommandListener.java | 14 +++--- .../MongoObservationCommandListenerTests.java | 49 ++++++++++++++++--- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MongoObservationCommandListener.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MongoObservationCommandListener.java index e299d7931..2cbe4f03d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MongoObservationCommandListener.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MongoObservationCommandListener.java @@ -15,6 +15,10 @@ */ package org.springframework.data.mongodb.observability; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.lang.Nullable; @@ -27,10 +31,6 @@ import com.mongodb.event.CommandListener; import com.mongodb.event.CommandStartedEvent; import com.mongodb.event.CommandSucceededEvent; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; - /** * Implement MongoDB's {@link CommandListener} using Micrometer's {@link Observation} API. * @@ -133,11 +133,10 @@ public class MongoObservationCommandListener implements CommandListener { } Observation observation = requestContext.getOrDefault(ObservationThreadLocalAccessor.KEY, null); - if (observation == null) { + if (observation == null || !(observation.getContext()instanceof MongoHandlerContext context)) { return; } - MongoHandlerContext context = (MongoHandlerContext) observation.getContext(); context.setCommandSucceededEvent(event); if (log.isDebugEnabled()) { @@ -157,11 +156,10 @@ public class MongoObservationCommandListener implements CommandListener { } Observation observation = requestContext.getOrDefault(ObservationThreadLocalAccessor.KEY, null); - if (observation == null) { + if (observation == null || !(observation.getContext()instanceof MongoHandlerContext context)) { return; } - MongoHandlerContext context = (MongoHandlerContext) observation.getContext(); context.setCommandFailedEvent(event); if (log.isDebugEnabled()) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/observability/MongoObservationCommandListenerTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/observability/MongoObservationCommandListenerTests.java index 8d7826453..9efc72683 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/observability/MongoObservationCommandListenerTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/observability/MongoObservationCommandListenerTests.java @@ -16,6 +16,15 @@ package org.springframework.data.mongodb.observability; import static io.micrometer.core.tck.MeterRegistryAssert.*; +import static org.mockito.Mockito.*; + +import io.micrometer.common.KeyValues; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.observation.DefaultMeterObservationHandler; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.bson.BsonDocument; import org.bson.BsonString; @@ -33,18 +42,12 @@ import com.mongodb.event.CommandFailedEvent; import com.mongodb.event.CommandStartedEvent; import com.mongodb.event.CommandSucceededEvent; -import io.micrometer.common.KeyValues; -import io.micrometer.core.instrument.MeterRegistry; -import io.micrometer.core.instrument.observation.DefaultMeterObservationHandler; -import io.micrometer.core.instrument.simple.SimpleMeterRegistry; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; - /** * Series of test cases exercising {@link MongoObservationCommandListener}. * * @author Marcin Grzejszczak * @author Greg Turnquist + * @author Mark Paluch */ class MongoObservationCommandListenerTests { @@ -176,6 +179,38 @@ class MongoObservationCommandListenerTests { assertThatTimerRegisteredWithTags(); } + @Test // GH-4481 + void completionShouldIgnoreIncompatibleObservationContext() { + + // given + RequestContext traceRequestContext = getContext(); + + Observation observation = mock(Observation.class); + traceRequestContext.put(ObservationThreadLocalAccessor.KEY, observation); + + // when + listener.commandSucceeded(new CommandSucceededEvent(traceRequestContext, 0, null, "insert", null, 0)); + + verify(observation).getContext(); + verifyNoMoreInteractions(observation); + } + + @Test // GH-4481 + void failureShouldIgnoreIncompatibleObservationContext() { + + // given + RequestContext traceRequestContext = getContext(); + + Observation observation = mock(Observation.class); + traceRequestContext.put(ObservationThreadLocalAccessor.KEY, observation); + + // when + listener.commandFailed(new CommandFailedEvent(traceRequestContext, 0, null, "insert", 0, null)); + + verify(observation).getContext(); + verifyNoMoreInteractions(observation); + } + private RequestContext getContext() { return ((SynchronousContextProvider) ContextProviderFactory.create(observationRegistry)).getContext(); }