diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/FilePart.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/FilePart.java index 6a203fcf576..e1237ce76e1 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/FilePart.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/FilePart.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,8 @@ import java.io.File; import reactor.core.publisher.Mono; /** - * Specialization of {@link Part} for a file upload. + * Specialization of {@link Part} that represents an uploaded file received in + * a multipart request. * * @author Rossen Stoyanchev * @since 5.0 @@ -34,7 +35,9 @@ public interface FilePart extends Part { String filename(); /** - * Transfer the file in this part to the given file destination. + * Convenience method to copy the content of the file in this part to the + * given destination file. If the destination file already exists, it will + * be truncated first. * @param dest the target file * @return completion {@code Mono} with the result of the file transfer, * possibly {@link IllegalStateException} if the part isn't a file diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java index 65d02c8f624..023efa1f9e4 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReader.java @@ -23,6 +23,7 @@ import java.nio.channels.FileChannel; import java.nio.channels.ReadableByteChannel; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; +import java.nio.file.OpenOption; import java.nio.file.StandardOpenOption; import java.util.Collections; import java.util.List; @@ -279,6 +280,9 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader private static class SynchronossFilePart extends DefaultSynchronossPart implements FilePart { + private static final OpenOption[] FILE_CHANNEL_OPTIONS = { + StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE }; + private final String filename; public SynchronossFilePart( @@ -299,7 +303,7 @@ public class SynchronossPartHttpMessageReader implements HttpMessageReader FileChannel output = null; try { input = Channels.newChannel(getStorage().getInputStream()); - output = FileChannel.open(destination.toPath(), StandardOpenOption.WRITE); + output = FileChannel.open(destination.toPath(), FILE_CHANNEL_OPTIONS); long size = (input instanceof FileChannel ? ((FileChannel) input).size() : Long.MAX_VALUE); long totalWritten = 0; while (totalWritten < size) { diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java index a0b85b3c4d4..f1393b6d0c6 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/SynchronossPartHttpMessageReaderTests.java @@ -16,11 +16,13 @@ package org.springframework.http.codec.multipart; -import java.io.IOException; +import java.io.File; +import java.time.Duration; import java.util.Map; import org.junit.Test; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import org.springframework.core.ResolvableType; @@ -28,25 +30,25 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.DefaultDataBufferFactory; -import org.springframework.http.HttpEntity; -import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; -import org.springframework.http.MockHttpOutputMessage; -import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.http.client.MultipartBodyBuilder; import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.mock.http.client.reactive.test.MockClientHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; -import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import static java.util.Collections.emptyMap; +import static java.util.Collections.*; import static org.junit.Assert.*; -import static org.springframework.core.ResolvableType.forClassWithGenerics; -import static org.springframework.http.HttpHeaders.CONTENT_LENGTH; -import static org.springframework.http.HttpHeaders.CONTENT_TYPE; -import static org.springframework.http.MediaType.MULTIPART_FORM_DATA; +import static org.springframework.core.ResolvableType.*; +import static org.springframework.http.HttpHeaders.*; +import static org.springframework.http.MediaType.*; /** + * Unit tests for {@link SynchronossPartHttpMessageReader}. + * * @author Sebastien Deleuze + * @author Rossen Stoyanchev */ public class SynchronossPartHttpMessageReaderTests { @@ -78,7 +80,7 @@ public class SynchronossPartHttpMessageReaderTests { } @Test - public void resolveParts() throws IOException { + public void resolveParts() { ServerHttpRequest request = generateMultipartRequest(); ResolvableType elementType = forClassWithGenerics(MultiValueMap.class, String.class, Part.class); MultiValueMap parts = this.reader.readMono(elementType, request, emptyMap()).block(); @@ -102,6 +104,24 @@ public class SynchronossPartHttpMessageReaderTests { assertEquals("bar", ((FormFieldPart) part).value()); } + @Test // SPR-16545 + public void transferTo() { + ServerHttpRequest request = generateMultipartRequest(); + ResolvableType elementType = forClassWithGenerics(MultiValueMap.class, String.class, Part.class); + MultiValueMap parts = this.reader.readMono(elementType, request, emptyMap()).block(); + + assertNotNull(parts); + FilePart part = (FilePart) parts.getFirst("fooPart"); + assertNotNull(part); + + File dest = new File(System.getProperty("java.io.tmpdir") + "/" + part.filename()); + part.transferTo(dest).block(Duration.ofSeconds(5)); + + assertTrue(dest.exists()); + assertEquals(12, dest.length()); + assertTrue(dest.delete()); + } + @Test public void bodyError() { ServerHttpRequest request = generateErrorMultipartRequest(); @@ -110,29 +130,24 @@ public class SynchronossPartHttpMessageReaderTests { } - private ServerHttpRequest generateMultipartRequest() throws IOException { - HttpHeaders fooHeaders = new HttpHeaders(); - fooHeaders.setContentType(MediaType.TEXT_PLAIN); - ClassPathResource fooResource = new ClassPathResource("org/springframework/http/codec/multipart/foo.txt"); - HttpEntity fooPart = new HttpEntity<>(fooResource, fooHeaders); - HttpEntity barPart = new HttpEntity<>("bar"); - FormHttpMessageConverter converter = new FormHttpMessageConverter(); - MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); - MultiValueMap parts = new LinkedMultiValueMap<>(); - parts.add("fooPart", fooPart); - parts.add("barPart", barPart); - converter.write(parts, MULTIPART_FORM_DATA, outputMessage); - byte[] content = outputMessage.getBodyAsBytes(); - return MockServerHttpRequest - .post("/foo") - .header(CONTENT_TYPE, outputMessage.getHeaders().getContentType().toString()) - .header(CONTENT_LENGTH, String.valueOf(content.length)) - .body(new String(content)); + private ServerHttpRequest generateMultipartRequest() { + + MultipartBodyBuilder partsBuilder = new MultipartBodyBuilder(); + partsBuilder.part("fooPart", new ClassPathResource("org/springframework/http/codec/multipart/foo.txt")); + partsBuilder.part("barPart", "bar"); + + MockClientHttpRequest outputMessage = new MockClientHttpRequest(HttpMethod.POST, "/"); + new MultipartHttpMessageWriter() + .write(Mono.just(partsBuilder.build()), null, MediaType.MULTIPART_FORM_DATA, outputMessage, null) + .block(Duration.ofSeconds(5)); + + return MockServerHttpRequest.post("/") + .contentType(outputMessage.getHeaders().getContentType()) + .body(outputMessage.getBody()); } private ServerHttpRequest generateErrorMultipartRequest() { - return MockServerHttpRequest - .post("/foo") + return MockServerHttpRequest.post("/") .header(CONTENT_TYPE, MULTIPART_FORM_DATA.toString()) .body(Flux.just(new DefaultDataBufferFactory().wrap("invalid content".getBytes()))); }