diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java
index 74544ffb73b..c9235049643 100644
--- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java
+++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java
@@ -32,6 +32,7 @@ import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
+import java.util.function.IntPredicate;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
@@ -473,6 +474,9 @@ public abstract class DataBufferUtils {
* Depending on the {@link DataBuffer} implementation, the returned buffer may be a single
* buffer containing all data of the provided buffers, or it may be a true composite that
* contains references to the buffers.
+ *
If {@code dataBuffers} contains an error signal, then all buffers that preceded the error
+ * will be {@linkplain #release(DataBuffer) released}, and the error is stored in the
+ * returned {@code Mono}.
* @param dataBuffers the data buffers that are to be composed
* @return a buffer that is composed from the {@code dataBuffers} argument
* @since 5.0.3
@@ -481,14 +485,26 @@ public abstract class DataBufferUtils {
Assert.notNull(dataBuffers, "'dataBuffers' must not be null");
return Flux.from(dataBuffers)
+ .onErrorResume(DataBufferUtils::exceptionDataBuffer)
.collectList()
.filter(list -> !list.isEmpty())
- .map(list -> {
+ .flatMap(list -> {
+ for (int i = 0; i < list.size(); i++) {
+ DataBuffer dataBuffer = list.get(i);
+ if (dataBuffer instanceof ExceptionDataBuffer) {
+ list.subList(0, i).forEach(DataBufferUtils::release);
+ return Mono.error(((ExceptionDataBuffer) dataBuffer).throwable());
+ }
+ }
DataBufferFactory bufferFactory = list.get(0).factory();
- return bufferFactory.join(list);
+ return Mono.just(bufferFactory.join(list));
});
}
+ private static Mono exceptionDataBuffer(Throwable throwable) {
+ return Mono.just(new ExceptionDataBuffer(throwable));
+ }
+
private static class ReadableByteChannelGenerator implements Consumer> {
@@ -658,4 +674,153 @@ public abstract class DataBufferUtils {
}
}
+ /**
+ * DataBuffer implementation that holds a {@link Throwable}, used in {@link #join(Publisher)}.
+ */
+ private static final class ExceptionDataBuffer implements DataBuffer {
+
+ private final Throwable throwable;
+
+
+ public ExceptionDataBuffer(Throwable throwable) {
+ this.throwable = throwable;
+ }
+
+ public Throwable throwable() {
+ return this.throwable;
+ }
+
+ // Unsupported
+
+ @Override
+ public DataBufferFactory factory() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int indexOf(IntPredicate predicate, int fromIndex) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int lastIndexOf(IntPredicate predicate, int fromIndex) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int readableByteCount() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int writableByteCount() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int capacity() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DataBuffer capacity(int capacity) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int readPosition() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DataBuffer readPosition(int readPosition) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int writePosition() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DataBuffer writePosition(int writePosition) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public byte getByte(int index) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public byte read() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DataBuffer read(byte[] destination) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DataBuffer read(byte[] destination, int offset, int length) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DataBuffer write(byte b) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DataBuffer write(byte[] source) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DataBuffer write(byte[] source, int offset, int length) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DataBuffer write(DataBuffer... buffers) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DataBuffer write(ByteBuffer... buffers) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DataBuffer slice(int index, int length) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ByteBuffer asByteBuffer() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ByteBuffer asByteBuffer(int index, int length) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public InputStream asInputStream() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public InputStream asInputStream(boolean releaseOnClose) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public OutputStream asOutputStream() {
+ throw new UnsupportedOperationException();
+ }
+ }
+
}
diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java
index 4b7ddb7dc39..f8419d40cc5 100644
--- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java
+++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java
@@ -35,6 +35,7 @@ import io.netty.buffer.ByteBuf;
import org.junit.Test;
import org.mockito.stubbing.Answer;
import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.core.io.ClassPathResource;
@@ -338,12 +339,26 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase {
DataBuffer bar = stringBuffer("bar");
DataBuffer baz = stringBuffer("baz");
Flux flux = Flux.just(foo, bar, baz);
+ Mono result = DataBufferUtils.join(flux);
- DataBuffer result = DataBufferUtils.join(flux).block(Duration.ofSeconds(5));
+ StepVerifier.create(result)
+ .consumeNextWith(dataBuffer -> {
+ assertEquals("foobarbaz", DataBufferTestUtils.dumpString(dataBuffer, StandardCharsets.UTF_8));
+ release(dataBuffer);
+ })
+ .verifyComplete();
+ }
- assertEquals("foobarbaz", DataBufferTestUtils.dumpString(result, StandardCharsets.UTF_8));
+ @Test
+ public void joinErrors() {
+ DataBuffer foo = stringBuffer("foo");
+ DataBuffer bar = stringBuffer("bar");
+ Flux flux = Flux.just(foo, bar).mergeWith(Flux.error(new RuntimeException()));
+ Mono result = DataBufferUtils.join(flux);
- release(result);
+ StepVerifier.create(result)
+ .expectError(RuntimeException.class)
+ .verify();
}
}