Browse Source

DATAMONGO-2393 - Fix BufferOverflow in GridFS upload.

AsyncInputStreamAdapter now properly splits and buffers incoming DataBuffers according the read requests of AsyncInputStream.read(…) calls.
Previously, the adapter used the input buffer size to be used as the output buffer size. A larger DataBuffer than the transfer buffer handed in through read(…) caused a BufferOverflow.

Original Pull Request: #799
pull/804/head
Mark Paluch 6 years ago committed by Christoph Strobl
parent
commit
48176a833c
  1. 133
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/gridfs/AsyncInputStreamAdapter.java

133
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/gridfs/AsyncInputStreamAdapter.java

@ -17,6 +17,8 @@ package org.springframework.data.mongodb.gridfs; @@ -17,6 +17,8 @@ package org.springframework.data.mongodb.gridfs;
import lombok.RequiredArgsConstructor;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.util.concurrent.Queues;
@ -25,14 +27,15 @@ import reactor.util.context.Context; @@ -25,14 +27,15 @@ import reactor.util.context.Context;
import java.nio.ByteBuffer;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.function.BiConsumer;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import com.mongodb.reactivestreams.client.Success;
import com.mongodb.reactivestreams.client.gridfs.AsyncInputStream;
@ -66,15 +69,16 @@ class AsyncInputStreamAdapter implements AsyncInputStream { @@ -66,15 +69,16 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
private final Publisher<? extends DataBuffer> buffers;
private final Context subscriberContext;
private final DefaultDataBufferFactory factory = new DefaultDataBufferFactory();
private volatile Subscription subscription;
private volatile boolean cancelled;
private volatile boolean complete;
private volatile boolean allDataBuffersReceived;
private volatile Throwable error;
private final Queue<BiConsumer<DataBuffer, Integer>> readRequests = Queues.<BiConsumer<DataBuffer, Integer>> small()
.get();
private final Queue<DataBuffer> bufferQueue = Queues.<DataBuffer> small().get();
// see DEMAND
volatile long demand;
@ -88,41 +92,75 @@ class AsyncInputStreamAdapter implements AsyncInputStream { @@ -88,41 +92,75 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
@Override
public Publisher<Integer> read(ByteBuffer dst) {
return Mono.create(sink -> {
return Flux.create(sink -> {
AtomicLong written = new AtomicLong();
readRequests.offer((db, bytecount) -> {
try {
if (error != null) {
sink.error(error);
onError(sink, error);
return;
}
if (bytecount == -1) {
sink.success(-1);
onComplete(sink, written.get() > 0 ? written.intValue() : -1);
return;
}
ByteBuffer byteBuffer = db.asByteBuffer();
int toWrite = byteBuffer.remaining();
int remaining = byteBuffer.remaining();
int writeCapacity = Math.min(dst.remaining(), remaining);
int limit = Math.min(byteBuffer.position() + writeCapacity, byteBuffer.capacity());
int toWrite = limit - byteBuffer.position();
if (toWrite == 0) {
onComplete(sink, written.intValue());
return;
}
int oldPosition = byteBuffer.position();
byteBuffer.limit(toWrite);
dst.put(byteBuffer);
sink.success(toWrite);
byteBuffer.limit(byteBuffer.capacity());
byteBuffer.position(oldPosition);
db.readPosition(db.readPosition() + toWrite);
written.addAndGet(toWrite);
} catch (Exception e) {
sink.error(e);
onError(sink, e);
} finally {
DataBufferUtils.release(db);
if (db != null && db.readableByteCount() == 0) {
DataBufferUtils.release(db);
}
}
});
request(1);
sink.onCancel(this::terminatePendingReads);
sink.onDispose(this::terminatePendingReads);
sink.onRequest(this::request);
});
}
void onError(FluxSink<Integer> sink, Throwable e) {
readRequests.poll();
sink.error(e);
}
void onComplete(FluxSink<Integer> sink, int writtenBytes) {
readRequests.poll();
DEMAND.decrementAndGet(this);
sink.next(writtenBytes);
sink.complete();
}
/*
* (non-Javadoc)
* @see com.mongodb.reactivestreams.client.gridfs.AsyncInputStream#skip(long)
@ -144,17 +182,19 @@ class AsyncInputStreamAdapter implements AsyncInputStream { @@ -144,17 +182,19 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
cancelled = true;
if (error != null) {
terminatePendingReads();
sink.error(error);
return;
}
terminatePendingReads();
sink.success(Success.SUCCESS);
});
}
protected void request(int n) {
protected void request(long n) {
if (complete) {
if (allDataBuffersReceived && bufferQueue.isEmpty()) {
terminatePendingReads();
return;
@ -176,18 +216,51 @@ class AsyncInputStreamAdapter implements AsyncInputStream { @@ -176,18 +216,51 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
requestFromSubscription(subscription);
}
}
}
void requestFromSubscription(Subscription subscription) {
long demand = DEMAND.get(AsyncInputStreamAdapter.this);
if (cancelled) {
subscription.cancel();
}
if (demand > 0 && DEMAND.compareAndSet(AsyncInputStreamAdapter.this, demand, demand - 1)) {
subscription.request(1);
drainLoop();
}
void drainLoop() {
while (DEMAND.get(AsyncInputStreamAdapter.this) > 0) {
DataBuffer wip = bufferQueue.peek();
if (wip == null) {
break;
}
if (wip.readableByteCount() == 0) {
bufferQueue.poll();
continue;
}
BiConsumer<DataBuffer, Integer> consumer = AsyncInputStreamAdapter.this.readRequests.peek();
if (consumer == null) {
break;
}
consumer.accept(wip, wip.readableByteCount());
}
if (bufferQueue.isEmpty()) {
if (allDataBuffersReceived) {
terminatePendingReads();
return;
}
if (demand > 0) {
subscription.request(1);
}
}
}
@ -199,7 +272,7 @@ class AsyncInputStreamAdapter implements AsyncInputStream { @@ -199,7 +272,7 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
BiConsumer<DataBuffer, Integer> readers;
while ((readers = readRequests.poll()) != null) {
readers.accept(factory.wrap(new byte[0]), -1);
readers.accept(null, -1);
}
}
@ -214,23 +287,21 @@ class AsyncInputStreamAdapter implements AsyncInputStream { @@ -214,23 +287,21 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
public void onSubscribe(Subscription s) {
AsyncInputStreamAdapter.this.subscription = s;
Operators.addCap(DEMAND, AsyncInputStreamAdapter.this, -1);
s.request(1);
}
@Override
public void onNext(DataBuffer dataBuffer) {
if (cancelled || complete) {
if (cancelled || allDataBuffersReceived) {
DataBufferUtils.release(dataBuffer);
Operators.onNextDropped(dataBuffer, AsyncInputStreamAdapter.this.subscriberContext);
return;
}
BiConsumer<DataBuffer, Integer> poll = AsyncInputStreamAdapter.this.readRequests.poll();
BiConsumer<DataBuffer, Integer> readRequest = AsyncInputStreamAdapter.this.readRequests.peek();
if (poll == null) {
if (readRequest == null) {
DataBufferUtils.release(dataBuffer);
Operators.onNextDropped(dataBuffer, AsyncInputStreamAdapter.this.subscriberContext);
@ -238,29 +309,31 @@ class AsyncInputStreamAdapter implements AsyncInputStream { @@ -238,29 +309,31 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
return;
}
poll.accept(dataBuffer, dataBuffer.readableByteCount());
bufferQueue.offer(dataBuffer);
requestFromSubscription(subscription);
drainLoop();
}
@Override
public void onError(Throwable t) {
if (AsyncInputStreamAdapter.this.cancelled || AsyncInputStreamAdapter.this.complete) {
if (AsyncInputStreamAdapter.this.cancelled || AsyncInputStreamAdapter.this.allDataBuffersReceived) {
Operators.onErrorDropped(t, AsyncInputStreamAdapter.this.subscriberContext);
return;
}
AsyncInputStreamAdapter.this.error = t;
AsyncInputStreamAdapter.this.complete = true;
AsyncInputStreamAdapter.this.allDataBuffersReceived = true;
terminatePendingReads();
}
@Override
public void onComplete() {
AsyncInputStreamAdapter.this.complete = true;
terminatePendingReads();
AsyncInputStreamAdapter.this.allDataBuffersReceived = true;
if (bufferQueue.isEmpty()) {
terminatePendingReads();
}
}
}
}

Loading…
Cancel
Save