Browse Source

Fix race condition in OutputCapture

Update `OutputCapture` to fix a race condition that could occur due to the
cache. Prior to this commit, when data was written whilst simultaneously
being read, any subsequent reading of data might miss the last output.

See gh-46685

Signed-off-by: Daniel Schmidt <daniel-github@ad-schmidt.de>
pull/46752/head
Daniel Schmidt 4 months ago committed by Phillip Webb
parent
commit
10b8a43291
  1. 36
      spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/system/OutputCapture.java
  2. 63
      spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/system/OutputCaptureTests.java

36
spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/system/OutputCapture.java

@ -23,6 +23,7 @@ import java.util.ArrayDeque;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Deque; import java.util.Deque;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Predicate; import java.util.function.Predicate;
@ -40,6 +41,7 @@ import org.springframework.util.ClassUtils;
* @author Phillip Webb * @author Phillip Webb
* @author Andy Wilkinson * @author Andy Wilkinson
* @author Sam Brannen * @author Sam Brannen
* @author Daniel Schmidt
* @see OutputCaptureExtension * @see OutputCaptureExtension
* @see OutputCaptureRule * @see OutputCaptureRule
*/ */
@ -49,11 +51,14 @@ class OutputCapture implements CapturedOutput {
private AnsiOutputState ansiOutputState; private AnsiOutputState ansiOutputState;
private final AtomicReference<String> out = new AtomicReference<>(null); private final AtomicLong outVersion = new AtomicLong();
private final AtomicReference<VersionedCacheResult> out = new AtomicReference<>(null);
private final AtomicReference<String> err = new AtomicReference<>(null); private final AtomicLong errVersion = new AtomicLong();
private final AtomicReference<VersionedCacheResult> err = new AtomicReference<>(null);
private final AtomicReference<String> all = new AtomicReference<>(null); private final AtomicLong allVersion = new AtomicLong();
private final AtomicReference<VersionedCacheResult> all = new AtomicReference<>(null);
/** /**
* Push a new system capture session onto the stack. * Push a new system capture session onto the stack.
@ -106,7 +111,7 @@ class OutputCapture implements CapturedOutput {
*/ */
@Override @Override
public String getAll() { public String getAll() {
return get(this.all, (type) -> true); return get(this.all, this.allVersion, (type) -> true);
} }
/** /**
@ -115,7 +120,7 @@ class OutputCapture implements CapturedOutput {
*/ */
@Override @Override
public String getOut() { public String getOut() {
return get(this.out, Type.OUT::equals); return get(this.out, this.outVersion, Type.OUT::equals);
} }
/** /**
@ -124,7 +129,7 @@ class OutputCapture implements CapturedOutput {
*/ */
@Override @Override
public String getErr() { public String getErr() {
return get(this.err, Type.ERR::equals); return get(this.err, this.errVersion, Type.ERR::equals);
} }
/** /**
@ -136,19 +141,24 @@ class OutputCapture implements CapturedOutput {
} }
void clearExisting() { void clearExisting() {
this.outVersion.incrementAndGet();
this.out.set(null); this.out.set(null);
this.errVersion.incrementAndGet();
this.err.set(null); this.err.set(null);
this.allVersion.incrementAndGet();
this.all.set(null); this.all.set(null);
} }
private String get(AtomicReference<String> existing, Predicate<Type> filter) { private String get(AtomicReference<VersionedCacheResult> resultCache, AtomicLong version, Predicate<Type> filter) {
Assert.state(!this.systemCaptures.isEmpty(), Assert.state(!this.systemCaptures.isEmpty(),
"No system captures found. Please check your output capture registration."); "No system captures found. Please check your output capture registration.");
String result = existing.get(); long currentVersion = version.get();
if (result == null) { VersionedCacheResult cached = resultCache.get();
result = build(filter); if (cached != null && cached.version == currentVersion) {
existing.compareAndSet(null, result); return cached.result;
} }
String result = build(filter);
resultCache.compareAndSet(null, new VersionedCacheResult(result, currentVersion));
return result; return result;
} }
@ -160,6 +170,10 @@ class OutputCapture implements CapturedOutput {
return builder.toString(); return builder.toString();
} }
private record VersionedCacheResult(String result, long version) {
}
/** /**
* A capture session that captures {@link System#out System.out} and {@link System#out * A capture session that captures {@link System#out System.out} and {@link System#out
* System.err}. * System.err}.

63
spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/system/OutputCaptureTests.java

@ -19,8 +19,13 @@ package org.springframework.boot.test.system;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.PrintStream; import java.io.PrintStream;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate; import java.util.function.Predicate;
import org.jspecify.annotations.Nullable;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -32,6 +37,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
* Tests for {@link OutputCapture}. * Tests for {@link OutputCapture}.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Daniel Schmidt
*/ */
class OutputCaptureTests { class OutputCaptureTests {
@ -188,6 +194,45 @@ class OutputCaptureTests {
assertThat(this.output.buildCount).isEqualTo(2); assertThat(this.output.buildCount).isEqualTo(2);
} }
@Test
void getOutCacheShouldNotReturnStaleDataWhenDataIsLoggedWhileReading() throws Exception {
this.output.push();
System.out.print("A");
this.output.waitAfterBuildLatch = new CountDownLatch(1);
ExecutorService executorService = null;
try {
executorService = Executors.newFixedThreadPool(2);
var readingThreadFuture = executorService.submit(() -> {
// this will release the releaseAfterBuildLatch and block on the waitAfterBuildLatch
assertThat(this.output.getOut()).isEqualTo("A");
});
var writingThreadFuture = executorService.submit(() -> {
// wait until we finished building the first result (but did not yet update the cache)
try {
this.output.releaseAfterBuildLatch.await();
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
// print something else and then release the latch, for the other thread to continue
System.out.print("B");
this.output.waitAfterBuildLatch.countDown();
});
readingThreadFuture.get();
writingThreadFuture.get();
}
finally {
if (executorService != null) {
executorService.shutdown();
executorService.awaitTermination(10, TimeUnit.SECONDS);
}
}
// If not synchronized correctly this will fail, because the second print did not clear the cache and the cache will return stale data.
assertThat(this.output.getOut()).isEqualTo("AB");
}
private void pushAndPrint() { private void pushAndPrint() {
this.output.push(); this.output.push();
System.out.print("A"); System.out.print("A");
@ -212,10 +257,26 @@ class OutputCaptureTests {
int buildCount; int buildCount;
@Nullable
CountDownLatch waitAfterBuildLatch = null;
CountDownLatch releaseAfterBuildLatch = new CountDownLatch(1);
@Override @Override
String build(Predicate<Type> filter) { String build(Predicate<Type> filter) {
this.buildCount++; this.buildCount++;
return super.build(filter); var result = super.build(filter);
this.releaseAfterBuildLatch.countDown();
if (this.waitAfterBuildLatch != null) {
try {
this.waitAfterBuildLatch.await();
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
}
return result;
} }
} }

Loading…
Cancel
Save