From 48176a833c173febb7d24ae7dc1b9a630eb9b674 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Mon, 21 Oct 2019 09:17:34 +0200 Subject: [PATCH] DATAMONGO-2393 - Fix BufferOverflow in GridFS upload. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AsyncInputStreamAdapter now properly splits and buffers incoming DataBuffers according the read requests of AsyncInputStream.read(…) calls. Previously, the adapter used the input buffer size to be used as the output buffer size. A larger DataBuffer than the transfer buffer handed in through read(…) caused a BufferOverflow. Original Pull Request: #799 --- .../gridfs/AsyncInputStreamAdapter.java | 133 ++++++++++++++---- 1 file changed, 103 insertions(+), 30 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/gridfs/AsyncInputStreamAdapter.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/gridfs/AsyncInputStreamAdapter.java index 51ea4a1dc..f9e3dc602 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/gridfs/AsyncInputStreamAdapter.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/gridfs/AsyncInputStreamAdapter.java @@ -17,6 +17,8 @@ package org.springframework.data.mongodb.gridfs; import lombok.RequiredArgsConstructor; import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; import reactor.util.concurrent.Queues; @@ -25,14 +27,15 @@ import reactor.util.context.Context; import java.nio.ByteBuffer; import java.util.Queue; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.function.BiConsumer; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; + import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferUtils; -import org.springframework.core.io.buffer.DefaultDataBufferFactory; import com.mongodb.reactivestreams.client.Success; import com.mongodb.reactivestreams.client.gridfs.AsyncInputStream; @@ -66,15 +69,16 @@ class AsyncInputStreamAdapter implements AsyncInputStream { private final Publisher buffers; private final Context subscriberContext; - private final DefaultDataBufferFactory factory = new DefaultDataBufferFactory(); private volatile Subscription subscription; private volatile boolean cancelled; - private volatile boolean complete; + private volatile boolean allDataBuffersReceived; private volatile Throwable error; private final Queue> readRequests = Queues.> small() .get(); + private final Queue bufferQueue = Queues. small().get(); + // see DEMAND volatile long demand; @@ -88,41 +92,75 @@ class AsyncInputStreamAdapter implements AsyncInputStream { @Override public Publisher read(ByteBuffer dst) { - return Mono.create(sink -> { + return Flux.create(sink -> { + AtomicLong written = new AtomicLong(); readRequests.offer((db, bytecount) -> { try { if (error != null) { - - sink.error(error); + onError(sink, error); return; } if (bytecount == -1) { - sink.success(-1); + onComplete(sink, written.get() > 0 ? written.intValue() : -1); return; } ByteBuffer byteBuffer = db.asByteBuffer(); - int toWrite = byteBuffer.remaining(); + int remaining = byteBuffer.remaining(); + int writeCapacity = Math.min(dst.remaining(), remaining); + int limit = Math.min(byteBuffer.position() + writeCapacity, byteBuffer.capacity()); + int toWrite = limit - byteBuffer.position(); + + if (toWrite == 0) { + onComplete(sink, written.intValue()); + return; + } + + int oldPosition = byteBuffer.position(); + + byteBuffer.limit(toWrite); dst.put(byteBuffer); - sink.success(toWrite); + byteBuffer.limit(byteBuffer.capacity()); + byteBuffer.position(oldPosition); + db.readPosition(db.readPosition() + toWrite); + written.addAndGet(toWrite); } catch (Exception e) { - sink.error(e); + onError(sink, e); } finally { - DataBufferUtils.release(db); + + if (db != null && db.readableByteCount() == 0) { + DataBufferUtils.release(db); + } } }); - request(1); + sink.onCancel(this::terminatePendingReads); + sink.onDispose(this::terminatePendingReads); + sink.onRequest(this::request); }); } + void onError(FluxSink sink, Throwable e) { + + readRequests.poll(); + sink.error(e); + } + + void onComplete(FluxSink sink, int writtenBytes) { + + readRequests.poll(); + DEMAND.decrementAndGet(this); + sink.next(writtenBytes); + sink.complete(); + } + /* * (non-Javadoc) * @see com.mongodb.reactivestreams.client.gridfs.AsyncInputStream#skip(long) @@ -144,17 +182,19 @@ class AsyncInputStreamAdapter implements AsyncInputStream { cancelled = true; if (error != null) { + terminatePendingReads(); sink.error(error); return; } + terminatePendingReads(); sink.success(Success.SUCCESS); }); } - protected void request(int n) { + protected void request(long n) { - if (complete) { + if (allDataBuffersReceived && bufferQueue.isEmpty()) { terminatePendingReads(); return; @@ -176,18 +216,51 @@ class AsyncInputStreamAdapter implements AsyncInputStream { requestFromSubscription(subscription); } } + } void requestFromSubscription(Subscription subscription) { - long demand = DEMAND.get(AsyncInputStreamAdapter.this); - if (cancelled) { subscription.cancel(); } - if (demand > 0 && DEMAND.compareAndSet(AsyncInputStreamAdapter.this, demand, demand - 1)) { - subscription.request(1); + drainLoop(); + } + + void drainLoop() { + + while (DEMAND.get(AsyncInputStreamAdapter.this) > 0) { + + DataBuffer wip = bufferQueue.peek(); + + if (wip == null) { + break; + } + + if (wip.readableByteCount() == 0) { + bufferQueue.poll(); + continue; + } + + BiConsumer consumer = AsyncInputStreamAdapter.this.readRequests.peek(); + if (consumer == null) { + break; + } + + consumer.accept(wip, wip.readableByteCount()); + } + + if (bufferQueue.isEmpty()) { + + if (allDataBuffersReceived) { + terminatePendingReads(); + return; + } + + if (demand > 0) { + subscription.request(1); + } } } @@ -199,7 +272,7 @@ class AsyncInputStreamAdapter implements AsyncInputStream { BiConsumer readers; while ((readers = readRequests.poll()) != null) { - readers.accept(factory.wrap(new byte[0]), -1); + readers.accept(null, -1); } } @@ -214,23 +287,21 @@ class AsyncInputStreamAdapter implements AsyncInputStream { public void onSubscribe(Subscription s) { AsyncInputStreamAdapter.this.subscription = s; - - Operators.addCap(DEMAND, AsyncInputStreamAdapter.this, -1); s.request(1); } @Override public void onNext(DataBuffer dataBuffer) { - if (cancelled || complete) { + if (cancelled || allDataBuffersReceived) { DataBufferUtils.release(dataBuffer); Operators.onNextDropped(dataBuffer, AsyncInputStreamAdapter.this.subscriberContext); return; } - BiConsumer poll = AsyncInputStreamAdapter.this.readRequests.poll(); + BiConsumer readRequest = AsyncInputStreamAdapter.this.readRequests.peek(); - if (poll == null) { + if (readRequest == null) { DataBufferUtils.release(dataBuffer); Operators.onNextDropped(dataBuffer, AsyncInputStreamAdapter.this.subscriberContext); @@ -238,29 +309,31 @@ class AsyncInputStreamAdapter implements AsyncInputStream { return; } - poll.accept(dataBuffer, dataBuffer.readableByteCount()); + bufferQueue.offer(dataBuffer); - requestFromSubscription(subscription); + drainLoop(); } @Override public void onError(Throwable t) { - if (AsyncInputStreamAdapter.this.cancelled || AsyncInputStreamAdapter.this.complete) { + if (AsyncInputStreamAdapter.this.cancelled || AsyncInputStreamAdapter.this.allDataBuffersReceived) { Operators.onErrorDropped(t, AsyncInputStreamAdapter.this.subscriberContext); return; } AsyncInputStreamAdapter.this.error = t; - AsyncInputStreamAdapter.this.complete = true; + AsyncInputStreamAdapter.this.allDataBuffersReceived = true; terminatePendingReads(); } @Override public void onComplete() { - AsyncInputStreamAdapter.this.complete = true; - terminatePendingReads(); + AsyncInputStreamAdapter.this.allDataBuffersReceived = true; + if (bufferQueue.isEmpty()) { + terminatePendingReads(); + } } } }