diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockMultipartHttpServletRequest.java b/spring-test/src/main/java/org/springframework/mock/web/MockMultipartHttpServletRequest.java index 935d66aee71..f9607f5bb2f 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockMultipartHttpServletRequest.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockMultipartHttpServletRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.mock.web; +import java.io.IOException; import java.util.Collections; import java.util.Enumeration; import java.util.Iterator; @@ -23,6 +24,8 @@ import java.util.List; import java.util.Map; import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.http.Part; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -121,9 +124,17 @@ public class MockMultipartHttpServletRequest extends MockHttpServletRequest impl if (file != null) { return file.getContentType(); } - else { - return null; + try { + Part part = getPart(paramOrFileName); + if (part != null) { + return part.getContentType(); + } + } + catch (ServletException | IOException ex) { + // Should never happen (we're not actually parsing) + throw new IllegalStateException(ex); } + return null; } @Override @@ -147,7 +158,7 @@ public class MockMultipartHttpServletRequest extends MockHttpServletRequest impl String contentType = getMultipartContentType(paramOrFileName); if (contentType != null) { HttpHeaders headers = new HttpHeaders(); - headers.add("Content-Type", contentType); + headers.add(HttpHeaders.CONTENT_TYPE, contentType); return headers; } else { diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockMultipartHttpServletRequest.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockMultipartHttpServletRequest.java index 7aeab6f98bf..d96a29b4168 100644 --- a/spring-web/src/test/java/org/springframework/mock/web/test/MockMultipartHttpServletRequest.java +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockMultipartHttpServletRequest.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.mock.web.test; +import java.io.IOException; import java.util.Collections; import java.util.Enumeration; import java.util.Iterator; @@ -23,6 +24,8 @@ import java.util.List; import java.util.Map; import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.http.Part; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -121,9 +124,17 @@ public class MockMultipartHttpServletRequest extends MockHttpServletRequest impl if (file != null) { return file.getContentType(); } - else { - return null; + try { + Part part = getPart(paramOrFileName); + if (part != null) { + return part.getContentType(); + } + } + catch (ServletException | IOException ex) { + // Should never happen (we're not actually parsing) + throw new IllegalStateException(ex); } + return null; } @Override @@ -147,7 +158,7 @@ public class MockMultipartHttpServletRequest extends MockHttpServletRequest impl String contentType = getMultipartContentType(paramOrFileName); if (contentType != null) { HttpHeaders headers = new HttpHeaders(); - headers.add("Content-Type", contentType); + headers.add(HttpHeaders.CONTENT_TYPE, contentType); return headers; } else { diff --git a/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java b/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java index a42d1b335e8..87082cc8bcf 100644 --- a/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java +++ b/spring-web/src/test/java/org/springframework/web/multipart/support/RequestPartServletServerHttpRequestTests.java @@ -139,23 +139,13 @@ public class RequestPartServletServerHttpRequestTests { assertArrayEquals(bytes, result); } - @Test + @Test // gh-25829 public void getBodyViaRequestPart() throws Exception { - MockMultipartHttpServletRequest mockRequest = new MockMultipartHttpServletRequest() { - @Override - public HttpHeaders getMultipartHeaders(String paramOrFileName) { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_OCTET_STREAM); - return headers; - } - }; - byte[] bytes = "content".getBytes("UTF-8"); MockPart mockPart = new MockPart("part", bytes); mockPart.getHeaders().setContentType(MediaType.APPLICATION_JSON); - mockRequest.addPart(mockPart); - mockRequest.addPart(mockPart); - ServerHttpRequest request = new RequestPartServletServerHttpRequest(mockRequest, "part"); + this.mockRequest.addPart(mockPart); + ServerHttpRequest request = new RequestPartServletServerHttpRequest(this.mockRequest, "part"); byte[] result = FileCopyUtils.copyToByteArray(request.getBody()); assertArrayEquals(bytes, result);