diff --git a/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/support/DataBufferUtils.java b/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/support/DataBufferUtils.java index 8da723e6556..ad50c7b3efa 100644 --- a/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/support/DataBufferUtils.java +++ b/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/support/DataBufferUtils.java @@ -29,6 +29,7 @@ import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSource; import reactor.core.subscriber.SubscriberWithContext; import org.springframework.core.io.buffer.DataBuffer; @@ -105,49 +106,7 @@ public abstract class DataBufferUtils { Assert.notNull(publisher, "'publisher' must not be null"); Assert.isTrue(maxByteCount >= 0, "'maxByteCount' must be a positive number"); - return Flux.from(publisher).lift(subscriber -> new Subscriber() { - - private Subscription subscription; - - private final AtomicLong byteCount = new AtomicLong(); - - @Override - public void onSubscribe(Subscription s) { - this.subscription = s; - subscriber.onSubscribe(s); - } - - @Override - public void onNext(DataBuffer dataBuffer) { - int delta = dataBuffer.readableByteCount(); - long currentCount = this.byteCount.addAndGet(delta); - if (currentCount > maxByteCount) { - int size = (int) (maxByteCount - currentCount + delta); - ByteBuffer byteBuffer = - (ByteBuffer) dataBuffer.asByteBuffer().limit(size); - DataBuffer partialBuffer = - dataBuffer.allocator().allocateBuffer(size); - partialBuffer.write(byteBuffer); - - subscriber.onNext(partialBuffer); - subscriber.onComplete(); - this.subscription.cancel(); - } - else { - subscriber.onNext(dataBuffer); - } - } - - @Override - public void onError(Throwable t) { - subscriber.onError(t); - } - - @Override - public void onComplete() { - subscriber.onComplete(); - } - }); + return new TakeByteUntilCount(publisher, maxByteCount); } /** @@ -162,6 +121,63 @@ public abstract class DataBufferUtils { return false; } + private static final class TakeByteUntilCount extends FluxSource { + + final long maxByteCount; + + TakeByteUntilCount(Publisher source, long maxByteCount) { + super(source); + this.maxByteCount = maxByteCount; + } + + @Override + public void subscribe(Subscriber subscriber) { + source.subscribe(new Subscriber() { + + private Subscription subscription; + + private final AtomicLong byteCount = new AtomicLong(); + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + subscriber.onSubscribe(s); + } + + @Override + public void onNext(DataBuffer dataBuffer) { + int delta = dataBuffer.readableByteCount(); + long currentCount = this.byteCount.addAndGet(delta); + if (currentCount > maxByteCount) { + int size = (int) (maxByteCount - currentCount + delta); + ByteBuffer byteBuffer = + (ByteBuffer) dataBuffer.asByteBuffer().limit(size); + DataBuffer partialBuffer = + dataBuffer.allocator().allocateBuffer(size); + partialBuffer.write(byteBuffer); + + subscriber.onNext(partialBuffer); + subscriber.onComplete(); + this.subscription.cancel(); + } + else { + subscriber.onNext(dataBuffer); + } + } + + @Override + public void onError(Throwable t) { + subscriber.onError(t); + } + + @Override + public void onComplete() { + subscriber.onComplete(); + } + }); + } + } + private static class ReadableByteChannelConsumer implements Consumer> {