diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartException.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartException.java new file mode 100644 index 00000000000..c2d78e0fa86 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartException.java @@ -0,0 +1,32 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.multipart; + +/** + * @author Brian Clozel + */ +@SuppressWarnings("serial") +public class MultipartException extends RuntimeException { + + public MultipartException(String message) { + super(message); + } + + public MultipartException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java index 338222e7c55..6caa02e7b1b 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java @@ -40,9 +40,11 @@ import org.synchronoss.cloud.nio.multipart.NioMultipartParser; import org.synchronoss.cloud.nio.multipart.NioMultipartParserListener; import org.synchronoss.cloud.nio.multipart.PartBodyStreamStorageFactory; import org.synchronoss.cloud.nio.stream.storage.StreamStorage; +import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; import org.springframework.core.ResolvableType; import org.springframework.core.codec.Hints; @@ -69,6 +71,7 @@ import org.springframework.util.Assert; * @author Sebastien Deleuze * @author Rossen Stoyanchev * @author Arjen Poutsma + * @author Brian Clozel * @since 5.0 * @see Synchronoss NIO Multipart * @see MultipartHttpMessageReader @@ -79,6 +82,12 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem private final PartBodyStreamStorageFactory streamStorageFactory = new DefaultPartBodyStreamStorageFactory(); + private long maxPartCount = -1; + + private long maxFilePartSize = -1; + + private long maxPartSize = -1; + public SynchronossPartHttpMessageReader() { this.bufferFactory = new DefaultDataBufferFactory(); } @@ -87,6 +96,55 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem this.bufferFactory = bufferFactory; } + /** + * Get the maximum number of parts allowed in a single multipart request. + * @since 5.1.11 + */ + public long getMaxPartCount() { + return maxPartCount; + } + + /** + * Configure the maximum number of parts allowed in a single multipart request. + * @since 5.1.11 + */ + public void setMaxPartCount(long maxPartCount) { + this.maxPartCount = maxPartCount; + } + + /** + * Get the maximum size of a file part. + * @since 5.1.11 + */ + public long getMaxFilePartSize() { + return this.maxFilePartSize; + } + + /** + * Configure the the maximum size of a file part. + * @since 5.1.11 + */ + public void setMaxFilePartSize(long maxFilePartSize) { + this.maxFilePartSize = maxFilePartSize; + } + + /** + * Get the maximum size of a part. + * @since 5.1.11 + */ + public long getMaxPartSize() { + return this.maxPartSize; + } + + /** + * Configure the maximum size of a part. + * For limits on file parts, use the dedicated {@link #setMaxFilePartSize(long)}. + * @since 5.1.11 + */ + public void setMaxPartSize(long maxPartSize) { + this.maxPartSize = maxPartSize; + } + @Override public List getReadableMediaTypes() { return Collections.singletonList(MediaType.MULTIPART_FORM_DATA); @@ -101,7 +159,8 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem @Override public Flux read(ResolvableType elementType, ReactiveHttpInputMessage message, Map hints) { - return Flux.create(new SynchronossPartGenerator(message, this.bufferFactory, this.streamStorageFactory)) + return Flux.create(new SynchronossPartGenerator(message, this.bufferFactory, this.streamStorageFactory, + new MultipartSizeLimiter(getMaxPartCount(), getMaxFilePartSize(), getMaxPartSize()))) .doOnNext(part -> { if (!Hints.isLoggingSuppressed(hints)) { LogFormatUtils.traceDebug(logger, traceOn -> Hints.getLogPrefix(hints) + "Parsed " + @@ -119,11 +178,60 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem } + private static class MultipartSizeLimiter { + + private final long maxPartCount; + + private final long maxFilePartSize; + + private final long maxPartSize; + + private boolean currentIsFilePart; + + private long currentPartCount; + + private long currentPartSize; + + + public MultipartSizeLimiter(long maxPartCount, long maxFilePartSize, long maxPartSize) { + this.maxPartCount = maxPartCount; + this.maxFilePartSize = maxFilePartSize; + this.maxPartSize = maxPartSize; + } + + public void startPart(boolean isFilePart) { + this.currentPartCount++; + this.currentIsFilePart = isFilePart; + if (this.maxPartCount != -1 && this.currentPartCount > this.maxPartCount) { + throw new IllegalStateException("Exceeded limit on maximum number of multipart parts"); + } + } + + public void endPart() { + this.currentPartSize = 0L; + this.currentIsFilePart = false; + } + + public void checkCurrentPartSize(long addedBytes) { + this.currentPartSize += addedBytes; + if (this.currentIsFilePart && this.maxFilePartSize != -1 && this.currentPartSize > this.maxFilePartSize) { + throw new IllegalStateException("Exceeded limit on max size of multipart file : " + this.maxFilePartSize); + } + else if (!this.currentIsFilePart && this.maxPartSize != -1 && this.currentPartSize > this.maxPartSize) { + throw new IllegalStateException("Exceeded limit on max size of multipart part : " + this.maxPartSize); + } + } + + } + /** - * Consume and feed input to the Synchronoss parser, then listen for parser - * output events and adapt to {@code Flux>}. + * Consume {@code DataBuffer} as a {@code BaseSubscriber} of the request body + * and feed it as input to the Synchronoss parser. Also listen for parser + * output events and adapt them to {@code Flux>} to emit parts + * for subscribers. */ - private static class SynchronossPartGenerator implements Consumer> { + private static class SynchronossPartGenerator extends BaseSubscriber + implements Consumer> { private final ReactiveHttpInputMessage inputMessage; @@ -131,12 +239,19 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem private final PartBodyStreamStorageFactory streamStorageFactory; - SynchronossPartGenerator(ReactiveHttpInputMessage inputMessage, DataBufferFactory bufferFactory, - PartBodyStreamStorageFactory streamStorageFactory) { + private final MultipartSizeLimiter limiter; + + private NioMultipartParserListener listener; + + private NioMultipartParser parser; + + public SynchronossPartGenerator(ReactiveHttpInputMessage inputMessage, DataBufferFactory bufferFactory, + PartBodyStreamStorageFactory streamStorageFactory, MultipartSizeLimiter limiter) { this.inputMessage = inputMessage; this.bufferFactory = bufferFactory; - this.streamStorageFactory = streamStorageFactory; + this.streamStorageFactory = new PartBodyStreamStorageFactoryDecorator(streamStorageFactory, limiter); + this.limiter = limiter; } @Override @@ -149,40 +264,53 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem Charset charset = Optional.ofNullable(mediaType.getCharset()).orElse(StandardCharsets.UTF_8); MultipartContext context = new MultipartContext(mediaType.toString(), length, charset.name()); - NioMultipartParserListener listener = new FluxSinkAdapterListener(emitter, this.bufferFactory, context); - NioMultipartParser parser = Multipart + this.listener = new FluxSinkAdapterListener(emitter, this.bufferFactory, context, this.limiter); + this.parser = Multipart .multipart(context) .usePartBodyStreamStorageFactory(this.streamStorageFactory) - .forNIO(listener); + // long to int downcast vs. keeping the default 16Kb value + //.withHeadersSizeLimit(this.limiter.maxPartSize) + .forNIO(this.listener); + this.inputMessage.getBody().subscribe(this); + } - this.inputMessage.getBody().subscribe(buffer -> { - byte[] resultBytes = new byte[buffer.readableByteCount()]; - buffer.read(resultBytes); - try { - parser.write(resultBytes); - } - catch (IOException ex) { - listener.onError("Exception thrown while providing input to the parser", ex); - } - finally { - DataBufferUtils.release(buffer); - } - }, ex -> { - try { - parser.close(); - listener.onError("Request body input error", ex); - } - catch (IOException ex2) { - listener.onError("Exception thrown while closing the parser", ex2); - } - }, () -> { - try { - parser.close(); - } - catch (IOException ex) { - listener.onError("Exception thrown while closing the parser", ex); - } - }); + @Override + protected void hookOnNext(DataBuffer buffer) { + int readableByteCount = buffer.readableByteCount(); + this.limiter.checkCurrentPartSize(readableByteCount); + byte[] resultBytes = new byte[readableByteCount]; + buffer.read(resultBytes); + try { + parser.write(resultBytes); + } + catch (IOException ex) { + this.cancel(); + listener.onError("Exception thrown while providing input to the parser", ex); + } + finally { + DataBufferUtils.release(buffer); + } + } + + @Override + protected void hookOnError(Throwable throwable) { + this.cancel(); + listener.onError("Could not parse multipart request", throwable); + } + + @Override + protected void hookOnCancel() { + this.cancel(); + } + + @Override + protected void hookFinally(SignalType type) { + try { + parser.close(); + } + catch (IOException ex) { + listener.onError("Exception thrown while closing the parser", ex); + } } private int getContentLength(HttpHeaders headers) { @@ -192,6 +320,28 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem } } + private static class PartBodyStreamStorageFactoryDecorator implements PartBodyStreamStorageFactory { + + private final PartBodyStreamStorageFactory streamStorageFactory; + + private final MultipartSizeLimiter limiter; + + public PartBodyStreamStorageFactoryDecorator(PartBodyStreamStorageFactory streamStorageFactory, + MultipartSizeLimiter limiter) { + this.streamStorageFactory = streamStorageFactory; + this.limiter = limiter; + } + + @Override + public StreamStorage newStreamStorageForPartBody(Map> partHeaders, int partIndex) { + HttpHeaders httpHeaders = new HttpHeaders(); + httpHeaders.putAll(partHeaders); + String filename = MultipartUtils.getFileName(httpHeaders); + this.limiter.startPart(filename != null); + return streamStorageFactory.newStreamStorageForPartBody(partHeaders, partIndex); + } + } + /** * Listen for parser output and adapt to {@code Flux>}. @@ -204,12 +354,16 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem private final MultipartContext context; + private final MultipartSizeLimiter limiter; + private final AtomicInteger terminated = new AtomicInteger(0); - FluxSinkAdapterListener(FluxSink sink, DataBufferFactory factory, MultipartContext context) { + FluxSinkAdapterListener(FluxSink sink, DataBufferFactory factory, + MultipartContext context, MultipartSizeLimiter limiter) { this.sink = sink; this.bufferFactory = factory; this.context = context; + this.limiter = limiter; } @Override @@ -217,6 +371,7 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem HttpHeaders httpHeaders = new HttpHeaders(); httpHeaders.putAll(headers); this.sink.next(createPart(storage, httpHeaders)); + this.limiter.endPart(); } private Part createPart(StreamStorage storage, HttpHeaders httpHeaders) { @@ -236,7 +391,7 @@ public class SynchronossPartHttpMessageReader extends LoggingCodecSupport implem @Override public void onError(String message, Throwable cause) { if (this.terminated.getAndIncrement() == 0) { - this.sink.error(new RuntimeException(message, cause)); + this.sink.error(new MultipartException(message, cause)); } } diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java index b8adb9e6044..14b4bcf59d9 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java @@ -17,9 +17,11 @@ package org.springframework.http.codec.multipart; import java.io.File; +import java.io.IOException; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Map; +import java.util.function.Consumer; import org.junit.jupiter.api.Test; import org.reactivestreams.Subscription; @@ -91,32 +93,32 @@ public class SynchronossPartHttpMessageReaderTests extends AbstractLeakCheckingT ServerHttpRequest request = generateMultipartRequest(); MultiValueMap parts = this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap()).block(); - assertThat(parts).containsOnlyKeys("fooPart", "barPart"); + assertThat(parts).containsOnlyKeys("filePart", "textPart"); - Part part = parts.getFirst("fooPart"); + Part part = parts.getFirst("filePart"); assertThat(part).isInstanceOf(FilePart.class); - assertThat(part.name()).isEqualTo("fooPart"); + assertThat(part.name()).isEqualTo("filePart"); assertThat(((FilePart) part).filename()).isEqualTo("foo.txt"); DataBuffer buffer = DataBufferUtils.join(part.content()).block(); assertThat(DataBufferTestUtils.dumpString(buffer, StandardCharsets.UTF_8)).isEqualTo("Lorem Ipsum."); DataBufferUtils.release(buffer); - part = parts.getFirst("barPart"); + part = parts.getFirst("textPart"); assertThat(part).isInstanceOf(FormFieldPart.class); - assertThat(part.name()).isEqualTo("barPart"); - assertThat(((FormFieldPart) part).value()).isEqualTo("bar"); + assertThat(part.name()).isEqualTo("textPart"); + assertThat(((FormFieldPart) part).value()).isEqualTo("sample-text"); } @Test // SPR-16545 - void transferTo() { + void transferTo() throws IOException { ServerHttpRequest request = generateMultipartRequest(); MultiValueMap parts = this.reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap()).block(); assertThat(parts).isNotNull(); - FilePart part = (FilePart) parts.getFirst("fooPart"); + FilePart part = (FilePart) parts.getFirst("filePart"); assertThat(part).isNotNull(); - File dest = new File(System.getProperty("java.io.tmpdir") + "/" + part.filename()); + File dest = File.createTempFile(part.filename(), "multipart"); part.transferTo(dest).block(Duration.ofSeconds(5)); assertThat(dest.exists()).isTrue(); @@ -139,12 +141,57 @@ public class SynchronossPartHttpMessageReaderTests extends AbstractLeakCheckingT subscriber.cancel(); } + @Test + void readTooManyParts() { + testMultipartExceptions( + reader -> reader.setMaxPartCount(1), + err -> { + assertThat(err).isInstanceOf(MultipartException.class) + .hasMessage("Could not parse multipart request"); + assertThat(err.getCause()).hasMessage("Exceeded limit on maximum number of multipart parts"); + } + ); + } + + @Test + void readFilePartTooBig() { + testMultipartExceptions( + reader -> reader.setMaxFilePartSize(5), + err -> { + assertThat(err).isInstanceOf(MultipartException.class) + .hasMessage("Could not parse multipart request"); + assertThat(err.getCause()).hasMessage("Exceeded limit on max size of multipart file : 5"); + } + ); + } - private ServerHttpRequest generateMultipartRequest() { + @Test + void readPartTooBig() { + testMultipartExceptions( + reader -> reader.setMaxPartSize(6), + err -> { + assertThat(err).isInstanceOf(MultipartException.class) + .hasMessage("Could not parse multipart request"); + assertThat(err.getCause()).hasMessage("Exceeded limit on max size of multipart part : 6"); + } + ); + } + private void testMultipartExceptions(Consumer configurer, + Consumer assertions) { + SynchronossPartHttpMessageReader synchronossReader = new SynchronossPartHttpMessageReader(this.bufferFactory); + configurer.accept(synchronossReader); + MultipartHttpMessageReader reader = new MultipartHttpMessageReader(synchronossReader); + ServerHttpRequest request = generateMultipartRequest(); + StepVerifier.create(reader.readMono(PARTS_ELEMENT_TYPE, request, emptyMap())) + .consumeErrorWith(assertions) + .verify(); + } + + private ServerHttpRequest generateMultipartRequest() { MultipartBodyBuilder partsBuilder = new MultipartBodyBuilder(); - partsBuilder.part("fooPart", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt")); - partsBuilder.part("barPart", "bar"); + partsBuilder.part("filePart", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt")); + partsBuilder.part("textPart", "sample-text"); MockClientHttpRequest outputMessage = new MockClientHttpRequest(HttpMethod.POST, "/"); new MultipartHttpMessageWriter()