diff --git a/spring-core/src/main/java/org/springframework/core/codec/AbstractDataBufferDecoder.java b/spring-core/src/main/java/org/springframework/core/codec/AbstractDataBufferDecoder.java index f6a29a747aa..ff01c0b47f4 100644 --- a/spring-core/src/main/java/org/springframework/core/codec/AbstractDataBufferDecoder.java +++ b/spring-core/src/main/java/org/springframework/core/codec/AbstractDataBufferDecoder.java @@ -48,12 +48,39 @@ import org.springframework.util.MimeType; @SuppressWarnings("deprecation") public abstract class AbstractDataBufferDecoder extends AbstractDecoder { + private int maxInMemorySize = 256 * 1024; + protected AbstractDataBufferDecoder(MimeType... supportedMimeTypes) { super(supportedMimeTypes); } + /** + * Configure a limit on the number of bytes that can be buffered whenever + * the input stream needs to be aggregated. This can be a result of + * decoding to a single {@code DataBuffer}, + * {@link java.nio.ByteBuffer ByteBuffer}, {@code byte[]}, + * {@link org.springframework.core.io.Resource Resource}, {@code String}, etc. + * It can also occur when splitting the input stream, e.g. delimited text, + * in which case the limit applies to data buffered between delimiters. + *

By default this is set to 256K. + * @param byteCount the max number of bytes to buffer, or -1 for unlimited + * @since 5.1.11 + */ + public void setMaxInMemorySize(int byteCount) { + this.maxInMemorySize = byteCount; + } + + /** + * Return the {@link #setMaxInMemorySize configured} byte count limit. + * @since 5.1.11 + */ + public int getMaxInMemorySize() { + return this.maxInMemorySize; + } + + @Override public Flux decode(Publisher input, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { @@ -65,7 +92,7 @@ public abstract class AbstractDataBufferDecoder extends AbstractDecoder { public Mono decodeToMono(Publisher input, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { - return DataBufferUtils.join(input) + return DataBufferUtils.join(input, this.maxInMemorySize) .map(buffer -> decodeDataBuffer(buffer, elementType, mimeType, hints)); } diff --git a/spring-core/src/main/java/org/springframework/core/codec/StringDecoder.java b/spring-core/src/main/java/org/springframework/core/codec/StringDecoder.java index d6d0ccfc271..2fa3cf0a405 100644 --- a/spring-core/src/main/java/org/springframework/core/codec/StringDecoder.java +++ b/spring-core/src/main/java/org/springframework/core/codec/StringDecoder.java @@ -25,15 +25,18 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.function.Consumer; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import org.springframework.core.ResolvableType; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.DataBufferWrapper; import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.io.buffer.LimitedDataBufferList; import org.springframework.core.io.buffer.PooledDataBuffer; import org.springframework.core.log.LogFormatUtils; import org.springframework.lang.Nullable; @@ -91,12 +94,18 @@ public final class StringDecoder extends AbstractDataBufferDecoder { byte[][] delimiterBytes = getDelimiterBytes(mimeType); + // TODO: Drop Consumer and use bufferUntil with Supplier (reactor-core#1925) + // TODO: Drop doOnDiscard(LimitedDataBufferList.class, ...) (reactor-core#1924) + LimitedDataBufferConsumer limiter = new LimitedDataBufferConsumer(getMaxInMemorySize()); + Flux inputFlux = Flux.defer(() -> { DataBufferUtils.Matcher matcher = DataBufferUtils.matcher(delimiterBytes); return Flux.from(input) .concatMapIterable(buffer -> endFrameAfterDelimiter(buffer, matcher)) + .doOnNext(limiter) .bufferUntil(buffer -> buffer instanceof EndFrameBuffer) .map(buffers -> joinAndStrip(buffers, this.stripDelimiter)) + .doOnDiscard(LimitedDataBufferList.class, LimitedDataBufferList::releaseAndClear) .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); }); @@ -279,4 +288,34 @@ public final class StringDecoder extends AbstractDataBufferDecoder { } + /** + * Temporary measure for reactor-core#1925. + * Consumer that adds to a {@link LimitedDataBufferList} to enforce limits. + */ + private static class LimitedDataBufferConsumer implements Consumer { + + private final LimitedDataBufferList bufferList; + + + public LimitedDataBufferConsumer(int maxInMemorySize) { + this.bufferList = new LimitedDataBufferList(maxInMemorySize); + } + + + @Override + public void accept(DataBuffer buffer) { + if (buffer instanceof EndFrameBuffer) { + this.bufferList.clear(); + } + else { + try { + this.bufferList.add(buffer); + } + catch (DataBufferLimitException ex) { + DataBufferUtils.release(buffer); + throw ex; + } + } + } + } } diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferLimitException.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferLimitException.java new file mode 100644 index 00000000000..ee606aed57f --- /dev/null +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferLimitException.java @@ -0,0 +1,37 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.core.io.buffer; + +/** + * Exception that indicates the cumulative number of bytes consumed from a + * stream of {@link DataBuffer DataBuffer}'s exceeded some pre-configured limit. + * This can be raised when data buffers are cached and aggregated, e.g. + * {@link DataBufferUtils#join}. Or it could also be raised when data buffers + * have been released but a parsed representation is being aggregated, e.g. async + * parsing with Jackson. + * + * @author Rossen Stoyanchev + * @since 5.1.11 + */ +@SuppressWarnings("serial") +public class DataBufferLimitException extends IllegalStateException { + + + public DataBufferLimitException(String message) { + super(message); + } + +} diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java index 197e4762180..aaab44b9864 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java @@ -525,16 +525,35 @@ public abstract class DataBufferUtils { */ @SuppressWarnings("unchecked") public static Mono join(Publisher dataBuffers) { - Assert.notNull(dataBuffers, "'dataBuffers' must not be null"); + return join(dataBuffers, -1); + } + + /** + * Variant of {@link #join(Publisher)} that behaves the same way up until + * the specified max number of bytes to buffer. Once the limit is exceeded, + * {@link DataBufferLimitException} is raised. + * @param buffers the data buffers that are to be composed + * @param maxByteCount the max number of bytes to buffer, or -1 for unlimited + * @return a buffer with the aggregated content, possibly an empty Mono if + * the max number of bytes to buffer is exceeded. + * @throws DataBufferLimitException if maxByteCount is exceeded + * @since 5.1.11 + */ + @SuppressWarnings("unchecked") + public static Mono join(Publisher buffers, int maxByteCount) { + Assert.notNull(buffers, "'dataBuffers' must not be null"); - if (dataBuffers instanceof Mono) { - return (Mono) dataBuffers; + if (buffers instanceof Mono) { + return (Mono) buffers; } - return Flux.from(dataBuffers) - .collectList() + // TODO: Drop doOnDiscard(LimitedDataBufferList.class, ...) (reactor-core#1924) + + return Flux.from(buffers) + .collect(() -> new LimitedDataBufferList(maxByteCount), LimitedDataBufferList::add) .filter(list -> !list.isEmpty()) .map(list -> list.get(0).factory().join(list)) + .doOnDiscard(LimitedDataBufferList.class, LimitedDataBufferList::releaseAndClear) .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); } diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/LimitedDataBufferList.java b/spring-core/src/main/java/org/springframework/core/io/buffer/LimitedDataBufferList.java new file mode 100644 index 00000000000..fb8c42aeeb0 --- /dev/null +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/LimitedDataBufferList.java @@ -0,0 +1,157 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.core.io.buffer; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.function.Predicate; + +import reactor.core.publisher.Flux; + +/** + * Custom {@link List} to collect data buffers with and enforce a + * limit on the total number of bytes buffered. For use with "collect" or + * other buffering operators in declarative APIs, e.g. {@link Flux}. + * + *

Adding elements increases the byte count and if the limit is exceeded, + * {@link DataBufferLimitException} is raised. {@link #clear()} resets the + * count. Remove and set are not supported. + * + *

Note: This class does not automatically release the + * buffers it contains. It is usually preferable to use hooks such as + * {@link Flux#doOnDiscard} that also take care of cancel and error signals, + * or otherwise {@link #releaseAndClear()} can be used. + * + * @author Rossen Stoyanchev + * @since 5.1.11 + */ +@SuppressWarnings("serial") +public class LimitedDataBufferList extends ArrayList { + + private final int maxByteCount; + + private int byteCount; + + + public LimitedDataBufferList(int maxByteCount) { + this.maxByteCount = maxByteCount; + } + + + @Override + public boolean add(DataBuffer buffer) { + boolean result = super.add(buffer); + if (result) { + updateCount(buffer.readableByteCount()); + } + return result; + } + + @Override + public void add(int index, DataBuffer buffer) { + super.add(index, buffer); + updateCount(buffer.readableByteCount()); + } + + @Override + public boolean addAll(Collection collection) { + boolean result = super.addAll(collection); + collection.forEach(buffer -> updateCount(buffer.readableByteCount())); + return result; + } + + @Override + public boolean addAll(int index, Collection collection) { + boolean result = super.addAll(index, collection); + collection.forEach(buffer -> updateCount(buffer.readableByteCount())); + return result; + } + + private void updateCount(int bytesToAdd) { + if (this.maxByteCount < 0) { + return; + } + if (bytesToAdd > Integer.MAX_VALUE - this.byteCount) { + raiseLimitException(); + } + else { + this.byteCount += bytesToAdd; + if (this.byteCount > this.maxByteCount) { + raiseLimitException(); + } + } + } + + private void raiseLimitException() { + // Do not release here, it's likely down via doOnDiscard.. + throw new DataBufferLimitException( + "Exceeded limit on max bytes to buffer : " + this.maxByteCount); + } + + @Override + public DataBuffer remove(int index) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean remove(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + protected void removeRange(int fromIndex, int toIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean removeAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean removeIf(Predicate filter) { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer set(int index, DataBuffer element) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + this.byteCount = 0; + super.clear(); + } + + /** + * Shortcut to {@link DataBufferUtils#release release} all data buffers and + * then {@link #clear()}. + */ + public void releaseAndClear() { + forEach(buf -> { + try { + DataBufferUtils.release(buf); + } + catch (Throwable ex) { + // Keep going.. + } + }); + clear(); + } + +} diff --git a/spring-core/src/test/java/org/springframework/core/codec/ResourceRegionEncoderTests.java b/spring-core/src/test/java/org/springframework/core/codec/ResourceRegionEncoderTests.java index 18f374a005e..ce83d5bebdf 100644 --- a/spring-core/src/test/java/org/springframework/core/codec/ResourceRegionEncoderTests.java +++ b/spring-core/src/test/java/org/springframework/core/codec/ResourceRegionEncoderTests.java @@ -19,7 +19,6 @@ package org.springframework.core.codec; import java.util.Collections; import java.util.function.Consumer; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.reactivestreams.Subscription; import reactor.core.publisher.BaseSubscriber; @@ -30,9 +29,9 @@ import reactor.test.StepVerifier; import org.springframework.core.ResolvableType; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.AbstractLeakCheckingTests; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferUtils; -import org.springframework.core.io.buffer.LeakAwareDataBufferFactory; import org.springframework.core.io.buffer.support.DataBufferTestUtils; import org.springframework.core.io.support.ResourceRegion; import org.springframework.util.MimeType; @@ -45,18 +44,10 @@ import static org.assertj.core.api.Assertions.assertThat; * Test cases for {@link ResourceRegionEncoder} class. * @author Brian Clozel */ -class ResourceRegionEncoderTests { +class ResourceRegionEncoderTests extends AbstractLeakCheckingTests { private ResourceRegionEncoder encoder = new ResourceRegionEncoder(); - private LeakAwareDataBufferFactory bufferFactory = new LeakAwareDataBufferFactory(); - - - @AfterEach - void tearDown() throws Exception { - this.bufferFactory.checkForLeaks(); - } - @Test void canEncode() { ResolvableType resourceRegion = ResolvableType.forClass(ResourceRegion.class); diff --git a/spring-core/src/test/java/org/springframework/core/codec/StringDecoderTests.java b/spring-core/src/test/java/org/springframework/core/codec/StringDecoderTests.java index 980f0a00721..a17ae084c2f 100644 --- a/spring-core/src/test/java/org/springframework/core/codec/StringDecoderTests.java +++ b/spring-core/src/test/java/org/springframework/core/codec/StringDecoderTests.java @@ -29,6 +29,7 @@ import reactor.test.StepVerifier; import org.springframework.core.ResolvableType; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; @@ -127,6 +128,20 @@ class StringDecoderTests extends AbstractDecoderTests { .verify()); } + @Test + void decodeNewLineWithLimit() { + Flux input = Flux.just( + stringBuffer("abc\n"), + stringBuffer("defg\n"), + stringBuffer("hijkl\n") + ); + this.decoder.setMaxInMemorySize(5); + + testDecode(input, String.class, step -> + step.expectNext("abc", "defg") + .verifyError(DataBufferLimitException.class)); + } + @Test void decodeNewLineIncludeDelimiters() { this.decoder = StringDecoder.allMimeTypes(StringDecoder.DEFAULT_DELIMITERS, false); diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java index 07671ba9c92..27067179cdc 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java @@ -813,13 +813,27 @@ class DataBufferUtilsTests extends AbstractDataBufferAllocatingTests { Mono result = DataBufferUtils.join(flux); StepVerifier.create(result) - .consumeNextWith(dataBuffer -> { - assertThat(DataBufferTestUtils.dumpString(dataBuffer, StandardCharsets.UTF_8)).isEqualTo("foobarbaz"); - release(dataBuffer); + .consumeNextWith(buf -> { + assertThat(DataBufferTestUtils.dumpString(buf, StandardCharsets.UTF_8)).isEqualTo("foobarbaz"); + release(buf); }) .verifyComplete(); } + @ParameterizedDataBufferAllocatingTest + void joinWithLimit(String displayName, DataBufferFactory bufferFactory) { + super.bufferFactory = bufferFactory; + + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + DataBuffer baz = stringBuffer("baz"); + Flux flux = Flux.just(foo, bar, baz); + Mono result = DataBufferUtils.join(flux, 8); + + StepVerifier.create(result) + .verifyError(DataBufferLimitException.class); + } + @ParameterizedDataBufferAllocatingTest void joinErrors(String displayName, DataBufferFactory bufferFactory) { super.bufferFactory = bufferFactory; diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/LimitedDataBufferListTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/LimitedDataBufferListTests.java new file mode 100644 index 00000000000..eeb816fe274 --- /dev/null +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/LimitedDataBufferListTests.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.core.io.buffer; + +import java.nio.charset.StandardCharsets; + +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for {@link LimitedDataBufferList}. + * @author Rossen Stoyanchev + * @since 5.1.11 + */ +public class LimitedDataBufferListTests { + + private final static DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); + + + @Test + void limitEnforced() { + Assertions.assertThatThrownBy(() -> new LimitedDataBufferList(5).add(toDataBuffer("123456"))) + .isInstanceOf(DataBufferLimitException.class); + } + + @Test + void limitIgnored() { + new LimitedDataBufferList(-1).add(toDataBuffer("123456")); + } + + @Test + void clearResetsCount() { + LimitedDataBufferList list = new LimitedDataBufferList(5); + list.add(toDataBuffer("12345")); + list.clear(); + list.add(toDataBuffer("12345")); + } + + + private static DataBuffer toDataBuffer(String value) { + return bufferFactory.wrap(value.getBytes(StandardCharsets.UTF_8)); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurer.java b/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurer.java index 17b820a9309..7ea35b9c6a6 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurer.java +++ b/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -143,6 +143,20 @@ public interface CodecConfigurer { */ void jaxb2Encoder(Encoder encoder); + /** + * Configure a limit on the number of bytes that can be buffered whenever + * the input stream needs to be aggregated. This can be a result of + * decoding to a single {@code DataBuffer}, + * {@link java.nio.ByteBuffer ByteBuffer}, {@code byte[]}, + * {@link org.springframework.core.io.Resource Resource}, {@code String}, etc. + * It can also occur when splitting the input stream, e.g. delimited text, + * in which case the limit applies to data buffered between delimiters. + *

By default this is not set, in which case individual codec defaults + * apply. All codecs are limited to 256K by default. + * @param byteCount the max number of bytes to buffer, or -1 for unlimited + * @sine 5.1.11 + */ + void maxInMemorySize(int byteCount); /** * Whether to log form data at DEBUG level, and headers at TRACE level. * Both may contain sensitive information. diff --git a/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageReader.java index 01c2d30b33b..63e4a10259f 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageReader.java +++ b/spring-web/src/main/java/org/springframework/http/codec/FormHttpMessageReader.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ import reactor.core.publisher.Mono; import org.springframework.core.ResolvableType; import org.springframework.core.codec.Hints; +import org.springframework.core.io.buffer.DataBufferLimitException; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.log.LogFormatUtils; import org.springframework.http.MediaType; @@ -62,6 +63,8 @@ public class FormHttpMessageReader extends LoggingCodecSupport private Charset defaultCharset = DEFAULT_CHARSET; + private int maxInMemorySize = 256 * 1024; + /** * Set the default character set to use for reading form data when the @@ -80,6 +83,26 @@ public class FormHttpMessageReader extends LoggingCodecSupport return this.defaultCharset; } + /** + * Set the max number of bytes for input form data. As form data is buffered + * before it is parsed, this helps to limit the amount of buffering. Once + * the limit is exceeded, {@link DataBufferLimitException} is raised. + *

By default this is set to 256K. + * @param byteCount the max number of bytes to buffer, or -1 for unlimited + * @since 5.1.11 + */ + public void setMaxInMemorySize(int byteCount) { + this.maxInMemorySize = byteCount; + } + + /** + * Return the {@link #setMaxInMemorySize configured} byte count limit. + * @since 5.1.11 + */ + public int getMaxInMemorySize() { + return this.maxInMemorySize; + } + @Override public boolean canRead(ResolvableType elementType, @Nullable MediaType mediaType) { @@ -105,7 +128,7 @@ public class FormHttpMessageReader extends LoggingCodecSupport MediaType contentType = message.getHeaders().getContentType(); Charset charset = getMediaTypeCharset(contentType); - return DataBufferUtils.join(message.getBody()) + return DataBufferUtils.join(message.getBody(), this.maxInMemorySize) .map(buffer -> { CharBuffer charBuffer = charset.decode(buffer.asByteBuffer()); String body = charBuffer.toString(); diff --git a/spring-web/src/main/java/org/springframework/http/codec/ServerCodecConfigurer.java b/spring-web/src/main/java/org/springframework/http/codec/ServerCodecConfigurer.java index fde144e3cbe..1479390b525 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/ServerCodecConfigurer.java +++ b/spring-web/src/main/java/org/springframework/http/codec/ServerCodecConfigurer.java @@ -76,6 +76,20 @@ public interface ServerCodecConfigurer extends CodecConfigurer { */ interface ServerDefaultCodecs extends DefaultCodecs { + /** + * Configure the {@code HttpMessageReader} to use for multipart requests. + *

By default, if + * Synchronoss NIO Multipart + * is present, this is set to + * {@link org.springframework.http.codec.multipart.MultipartHttpMessageReader + * MultipartHttpMessageReader} created with an instance of + * {@link org.springframework.http.codec.multipart.SynchronossPartHttpMessageReader + * SynchronossPartHttpMessageReader}. + * @param reader the message reader to use for multipart requests. + * @since 5.1.11 + */ + void multipartReader(HttpMessageReader reader); + /** * Configure the {@code Encoder} to use for Server-Sent Events. *

By default if this is not set, and Jackson is available, the diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/AbstractJackson2Decoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/AbstractJackson2Decoder.java index 595ce6f2810..bf7f68d8b82 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/json/AbstractJackson2Decoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/json/AbstractJackson2Decoder.java @@ -37,6 +37,7 @@ import org.springframework.core.codec.CodecException; import org.springframework.core.codec.DecodingException; import org.springframework.core.codec.Hints; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.log.LogFormatUtils; import org.springframework.http.codec.HttpMessageDecoder; @@ -59,6 +60,9 @@ import org.springframework.util.MimeType; */ public abstract class AbstractJackson2Decoder extends Jackson2CodecSupport implements HttpMessageDecoder { + private int maxInMemorySize = 256 * 1024; + + /** * Constructor with a Jackson {@link ObjectMapper} to use. */ @@ -67,6 +71,28 @@ public abstract class AbstractJackson2Decoder extends Jackson2CodecSupport imple } + /** + * Set the max number of bytes that can be buffered by this decoder. This + * is either the size of the entire input when decoding as a whole, or the + * size of one top-level JSON object within a JSON stream. When the limit + * is exceeded, {@link DataBufferLimitException} is raised. + *

By default this is set to 256K. + * @param byteCount the max number of bytes to buffer, or -1 for unlimited + * @since 5.1.11 + */ + public void setMaxInMemorySize(int byteCount) { + this.maxInMemorySize = byteCount; + } + + /** + * Return the {@link #setMaxInMemorySize configured} byte count limit. + * @since 5.1.11 + */ + public int getMaxInMemorySize() { + return this.maxInMemorySize; + } + + @Override public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) { JavaType javaType = getObjectMapper().getTypeFactory().constructType(elementType.getType()); @@ -81,7 +107,7 @@ public abstract class AbstractJackson2Decoder extends Jackson2CodecSupport imple ObjectMapper mapper = getObjectMapper(); Flux tokens = Jackson2Tokenizer.tokenize( - Flux.from(input), mapper.getFactory(), mapper, true); + Flux.from(input), mapper.getFactory(), mapper, true, getMaxInMemorySize()); ObjectReader reader = getObjectReader(elementType, hints); @@ -103,7 +129,7 @@ public abstract class AbstractJackson2Decoder extends Jackson2CodecSupport imple public Mono decodeToMono(Publisher input, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { - return DataBufferUtils.join(input) + return DataBufferUtils.join(input, this.maxInMemorySize) .map(dataBuffer -> decode(dataBuffer, elementType, mimeType, hints)); } diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2Tokenizer.java b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2Tokenizer.java index 19a03677aa0..1846d3e3bf3 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2Tokenizer.java +++ b/spring-web/src/main/java/org/springframework/http/codec/json/Jackson2Tokenizer.java @@ -35,6 +35,7 @@ import reactor.core.publisher.Flux; import org.springframework.core.codec.DecodingException; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; import org.springframework.core.io.buffer.DataBufferUtils; /** @@ -61,30 +62,39 @@ final class Jackson2Tokenizer { private int arrayDepth; + private final int maxInMemorySize; + + private int byteCount; + + // TODO: change to ByteBufferFeeder when supported by Jackson // See https://github.com/FasterXML/jackson-core/issues/478 private final ByteArrayFeeder inputFeeder; - private Jackson2Tokenizer( - JsonParser parser, DeserializationContext deserializationContext, boolean tokenizeArrayElements) { + private Jackson2Tokenizer(JsonParser parser, DeserializationContext deserializationContext, + boolean tokenizeArrayElements, int maxInMemorySize) { this.parser = parser; this.deserializationContext = deserializationContext; this.tokenizeArrayElements = tokenizeArrayElements; this.tokenBuffer = new TokenBuffer(parser, deserializationContext); this.inputFeeder = (ByteArrayFeeder) this.parser.getNonBlockingInputFeeder(); + this.maxInMemorySize = maxInMemorySize; } private List tokenize(DataBuffer dataBuffer) { - byte[] bytes = new byte[dataBuffer.readableByteCount()]; + int bufferSize = dataBuffer.readableByteCount(); + byte[] bytes = new byte[bufferSize]; dataBuffer.read(bytes); DataBufferUtils.release(dataBuffer); try { this.inputFeeder.feedInput(bytes, 0, bytes.length); - return parseTokenBufferFlux(); + List result = parseTokenBufferFlux(); + assertInMemorySize(bufferSize, result); + return result; } catch (JsonProcessingException ex) { throw new DecodingException("JSON decoding error: " + ex.getOriginalMessage(), ex); @@ -174,18 +184,40 @@ final class Jackson2Tokenizer { (token == JsonToken.END_ARRAY && this.arrayDepth == 0)); } + private void assertInMemorySize(int currentBufferSize, List result) { + if (this.maxInMemorySize >= 0) { + if (!result.isEmpty()) { + this.byteCount = 0; + } + else if (currentBufferSize > Integer.MAX_VALUE - this.byteCount) { + raiseLimitException(); + } + else { + this.byteCount += currentBufferSize; + if (this.byteCount > this.maxInMemorySize) { + raiseLimitException(); + } + } + } + } + + private void raiseLimitException() { + throw new DataBufferLimitException( + "Exceeded limit on max bytes per JSON object: " + this.maxInMemorySize); + } + /** * Tokenize the given {@code Flux} into {@code Flux}. * @param dataBuffers the source data buffers * @param jsonFactory the factory to use * @param objectMapper the current mapper instance - * @param tokenizeArrayElements if {@code true} and the "top level" JSON object is + * @param tokenizeArrays if {@code true} and the "top level" JSON object is * an array, each element is returned individually immediately after it is received * @return the resulting token buffers */ public static Flux tokenize(Flux dataBuffers, JsonFactory jsonFactory, - ObjectMapper objectMapper, boolean tokenizeArrayElements) { + ObjectMapper objectMapper, boolean tokenizeArrays, int maxInMemorySize) { try { JsonParser parser = jsonFactory.createNonBlockingByteArrayParser(); @@ -194,7 +226,7 @@ final class Jackson2Tokenizer { context = ((DefaultDeserializationContext) context).createInstance( objectMapper.getDeserializationConfig(), parser, objectMapper.getInjectableValues()); } - Jackson2Tokenizer tokenizer = new Jackson2Tokenizer(parser, context, tokenizeArrayElements); + Jackson2Tokenizer tokenizer = new Jackson2Tokenizer(parser, context, tokenizeArrays, maxInMemorySize); return dataBuffers.concatMapIterable(tokenizer::tokenize).concatWith(tokenizer.endOfInput()); } catch (IOException ex) { diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageReader.java index c2392921cdc..0d47dd6fcef 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageReader.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageReader.java @@ -65,6 +65,14 @@ public class MultipartHttpMessageReader extends LoggingCodecSupport } + /** + * Return the configured parts reader. + * @since 5.1.11 + */ + public HttpMessageReader getPartReader() { + return this.partReader; + } + @Override public List getReadableMediaTypes() { return Collections.singletonList(MediaType.MULTIPART_FORM_DATA); diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java index 642e16688f7..86f3db25718 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,14 +40,18 @@ import org.synchronoss.cloud.nio.multipart.NioMultipartParser; import org.synchronoss.cloud.nio.multipart.NioMultipartParserListener; import org.synchronoss.cloud.nio.multipart.PartBodyStreamStorageFactory; import org.synchronoss.cloud.nio.stream.storage.StreamStorage; +import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; import org.springframework.core.ResolvableType; +import org.springframework.core.codec.DecodingException; import org.springframework.core.codec.Hints; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferLimitException; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.core.log.LogFormatUtils; @@ -69,15 +73,82 @@ import org.springframework.util.Assert; * @author Sebastien Deleuze * @author Rossen Stoyanchev * @author Arjen Poutsma + * @author Brian Clozel * @since 5.0 * @see Synchronoss NIO Multipart * @see MultipartHttpMessageReader */ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implements HttpMessageReader { - private final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); + // Static DataBufferFactory to copy from FileInputStream or wrap bytes[]. + private static final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); - private final PartBodyStreamStorageFactory streamStorageFactory = new DefaultPartBodyStreamStorageFactory(); + + private int maxInMemorySize = 256 * 1024; + + private long maxDiskUsagePerPart = -1; + + private long maxParts = -1; + + + /** + * Configure the maximum amount of memory that is allowed to use per part. + * When the limit is exceeded: + *
    + *
  • file parts are written to a temporary file. + *
  • non-file parts are rejected with {@link DataBufferLimitException}. + *
+ *

By default this is set to 256K. + * @param byteCount the in-memory limit in bytes; if set to -1 this limit is + * not enforced, and all parts may be written to disk and are limited only + * by the {@link #setMaxDiskUsagePerPart(long) maxDiskUsagePerPart} property. + * @since 5.1.11 + */ + public void setMaxInMemorySize(int byteCount) { + this.maxInMemorySize = byteCount; + } + + /** + * Get the {@link #setMaxInMemorySize configured} maximum in-memory size. + * @since 5.1.11 + */ + public int getMaxInMemorySize() { + return this.maxInMemorySize; + } + + /** + * Configure the maximum amount of disk space allowed for file parts. + *

By default this is set to -1. + * @param maxDiskUsagePerPart the disk limit in bytes, or -1 for unlimited + * @since 5.1.11 + */ + public void setMaxDiskUsagePerPart(long maxDiskUsagePerPart) { + this.maxDiskUsagePerPart = maxDiskUsagePerPart; + } + + /** + * Get the {@link #setMaxDiskUsagePerPart configured} maximum disk usage. + * @since 5.1.11 + */ + public long getMaxDiskUsagePerPart() { + return this.maxDiskUsagePerPart; + } + + /** + * Specify the maximum number of parts allowed in a given multipart request. + * @since 5.1.11 + */ + public void setMaxParts(long maxParts) { + this.maxParts = maxParts; + } + + /** + * Return the {@link #setMaxParts configured} limit on the number of parts. + * @since 5.1.11 + */ + public long getMaxParts() { + return this.maxParts; + } @Override @@ -94,7 +165,7 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem @Override public Flux read(ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) { - return Flux.create(new SynchronossPartGenerator(message, this.bufferFactory, this.streamStorageFactory)) + return Flux.create(new SynchronossPartGenerator(message)) .doOnNext(part -> { if (!Hints.isLoggingSuppressed(hints)) { LogFormatUtils.traceDebug(logger, traceOn -> Hints.getLogPrefix(hints) + "Parsed " + @@ -107,33 +178,36 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem @Override - public Mono readMono(ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) { - return Mono.error(new UnsupportedOperationException("Cannot read multipart request body into single Part")); + public Mono readMono( + ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) { + + return Mono.error(new UnsupportedOperationException( + "Cannot read multipart request body into single Part")); } /** - * Consume and feed input to the Synchronoss parser, then listen for parser - * output events and adapt to {@code Flux>}. + * Subscribe to the input stream and feed the Synchronoss parser. Then listen + * for parser output, creating parts, and pushing them into the FluxSink. */ - private static class SynchronossPartGenerator implements Consumer> { + private class SynchronossPartGenerator extends BaseSubscriber implements Consumer> { private final ReactiveHttpInputMessage inputMessage; - private final DataBufferFactory bufferFactory; + private final LimitedPartBodyStreamStorageFactory storageFactory = new LimitedPartBodyStreamStorageFactory(); - private final PartBodyStreamStorageFactory streamStorageFactory; + private NioMultipartParserListener listener; - SynchronossPartGenerator(ReactiveHttpInputMessage inputMessage, DataBufferFactory bufferFactory, - PartBodyStreamStorageFactory streamStorageFactory) { + private NioMultipartParser parser; + + public SynchronossPartGenerator(ReactiveHttpInputMessage inputMessage) { this.inputMessage = inputMessage; - this.bufferFactory = bufferFactory; - this.streamStorageFactory = streamStorageFactory; } + @Override - public void accept(FluxSink emitter) { + public void accept(FluxSink sink) { HttpHeaders headers = this.inputMessage.getHeaders(); MediaType mediaType = headers.getContentType(); Assert.state(mediaType != null, "No content type set"); @@ -142,40 +216,57 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem Charset charset = Optional.ofNullable(mediaType.getCharset()).orElse(StandardCharsets.UTF_8); MultipartContext context = new MultipartContext(mediaType.toString(), length, charset.name()); - NioMultipartParserListener listener = new FluxSinkAdapterListener(emitter, this.bufferFactory, context); - NioMultipartParser parser = Multipart + this.listener = new FluxSinkAdapterListener(sink, context, this.storageFactory); + + this.parser = Multipart .multipart(context) - .usePartBodyStreamStorageFactory(this.streamStorageFactory) - .forNIO(listener); - - this.inputMessage.getBody().subscribe(buffer -> { - byte[] resultBytes = new byte[buffer.readableByteCount()]; - buffer.read(resultBytes); - try { - parser.write(resultBytes); - } - catch (IOException ex) { - listener.onError("Exception thrown providing input to the parser", ex); - } - finally { - DataBufferUtils.release(buffer); - } - }, ex -> { - try { - listener.onError("Request body input error", ex); - parser.close(); - } - catch (IOException ex2) { - listener.onError("Exception thrown while closing the parser", ex2); - } - }, () -> { - try { - parser.close(); - } - catch (IOException ex) { - listener.onError("Exception thrown while closing the parser", ex); - } - }); + .usePartBodyStreamStorageFactory(this.storageFactory) + .forNIO(this.listener); + + this.inputMessage.getBody().subscribe(this); + } + + @Override + protected void hookOnNext(DataBuffer buffer) { + int size = buffer.readableByteCount(); + this.storageFactory.increaseByteCount(size); + byte[] resultBytes = new byte[size]; + buffer.read(resultBytes); + try { + this.parser.write(resultBytes); + } + catch (IOException ex) { + cancel(); + int index = this.storageFactory.getCurrentPartIndex(); + this.listener.onError("Parser error for part [" + index + "]", ex); + } + finally { + DataBufferUtils.release(buffer); + } + } + + @Override + protected void hookOnError(Throwable ex) { + try { + this.parser.close(); + } + catch (IOException ex2) { + // ignore + } + finally { + int index = this.storageFactory.getCurrentPartIndex(); + this.listener.onError("Failure while parsing part[" + index + "]", ex); + } + } + + @Override + protected void hookFinally(SignalType type) { + try { + this.parser.close(); + } + catch (IOException ex) { + this.listener.onError("Error while closing parser", ex); + } } private int getContentLength(HttpHeaders headers) { @@ -186,6 +277,54 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem } + private class LimitedPartBodyStreamStorageFactory implements PartBodyStreamStorageFactory { + + private final PartBodyStreamStorageFactory storageFactory = maxInMemorySize > 0 ? + new DefaultPartBodyStreamStorageFactory(maxInMemorySize) : + new DefaultPartBodyStreamStorageFactory(); + + private int index = 1; + + private boolean isFilePart; + + private long partSize; + + + public int getCurrentPartIndex() { + return this.index; + } + + @Override + public StreamStorage newStreamStorageForPartBody(Map> headers, int index) { + this.index = index; + this.isFilePart = (MultipartUtils.getFileName(headers) != null); + this.partSize = 0; + if (maxParts > 0 && index > maxParts) { + throw new DecodingException("Too many parts (" + index + " allowed)"); + } + return this.storageFactory.newStreamStorageForPartBody(headers, index); + } + + public void increaseByteCount(long byteCount) { + this.partSize += byteCount; + if (maxInMemorySize > 0 && !this.isFilePart && this.partSize >= maxInMemorySize) { + throw new DataBufferLimitException("Part[" + this.index + "] " + + "exceeded the in-memory limit of " + maxInMemorySize + " bytes"); + } + if (maxDiskUsagePerPart > 0 && this.isFilePart && this.partSize > maxDiskUsagePerPart) { + throw new DecodingException("Part[" + this.index + "] " + + "exceeded the disk usage limit of " + maxDiskUsagePerPart + " bytes"); + } + } + + public void partFinished() { + this.index++; + this.isFilePart = false; + this.partSize = 0; + } + } + + /** * Listen for parser output and adapt to {@code Flux>}. */ @@ -193,43 +332,48 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem private final FluxSink sink; - private final DataBufferFactory bufferFactory; - private final MultipartContext context; + private final LimitedPartBodyStreamStorageFactory storageFactory; + private final AtomicInteger terminated = new AtomicInteger(0); - FluxSinkAdapterListener(FluxSink sink, DataBufferFactory factory, MultipartContext context) { + + FluxSinkAdapterListener( + FluxSink sink, MultipartContext context, LimitedPartBodyStreamStorageFactory factory) { + this.sink = sink; - this.bufferFactory = factory; this.context = context; + this.storageFactory = factory; } + @Override public void onPartFinished(StreamStorage storage, Map> headers) { HttpHeaders httpHeaders = new HttpHeaders(); httpHeaders.putAll(headers); + this.storageFactory.partFinished(); this.sink.next(createPart(storage, httpHeaders)); } private Part createPart(StreamStorage storage, HttpHeaders httpHeaders) { String filename = MultipartUtils.getFileName(httpHeaders); if (filename != null) { - return new SynchronossFilePart(httpHeaders, filename, storage, this.bufferFactory); + return new SynchronossFilePart(httpHeaders, filename, storage); } else if (MultipartUtils.isFormField(httpHeaders, this.context)) { String value = MultipartUtils.readFormParameterValue(storage, httpHeaders); - return new SynchronossFormFieldPart(httpHeaders, this.bufferFactory, value); + return new SynchronossFormFieldPart(httpHeaders, value); } else { - return new SynchronossPart(httpHeaders, storage, this.bufferFactory); + return new SynchronossPart(httpHeaders, storage); } } @Override public void onError(String message, Throwable cause) { if (this.terminated.getAndIncrement() == 0) { - this.sink.error(new RuntimeException(message, cause)); + this.sink.error(new DecodingException(message, cause)); } } @@ -256,14 +400,10 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem private final HttpHeaders headers; - private final DataBufferFactory bufferFactory; - - AbstractSynchronossPart(HttpHeaders headers, DataBufferFactory bufferFactory) { + AbstractSynchronossPart(HttpHeaders headers) { Assert.notNull(headers, "HttpHeaders is required"); - Assert.notNull(bufferFactory, "DataBufferFactory is required"); this.name = MultipartUtils.getFieldName(headers); this.headers = headers; - this.bufferFactory = bufferFactory; } @Override @@ -276,10 +416,6 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem return this.headers; } - DataBufferFactory getBufferFactory() { - return this.bufferFactory; - } - @Override public String toString() { return "Part '" + this.name + "', headers=" + this.headers; @@ -291,15 +427,15 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem private final StreamStorage storage; - SynchronossPart(HttpHeaders headers, StreamStorage storage, DataBufferFactory factory) { - super(headers, factory); + SynchronossPart(HttpHeaders headers, StreamStorage storage) { + super(headers); Assert.notNull(storage, "StreamStorage is required"); this.storage = storage; } @Override public Flux content() { - return DataBufferUtils.readInputStream(getStorage()::getInputStream, getBufferFactory(), 4096); + return DataBufferUtils.readInputStream(getStorage()::getInputStream, bufferFactory, 4096); } protected StreamStorage getStorage() { @@ -315,8 +451,8 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem private final String filename; - SynchronossFilePart(HttpHeaders headers, String filename, StreamStorage storage, DataBufferFactory factory) { - super(headers, storage, factory); + SynchronossFilePart(HttpHeaders headers, String filename, StreamStorage storage) { + super(headers, storage); this.filename = filename; } @@ -375,8 +511,8 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem private final String content; - SynchronossFormFieldPart(HttpHeaders headers, DataBufferFactory bufferFactory, String content) { - super(headers, bufferFactory); + SynchronossFormFieldPart(HttpHeaders headers, String content) { + super(headers); this.content = content; } @@ -388,9 +524,7 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem @Override public Flux content() { byte[] bytes = this.content.getBytes(getCharset()); - DataBuffer buffer = getBufferFactory().allocateBuffer(bytes.length); - buffer.write(bytes); - return Flux.just(buffer); + return Flux.just(bufferFactory.wrap(bytes)); } private Charset getCharset() { diff --git a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java index 37d7ae4d909..ec530bcbf38 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java @@ -36,6 +36,7 @@ import org.springframework.core.ResolvableType; import org.springframework.core.codec.Decoder; import org.springframework.core.codec.DecodingException; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -73,7 +74,7 @@ import org.springframework.util.MimeType; public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder { /** The default max size for aggregating messages. */ - protected static final int DEFAULT_MESSAGE_MAX_SIZE = 64 * 1024; + protected static final int DEFAULT_MESSAGE_MAX_SIZE = 256 * 1024; private static final ConcurrentMap, Method> methodCache = new ConcurrentReferenceHashMap<>(); @@ -101,10 +102,23 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements DecoderBy default, this is set to 256K. + * @param maxMessageSize the max size per message, or -1 for unlimited + */ public void setMaxMessageSize(int maxMessageSize) { this.maxMessageSize = maxMessageSize; } + /** + * Return the {@link #setMaxMessageSize configured} message size limit. + * @since 5.1.11 + */ + public int getMaxMessageSize() { + return this.maxMessageSize; + } + @Override public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) { @@ -127,7 +141,7 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder decodeToMono(Publisher inputStream, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { - return DataBufferUtils.join(inputStream) + return DataBufferUtils.join(inputStream, this.maxMessageSize) .map(dataBuffer -> decode(dataBuffer, elementType, mimeType, hints)); } @@ -204,9 +218,9 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder this.maxMessageSize) { - throw new DecodingException( - "The number of bytes to read from the incoming stream " + + if (this.maxMessageSize > 0 && this.messageBytesToRead > this.maxMessageSize) { + throw new DataBufferLimitException( + "The number of bytes to read for message " + "(" + this.messageBytesToRead + ") exceeds " + "the configured limit (" + this.maxMessageSize + ")"); } diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java b/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java index 564db7f262a..39d76ad0ddb 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java +++ b/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import org.springframework.core.codec.AbstractDataBufferDecoder; import org.springframework.core.codec.ByteArrayDecoder; import org.springframework.core.codec.ByteArrayEncoder; import org.springframework.core.codec.ByteBufferDecoder; @@ -29,6 +30,7 @@ import org.springframework.core.codec.DataBufferDecoder; import org.springframework.core.codec.DataBufferEncoder; import org.springframework.core.codec.Decoder; import org.springframework.core.codec.Encoder; +import org.springframework.core.codec.ResourceDecoder; import org.springframework.core.codec.StringDecoder; import org.springframework.http.codec.CodecConfigurer; import org.springframework.http.codec.DecoderHttpMessageReader; @@ -38,6 +40,7 @@ import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.HttpMessageWriter; import org.springframework.http.codec.ResourceHttpMessageReader; import org.springframework.http.codec.ResourceHttpMessageWriter; +import org.springframework.http.codec.json.AbstractJackson2Decoder; import org.springframework.http.codec.json.Jackson2JsonDecoder; import org.springframework.http.codec.json.Jackson2JsonEncoder; import org.springframework.http.codec.json.Jackson2SmileDecoder; @@ -95,6 +98,9 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs { @Nullable private Encoder jaxb2Encoder; + @Nullable + private Integer maxInMemorySize; + private boolean enableLoggingRequestDetails = false; private boolean registerDefaults = true; @@ -130,6 +136,16 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs { this.jaxb2Encoder = encoder; } + @Override + public void maxInMemorySize(int byteCount) { + this.maxInMemorySize = byteCount; + } + + @Nullable + protected Integer maxInMemorySize() { + return this.maxInMemorySize; + } + @Override public void enableLoggingRequestDetails(boolean enable) { this.enableLoggingRequestDetails = enable; @@ -155,17 +171,20 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs { return Collections.emptyList(); } List> readers = new ArrayList<>(); - readers.add(new DecoderHttpMessageReader<>(new ByteArrayDecoder())); - readers.add(new DecoderHttpMessageReader<>(new ByteBufferDecoder())); - readers.add(new DecoderHttpMessageReader<>(new DataBufferDecoder())); - readers.add(new ResourceHttpMessageReader()); - readers.add(new DecoderHttpMessageReader<>(StringDecoder.textPlainOnly())); + readers.add(new DecoderHttpMessageReader<>(init(new ByteArrayDecoder()))); + readers.add(new DecoderHttpMessageReader<>(init(new ByteBufferDecoder()))); + readers.add(new DecoderHttpMessageReader<>(init(new DataBufferDecoder()))); + readers.add(new ResourceHttpMessageReader(init(new ResourceDecoder()))); + readers.add(new DecoderHttpMessageReader<>(init(StringDecoder.textPlainOnly()))); if (protobufPresent) { - Decoder decoder = this.protobufDecoder != null ? this.protobufDecoder : new ProtobufDecoder(); + Decoder decoder = this.protobufDecoder != null ? this.protobufDecoder : init(new ProtobufDecoder()); readers.add(new DecoderHttpMessageReader<>(decoder)); } FormHttpMessageReader formReader = new FormHttpMessageReader(); + if (this.maxInMemorySize != null) { + formReader.setMaxInMemorySize(this.maxInMemorySize); + } formReader.setEnableLoggingRequestDetails(this.enableLoggingRequestDetails); readers.add(formReader); @@ -174,6 +193,28 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs { return readers; } + private > T init(T decoder) { + if (this.maxInMemorySize != null) { + if (decoder instanceof AbstractDataBufferDecoder) { + ((AbstractDataBufferDecoder) decoder).setMaxInMemorySize(this.maxInMemorySize); + } + if (decoder instanceof ProtobufDecoder) { + ((ProtobufDecoder) decoder).setMaxMessageSize(this.maxInMemorySize); + } + if (jackson2Present) { + if (decoder instanceof AbstractJackson2Decoder) { + ((AbstractJackson2Decoder) decoder).setMaxInMemorySize(this.maxInMemorySize); + } + } + if (jaxb2Present) { + if (decoder instanceof Jaxb2XmlDecoder) { + ((Jaxb2XmlDecoder) decoder).setMaxInMemorySize(this.maxInMemorySize); + } + } + } + return decoder; + } + /** * Hook for client or server specific typed readers. */ @@ -189,13 +230,13 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs { } List> readers = new ArrayList<>(); if (jackson2Present) { - readers.add(new DecoderHttpMessageReader<>(getJackson2JsonDecoder())); + readers.add(new DecoderHttpMessageReader<>(init(getJackson2JsonDecoder()))); } if (jackson2SmilePresent) { - readers.add(new DecoderHttpMessageReader<>(new Jackson2SmileDecoder())); + readers.add(new DecoderHttpMessageReader<>(init(new Jackson2SmileDecoder()))); } if (jaxb2Present) { - Decoder decoder = this.jaxb2Decoder != null ? this.jaxb2Decoder : new Jaxb2XmlDecoder(); + Decoder decoder = this.jaxb2Decoder != null ? this.jaxb2Decoder : init(new Jaxb2XmlDecoder()); readers.add(new DecoderHttpMessageReader<>(decoder)); } extendObjectReaders(readers); @@ -216,7 +257,7 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs { return Collections.emptyList(); } List> result = new ArrayList<>(); - result.add(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); + result.add(new DecoderHttpMessageReader<>(init(StringDecoder.allMimeTypes()))); return result; } diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/ServerDefaultCodecsImpl.java b/spring-web/src/main/java/org/springframework/http/codec/support/ServerDefaultCodecsImpl.java index ef91fb7c369..37e924cd7e9 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/support/ServerDefaultCodecsImpl.java +++ b/spring-web/src/main/java/org/springframework/http/codec/support/ServerDefaultCodecsImpl.java @@ -39,10 +39,18 @@ class ServerDefaultCodecsImpl extends BaseDefaultCodecs implements ServerCodecCo DefaultServerCodecConfigurer.class.getClassLoader()); + @Nullable + private HttpMessageReader multipartReader; + @Nullable private Encoder sseEncoder; + @Override + public void multipartReader(HttpMessageReader reader) { + this.multipartReader = reader; + } + @Override public void serverSentEventEncoder(Encoder encoder) { this.sseEncoder = encoder; @@ -51,10 +59,18 @@ class ServerDefaultCodecsImpl extends BaseDefaultCodecs implements ServerCodecCo @Override protected void extendTypedReaders(List> typedReaders) { + if (this.multipartReader != null) { + typedReaders.add(this.multipartReader); + return; + } if (synchronossMultipartPresent) { boolean enable = isEnableLoggingRequestDetails(); SynchronossPartHttpMessageReader partReader = new SynchronossPartHttpMessageReader(); + Integer size = maxInMemorySize(); + if (size != null) { + partReader.setMaxInMemorySize(size); + } partReader.setEnableLoggingRequestDetails(enable); typedReaders.add(partReader); diff --git a/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlDecoder.java index 71a233eff71..36d2319f033 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlDecoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlDecoder.java @@ -49,6 +49,7 @@ import org.springframework.core.codec.CodecException; import org.springframework.core.codec.DecodingException; import org.springframework.core.codec.Hints; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.log.LogFormatUtils; import org.springframework.lang.Nullable; @@ -87,6 +88,8 @@ public class Jaxb2XmlDecoder extends AbstractDecoder { private Function unmarshallerProcessor = Function.identity(); + private int maxInMemorySize = 256 * 1024; + public Jaxb2XmlDecoder() { super(MimeTypeUtils.APPLICATION_XML, MimeTypeUtils.TEXT_XML); @@ -119,6 +122,28 @@ public class Jaxb2XmlDecoder extends AbstractDecoder { return this.unmarshallerProcessor; } + /** + * Set the max number of bytes that can be buffered by this decoder. + * This is either the size of the entire input when decoding as a whole, or when + * using async parsing with Aalto XML, it is the size of one top-level XML tree. + * When the limit is exceeded, {@link DataBufferLimitException} is raised. + *

By default this is set to 256K. + * @param byteCount the max number of bytes to buffer, or -1 for unlimited + * @since 5.1.11 + */ + public void setMaxInMemorySize(int byteCount) { + this.maxInMemorySize = byteCount; + this.xmlEventDecoder.setMaxInMemorySize(byteCount); + } + + /** + * Return the {@link #setMaxInMemorySize configured} byte count limit. + * @since 5.1.11 + */ + public int getMaxInMemorySize() { + return this.maxInMemorySize; + } + @Override public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) { @@ -153,7 +178,7 @@ public class Jaxb2XmlDecoder extends AbstractDecoder { public Mono decodeToMono(Publisher input, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { - return DataBufferUtils.join(input) + return DataBufferUtils.join(input, this.maxInMemorySize) .map(dataBuffer -> decode(dataBuffer, elementType, mimeType, hints)); } diff --git a/spring-web/src/main/java/org/springframework/http/codec/xml/XmlEventDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/xml/XmlEventDecoder.java index cb39ba1307c..2305525a099 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/xml/XmlEventDecoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/xml/XmlEventDecoder.java @@ -40,6 +40,7 @@ import reactor.core.publisher.Flux; import org.springframework.core.ResolvableType; import org.springframework.core.codec.AbstractDecoder; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.lang.Nullable; import org.springframework.util.ClassUtils; @@ -89,26 +90,50 @@ public class XmlEventDecoder extends AbstractDecoder { boolean useAalto = aaltoPresent; + private int maxInMemorySize = 256 * 1024; + public XmlEventDecoder() { super(MimeTypeUtils.APPLICATION_XML, MimeTypeUtils.TEXT_XML); } + /** + * Set the max number of bytes that can be buffered by this decoder. This + * is either the size the entire input when decoding as a whole, or when + * using async parsing via Aalto XML, it is size one top-level XML tree. + * When the limit is exceeded, {@link DataBufferLimitException} is raised. + *

By default this is set to 256K. + * @param byteCount the max number of bytes to buffer, or -1 for unlimited + * @since 5.1.11 + */ + public void setMaxInMemorySize(int byteCount) { + this.maxInMemorySize = byteCount; + } + + /** + * Return the {@link #setMaxInMemorySize configured} byte count limit. + * @since 5.1.11 + */ + public int getMaxInMemorySize() { + return this.maxInMemorySize; + } + + @Override @SuppressWarnings({"rawtypes", "unchecked", "cast"}) // XMLEventReader is Iterator on JDK 9 public Flux decode(Publisher input, ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { if (this.useAalto) { - AaltoDataBufferToXmlEvent mapper = new AaltoDataBufferToXmlEvent(); + AaltoDataBufferToXmlEvent mapper = new AaltoDataBufferToXmlEvent(this.maxInMemorySize); return Flux.from(input) .flatMapIterable(mapper) .doFinally(signalType -> mapper.endOfInput()); } else { - return DataBufferUtils.join(input). - flatMapIterable(buffer -> { + return DataBufferUtils.join(input, this.maxInMemorySize) + .flatMapIterable(buffer -> { try { InputStream is = buffer.asInputStream(); Iterator eventReader = inputFactory.createXMLEventReader(is); @@ -140,10 +165,22 @@ public class XmlEventDecoder extends AbstractDecoder { private final XMLEventAllocator eventAllocator = EventAllocatorImpl.getDefaultInstance(); + private final int maxInMemorySize; + + private int byteCount; + + private int elementDepth; + + + public AaltoDataBufferToXmlEvent(int maxInMemorySize) { + this.maxInMemorySize = maxInMemorySize; + } + @Override public List apply(DataBuffer dataBuffer) { try { + increaseByteCount(dataBuffer); this.streamReader.getInputFeeder().feedInput(dataBuffer.asByteBuffer()); List events = new ArrayList<>(); while (true) { @@ -157,8 +194,12 @@ public class XmlEventDecoder extends AbstractDecoder { if (event.isEndDocument()) { break; } + checkDepthAndResetByteCount(event); } } + if (this.maxInMemorySize > 0 && this.byteCount > this.maxInMemorySize) { + raiseLimitException(); + } return events; } catch (XMLStreamException ex) { @@ -169,9 +210,40 @@ public class XmlEventDecoder extends AbstractDecoder { } } + private void increaseByteCount(DataBuffer dataBuffer) { + if (this.maxInMemorySize > 0) { + if (dataBuffer.readableByteCount() > Integer.MAX_VALUE - this.byteCount) { + raiseLimitException(); + } + else { + this.byteCount += dataBuffer.readableByteCount(); + } + } + } + + private void checkDepthAndResetByteCount(XMLEvent event) { + if (this.maxInMemorySize > 0) { + if (event.isStartElement()) { + this.byteCount = this.elementDepth == 1 ? 0 : this.byteCount; + this.elementDepth++; + } + else if (event.isEndElement()) { + this.elementDepth--; + this.byteCount = this.elementDepth == 1 ? 0 : this.byteCount; + } + } + } + + private void raiseLimitException() { + throw new DataBufferLimitException( + "Exceeded limit on max bytes per XML top-level node: " + this.maxInMemorySize); + } + public void endOfInput() { this.streamReader.getInputFeeder().endOfInput(); } } + + } diff --git a/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2TokenizerTests.java b/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2TokenizerTests.java index dbedb4dcee7..aa92701eeb2 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2TokenizerTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/json/Jackson2TokenizerTests.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.nio.charset.StandardCharsets; import java.util.List; -import java.util.function.Consumer; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.TreeNode; @@ -36,6 +35,7 @@ import reactor.test.StepVerifier; import org.springframework.core.codec.DecodingException; import org.springframework.core.io.buffer.AbstractLeakCheckingTests; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; @@ -181,11 +181,68 @@ public class Jackson2TokenizerTests extends AbstractLeakCheckingTests { testTokenize(asList("[1", ",2,", "3]"), asList("1", "2", "3"), true); } + private void testTokenize(List input, List output, boolean tokenize) { + StepVerifier.FirstStep builder = StepVerifier.create(decode(input, tokenize, -1)); + output.forEach(expected -> builder.assertNext(actual -> { + try { + JSONAssert.assertEquals(expected, actual, true); + } + catch (JSONException ex) { + throw new RuntimeException(ex); + } + })); + builder.verifyComplete(); + } + + @Test + public void testLimit() { + + List source = asList("[", + "{", "\"id\":1,\"name\":\"Dan\"", "},", + "{", "\"id\":2,\"name\":\"Ron\"", "},", + "{", "\"id\":3,\"name\":\"Bartholomew\"", "}", + "]"); + + String expected = String.join("", source); + int maxInMemorySize = expected.length(); + + StepVerifier.create(decode(source, false, maxInMemorySize)) + .expectNext(expected) + .verifyComplete(); + + StepVerifier.create(decode(source, false, maxInMemorySize - 1)) + .expectError(DataBufferLimitException.class); + } + + @Test + public void testLimitTokenized() { + + List source = asList("[", + "{", "\"id\":1, \"name\":\"Dan\"", "},", + "{", "\"id\":2, \"name\":\"Ron\"", "},", + "{", "\"id\":3, \"name\":\"Bartholomew\"", "}", + "]"); + + String expected = "{\"id\":3,\"name\":\"Bartholomew\"}"; + int maxInMemorySize = expected.length(); + + StepVerifier.create(decode(source, true, maxInMemorySize)) + .expectNext("{\"id\":1,\"name\":\"Dan\"}") + .expectNext("{\"id\":2,\"name\":\"Ron\"}") + .expectNext(expected) + .verifyComplete(); + + StepVerifier.create(decode(source, true, maxInMemorySize - 1)) + .expectNext("{\"id\":1,\"name\":\"Dan\"}") + .expectNext("{\"id\":2,\"name\":\"Ron\"}") + .verifyError(DataBufferLimitException.class); + } + @Test public void errorInStream() { DataBuffer buffer = stringBuffer("{\"id\":1,\"name\":"); Flux source = Flux.just(buffer).concatWith(Flux.error(new RuntimeException())); - Flux result = Jackson2Tokenizer.tokenize(source, this.jsonFactory, this.objectMapper, true); + Flux result = Jackson2Tokenizer.tokenize(source, this.jsonFactory, this.objectMapper, true, -1); StepVerifier.create(result) .expectError(RuntimeException.class) @@ -195,7 +252,7 @@ public class Jackson2TokenizerTests extends AbstractLeakCheckingTests { @Test // SPR-16521 public void jsonEOFExceptionIsWrappedAsDecodingError() { Flux source = Flux.just(stringBuffer("{\"status\": \"noClosingQuote}")); - Flux tokens = Jackson2Tokenizer.tokenize(source, this.jsonFactory, this.objectMapper, false); + Flux tokens = Jackson2Tokenizer.tokenize(source, this.jsonFactory, this.objectMapper, false, -1); StepVerifier.create(tokens) .expectError(DecodingException.class) @@ -203,12 +260,13 @@ public class Jackson2TokenizerTests extends AbstractLeakCheckingTests { } - private void testTokenize(List source, List expected, boolean tokenizeArrayElements) { + private Flux decode(List source, boolean tokenize, int maxInMemorySize) { + Flux tokens = Jackson2Tokenizer.tokenize( Flux.fromIterable(source).map(this::stringBuffer), - this.jsonFactory, this.objectMapper, tokenizeArrayElements); + this.jsonFactory, this.objectMapper, tokenize, maxInMemorySize); - Flux result = tokens + return tokens .map(tokenBuffer -> { try { TreeNode root = this.objectMapper.readTree(tokenBuffer.asParser()); @@ -218,10 +276,6 @@ public class Jackson2TokenizerTests extends AbstractLeakCheckingTests { throw new UncheckedIOException(ex); } }); - - StepVerifier.FirstStep builder = StepVerifier.create(result); - expected.forEach(s -> builder.assertNext(new JSONAssertConsumer(s))); - builder.verifyComplete(); } private DataBuffer stringBuffer(String value) { @@ -231,24 +285,4 @@ public class Jackson2TokenizerTests extends AbstractLeakCheckingTests { return buffer; } - - private static class JSONAssertConsumer implements Consumer { - - private final String expected; - - JSONAssertConsumer(String expected) { - this.expected = expected; - } - - @Override - public void accept(String s) { - try { - JSONAssert.assertEquals(this.expected, s, true); - } - catch (JSONException ex) { - throw new RuntimeException(ex); - } - } - } - } diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java index b9fc46381ab..8533dd426da 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java @@ -17,19 +17,26 @@ package org.springframework.http.codec.multipart; import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Map; +import java.util.function.Consumer; import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import org.springframework.core.ResolvableType; +import org.springframework.core.codec.DecodingException; import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.buffer.AbstractLeakCheckingTests; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferUtils; -import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.io.buffer.support.DataBufferTestUtils; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.client.MultipartBodyBuilder; @@ -49,17 +56,20 @@ import static org.springframework.http.MediaType.MULTIPART_FORM_DATA; * * @author Sebastien Deleuze * @author Rossen Stoyanchev + * @author Brian Clozel */ -public class SynchronossPartHttpMessageReaderTests { +public class SynchronossPartHttpMessageReaderTests extends AbstractLeakCheckingTests { private final MultipartHttpMessageReader reader = new MultipartHttpMessageReader(new SynchronossPartHttpMessageReader()); + private static final ResolvableType PARTS_ELEMENT_TYPE = + forClassWithGenerics(MultiValueMap.class, String.class, Part.class); @Test - public void canRead() { + void canRead() { assertThat(this.reader.canRead( - forClassWithGenerics(MultiValueMap.class, String.class, Part.class), + PARTS_ELEMENT_TYPE, MediaType.MULTIPART_FORM_DATA)).isTrue(); assertThat(this.reader.canRead( @@ -80,43 +90,36 @@ public class SynchronossPartHttpMessageReaderTests { } @Test - public void resolveParts() { + void resolveParts() { ServerHttpRequest request = generateMultipartRequest(); - ResolvableType elementType = forClassWithGenerics(MultiValueMap.class, String.class, Part.class); - MultiValueMap parts = this.reader.readMono(elementType, request, emptyMap()).block(); - assertThat(parts.size()).isEqualTo(2); - - assertThat(parts.containsKey("fooPart")).isTrue(); - Part part = parts.getFirst("fooPart"); - boolean condition1 = part instanceof FilePart; - assertThat(condition1).isTrue(); - assertThat(part.name()).isEqualTo("fooPart"); + MultiValueMap parts = this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap()).block(); + + assertThat(parts).containsOnlyKeys("filePart", "textPart"); + + Part part = parts.getFirst("filePart"); + assertThat(part).isInstanceOf(FilePart.class); + assertThat(part.name()).isEqualTo("filePart"); assertThat(((FilePart) part).filename()).isEqualTo("foo.txt"); DataBuffer buffer = DataBufferUtils.join(part.content()).block(); - assertThat(buffer.readableByteCount()).isEqualTo(12); - byte[] byteContent = new byte[12]; - buffer.read(byteContent); - assertThat(new String(byteContent)).isEqualTo("Lorem Ipsum."); - - assertThat(parts.containsKey("barPart")).isTrue(); - part = parts.getFirst("barPart"); - boolean condition = part instanceof FormFieldPart; - assertThat(condition).isTrue(); - assertThat(part.name()).isEqualTo("barPart"); - assertThat(((FormFieldPart) part).value()).isEqualTo("bar"); + assertThat(DataBufferTestUtils.dumpString(buffer, StandardCharsets.UTF_8)).isEqualTo("Lorem Ipsum."); + DataBufferUtils.release(buffer); + + part = parts.getFirst("textPart"); + assertThat(part).isInstanceOf(FormFieldPart.class); + assertThat(part.name()).isEqualTo("textPart"); + assertThat(((FormFieldPart) part).value()).isEqualTo("sample-text"); } @Test // SPR-16545 - public void transferTo() { + void transferTo() throws IOException { ServerHttpRequest request = generateMultipartRequest(); - ResolvableType elementType = forClassWithGenerics(MultiValueMap.class, String.class, Part.class); - MultiValueMap parts = this.reader.readMono(elementType, request, emptyMap()).block(); + MultiValueMap parts = this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap()).block(); assertThat(parts).isNotNull(); - FilePart part = (FilePart) parts.getFirst("fooPart"); + FilePart part = (FilePart) parts.getFirst("filePart"); assertThat(part).isNotNull(); - File dest = new File(System.getProperty("java.io.tmpdir") + "/" + part.filename()); + File dest = File.createTempFile(part.filename(), "multipart"); part.transferTo(dest).block(Duration.ofSeconds(5)); assertThat(dest.exists()).isTrue(); @@ -125,33 +128,95 @@ public class SynchronossPartHttpMessageReaderTests { } @Test - public void bodyError() { + void bodyError() { ServerHttpRequest request = generateErrorMultipartRequest(); - ResolvableType elementType = forClassWithGenerics(MultiValueMap.class, String.class, Part.class); - StepVerifier.create(this.reader.readMono(elementType, request, emptyMap())).verifyError(); + StepVerifier.create(this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap())).verifyError(); + } + + @Test + void readPartsWithoutDemand() { + ServerHttpRequest request = generateMultipartRequest(); + Mono> parts = this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap()); + ZeroDemandSubscriber subscriber = new ZeroDemandSubscriber(); + parts.subscribe(subscriber); + subscriber.cancel(); } + @Test + void readTooManyParts() { + testMultipartExceptions(reader -> reader.setMaxParts(1), ex -> { + assertThat(ex) + .isInstanceOf(DecodingException.class) + .hasMessageStartingWith("Failure while parsing part[2]"); + assertThat(ex.getCause()) + .hasMessage("Too many parts (2 allowed)"); + } + ); + } - private ServerHttpRequest generateMultipartRequest() { + @Test + void readFilePartTooBig() { + testMultipartExceptions(reader -> reader.setMaxDiskUsagePerPart(5), ex -> { + assertThat(ex) + .isInstanceOf(DecodingException.class) + .hasMessageStartingWith("Failure while parsing part[1]"); + assertThat(ex.getCause()) + .hasMessage("Part[1] exceeded the disk usage limit of 5 bytes"); + } + ); + } + + @Test + void readPartHeadersTooBig() { + testMultipartExceptions(reader -> reader.setMaxInMemorySize(1), ex -> { + assertThat(ex) + .isInstanceOf(DecodingException.class) + .hasMessageStartingWith("Failure while parsing part[1]"); + assertThat(ex.getCause()) + .hasMessage("Part[1] exceeded the in-memory limit of 1 bytes"); + } + ); + } + + private void testMultipartExceptions( + Consumer configurer, Consumer assertions) { + + SynchronossPartHttpMessageReader reader = new SynchronossPartHttpMessageReader(); + configurer.accept(reader); + MultipartHttpMessageReader multipartReader = new MultipartHttpMessageReader(reader); + StepVerifier.create(multipartReader.readMono(PARTS_ELEMENT_TYPE, generateMultipartRequest(), emptyMap())) + .consumeErrorWith(assertions) + .verify(); + } + private ServerHttpRequest generateMultipartRequest() { MultipartBodyBuilder partsBuilder = new MultipartBodyBuilder(); - partsBuilder.part("fooPart", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt")); - partsBuilder.part("barPart", "bar"); + partsBuilder.part("filePart", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt")); + partsBuilder.part("textPart", "sample-text"); MockClientHttpRequest outputMessage = new MockClientHttpRequest(HttpMethod.POST, "/"); new MultipartHttpMessageWriter() .write(Mono.just(partsBuilder.build()), null, MediaType.MULTIPART_FORM_DATA, outputMessage, null) .block(Duration.ofSeconds(5)); - + Flux requestBody = outputMessage.getBody() + .map(buffer -> this.bufferFactory.wrap(buffer.asByteBuffer())); return MockServerHttpRequest.post("/") .contentType(outputMessage.getHeaders().getContentType()) - .body(outputMessage.getBody()); + .body(requestBody); } private ServerHttpRequest generateErrorMultipartRequest() { return MockServerHttpRequest.post("/") .header(CONTENT_TYPE, MULTIPART_FORM_DATA.toString()) - .body(Flux.just(new DefaultDataBufferFactory().wrap("invalid content".getBytes()))); + .body(Flux.just(this.bufferFactory.wrap("invalid content".getBytes()))); + } + + private static class ZeroDemandSubscriber extends BaseSubscriber> { + + @Override + protected void hookOnSubscribe(Subscription subscription) { + // Just subscribe without requesting + } } } diff --git a/spring-web/src/test/java/org/springframework/http/codec/support/ServerCodecConfigurerTests.java b/spring-web/src/test/java/org/springframework/http/codec/support/ServerCodecConfigurerTests.java index 777f8f51bca..5698e154dd3 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/support/ServerCodecConfigurerTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/support/ServerCodecConfigurerTests.java @@ -36,6 +36,7 @@ import org.springframework.core.codec.DataBufferDecoder; import org.springframework.core.codec.DataBufferEncoder; import org.springframework.core.codec.Decoder; import org.springframework.core.codec.Encoder; +import org.springframework.core.codec.ResourceDecoder; import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.MediaType; @@ -124,13 +125,45 @@ public class ServerCodecConfigurerTests { .filter(e -> e == encoder).orElse(null)).isSameAs(encoder); } + @Test + public void maxInMemorySize() { + int size = 99; + this.configurer.defaultCodecs().maxInMemorySize(size); + List> readers = this.configurer.getReaders(); + assertThat(readers.size()).isEqualTo(13); + assertThat(((ByteArrayDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size); + assertThat(((ByteBufferDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size); + assertThat(((DataBufferDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size); + + ResourceHttpMessageReader resourceReader = (ResourceHttpMessageReader) nextReader(readers); + ResourceDecoder decoder = (ResourceDecoder) resourceReader.getDecoder(); + assertThat(decoder.getMaxInMemorySize()).isEqualTo(size); + + assertThat(((StringDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size); + assertThat(((ProtobufDecoder) getNextDecoder(readers)).getMaxMessageSize()).isEqualTo(size); + assertThat(((FormHttpMessageReader) nextReader(readers)).getMaxInMemorySize()).isEqualTo(size); + assertThat(((SynchronossPartHttpMessageReader) nextReader(readers)).getMaxInMemorySize()).isEqualTo(size); + + MultipartHttpMessageReader multipartReader = (MultipartHttpMessageReader) nextReader(readers); + SynchronossPartHttpMessageReader reader = (SynchronossPartHttpMessageReader) multipartReader.getPartReader(); + assertThat((reader).getMaxInMemorySize()).isEqualTo(size); + + assertThat(((Jackson2JsonDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size); + assertThat(((Jackson2SmileDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size); + assertThat(((Jaxb2XmlDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size); + assertThat(((StringDecoder) getNextDecoder(readers)).getMaxInMemorySize()).isEqualTo(size); + } private Decoder getNextDecoder(List> readers) { - HttpMessageReader reader = readers.get(this.index.getAndIncrement()); + HttpMessageReader reader = nextReader(readers); assertThat(reader.getClass()).isEqualTo(DecoderHttpMessageReader.class); return ((DecoderHttpMessageReader) reader).getDecoder(); } + private HttpMessageReader nextReader(List> readers) { + return readers.get(this.index.getAndIncrement()); + } + private Encoder getNextEncoder(List> writers) { HttpMessageWriter writer = writers.get(this.index.getAndIncrement()); assertThat(writer.getClass()).isEqualTo(EncoderHttpMessageWriter.class); diff --git a/spring-web/src/test/java/org/springframework/http/codec/xml/XmlEventDecoderTests.java b/spring-web/src/test/java/org/springframework/http/codec/xml/XmlEventDecoderTests.java index 56a271d4577..eb1a7b38910 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/xml/XmlEventDecoderTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/xml/XmlEventDecoderTests.java @@ -28,6 +28,7 @@ import reactor.test.StepVerifier; import org.springframework.core.io.buffer.AbstractLeakCheckingTests; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferLimitException; import static org.assertj.core.api.Assertions.assertThat; @@ -44,11 +45,12 @@ public class XmlEventDecoderTests extends AbstractLeakCheckingTests { private XmlEventDecoder decoder = new XmlEventDecoder(); + @Test public void toXMLEventsAalto() { Flux events = - this.decoder.decode(stringBuffer(XML), null, null, Collections.emptyMap()); + this.decoder.decode(stringBufferMono(XML), null, null, Collections.emptyMap()); StepVerifier.create(events) .consumeNextWith(e -> assertThat(e.isStartDocument()).isTrue()) @@ -69,7 +71,7 @@ public class XmlEventDecoderTests extends AbstractLeakCheckingTests { decoder.useAalto = false; Flux events = - this.decoder.decode(stringBuffer(XML), null, null, Collections.emptyMap()); + this.decoder.decode(stringBufferMono(XML), null, null, Collections.emptyMap()); StepVerifier.create(events) .consumeNextWith(e -> assertThat(e.isStartDocument()).isTrue()) @@ -86,10 +88,32 @@ public class XmlEventDecoderTests extends AbstractLeakCheckingTests { .verify(); } + @Test + public void toXMLEventsWithLimit() { + + this.decoder.setMaxInMemorySize(6); + + Flux source = Flux.just( + "", "", "foofoo", "", "", "barbarbar", "", ""); + + Flux events = this.decoder.decode( + source.map(this::stringBuffer), null, null, Collections.emptyMap()); + + StepVerifier.create(events) + .consumeNextWith(e -> assertThat(e.isStartDocument()).isTrue()) + .consumeNextWith(e -> assertStartElement(e, "pojo")) + .consumeNextWith(e -> assertStartElement(e, "foo")) + .consumeNextWith(e -> assertCharacters(e, "foofoo")) + .consumeNextWith(e -> assertEndElement(e, "foo")) + .consumeNextWith(e -> assertStartElement(e, "bar")) + .expectError(DataBufferLimitException.class) + .verify(); + } + @Test public void decodeErrorAalto() { Flux source = Flux.concat( - stringBuffer(""), + stringBufferMono(""), Flux.error(new RuntimeException())); Flux events = @@ -107,7 +131,7 @@ public class XmlEventDecoderTests extends AbstractLeakCheckingTests { decoder.useAalto = false; Flux source = Flux.concat( - stringBuffer(""), + stringBufferMono(""), Flux.error(new RuntimeException())); Flux events = @@ -133,13 +157,15 @@ public class XmlEventDecoderTests extends AbstractLeakCheckingTests { assertThat(event.asCharacters().getData()).isEqualTo(expectedData); } - private Mono stringBuffer(String value) { - return Mono.defer(() -> { - byte[] bytes = value.getBytes(StandardCharsets.UTF_8); - DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length); - buffer.write(bytes); - return Mono.just(buffer); - }); + private DataBuffer stringBuffer(String value) { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + DataBuffer buffer = this.bufferFactory.allocateBuffer(bytes.length); + buffer.write(bytes); + return buffer; + } + + private Mono stringBufferMono(String value) { + return Mono.defer(() -> Mono.just(stringBuffer(value))); } } diff --git a/src/docs/asciidoc/web/webflux.adoc b/src/docs/asciidoc/web/webflux.adoc index 3af5baa28f9..b1ec0d8fcc7 100644 --- a/src/docs/asciidoc/web/webflux.adoc +++ b/src/docs/asciidoc/web/webflux.adoc @@ -818,6 +818,33 @@ for repeated, map-like access to parts, or otherwise rely on the `SynchronossPartHttpMessageReader` for a one-time access to `Flux`. +[[webflux-codecs-limits]] +==== Limits + +`Decoder` and `HttpMessageReader` implementations that buffer some or all of the input +stream can be configured with a limit on the maximum number of bytes to buffer in memory. +In some cases buffering occurs because input is aggregated and represented as a single +object, e.g. controller method with `@RequestBody byte[]`, `x-www-form-urlencoded` data, +and so on. Buffering can also occurs with streaming, when splitting the input stream, +e.g. delimited text, a stream of JSON objects, and so on. For those streaming cases, the +limit applies to the number of bytes associted with one object in the stream. + +To configure buffer sizes, you can check if a given `Decoder` or `HttpMessageReader` +exposes a `maxInMemorySize` property and if so the Javadoc will have details about default +values. In WebFlux, the `ServerCodecConfigurer` provides a +<> from where to set all codecs, through the +`maxInMemorySize` property for default codecs. + +For <> the `maxInMemorySize` property limits +the size of non-file parts. For file parts it determines the threshold at which the part +is written to disk. For file parts written to disk, there is an additional +`maxDiskUsagePerPart` property to limit the amount of disk space per part. There is also +a `maxParts` property to limit the overall number of parts in a multipart request. +To configure all 3 in WebFlux, you'll need to supply a pre-configured instance of +`MultipartHttpMessageReader` to `ServerCodecConfigurer`. + + + [[webflux-codecs-streaming]] ==== Streaming [.small]#<>#