diff --git a/spring-web/src/main/java/org/springframework/web/bind/support/WebRequestDataBinder.java b/spring-web/src/main/java/org/springframework/web/bind/support/WebRequestDataBinder.java index b896470e9c0..164c77ff415 100644 --- a/spring-web/src/main/java/org/springframework/web/bind/support/WebRequestDataBinder.java +++ b/spring-web/src/main/java/org/springframework/web/bind/support/WebRequestDataBinder.java @@ -17,6 +17,8 @@ package org.springframework.web.bind.support; import java.io.IOException; +import java.util.List; +import java.util.Map; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -24,6 +26,8 @@ import javax.servlet.http.Part; import org.springframework.beans.MutablePropertyValues; import org.springframework.util.ClassUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.validation.BindException; import org.springframework.web.bind.WebDataBinder; @@ -119,7 +123,7 @@ public class WebRequestDataBinder extends WebDataBinder { } else if (ClassUtils.hasMethod(HttpServletRequest.class, "getParts")) { HttpServletRequest serlvetRequest = ((NativeWebRequest) request).getNativeRequest(HttpServletRequest.class); - new Servlet3MultipartHelper().bindParts(serlvetRequest, mpvs); + new Servlet3MultipartHelper(isBindEmptyMultipartFiles()).bindParts(serlvetRequest, mpvs); } } doBind(mpvs); @@ -154,10 +158,30 @@ public class WebRequestDataBinder extends WebDataBinder { */ private static class Servlet3MultipartHelper { + private final boolean bindEmptyMultipartFiles; + + + public Servlet3MultipartHelper(boolean bindEmptyMultipartFiles) { + this.bindEmptyMultipartFiles = bindEmptyMultipartFiles; + } + + public void bindParts(HttpServletRequest request, MutablePropertyValues mpvs) { try { - for(Part part : request.getParts()) { - mpvs.add(part.getName(), part); + MultiValueMap map = new LinkedMultiValueMap(); + for (Part part : request.getParts()) { + map.add(part.getName(), part); + } + for (Map.Entry> entry: map.entrySet()) { + if (entry.getValue().size() == 1) { + Part part = entry.getValue().get(0); + if (this.bindEmptyMultipartFiles || part.getSize() > 0) { + mpvs.add(entry.getKey(), part); + } + } + else { + mpvs.add(entry.getKey(), entry.getValue()); + } } } catch (IOException ex) { diff --git a/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java index e859bfca3fd..c89e38aa38f 100644 --- a/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/bind/support/WebRequestDataBinderIntegrationTests.java @@ -21,7 +21,6 @@ import java.util.List; import javax.servlet.MultipartConfigElement; import javax.servlet.ServletException; -import javax.servlet.annotation.MultipartConfig; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -38,7 +37,6 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.http.MediaType; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; -import org.springframework.mock.web.test.MockMultipartFile; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.SocketUtils; @@ -109,7 +107,7 @@ public class WebRequestDataBinderIntegrationTests { partsServlet.setBean(bean); MultiValueMap parts = new LinkedMultiValueMap(); - MockMultipartFile firstPart = new MockMultipartFile("fileName", "aValue".getBytes()); + Resource firstPart = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); parts.add("firstPart", firstPart); parts.add("secondPart", "secondValue"); @@ -134,7 +132,7 @@ public class WebRequestDataBinderIntegrationTests { template.postForLocation(baseUrl + "/partlist", parts); assertNotNull(bean.getPartList()); - assertEquals(parts.size(), bean.getPartList().size()); + assertEquals(parts.get("partList").size(), bean.getPartList().size()); } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartIntegrationTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartIntegrationTests.java index 7ff9859aebb..f23e76a9281 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartIntegrationTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartIntegrationTests.java @@ -19,13 +19,14 @@ package org.springframework.web.servlet.mvc.method.annotation; import java.net.URI; import java.util.Arrays; +import javax.servlet.MultipartConfigElement; + import org.eclipse.jetty.server.Server; import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletHolder; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -62,6 +63,7 @@ import static org.junit.Assert.*; * Test access to parts of a multipart request with {@link RequestPart}. * * @author Rossen Stoyanchev + * @author Brian Clozel */ public class RequestPartIntegrationTests { @@ -91,10 +93,9 @@ public class RequestPartIntegrationTests { ServletHolder standardResolverServlet = new ServletHolder(DispatcherServlet.class); standardResolverServlet.setInitParameter("contextConfigLocation", config.getName()); standardResolverServlet.setInitParameter("contextClass", AnnotationConfigWebApplicationContext.class.getName()); + standardResolverServlet.getRegistration().setMultipartConfig(new MultipartConfigElement("")); handler.addServlet(standardResolverServlet, "/standard-resolver/*"); - // TODO: add Servlet 3.0 test case without MultipartResolver - server.setHandler(handler); server.start(); } @@ -123,7 +124,6 @@ public class RequestPartIntegrationTests { } @Test - @Ignore("jetty 6.1.9 doesn't support Servlet 3.0") public void standardMultipartResolver() throws Exception { testCreate(baseUrl + "/standard-resolver/test"); }