diff --git a/spring-web/src/main/java/org/springframework/http/converter/FormHttpMessageConverter.java b/spring-web/src/main/java/org/springframework/http/converter/FormHttpMessageConverter.java index 1594dcb4208..37027107cef 100644 --- a/spring-web/src/main/java/org/springframework/http/converter/FormHttpMessageConverter.java +++ b/spring-web/src/main/java/org/springframework/http/converter/FormHttpMessageConverter.java @@ -28,6 +28,7 @@ import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import org.jspecify.annotations.Nullable; @@ -485,9 +486,18 @@ public class FormHttpMessageConverter implements HttpMessageConverter { - writeParts(outputStream, parts, boundary); - writeEnd(outputStream, boundary); + boolean repeatable = checkPartsRepeatable(parts); + streamingOutputMessage.setBody(new StreamingHttpOutputMessage.Body() { + @Override + public void writeTo(OutputStream outputStream) throws IOException { + FormHttpMessageConverter.this.writeParts(outputStream, parts, boundary); + writeEnd(outputStream, boundary); + } + + @Override + public boolean repeatable() { + return repeatable; + } }); } else { @@ -496,6 +506,35 @@ public class FormHttpMessageConverter implements HttpMessageConverter boolean checkPartsRepeatable(MultiValueMap map) { + return map.entrySet().stream().allMatch(e -> e.getValue().stream().filter(Objects::nonNull).allMatch(part -> { + HttpHeaders headers = null; + Object body = part; + if (part instanceof HttpEntity entity) { + headers = entity.getHeaders(); + body = entity.getBody(); + Assert.state(body != null, "Empty body for part '" + e.getKey() + "': " + part); + } + HttpMessageConverter converter = findConverterFor(e.getKey(), headers, body); + return (converter instanceof AbstractHttpMessageConverter ahmc && + ((AbstractHttpMessageConverter) ahmc).supportsRepeatableWrites((T) body)); + })); + } + + private @Nullable HttpMessageConverter findConverterFor( + String name, @Nullable HttpHeaders headers, Object body) { + + Class partType = body.getClass(); + MediaType contentType = (headers != null ? headers.getContentType() : null); + for (HttpMessageConverter converter : this.partConverters) { + if (converter.canWrite(partType, contentType)) { + return converter; + } + } + return null; + } + /** * When {@link #setMultipartCharset(Charset)} is configured (i.e. RFC 2047, * {@code encoded-word} syntax) we need to use ASCII for part headers, or @@ -521,32 +560,27 @@ public class FormHttpMessageConverter implements HttpMessageConverter partEntity, OutputStream os) throws IOException { Object partBody = partEntity.getBody(); - if (partBody == null) { - throw new IllegalStateException("Empty body for part '" + name + "': " + partEntity); - } - Class partType = partBody.getClass(); + Assert.state(partBody != null, "Empty body for part '" + name + "': " + partEntity); HttpHeaders partHeaders = partEntity.getHeaders(); MediaType partContentType = partHeaders.getContentType(); - for (HttpMessageConverter messageConverter : this.partConverters) { - if (messageConverter.canWrite(partType, partContentType)) { - Charset charset = isFilenameCharsetSet() ? StandardCharsets.US_ASCII : this.charset; - HttpOutputMessage multipartMessage = new MultipartHttpOutputMessage(os, charset); - String filename = getFilename(partBody); - ContentDisposition.Builder cd = ContentDisposition.formData() - .name(name); - if (filename != null) { - cd.filename(filename, this.multipartCharset); - } - multipartMessage.getHeaders().setContentDisposition(cd.build()); - if (!partHeaders.isEmpty()) { - multipartMessage.getHeaders().putAll(partHeaders); - } - ((HttpMessageConverter) messageConverter).write(partBody, partContentType, multipartMessage); - return; + HttpMessageConverter converter = findConverterFor(name, partHeaders, partBody); + if (converter != null) { + Charset charset = isFilenameCharsetSet() ? StandardCharsets.US_ASCII : this.charset; + HttpOutputMessage multipartMessage = new MultipartHttpOutputMessage(os, charset); + String filename = getFilename(partBody); + ContentDisposition.Builder cd = ContentDisposition.formData().name(name); + if (filename != null) { + cd.filename(filename, this.multipartCharset); + } + multipartMessage.getHeaders().setContentDisposition(cd.build()); + if (!partHeaders.isEmpty()) { + multipartMessage.getHeaders().putAll(partHeaders); } + ((HttpMessageConverter) converter).write(partBody, partContentType, multipartMessage); + return; } - throw new HttpMessageNotWritableException("Could not write request: no suitable HttpMessageConverter " + - "found for request type [" + partType.getName() + "]"); + throw new HttpMessageNotWritableException("Could not write request: " + + "no suitable HttpMessageConverter found for request type [" + partBody.getClass().getName() + "]"); } /** diff --git a/spring-web/src/test/java/org/springframework/http/converter/FormHttpMessageConverterTests.java b/spring-web/src/test/java/org/springframework/http/converter/FormHttpMessageConverterTests.java index e0ac599e3f5..38255a69679 100644 --- a/spring-web/src/test/java/org/springframework/http/converter/FormHttpMessageConverterTests.java +++ b/spring-web/src/test/java/org/springframework/http/converter/FormHttpMessageConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,6 +40,7 @@ import org.springframework.core.io.Resource; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.StreamingHttpOutputMessage; import org.springframework.http.converter.support.AllEncompassingFormHttpMessageConverter; import org.springframework.http.converter.xml.SourceHttpMessageConverter; import org.springframework.util.LinkedMultiValueMap; @@ -204,7 +205,7 @@ class FormHttpMessageConverterTests { parameters.put("charset", UTF_8.name()); parameters.put("foo", "bar"); - MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + StreamingMockHttpOutputMessage outputMessage = new StreamingMockHttpOutputMessage(); this.converter.write(parts, new MediaType("multipart", "form-data", parameters), outputMessage); final MediaType contentType = outputMessage.getHeaders().getContentType(); @@ -248,6 +249,8 @@ class FormHttpMessageConverterTests { item = items.get(5); assertThat(item.getFieldName()).isEqualTo("json"); assertThat(item.getContentType()).isEqualTo("application/json"); + + assertThat(outputMessage.wasRepeatable()).isTrue(); } @Test @@ -286,7 +289,7 @@ class FormHttpMessageConverterTests { parameters.put("charset", UTF_8.name()); parameters.put("foo", "bar"); - MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + StreamingMockHttpOutputMessage outputMessage = new StreamingMockHttpOutputMessage(); this.converter.write(parts, new MediaType("multipart", "form-data", parameters), outputMessage); final MediaType contentType = outputMessage.getHeaders().getContentType(); @@ -330,6 +333,8 @@ class FormHttpMessageConverterTests { item = items.get(5); assertThat(item.getFieldName()).isEqualTo("xml"); assertThat(item.getContentType()).isEqualTo("text/xml"); + + assertThat(outputMessage.wasRepeatable()).isFalse(); } @Test // SPR-13309 @@ -444,6 +449,27 @@ class FormHttpMessageConverterTests { } + private static class StreamingMockHttpOutputMessage extends MockHttpOutputMessage implements StreamingHttpOutputMessage { + + private boolean repeatable; + + public boolean wasRepeatable() { + return this.repeatable; + } + + @Override + public void setBody(Body body) { + try { + this.repeatable = body.repeatable(); + body.writeTo(getBody()); + } + catch (IOException ex) { + throw new RuntimeException(ex); + } + } + } + + private static class MockHttpOutputMessageRequestContext implements UploadContext { private final MockHttpOutputMessage outputMessage;