diff --git a/spring-web/src/main/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolver.java b/spring-web/src/main/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolver.java index b254bab033a..5ec3acbbe13 100644 --- a/spring-web/src/main/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolver.java +++ b/spring-web/src/main/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2012 the original author or authors. + * Copyright 2002-2013 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.web.method.annotation; import java.beans.PropertyEditor; +import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; @@ -161,6 +162,10 @@ public class RequestParamMethodArgumentResolver extends AbstractNamedValueMethod assertIsMultipartRequest(servletRequest); arg = servletRequest.getPart(name); } + else if (isPartCollection(parameter)) { + assertIsMultipartRequest(servletRequest); + arg = new ArrayList(servletRequest.getParts()); + } else { arg = null; if (multipartRequest != null) { @@ -188,14 +193,24 @@ public class RequestParamMethodArgumentResolver extends AbstractNamedValueMethod } private boolean isMultipartFileCollection(MethodParameter parameter) { + Class collectionType = getCollectionParameterType(parameter); + return ((collectionType != null) && collectionType.equals(MultipartFile.class)); + } + + private boolean isPartCollection(MethodParameter parameter) { + Class collectionType = getCollectionParameterType(parameter); + return ((collectionType != null) && "javax.servlet.http.Part".equals(collectionType.getName())); + } + + private Class getCollectionParameterType(MethodParameter parameter) { Class paramType = parameter.getParameterType(); if (Collection.class.equals(paramType) || List.class.isAssignableFrom(paramType)){ Class valueType = GenericCollectionTypeResolver.getCollectionParameterType(parameter); - if (valueType != null && valueType.equals(MultipartFile.class)) { - return true; + if (valueType != null) { + return valueType; } } - return false; + return null; } @Override diff --git a/spring-web/src/test/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverTests.java b/spring-web/src/test/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverTests.java index 357146ec0d7..eec01e8a673 100644 --- a/spring-web/src/test/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverTests.java +++ b/spring-web/src/test/java/org/springframework/web/method/annotation/RequestParamMethodArgumentResolverTests.java @@ -68,7 +68,7 @@ public class RequestParamMethodArgumentResolverTests { private MethodParameter paramStringNotAnnot; private MethodParameter paramMultipartFileNotAnnot; private MethodParameter paramMultipartFileList; - private MethodParameter paramServlet30Part; + private MethodParameter paramPart; private MethodParameter paramRequestPartAnnot; private MethodParameter paramRequired; private MethodParameter paramNotRequired; @@ -99,8 +99,8 @@ public class RequestParamMethodArgumentResolverTests { paramMultipartFileNotAnnot.initParameterNameDiscovery(paramNameDiscoverer); paramMultipartFileList = new MethodParameter(method, 7); paramMultipartFileList.initParameterNameDiscovery(paramNameDiscoverer); - paramServlet30Part = new MethodParameter(method, 8); - paramServlet30Part.initParameterNameDiscovery(paramNameDiscoverer); + paramPart = new MethodParameter(method, 8); + paramPart.initParameterNameDiscovery(paramNameDiscoverer); paramRequestPartAnnot = new MethodParameter(method, 9); paramRequired = new MethodParameter(method, 10); paramNotRequired = new MethodParameter(method, 11); @@ -119,7 +119,7 @@ public class RequestParamMethodArgumentResolverTests { assertFalse("non-@RequestParam parameter supported", resolver.supportsParameter(paramMap)); assertTrue("Simple type params supported w/o annotations", resolver.supportsParameter(paramStringNotAnnot)); assertTrue("MultipartFile parameter not supported", resolver.supportsParameter(paramMultipartFileNotAnnot)); - assertTrue("Part parameter not supported", resolver.supportsParameter(paramServlet30Part)); + assertTrue("Part parameter not supported", resolver.supportsParameter(paramPart)); resolver = new RequestParamMethodArgumentResolver(null, false); assertFalse(resolver.supportsParameter(paramStringNotAnnot)); @@ -220,15 +220,15 @@ public class RequestParamMethodArgumentResolverTests { } @Test - public void resolveServlet30Part() throws Exception { - MockPart expected = new MockPart("servlet30Part", "Hello World".getBytes()); + public void resolvePart() throws Exception { + MockPart expected = new MockPart("part", "Hello World".getBytes()); MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("POST"); request.setContentType("multipart/form-data"); request.addPart(expected); webRequest = new ServletWebRequest(request); - Object result = resolver.resolveArgument(paramServlet30Part, null, webRequest, null); + Object result = resolver.resolveArgument(paramPart, null, webRequest, null); assertTrue(result instanceof Part); assertEquals("Invalid result", expected, result); @@ -331,7 +331,7 @@ public class RequestParamMethodArgumentResolverTests { String stringNotAnnot, MultipartFile multipartFileNotAnnot, List multipartFileList, - Part servlet30Part, + Part part, @RequestPart MultipartFile requestPartAnnot, @RequestParam(value = "name") String paramRequired, @RequestParam(value = "name", required=false) String paramNotRequired) { diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java index f3474d102a8..e14d8bae198 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2012 the original author or authors. + * Copyright 2002-2013 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,8 +17,10 @@ package org.springframework.web.servlet.mvc.method.annotation; import java.lang.annotation.Annotation; +import java.util.ArrayList; import java.util.Collection; import java.util.List; + import javax.servlet.http.HttpServletRequest; import org.springframework.core.GenericCollectionTypeResolver; @@ -130,8 +132,13 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM arg = multipartRequest.getFiles(partName); } else if ("javax.servlet.http.Part".equals(parameter.getParameterType().getName())) { + assertIsMultipartRequest(servletRequest); arg = servletRequest.getPart(partName); } + else if (isPartCollection(parameter)) { + assertIsMultipartRequest(servletRequest); + arg = new ArrayList(servletRequest.getParts()); + } else { try { HttpInputMessage inputMessage = new RequestPartServletServerHttpRequest(servletRequest, partName); @@ -177,14 +184,24 @@ public class RequestPartMethodArgumentResolver extends AbstractMessageConverterM } private boolean isMultipartFileCollection(MethodParameter parameter) { + Class collectionType = getCollectionParameterType(parameter); + return ((collectionType != null) && collectionType.equals(MultipartFile.class)); + } + + private boolean isPartCollection(MethodParameter parameter) { + Class collectionType = getCollectionParameterType(parameter); + return ((collectionType != null) && "javax.servlet.http.Part".equals(collectionType.getName())); + } + + private Class getCollectionParameterType(MethodParameter parameter) { Class paramType = parameter.getParameterType(); if (Collection.class.equals(paramType) || List.class.isAssignableFrom(paramType)){ Class valueType = GenericCollectionTypeResolver.getCollectionParameterType(parameter); - if (valueType != null && valueType.equals(MultipartFile.class)) { - return true; + if (valueType != null) { + return valueType; } } - return false; + return null; } private void validate(WebDataBinder binder, MethodParameter parameter) throws MethodArgumentNotValidException { diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java index dd5a4570bdd..46a4c00cc21 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java @@ -53,6 +53,8 @@ import org.springframework.web.multipart.support.RequestPartServletServerHttpReq import static org.junit.Assert.*; import static org.mockito.BDDMockito.*; +import static org.mockito.Matchers.*; +import static org.mockito.Mockito.*; /** * Test fixture with {@link RequestPartMethodArgumentResolver} and mock {@link HttpMessageConverter}. @@ -75,7 +77,8 @@ public class RequestPartMethodArgumentResolverTests { private MethodParameter paramMultipartFileList; private MethodParameter paramInt; private MethodParameter paramMultipartFileNotAnnot; - private MethodParameter paramServlet30Part; + private MethodParameter paramPart; + private MethodParameter paramPartList; private MethodParameter paramRequestParamAnnot; private NativeWebRequest webRequest; @@ -88,8 +91,9 @@ public class RequestPartMethodArgumentResolverTests { @Before public void setUp() throws Exception { - Method method = getClass().getMethod("handle", SimpleBean.class, SimpleBean.class, SimpleBean.class, - MultipartFile.class, List.class, Integer.TYPE, MultipartFile.class, Part.class, MultipartFile.class); + Method method = getClass().getMethod("handle", SimpleBean.class, SimpleBean.class, + SimpleBean.class, MultipartFile.class, List.class, Integer.TYPE, + MultipartFile.class, Part.class, List.class, MultipartFile.class); paramRequestPart = new MethodParameter(method, 0); paramRequestPart.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); @@ -100,9 +104,10 @@ public class RequestPartMethodArgumentResolverTests { paramInt = new MethodParameter(method, 5); paramMultipartFileNotAnnot = new MethodParameter(method, 6); paramMultipartFileNotAnnot.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); - paramServlet30Part = new MethodParameter(method, 7); - paramServlet30Part.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); - paramRequestParamAnnot = new MethodParameter(method, 8); + paramPart = new MethodParameter(method, 7); + paramPart.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer()); + paramPartList = new MethodParameter(method, 8); + paramRequestParamAnnot = new MethodParameter(method, 9); messageConverter = mock(HttpMessageConverter.class); given(messageConverter.getSupportedMediaTypes()).willReturn(Collections.singletonList(MediaType.TEXT_PLAIN)); @@ -123,7 +128,7 @@ public class RequestPartMethodArgumentResolverTests { public void supportsParameter() { assertTrue("RequestPart parameter not supported", resolver.supportsParameter(paramRequestPart)); assertTrue("MultipartFile parameter not supported", resolver.supportsParameter(paramMultipartFileNotAnnot)); - assertTrue("Part parameter not supported", resolver.supportsParameter(paramServlet30Part)); + assertTrue("Part parameter not supported", resolver.supportsParameter(paramPart)); assertFalse("non-RequestPart parameter supported", resolver.supportsParameter(paramInt)); assertFalse("@RequestParam args not supported", resolver.supportsParameter(paramRequestParamAnnot)); } @@ -157,20 +162,37 @@ public class RequestPartMethodArgumentResolverTests { } @Test - public void resolveServlet30PartArgument() throws Exception { - MockPart expected = new MockPart("servlet30Part", "Hello World".getBytes()); + public void resolvePartArgument() throws Exception { + MockPart expected = new MockPart("part", "Hello World".getBytes()); MockHttpServletRequest request = new MockHttpServletRequest(); request.setMethod("POST"); request.setContentType("multipart/form-data"); request.addPart(expected); webRequest = new ServletWebRequest(request); - Object result = resolver.resolveArgument(paramServlet30Part, null, webRequest, null); + Object result = resolver.resolveArgument(paramPart, null, webRequest, null); assertTrue(result instanceof Part); assertEquals("Invalid result", expected, result); } + @Test + public void resolvePartListArgument() throws Exception { + MockPart part1 = new MockPart("requestPart1", "Hello World 1".getBytes()); + MockPart part2 = new MockPart("requestPart2", "Hello World 2".getBytes()); + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setMethod("POST"); + request.setContentType("multipart/form-data"); + request.addPart(part1); + request.addPart(part2); + webRequest = new ServletWebRequest(request); + + Object result = resolver.resolveArgument(paramPartList, null, webRequest, null); + + assertTrue(result instanceof List); + assertEquals(Arrays.asList(part1, part2), result); + } + @Test public void resolveRequestPart() throws Exception { testResolveArgument(new SimpleBean("foo"), paramRequestPart); @@ -276,7 +298,8 @@ public class RequestPartMethodArgumentResolverTests { @RequestPart("requestPart") List multipartFileList, int i, MultipartFile multipartFileNotAnnot, - Part servlet30Part, + Part part, + @RequestPart("requestPart") List partList, @RequestParam MultipartFile requestParamAnnot) { }