diff --git a/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java b/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java index b414758e53..cb5c8fe179 100644 --- a/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java +++ b/web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,10 +19,6 @@ package org.springframework.security.web.server; import java.util.ArrayList; import java.util.List; import java.util.ListIterator; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; -import java.util.concurrent.locks.Lock; -import java.util.concurrent.locks.ReentrantLock; import io.micrometer.common.KeyValue; import io.micrometer.common.KeyValues; @@ -258,12 +254,13 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP class SimpleAroundWebFilterObservation implements AroundWebFilterObservation { + private final Object lock = new Object(); + private final PhasedObservation before; private final PhasedObservation after; - private final AtomicReference currentObservation = new AtomicReference<>( - PhasedObservation.NOOP); + private volatile PhasedObservation currentObservation = PhasedObservation.NOOP; SimpleAroundWebFilterObservation(Observation before, Observation after) { this.before = new PhasedObservation(before); @@ -272,22 +269,26 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP @Override public Observation start() { - if (this.currentObservation.compareAndSet(PhasedObservation.NOOP, this.before)) { - this.before.start(); - return this.before.observation; - } - if (this.currentObservation.compareAndSet(this.before, this.after)) { - this.before.stop(); - this.after.start(); - return this.after.observation; + synchronized (this.lock) { + if (this.currentObservation == PhasedObservation.NOOP) { + this.before.start(); + this.currentObservation = this.before; + return this.currentObservation; + } + if (this.currentObservation == this.before) { + this.before.stop(); + this.after.start(); + this.currentObservation = this.after; + return this.currentObservation; + } } return Observation.NOOP; } @Override public Observation error(Throwable ex) { - this.currentObservation.get().error(ex); - return this.currentObservation.get().observation; + this.currentObservation.error(ex); + return this.currentObservation.observation; } @Override @@ -303,42 +304,42 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP @Override public Observation contextualName(String contextualName) { - return this.currentObservation.get().observation.contextualName(contextualName); + return this.currentObservation.observation.contextualName(contextualName); } @Override public Observation parentObservation(Observation parentObservation) { - return this.currentObservation.get().observation.parentObservation(parentObservation); + return this.currentObservation.observation.parentObservation(parentObservation); } @Override public Observation lowCardinalityKeyValue(KeyValue keyValue) { - return this.currentObservation.get().observation.lowCardinalityKeyValue(keyValue); + return this.currentObservation.observation.lowCardinalityKeyValue(keyValue); } @Override public Observation highCardinalityKeyValue(KeyValue keyValue) { - return this.currentObservation.get().observation.highCardinalityKeyValue(keyValue); + return this.currentObservation.observation.highCardinalityKeyValue(keyValue); } @Override public Observation observationConvention(ObservationConvention observationConvention) { - return this.currentObservation.get().observation.observationConvention(observationConvention); + return this.currentObservation.observation.observationConvention(observationConvention); } @Override public Observation event(Event event) { - return this.currentObservation.get().observation.event(event); + return this.currentObservation.observation.event(event); } @Override public Context getContext() { - return this.currentObservation.get().observation.getContext(); + return this.currentObservation.observation.getContext(); } @Override public Scope openScope() { - return this.currentObservation.get().observation.openScope(); + return this.currentObservation.observation.openScope(); } @Override @@ -386,7 +387,7 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP @Override public String toString() { - return this.currentObservation.get().observation.toString(); + return this.currentObservation.observation.toString(); } } @@ -665,9 +666,9 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP private static final PhasedObservation NOOP = new PhasedObservation(Observation.NOOP); - private final Lock lock = new ReentrantLock(); + private final Object lock = new Object(); - private final AtomicInteger phase = new AtomicInteger(0); + private volatile int phase = 0; private final Observation observation; @@ -717,57 +718,41 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP @Override public PhasedObservation start() { - try { - this.lock.lock(); - if (this.phase.compareAndSet(0, 1)) { + synchronized (this.lock) { + if (this.phase == 0) { this.observation.start(); + this.phase = 1; } } - finally { - this.lock.unlock(); - } return this; } @Override public PhasedObservation error(Throwable ex) { - try { - this.lock.lock(); - if (this.phase.get() == 1) { + synchronized (this.lock) { + if (this.phase == 1) { this.observation.error(ex); } } - finally { - this.lock.unlock(); - } return this; } @Override public void stop() { - try { - this.lock.lock(); - if (this.phase.compareAndSet(1, 2)) { + synchronized (this.lock) { + if (this.phase == 1) { this.observation.stop(); + this.phase = 2; } } - finally { - this.lock.unlock(); - } } void close() { - try { - this.lock.lock(); - if (this.phase.compareAndSet(1, 3)) { + synchronized (this.lock) { + if (this.phase == 1) { this.observation.stop(); } - else { - this.phase.set(3); - } - } - finally { - this.lock.unlock(); + this.phase = 3; } }