diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/gridfs/AsyncInputStreamAdapter.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/gridfs/AsyncInputStreamAdapter.java index 51ea4a1dc..f9e3dc602 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/gridfs/AsyncInputStreamAdapter.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/gridfs/AsyncInputStreamAdapter.java @@ -17,6 +17,8 @@ package org.springframework.data.mongodb.gridfs; import lombok.RequiredArgsConstructor; import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; import reactor.util.concurrent.Queues; @@ -25,14 +27,15 @@ import reactor.util.context.Context; import java.nio.ByteBuffer; import java.util.Queue; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.function.BiConsumer; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; + import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferUtils; -import org.springframework.core.io.buffer.DefaultDataBufferFactory; import com.mongodb.reactivestreams.client.Success; import com.mongodb.reactivestreams.client.gridfs.AsyncInputStream; @@ -66,15 +69,16 @@ class AsyncInputStreamAdapter implements AsyncInputStream { private final Publisher buffers; private final Context subscriberContext; - private final DefaultDataBufferFactory factory = new DefaultDataBufferFactory(); private volatile Subscription subscription; private volatile boolean cancelled; - private volatile boolean complete; + private volatile boolean allDataBuffersReceived; private volatile Throwable error; private final Queue> readRequests = Queues.> small() .get(); + private final Queue bufferQueue = Queues. small().get(); + // see DEMAND volatile long demand; @@ -88,41 +92,75 @@ class AsyncInputStreamAdapter implements AsyncInputStream { @Override public Publisher read(ByteBuffer dst) { - return Mono.create(sink -> { + return Flux.create(sink -> { + AtomicLong written = new AtomicLong(); readRequests.offer((db, bytecount) -> { try { if (error != null) { - - sink.error(error); + onError(sink, error); return; } if (bytecount == -1) { - sink.success(-1); + onComplete(sink, written.get() > 0 ? written.intValue() : -1); return; } ByteBuffer byteBuffer = db.asByteBuffer(); - int toWrite = byteBuffer.remaining(); + int remaining = byteBuffer.remaining(); + int writeCapacity = Math.min(dst.remaining(), remaining); + int limit = Math.min(byteBuffer.position() + writeCapacity, byteBuffer.capacity()); + int toWrite = limit - byteBuffer.position(); + + if (toWrite == 0) { + onComplete(sink, written.intValue()); + return; + } + + int oldPosition = byteBuffer.position(); + + byteBuffer.limit(toWrite); dst.put(byteBuffer); - sink.success(toWrite); + byteBuffer.limit(byteBuffer.capacity()); + byteBuffer.position(oldPosition); + db.readPosition(db.readPosition() + toWrite); + written.addAndGet(toWrite); } catch (Exception e) { - sink.error(e); + onError(sink, e); } finally { - DataBufferUtils.release(db); + + if (db != null && db.readableByteCount() == 0) { + DataBufferUtils.release(db); + } } }); - request(1); + sink.onCancel(this::terminatePendingReads); + sink.onDispose(this::terminatePendingReads); + sink.onRequest(this::request); }); } + void onError(FluxSink sink, Throwable e) { + + readRequests.poll(); + sink.error(e); + } + + void onComplete(FluxSink sink, int writtenBytes) { + + readRequests.poll(); + DEMAND.decrementAndGet(this); + sink.next(writtenBytes); + sink.complete(); + } + /* * (non-Javadoc) * @see com.mongodb.reactivestreams.client.gridfs.AsyncInputStream#skip(long) @@ -144,17 +182,19 @@ class AsyncInputStreamAdapter implements AsyncInputStream { cancelled = true; if (error != null) { + terminatePendingReads(); sink.error(error); return; } + terminatePendingReads(); sink.success(Success.SUCCESS); }); } - protected void request(int n) { + protected void request(long n) { - if (complete) { + if (allDataBuffersReceived && bufferQueue.isEmpty()) { terminatePendingReads(); return; @@ -176,18 +216,51 @@ class AsyncInputStreamAdapter implements AsyncInputStream { requestFromSubscription(subscription); } } + } void requestFromSubscription(Subscription subscription) { - long demand = DEMAND.get(AsyncInputStreamAdapter.this); - if (cancelled) { subscription.cancel(); } - if (demand > 0 && DEMAND.compareAndSet(AsyncInputStreamAdapter.this, demand, demand - 1)) { - subscription.request(1); + drainLoop(); + } + + void drainLoop() { + + while (DEMAND.get(AsyncInputStreamAdapter.this) > 0) { + + DataBuffer wip = bufferQueue.peek(); + + if (wip == null) { + break; + } + + if (wip.readableByteCount() == 0) { + bufferQueue.poll(); + continue; + } + + BiConsumer consumer = AsyncInputStreamAdapter.this.readRequests.peek(); + if (consumer == null) { + break; + } + + consumer.accept(wip, wip.readableByteCount()); + } + + if (bufferQueue.isEmpty()) { + + if (allDataBuffersReceived) { + terminatePendingReads(); + return; + } + + if (demand > 0) { + subscription.request(1); + } } } @@ -199,7 +272,7 @@ class AsyncInputStreamAdapter implements AsyncInputStream { BiConsumer readers; while ((readers = readRequests.poll()) != null) { - readers.accept(factory.wrap(new byte[0]), -1); + readers.accept(null, -1); } } @@ -214,23 +287,21 @@ class AsyncInputStreamAdapter implements AsyncInputStream { public void onSubscribe(Subscription s) { AsyncInputStreamAdapter.this.subscription = s; - - Operators.addCap(DEMAND, AsyncInputStreamAdapter.this, -1); s.request(1); } @Override public void onNext(DataBuffer dataBuffer) { - if (cancelled || complete) { + if (cancelled || allDataBuffersReceived) { DataBufferUtils.release(dataBuffer); Operators.onNextDropped(dataBuffer, AsyncInputStreamAdapter.this.subscriberContext); return; } - BiConsumer poll = AsyncInputStreamAdapter.this.readRequests.poll(); + BiConsumer readRequest = AsyncInputStreamAdapter.this.readRequests.peek(); - if (poll == null) { + if (readRequest == null) { DataBufferUtils.release(dataBuffer); Operators.onNextDropped(dataBuffer, AsyncInputStreamAdapter.this.subscriberContext); @@ -238,29 +309,31 @@ class AsyncInputStreamAdapter implements AsyncInputStream { return; } - poll.accept(dataBuffer, dataBuffer.readableByteCount()); + bufferQueue.offer(dataBuffer); - requestFromSubscription(subscription); + drainLoop(); } @Override public void onError(Throwable t) { - if (AsyncInputStreamAdapter.this.cancelled || AsyncInputStreamAdapter.this.complete) { + if (AsyncInputStreamAdapter.this.cancelled || AsyncInputStreamAdapter.this.allDataBuffersReceived) { Operators.onErrorDropped(t, AsyncInputStreamAdapter.this.subscriberContext); return; } AsyncInputStreamAdapter.this.error = t; - AsyncInputStreamAdapter.this.complete = true; + AsyncInputStreamAdapter.this.allDataBuffersReceived = true; terminatePendingReads(); } @Override public void onComplete() { - AsyncInputStreamAdapter.this.complete = true; - terminatePendingReads(); + AsyncInputStreamAdapter.this.allDataBuffersReceived = true; + if (bufferQueue.isEmpty()) { + terminatePendingReads(); + } } } }