diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartException.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartException.java
new file mode 100644
index 00000000000..c2d78e0fa86
--- /dev/null
+++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartException.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright 2002-2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.http.codec.multipart;
+
+/**
+ * @author Brian Clozel
+ */
+@SuppressWarnings("serial")
+public class MultipartException extends RuntimeException {
+
+ public MultipartException(String message) {
+ super(message);
+ }
+
+ public MultipartException(String message, Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java
index 338222e7c55..6caa02e7b1b 100644
--- a/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java
+++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java
@@ -40,9 +40,11 @@ import org.synchronoss.cloud.nio.multipart.NioMultipartParser;
import org.synchronoss.cloud.nio.multipart.NioMultipartParserListener;
import org.synchronoss.cloud.nio.multipart.PartBodyStreamStorageFactory;
import org.synchronoss.cloud.nio.stream.storage.StreamStorage;
+import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
+import reactor.core.publisher.SignalType;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.Hints;
@@ -69,6 +71,7 @@ import org.springframework.util.Assert;
* @author Sebastien Deleuze
* @author Rossen Stoyanchev
* @author Arjen Poutsma
+ * @author Brian Clozel
* @since 5.0
* @see Synchronoss NIO Multipart
* @see MultipartHttpMessageReader
@@ -79,6 +82,12 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
private final PartBodyStreamStorageFactory streamStorageFactory = new DefaultPartBodyStreamStorageFactory();
+ private long maxPartCount = -1;
+
+ private long maxFilePartSize = -1;
+
+ private long maxPartSize = -1;
+
public SynchronossPartHttpMessageReader() {
this.bufferFactory = new DefaultDataBufferFactory();
}
@@ -87,6 +96,55 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
this.bufferFactory = bufferFactory;
}
+ /**
+ * Get the maximum number of parts allowed in a single multipart request.
+ * @since 5.1.11
+ */
+ public long getMaxPartCount() {
+ return maxPartCount;
+ }
+
+ /**
+ * Configure the maximum number of parts allowed in a single multipart request.
+ * @since 5.1.11
+ */
+ public void setMaxPartCount(long maxPartCount) {
+ this.maxPartCount = maxPartCount;
+ }
+
+ /**
+ * Get the maximum size of a file part.
+ * @since 5.1.11
+ */
+ public long getMaxFilePartSize() {
+ return this.maxFilePartSize;
+ }
+
+ /**
+ * Configure the the maximum size of a file part.
+ * @since 5.1.11
+ */
+ public void setMaxFilePartSize(long maxFilePartSize) {
+ this.maxFilePartSize = maxFilePartSize;
+ }
+
+ /**
+ * Get the maximum size of a part.
+ * @since 5.1.11
+ */
+ public long getMaxPartSize() {
+ return this.maxPartSize;
+ }
+
+ /**
+ * Configure the maximum size of a part.
+ * For limits on file parts, use the dedicated {@link #setMaxFilePartSize(long)}.
+ * @since 5.1.11
+ */
+ public void setMaxPartSize(long maxPartSize) {
+ this.maxPartSize = maxPartSize;
+ }
+
@Override
public List getReadableMediaTypes() {
return Collections.singletonList(MediaType.MULTIPART_FORM_DATA);
@@ -101,7 +159,8 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
@Override
public Flux read(ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) {
- return Flux.create(new SynchronossPartGenerator(message, this.bufferFactory, this.streamStorageFactory))
+ return Flux.create(new SynchronossPartGenerator(message, this.bufferFactory, this.streamStorageFactory,
+ new MultipartSizeLimiter(getMaxPartCount(), getMaxFilePartSize(), getMaxPartSize())))
.doOnNext(part -> {
if (!Hints.isLoggingSuppressed(hints)) {
LogFormatUtils.traceDebug(logger, traceOn -> Hints.getLogPrefix(hints) + "Parsed " +
@@ -119,11 +178,60 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
}
+ private static class MultipartSizeLimiter {
+
+ private final long maxPartCount;
+
+ private final long maxFilePartSize;
+
+ private final long maxPartSize;
+
+ private boolean currentIsFilePart;
+
+ private long currentPartCount;
+
+ private long currentPartSize;
+
+
+ public MultipartSizeLimiter(long maxPartCount, long maxFilePartSize, long maxPartSize) {
+ this.maxPartCount = maxPartCount;
+ this.maxFilePartSize = maxFilePartSize;
+ this.maxPartSize = maxPartSize;
+ }
+
+ public void startPart(boolean isFilePart) {
+ this.currentPartCount++;
+ this.currentIsFilePart = isFilePart;
+ if (this.maxPartCount != -1 && this.currentPartCount > this.maxPartCount) {
+ throw new IllegalStateException("Exceeded limit on maximum number of multipart parts");
+ }
+ }
+
+ public void endPart() {
+ this.currentPartSize = 0L;
+ this.currentIsFilePart = false;
+ }
+
+ public void checkCurrentPartSize(long addedBytes) {
+ this.currentPartSize += addedBytes;
+ if (this.currentIsFilePart && this.maxFilePartSize != -1 && this.currentPartSize > this.maxFilePartSize) {
+ throw new IllegalStateException("Exceeded limit on max size of multipart file : " + this.maxFilePartSize);
+ }
+ else if (!this.currentIsFilePart && this.maxPartSize != -1 && this.currentPartSize > this.maxPartSize) {
+ throw new IllegalStateException("Exceeded limit on max size of multipart part : " + this.maxPartSize);
+ }
+ }
+
+ }
+
/**
- * Consume and feed input to the Synchronoss parser, then listen for parser
- * output events and adapt to {@code Flux>}.
+ * Consume {@code DataBuffer} as a {@code BaseSubscriber} of the request body
+ * and feed it as input to the Synchronoss parser. Also listen for parser
+ * output events and adapt them to {@code Flux>} to emit parts
+ * for subscribers.
*/
- private static class SynchronossPartGenerator implements Consumer> {
+ private static class SynchronossPartGenerator extends BaseSubscriber
+ implements Consumer> {
private final ReactiveHttpInputMessage inputMessage;
@@ -131,12 +239,19 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
private final PartBodyStreamStorageFactory streamStorageFactory;
- SynchronossPartGenerator(ReactiveHttpInputMessage inputMessage, DataBufferFactory bufferFactory,
- PartBodyStreamStorageFactory streamStorageFactory) {
+ private final MultipartSizeLimiter limiter;
+
+ private NioMultipartParserListener listener;
+
+ private NioMultipartParser parser;
+
+ public SynchronossPartGenerator(ReactiveHttpInputMessage inputMessage, DataBufferFactory bufferFactory,
+ PartBodyStreamStorageFactory streamStorageFactory, MultipartSizeLimiter limiter) {
this.inputMessage = inputMessage;
this.bufferFactory = bufferFactory;
- this.streamStorageFactory = streamStorageFactory;
+ this.streamStorageFactory = new PartBodyStreamStorageFactoryDecorator(streamStorageFactory, limiter);
+ this.limiter = limiter;
}
@Override
@@ -149,40 +264,53 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
Charset charset = Optional.ofNullable(mediaType.getCharset()).orElse(StandardCharsets.UTF_8);
MultipartContext context = new MultipartContext(mediaType.toString(), length, charset.name());
- NioMultipartParserListener listener = new FluxSinkAdapterListener(emitter, this.bufferFactory, context);
- NioMultipartParser parser = Multipart
+ this.listener = new FluxSinkAdapterListener(emitter, this.bufferFactory, context, this.limiter);
+ this.parser = Multipart
.multipart(context)
.usePartBodyStreamStorageFactory(this.streamStorageFactory)
- .forNIO(listener);
+ // long to int downcast vs. keeping the default 16Kb value
+ //.withHeadersSizeLimit(this.limiter.maxPartSize)
+ .forNIO(this.listener);
+ this.inputMessage.getBody().subscribe(this);
+ }
- this.inputMessage.getBody().subscribe(buffer -> {
- byte[] resultBytes = new byte[buffer.readableByteCount()];
- buffer.read(resultBytes);
- try {
- parser.write(resultBytes);
- }
- catch (IOException ex) {
- listener.onError("Exception thrown while providing input to the parser", ex);
- }
- finally {
- DataBufferUtils.release(buffer);
- }
- }, ex -> {
- try {
- parser.close();
- listener.onError("Request body input error", ex);
- }
- catch (IOException ex2) {
- listener.onError("Exception thrown while closing the parser", ex2);
- }
- }, () -> {
- try {
- parser.close();
- }
- catch (IOException ex) {
- listener.onError("Exception thrown while closing the parser", ex);
- }
- });
+ @Override
+ protected void hookOnNext(DataBuffer buffer) {
+ int readableByteCount = buffer.readableByteCount();
+ this.limiter.checkCurrentPartSize(readableByteCount);
+ byte[] resultBytes = new byte[readableByteCount];
+ buffer.read(resultBytes);
+ try {
+ parser.write(resultBytes);
+ }
+ catch (IOException ex) {
+ this.cancel();
+ listener.onError("Exception thrown while providing input to the parser", ex);
+ }
+ finally {
+ DataBufferUtils.release(buffer);
+ }
+ }
+
+ @Override
+ protected void hookOnError(Throwable throwable) {
+ this.cancel();
+ listener.onError("Could not parse multipart request", throwable);
+ }
+
+ @Override
+ protected void hookOnCancel() {
+ this.cancel();
+ }
+
+ @Override
+ protected void hookFinally(SignalType type) {
+ try {
+ parser.close();
+ }
+ catch (IOException ex) {
+ listener.onError("Exception thrown while closing the parser", ex);
+ }
}
private int getContentLength(HttpHeaders headers) {
@@ -192,6 +320,28 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
}
}
+ private static class PartBodyStreamStorageFactoryDecorator implements PartBodyStreamStorageFactory {
+
+ private final PartBodyStreamStorageFactory streamStorageFactory;
+
+ private final MultipartSizeLimiter limiter;
+
+ public PartBodyStreamStorageFactoryDecorator(PartBodyStreamStorageFactory streamStorageFactory,
+ MultipartSizeLimiter limiter) {
+ this.streamStorageFactory = streamStorageFactory;
+ this.limiter = limiter;
+ }
+
+ @Override
+ public StreamStorage newStreamStorageForPartBody(Map> partHeaders, int partIndex) {
+ HttpHeaders httpHeaders = new HttpHeaders();
+ httpHeaders.putAll(partHeaders);
+ String filename = MultipartUtils.getFileName(httpHeaders);
+ this.limiter.startPart(filename != null);
+ return streamStorageFactory.newStreamStorageForPartBody(partHeaders, partIndex);
+ }
+ }
+
/**
* Listen for parser output and adapt to {@code Flux>}.
@@ -204,12 +354,16 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
private final MultipartContext context;
+ private final MultipartSizeLimiter limiter;
+
private final AtomicInteger terminated = new AtomicInteger(0);
- FluxSinkAdapterListener(FluxSink sink, DataBufferFactory factory, MultipartContext context) {
+ FluxSinkAdapterListener(FluxSink sink, DataBufferFactory factory,
+ MultipartContext context, MultipartSizeLimiter limiter) {
this.sink = sink;
this.bufferFactory = factory;
this.context = context;
+ this.limiter = limiter;
}
@Override
@@ -217,6 +371,7 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.putAll(headers);
this.sink.next(createPart(storage, httpHeaders));
+ this.limiter.endPart();
}
private Part createPart(StreamStorage storage, HttpHeaders httpHeaders) {
@@ -236,7 +391,7 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem
@Override
public void onError(String message, Throwable cause) {
if (this.terminated.getAndIncrement() == 0) {
- this.sink.error(new RuntimeException(message, cause));
+ this.sink.error(new MultipartException(message, cause));
}
}
diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java
index b8adb9e6044..14b4bcf59d9 100644
--- a/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java
+++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java
@@ -17,9 +17,11 @@
package org.springframework.http.codec.multipart;
import java.io.File;
+import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Map;
+import java.util.function.Consumer;
import org.junit.jupiter.api.Test;
import org.reactivestreams.Subscription;
@@ -91,32 +93,32 @@ public class SynchronossPartHttpMessageReaderTests extends AbstractLeakCheckingT
ServerHttpRequest request = generateMultipartRequest();
MultiValueMap parts = this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap()).block();
- assertThat(parts).containsOnlyKeys("fooPart", "barPart");
+ assertThat(parts).containsOnlyKeys("filePart", "textPart");
- Part part = parts.getFirst("fooPart");
+ Part part = parts.getFirst("filePart");
assertThat(part).isInstanceOf(FilePart.class);
- assertThat(part.name()).isEqualTo("fooPart");
+ assertThat(part.name()).isEqualTo("filePart");
assertThat(((FilePart) part).filename()).isEqualTo("foo.txt");
DataBuffer buffer = DataBufferUtils.join(part.content()).block();
assertThat(DataBufferTestUtils.dumpString(buffer, StandardCharsets.UTF_8)).isEqualTo("Lorem Ipsum.");
DataBufferUtils.release(buffer);
- part = parts.getFirst("barPart");
+ part = parts.getFirst("textPart");
assertThat(part).isInstanceOf(FormFieldPart.class);
- assertThat(part.name()).isEqualTo("barPart");
- assertThat(((FormFieldPart) part).value()).isEqualTo("bar");
+ assertThat(part.name()).isEqualTo("textPart");
+ assertThat(((FormFieldPart) part).value()).isEqualTo("sample-text");
}
@Test // SPR-16545
- void transferTo() {
+ void transferTo() throws IOException {
ServerHttpRequest request = generateMultipartRequest();
MultiValueMap parts = this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap()).block();
assertThat(parts).isNotNull();
- FilePart part = (FilePart) parts.getFirst("fooPart");
+ FilePart part = (FilePart) parts.getFirst("filePart");
assertThat(part).isNotNull();
- File dest = new File(System.getProperty("java.io.tmpdir") + "/" + part.filename());
+ File dest = File.createTempFile(part.filename(), "multipart");
part.transferTo(dest).block(Duration.ofSeconds(5));
assertThat(dest.exists()).isTrue();
@@ -139,12 +141,57 @@ public class SynchronossPartHttpMessageReaderTests extends AbstractLeakCheckingT
subscriber.cancel();
}
+ @Test
+ void readTooManyParts() {
+ testMultipartExceptions(
+ reader -> reader.setMaxPartCount(1),
+ err -> {
+ assertThat(err).isInstanceOf(MultipartException.class)
+ .hasMessage("Could not parse multipart request");
+ assertThat(err.getCause()).hasMessage("Exceeded limit on maximum number of multipart parts");
+ }
+ );
+ }
+
+ @Test
+ void readFilePartTooBig() {
+ testMultipartExceptions(
+ reader -> reader.setMaxFilePartSize(5),
+ err -> {
+ assertThat(err).isInstanceOf(MultipartException.class)
+ .hasMessage("Could not parse multipart request");
+ assertThat(err.getCause()).hasMessage("Exceeded limit on max size of multipart file : 5");
+ }
+ );
+ }
- private ServerHttpRequest generateMultipartRequest() {
+ @Test
+ void readPartTooBig() {
+ testMultipartExceptions(
+ reader -> reader.setMaxPartSize(6),
+ err -> {
+ assertThat(err).isInstanceOf(MultipartException.class)
+ .hasMessage("Could not parse multipart request");
+ assertThat(err.getCause()).hasMessage("Exceeded limit on max size of multipart part : 6");
+ }
+ );
+ }
+ private void testMultipartExceptions(Consumer configurer,
+ Consumer assertions) {
+ SynchronossPartHttpMessageReader synchronossReader = new SynchronossPartHttpMessageReader(this.bufferFactory);
+ configurer.accept(synchronossReader);
+ MultipartHttpMessageReader reader = new MultipartHttpMessageReader(synchronossReader);
+ ServerHttpRequest request = generateMultipartRequest();
+ StepVerifier.create(reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap()))
+ .consumeErrorWith(assertions)
+ .verify();
+ }
+
+ private ServerHttpRequest generateMultipartRequest() {
MultipartBodyBuilder partsBuilder = new MultipartBodyBuilder();
- partsBuilder.part("fooPart", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt"));
- partsBuilder.part("barPart", "bar");
+ partsBuilder.part("filePart", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt"));
+ partsBuilder.part("textPart", "sample-text");
MockClientHttpRequest outputMessage = new MockClientHttpRequest(HttpMethod.POST, "/");
new MultipartHttpMessageWriter()