From 6361b0cb236e62b24a07af5d0a20f42d78d4efaa Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Thu, 20 Dec 2018 21:33:33 +0100 Subject: [PATCH] Write CharSequence instances to DataBuffers Prior to this commit, one could write a `CharSequence` to an existing `DataBuffer` instance by turning it into a byte array or `ByteBuffer` first. This had the following disadvantages: 1. Memory allocation was not efficient (not leveraging pooled memory when available) 2. Dealing with `CharsetEncoder` is not always easy 3. `DataBuffer` implementations, like `NettyDataBuffer` can use optimized implementations in some cases This commit adds a new `DataBuffer#write(CharSequence, Charset)` method for those cases and also an `ensureCapacity` method useful for checking that the current buffer has enough capacity to write to it.. Issue: SPR-17558 --- .../core/io/buffer/DataBuffer.java | 58 ++++++++++++++++- .../core/io/buffer/DefaultDataBuffer.java | 16 +++-- .../core/io/buffer/NettyDataBuffer.java | 29 ++++++++- .../core/io/buffer/DataBufferTests.java | 65 +++++++++++++++++++ .../core/io/buffer/LeakAwareDataBuffer.java | 11 ++++ .../reactive/UndertowServerHttpRequest.java | 11 ++++ .../reactive/function/BodyInsertersTests.java | 17 +++-- 7 files changed, 190 insertions(+), 17 deletions(-) diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBuffer.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBuffer.java index 6818c1be1a5..9041099449e 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBuffer.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBuffer.java @@ -19,8 +19,15 @@ package org.springframework.core.io.buffer; import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.CoderResult; +import java.nio.charset.CodingErrorAction; import java.util.function.IntPredicate; +import org.springframework.util.Assert; + /** * Basic abstraction over byte buffers. * @@ -45,6 +52,7 @@ import java.util.function.IntPredicate; * can also be used on non-Netty platforms (i.e. Servlet containers). * * @author Arjen Poutsma + * @author Brian Clozel * @since 5.0 * @see DataBufferFactory */ @@ -106,6 +114,16 @@ public interface DataBuffer { */ DataBuffer capacity(int capacity); + /** + * Ensure that the current buffer has enough {@link #writableByteCount()} + * to write the amount of data given as an argument. If not, the missing + * capacity will be added to the buffer. + * @param capacity the writable capacity to check for + * @return this buffer + * @since 5.1.4 + */ + DataBuffer ensureCapacity(int capacity); + /** * Return the position from which this buffer will read. * @return the read position @@ -181,7 +199,7 @@ public interface DataBuffer { DataBuffer write(byte b); /** - * Write the given source into this buffer, startin at the current writing position + * Write the given source into this buffer, starting at the current writing position * of this buffer. * @param source the bytes to be written into this buffer * @return this buffer @@ -215,6 +233,44 @@ public interface DataBuffer { */ DataBuffer write(ByteBuffer... buffers); + /** + * Write the given {@code CharSequence} using the given {@code Charset}, + * starting at the current writing position. + * @param charSequence the char sequence to write into this buffer + * @param charset the charset to encode the char sequence with + * @return this buffer + * @since 5.1.4 + */ + default DataBuffer write(CharSequence charSequence, Charset charset) { + Assert.notNull(charSequence, "'charSequence' must not be null"); + Assert.notNull(charset, "'charset' must not be null"); + CharsetEncoder charsetEncoder = charset.newEncoder() + .onMalformedInput(CodingErrorAction.REPLACE) + .onUnmappableCharacter(CodingErrorAction.REPLACE); + CharBuffer inBuffer = CharBuffer.wrap(charSequence); + int estimatedSize = (int) (inBuffer.remaining() * charsetEncoder.averageBytesPerChar()); + ByteBuffer outBuffer = ensureCapacity(estimatedSize) + .asByteBuffer(writePosition(), writableByteCount()); + for (; ; ) { + CoderResult cr = inBuffer.hasRemaining() ? + charsetEncoder.encode(inBuffer, outBuffer, true) : CoderResult.UNDERFLOW; + if (cr.isUnderflow()) { + cr = charsetEncoder.flush(outBuffer); + } + if (cr.isUnderflow()) { + break; + } + if (cr.isOverflow()) { + writePosition(outBuffer.position()); + int maximumSize = (int) (inBuffer.remaining() * charsetEncoder.maxBytesPerChar()); + ensureCapacity(maximumSize); + outBuffer = asByteBuffer(writePosition(), writableByteCount()); + } + } + writePosition(outBuffer.position()); + return this; + } + /** * Create a new {@code DataBuffer} whose contents is a shared subsequence of this * data buffer's content. Data between this data buffer and the returned buffer is diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DefaultDataBuffer.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DefaultDataBuffer.java index 37fe545843a..ced1890961a 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DefaultDataBuffer.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DefaultDataBuffer.java @@ -215,6 +215,15 @@ public class DefaultDataBuffer implements DataBuffer { return this; } + @Override + public DataBuffer ensureCapacity(int length) { + if (length > writableByteCount()) { + int newCapacity = calculateCapacity(this.writePosition + length); + capacity(newCapacity); + } + return this; + } + private static ByteBuffer allocate(int capacity, boolean direct) { return direct ? ByteBuffer.allocateDirect(capacity) : ByteBuffer.allocate(capacity); } @@ -369,13 +378,6 @@ public class DefaultDataBuffer implements DataBuffer { return new DefaultDataBufferOutputStream(); } - private void ensureCapacity(int length) { - if (length <= writableByteCount()) { - return; - } - int newCapacity = calculateCapacity(this.writePosition + length); - capacity(newCapacity); - } /** * Calculate the capacity of the buffer. diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/NettyDataBuffer.java b/spring-core/src/main/java/org/springframework/core/io/buffer/NettyDataBuffer.java index 798a3f397c9..455d9d3afc4 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/NettyDataBuffer.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/NettyDataBuffer.java @@ -19,11 +19,14 @@ package org.springframework.core.io.buffer; import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.util.function.IntPredicate; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.ByteBufOutputStream; +import io.netty.buffer.ByteBufUtil; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; @@ -138,6 +141,12 @@ public class NettyDataBuffer implements PooledDataBuffer { return this; } + @Override + public DataBuffer ensureCapacity(int capacity) { + this.byteBuf.ensureWritable(capacity); + return this; + } + @Override public byte read() { return this.byteBuf.readByte(); @@ -178,14 +187,14 @@ public class NettyDataBuffer implements PooledDataBuffer { if (!ObjectUtils.isEmpty(buffers)) { if (hasNettyDataBuffers(buffers)) { ByteBuf[] nativeBuffers = new ByteBuf[buffers.length]; - for (int i = 0 ; i < buffers.length; i++) { + for (int i = 0; i < buffers.length; i++) { nativeBuffers[i] = ((NettyDataBuffer) buffers[i]).getNativeBuffer(); } write(nativeBuffers); } else { ByteBuffer[] byteBuffers = new ByteBuffer[buffers.length]; - for (int i = 0 ; i < buffers.length; i++) { + for (int i = 0; i < buffers.length; i++) { byteBuffers[i] = buffers[i].asByteBuffer(); } @@ -229,6 +238,22 @@ public class NettyDataBuffer implements PooledDataBuffer { return this; } + @Override + public DataBuffer write(CharSequence charSequence, Charset charset) { + Assert.notNull(charSequence, "'charSequence' must not be null"); + Assert.notNull(charset, "'charset' must not be null"); + if (StandardCharsets.UTF_8.equals(charset)) { + ByteBufUtil.writeUtf8(this.byteBuf, charSequence); + } + else if (StandardCharsets.US_ASCII.equals(charset)) { + ByteBufUtil.writeAscii(this.byteBuf, charSequence); + } + else { + return PooledDataBuffer.super.write(charSequence, charset); + } + return this; + } + @Override public NettyDataBuffer slice(int index, int length) { ByteBuf slice = this.byteBuf.slice(index, length); diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferTests.java index c76e0937ec7..5591e628b70 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferTests.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferTests.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import org.junit.Test; @@ -150,6 +151,70 @@ public class DataBufferTests extends AbstractDataBufferAllocatingTestCase { release(buffer); } + @Test + public void writeNullString() { + DataBuffer buffer = createDataBuffer(1); + try { + buffer.write(null, StandardCharsets.UTF_8); + fail("IllegalArgumentException expected"); + } + catch (IllegalArgumentException exc) { + } + finally { + release(buffer); + } + } + + @Test + public void writeNullCharset() { + DataBuffer buffer = createDataBuffer(1); + try { + buffer.write("test", null); + fail("IllegalArgumentException expected"); + } + catch (IllegalArgumentException exc) { + } + finally { + release(buffer); + } + } + + @Test + public void writeUtf8String() { + DataBuffer buffer = createDataBuffer(6); + buffer.write("Spring", StandardCharsets.UTF_8); + + byte[] result = new byte[6]; + buffer.read(result); + + assertArrayEquals("Spring".getBytes(StandardCharsets.UTF_8), result); + release(buffer); + } + + @Test + public void writeUtf8StringOutGrowsCapacity() { + DataBuffer buffer = createDataBuffer(5); + buffer.write("Spring €", StandardCharsets.UTF_8); + + byte[] result = new byte[10]; + buffer.read(result); + + assertArrayEquals("Spring €".getBytes(StandardCharsets.UTF_8), result); + release(buffer); + } + + @Test + public void writeIsoString() { + DataBuffer buffer = createDataBuffer(3); + buffer.write("\u00A3", StandardCharsets.ISO_8859_1); + + byte[] result = new byte[1]; + buffer.read(result); + + assertArrayEquals("\u00A3".getBytes(StandardCharsets.ISO_8859_1), result); + release(buffer); + } + @Test public void inputStream() throws IOException { DataBuffer buffer = createDataBuffer(4); diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/LeakAwareDataBuffer.java b/spring-core/src/test/java/org/springframework/core/io/buffer/LeakAwareDataBuffer.java index 94c194a1c1d..ee8fdd2a62b 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/LeakAwareDataBuffer.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/LeakAwareDataBuffer.java @@ -19,6 +19,7 @@ package org.springframework.core.io.buffer; import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; +import java.nio.charset.Charset; import java.util.function.IntPredicate; import org.springframework.util.Assert; @@ -139,6 +140,11 @@ class LeakAwareDataBuffer implements PooledDataBuffer { return this.delegate.capacity(newCapacity); } + @Override + public DataBuffer ensureCapacity(int capacity) { + return this.delegate.ensureCapacity(capacity); + } + @Override public byte getByte(int index) { return this.delegate.getByte(index); @@ -184,6 +190,11 @@ class LeakAwareDataBuffer implements PooledDataBuffer { return this.delegate.write(byteBuffers); } + @Override + public DataBuffer write(CharSequence charSequence, Charset charset) { + return this.delegate.write(charSequence, charset); + } + @Override public DataBuffer slice(int index, int length) { return this.delegate.slice(index, length); diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java index 05e641a3ce6..49fbedf0a4d 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java @@ -23,6 +23,7 @@ import java.net.InetSocketAddress; import java.net.URI; import java.net.URISyntaxException; import java.nio.ByteBuffer; +import java.nio.charset.Charset; import java.util.function.IntPredicate; import javax.net.ssl.SSLSession; @@ -294,6 +295,11 @@ class UndertowServerHttpRequest extends AbstractServerHttpRequest { return this.dataBuffer.capacity(newCapacity); } + @Override + public DataBuffer ensureCapacity(int capacity) { + return this.dataBuffer.ensureCapacity(capacity); + } + @Override public byte getByte(int index) { return this.dataBuffer.getByte(index); @@ -343,6 +349,11 @@ class UndertowServerHttpRequest extends AbstractServerHttpRequest { return this.dataBuffer.write(byteBuffers); } + @Override + public DataBuffer write(CharSequence charSequence, Charset charset) { + return this.dataBuffer.write(charSequence, charset); + } + @Override public DataBuffer slice(int index, int length) { return this.dataBuffer.slice(index, length); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyInsertersTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyInsertersTests.java index ff788331bc4..b352630cfbf 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyInsertersTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyInsertersTests.java @@ -30,6 +30,7 @@ import java.util.Map; import java.util.Optional; import com.fasterxml.jackson.annotation.JsonView; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import reactor.core.publisher.Flux; @@ -44,6 +45,7 @@ import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.DefaultDataBuffer; import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.io.buffer.support.DataBufferTestUtils; import org.springframework.http.HttpMethod; import org.springframework.http.HttpRange; import org.springframework.http.ReactiveHttpOutputMessage; @@ -121,10 +123,11 @@ public class BodyInsertersTests { MockServerHttpResponse response = new MockServerHttpResponse(); Mono result = inserter.insert(response, this.context); StepVerifier.create(result).expectComplete().verify(); - - DataBuffer buffer = new DefaultDataBufferFactory().wrap(body.getBytes(UTF_8)); StepVerifier.create(response.getBody()) - .expectNext(buffer) + .consumeNextWith(buf -> { + String actual = DataBufferTestUtils.dumpString(buf, UTF_8); + Assert.assertEquals("foo", actual); + }) .expectComplete() .verify(); } @@ -166,11 +169,11 @@ public class BodyInsertersTests { MockServerHttpResponse response = new MockServerHttpResponse(); Mono result = inserter.insert(response, this.context); StepVerifier.create(result).expectComplete().verify(); - - ByteBuffer byteBuffer = ByteBuffer.wrap("foo".getBytes(UTF_8)); - DataBuffer buffer = new DefaultDataBufferFactory().wrap(byteBuffer); StepVerifier.create(response.getBody()) - .expectNext(buffer) + .consumeNextWith(buf -> { + String actual = DataBufferTestUtils.dumpString(buf, UTF_8); + Assert.assertEquals("foo", actual); + }) .expectComplete() .verify(); }