diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MongoObservationCommandListener.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MongoObservationCommandListener.java index e299d7931..2cbe4f03d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MongoObservationCommandListener.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MongoObservationCommandListener.java @@ -15,6 +15,10 @@ */ package org.springframework.data.mongodb.observability; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.lang.Nullable; @@ -27,10 +31,6 @@ import com.mongodb.event.CommandListener; import com.mongodb.event.CommandStartedEvent; import com.mongodb.event.CommandSucceededEvent; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; - /** * Implement MongoDB's {@link CommandListener} using Micrometer's {@link Observation} API. * @@ -133,11 +133,10 @@ public class MongoObservationCommandListener implements CommandListener { } Observation observation = requestContext.getOrDefault(ObservationThreadLocalAccessor.KEY, null); - if (observation == null) { + if (observation == null || !(observation.getContext()instanceof MongoHandlerContext context)) { return; } - MongoHandlerContext context = (MongoHandlerContext) observation.getContext(); context.setCommandSucceededEvent(event); if (log.isDebugEnabled()) { @@ -157,11 +156,10 @@ public class MongoObservationCommandListener implements CommandListener { } Observation observation = requestContext.getOrDefault(ObservationThreadLocalAccessor.KEY, null); - if (observation == null) { + if (observation == null || !(observation.getContext()instanceof MongoHandlerContext context)) { return; } - MongoHandlerContext context = (MongoHandlerContext) observation.getContext(); context.setCommandFailedEvent(event); if (log.isDebugEnabled()) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/observability/MongoObservationCommandListenerTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/observability/MongoObservationCommandListenerTests.java index 8d7826453..9efc72683 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/observability/MongoObservationCommandListenerTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/observability/MongoObservationCommandListenerTests.java @@ -16,6 +16,15 @@ package org.springframework.data.mongodb.observability; import static io.micrometer.core.tck.MeterRegistryAssert.*; +import static org.mockito.Mockito.*; + +import io.micrometer.common.KeyValues; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.observation.DefaultMeterObservationHandler; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.bson.BsonDocument; import org.bson.BsonString; @@ -33,18 +42,12 @@ import com.mongodb.event.CommandFailedEvent; import com.mongodb.event.CommandStartedEvent; import com.mongodb.event.CommandSucceededEvent; -import io.micrometer.common.KeyValues; -import io.micrometer.core.instrument.MeterRegistry; -import io.micrometer.core.instrument.observation.DefaultMeterObservationHandler; -import io.micrometer.core.instrument.simple.SimpleMeterRegistry; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; - /** * Series of test cases exercising {@link MongoObservationCommandListener}. * * @author Marcin Grzejszczak * @author Greg Turnquist + * @author Mark Paluch */ class MongoObservationCommandListenerTests { @@ -176,6 +179,38 @@ class MongoObservationCommandListenerTests { assertThatTimerRegisteredWithTags(); } + @Test // GH-4481 + void completionShouldIgnoreIncompatibleObservationContext() { + + // given + RequestContext traceRequestContext = getContext(); + + Observation observation = mock(Observation.class); + traceRequestContext.put(ObservationThreadLocalAccessor.KEY, observation); + + // when + listener.commandSucceeded(new CommandSucceededEvent(traceRequestContext, 0, null, "insert", null, 0)); + + verify(observation).getContext(); + verifyNoMoreInteractions(observation); + } + + @Test // GH-4481 + void failureShouldIgnoreIncompatibleObservationContext() { + + // given + RequestContext traceRequestContext = getContext(); + + Observation observation = mock(Observation.class); + traceRequestContext.put(ObservationThreadLocalAccessor.KEY, observation); + + // when + listener.commandFailed(new CommandFailedEvent(traceRequestContext, 0, null, "insert", 0, null)); + + verify(observation).getContext(); + verifyNoMoreInteractions(observation); + } + private RequestContext getContext() { return ((SynchronousContextProvider) ContextProviderFactory.create(observationRegistry)).getContext(); }