diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java index acec012beb8..9fd809fe28c 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java @@ -18,18 +18,24 @@ package org.springframework.core.io.buffer; import java.io.IOException; import java.io.InputStream; +import java.io.OutputStream; import java.nio.ByteBuffer; import java.nio.channels.AsynchronousFileChannel; import java.nio.channels.Channel; import java.nio.channels.Channels; import java.nio.channels.CompletionHandler; import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiFunction; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; import reactor.core.publisher.SynchronousSink; import org.springframework.lang.Nullable; @@ -104,7 +110,6 @@ public abstract class DataBufferUtils { * @param bufferSize the maximum size of the data buffers * @return a flux of data buffers read from the given channel */ - @SuppressWarnings("deprecation") public static Flux read(AsynchronousFileChannel channel, long position, DataBufferFactory dataBufferFactory, int bufferSize) { @@ -114,15 +119,104 @@ public abstract class DataBufferUtils { ByteBuffer byteBuffer = ByteBuffer.allocate(bufferSize); - return Flux.create(emitter -> { - emitter.onDispose(() -> closeChannel(channel)); - AsynchronousFileChannelCompletionHandler completionHandler = - new AsynchronousFileChannelCompletionHandler(emitter, position, + return Flux.create(sink -> { + sink.onDispose(() -> closeChannel(channel)); + AsynchronousFileChannelReadCompletionHandler completionHandler = + new AsynchronousFileChannelReadCompletionHandler(sink, position, dataBufferFactory, byteBuffer); channel.read(byteBuffer, position, channel, completionHandler); }); } + /** + * Write the given stream of {@link DataBuffer}s to the given {@code OutputStream}. Does + * not close the output stream when the flux is terminated, but + * does {@linkplain #release(DataBuffer) release} the data buffers in the + * source. + *

Note that the writing process does not start until the returned {@code Mono} is subscribed + * to. + * @param source the stream of data buffers to be written + * @param outputStream the output stream to write to + * @return a mono that starts the writing process when subscribed to, and that indicates the + * completion of the process + */ + public static Mono write(Publisher source, + OutputStream outputStream) { + + Assert.notNull(source, "'source' must not be null"); + Assert.notNull(outputStream, "'outputStream' must not be null"); + + WritableByteChannel channel = Channels.newChannel(outputStream); + return write(source, channel); + } + + /** + * Write the given stream of {@link DataBuffer}s to the given {@code WritableByteChannel}. Does + * not close the channel when the flux is terminated, but + * does {@linkplain #release(DataBuffer) release} the data buffers in the + * source. + *

Note that the writing process does not start until the returned {@code Mono} is subscribed + * to. + * @param source the stream of data buffers to be written + * @param channel the channel to write to + * @return a mono that starts the writing process when subscribed to, and that indicates the + * completion of the process + */ + public static Mono write(Publisher source, + WritableByteChannel channel) { + + Assert.notNull(source, "'source' must not be null"); + Assert.notNull(channel, "'channel' must not be null"); + + Flux flux = Flux.from(source); + + return Mono.create(sink -> + flux.subscribe(dataBuffer -> { + try { + ByteBuffer byteBuffer = dataBuffer.asByteBuffer(); + while (byteBuffer.hasRemaining()) { + channel.write(byteBuffer); + } + release(dataBuffer); + } + catch (IOException ex) { + sink.error(ex); + } + + }, + sink::error, + sink::success)); + } + + /** + * Write the given stream of {@link DataBuffer}s to the given {@code AsynchronousFileChannel}. + * Does not close the channel when the flux is terminated, but + * does {@linkplain #release(DataBuffer) release} the data buffers in the + * source. + *

Note that the writing process does not start until the returned {@code Mono} is subscribed + * to. + * @param source the stream of data buffers to be written + * @param channel the channel to write to + * @return a mono that starts the writing process when subscribed to, and that indicates the + * completion of the process + */ + public static Mono write(Publisher source, AsynchronousFileChannel channel, + long position) { + + Assert.notNull(source, "'source' must not be null"); + Assert.notNull(channel, "'channel' must not be null"); + Assert.isTrue(position >= 0, "'position' must be >= 0"); + + Flux flux = Flux.from(source); + + return Mono.create(sink -> { + BaseSubscriber subscriber = + new AsynchronousFileChannelWriteCompletionHandler(sink, channel, position); + flux.subscribe(subscriber); + }); + } + + private static void closeChannel(@Nullable Channel channel) { try { if (channel != null) { @@ -272,10 +366,10 @@ public abstract class DataBufferUtils { } } - private static class AsynchronousFileChannelCompletionHandler + private static class AsynchronousFileChannelReadCompletionHandler implements CompletionHandler { - private final FluxSink emitter; + private final FluxSink sink; private final ByteBuffer byteBuffer; @@ -283,9 +377,9 @@ public abstract class DataBufferUtils { private long position; - private AsynchronousFileChannelCompletionHandler(FluxSink emitter, + private AsynchronousFileChannelReadCompletionHandler(FluxSink sink, long position, DataBufferFactory dataBufferFactory, ByteBuffer byteBuffer) { - this.emitter = emitter; + this.sink = sink; this.position = position; this.dataBufferFactory = dataBufferFactory; this.byteBuffer = byteBuffer; @@ -301,7 +395,7 @@ public abstract class DataBufferUtils { try { dataBuffer.write(this.byteBuffer); release = false; - this.emitter.next(dataBuffer); + this.sink.next(dataBuffer); } finally { if (release) { @@ -310,20 +404,82 @@ public abstract class DataBufferUtils { } this.byteBuffer.clear(); - if (!this.emitter.isCancelled()) { + if (!this.sink.isCancelled()) { channel.read(this.byteBuffer, this.position, channel, this); } } else { - this.emitter.complete(); + this.sink.complete(); closeChannel(channel); } } @Override public void failed(Throwable exc, AsynchronousFileChannel channel) { - this.emitter.error(exc); + this.sink.error(exc); closeChannel(channel); } } + + private static class AsynchronousFileChannelWriteCompletionHandler + extends BaseSubscriber + implements CompletionHandler { + + private final MonoSink sink; + + private final AsynchronousFileChannel channel; + + private long position; + + @Nullable + private DataBuffer dataBuffer; + + public AsynchronousFileChannelWriteCompletionHandler( + MonoSink sink, AsynchronousFileChannel channel, long position) { + this.sink = sink; + this.channel = channel; + this.position = position; + } + + @Override + protected void hookOnSubscribe(Subscription subscription) { + request(1); + } + + @Override + protected void hookOnNext(DataBuffer value) { + this.dataBuffer = value; + ByteBuffer byteBuffer = value.asByteBuffer(); + + this.channel.write(byteBuffer, this.position, byteBuffer, this); + } + + @Override + protected void hookOnError(Throwable throwable) { + this.sink.error(throwable); + } + + @Override + protected void hookOnComplete() { + this.sink.success(); + } + + @Override + public void completed(Integer written, ByteBuffer byteBuffer) { + this.position += written; + if (byteBuffer.hasRemaining()) { + this.channel.write(byteBuffer, this.position, byteBuffer, this); + } + else { + release(this.dataBuffer); + request(1); + } + + } + + @Override + public void failed(Throwable exc, ByteBuffer byteBuffer) { + this.sink.error(exc); + } + } } diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java index b96867dd95c..c56cbe6fb10 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java @@ -17,17 +17,23 @@ package org.springframework.core.io.buffer; import java.io.InputStream; +import java.io.OutputStream; import java.net.URI; import java.nio.channels.AsynchronousFileChannel; import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; +import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; +import java.util.stream.Collectors; import org.junit.Test; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import static org.junit.Assert.assertFalse; +import static org.junit.Assert.*; /** * @author Arjen Poutsma @@ -112,6 +118,79 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { .verify(); } + @Test + public void writeOutputStream() throws Exception { + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + DataBuffer baz = stringBuffer("baz"); + DataBuffer qux = stringBuffer("qux"); + Flux flux = Flux.just(foo, bar, baz, qux); + + Path tempFile = Files.createTempFile("DataBufferUtilsTests", null); + OutputStream os = Files.newOutputStream(tempFile); + + Mono writeResult = DataBufferUtils.write(flux, os); + StepVerifier.create(writeResult) + .expectComplete() + .verify(); + + String result = Files.readAllLines(tempFile) + .stream() + .collect(Collectors.joining()); + + assertEquals("foobarbazqux", result); + os.close(); + } + + @Test + public void writeWritableByteChannel() throws Exception { + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + DataBuffer baz = stringBuffer("baz"); + DataBuffer qux = stringBuffer("qux"); + Flux flux = Flux.just(foo, bar, baz, qux); + + Path tempFile = Files.createTempFile("DataBufferUtilsTests", null); + WritableByteChannel channel = Files.newByteChannel(tempFile, StandardOpenOption.WRITE); + + Mono writeResult = DataBufferUtils.write(flux, channel); + StepVerifier.create(writeResult) + .expectComplete() + .verify(); + + String result = Files.readAllLines(tempFile) + .stream() + .collect(Collectors.joining()); + + assertEquals("foobarbazqux", result); + channel.close(); + } + + @Test + public void writeAsynchronousFileChannel() throws Exception { + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + DataBuffer baz = stringBuffer("baz"); + DataBuffer qux = stringBuffer("qux"); + Flux flux = Flux.just(foo, bar, baz, qux); + + Path tempFile = Files.createTempFile("DataBufferUtilsTests", null); + AsynchronousFileChannel channel = + AsynchronousFileChannel.open(tempFile, StandardOpenOption.WRITE); + + Mono writeResult = DataBufferUtils.write(flux, channel, 0); + StepVerifier.create(writeResult) + .expectComplete() + .verify(); + + String result = Files.readAllLines(tempFile) + .stream() + .collect(Collectors.joining()); + + assertEquals("foobarbazqux", result); + channel.close(); + } + @Test public void takeUntilByteCount() throws Exception { DataBuffer foo = stringBuffer("foo");