diff --git a/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingClientHttpRequest.java index 52c3eb60e81..394a8307635 100644 --- a/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingClientHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/client/SimpleBufferingClientHttpRequest.java @@ -91,6 +91,14 @@ final class SimpleBufferingClientHttpRequest extends AbstractBufferingClientHttp * @param headers the headers to add */ static void addHeaders(HttpURLConnection connection, HttpHeaders headers) { + String method = connection.getRequestMethod(); + if (method.equals("PUT") || method.equals("DELETE")) { + if (!StringUtils.hasText(headers.getFirst(HttpHeaders.ACCEPT))) { + // Avoid "text/html, image/gif, image/jpeg, *; q=.2, */*; q=.2" + // from HttpUrlConnection which prevents JSON error response details. + headers.set(HttpHeaders.ACCEPT, "*/*"); + } + } headers.forEach((headerName, headerValues) -> { if (HttpHeaders.COOKIE.equalsIgnoreCase(headerName)) { // RFC 6265 String headerValue = StringUtils.collectionToDelimitedString(headerValues, "; "); diff --git a/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpRequestFactoryTests.java index f698b01728b..df8f9a58893 100644 --- a/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpRequestFactoryTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/SimpleClientHttpRequestFactoryTests.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.springframework.http.HttpHeaders; +import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -31,9 +32,11 @@ import static org.mockito.Mockito.verify; */ public class SimpleClientHttpRequestFactoryTests { - @Test // SPR-13225 + + @Test // SPR-13225 public void headerWithNullValue() { HttpURLConnection urlConnection = mock(HttpURLConnection.class); + given(urlConnection.getRequestMethod()).willReturn("GET"); HttpHeaders headers = new HttpHeaders(); headers.set("foo", null); SimpleBufferingClientHttpRequest.addHeaders(urlConnection, headers); diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java index fbe20803a06..3877572704d 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java @@ -25,8 +25,13 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Set; +import java.util.stream.Collectors; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -43,6 +48,7 @@ import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestInitializer; import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.util.StreamUtils; @@ -486,6 +492,49 @@ public class RestTemplateTests { verify(response).close(); } + @Test // gh-23740 + public void headerAcceptAllOnPut() throws Exception { + MockWebServer server = new MockWebServer(); + server.enqueue(new MockResponse().setResponseCode(500).setBody("internal server error")); + server.start(); + try { + template.setRequestFactory(new SimpleClientHttpRequestFactory()); + template.put(server.url("/internal/server/error").uri(), null); + assertThat(server.takeRequest().getHeader("Accept")).isEqualTo("*/*"); + } + finally { + server.shutdown(); + } + } + + @Test // gh-23740 + public void keepGivenAcceptHeaderOnPut() throws Exception { + MockWebServer server = new MockWebServer(); + server.enqueue(new MockResponse().setResponseCode(500).setBody("internal server error")); + server.start(); + try { + HttpHeaders headers = new HttpHeaders(); + headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); + HttpEntity entity = new HttpEntity<>(null, headers); + template.setRequestFactory(new SimpleClientHttpRequestFactory()); + template.exchange(server.url("/internal/server/error").uri(), PUT, entity, Void.class); + + RecordedRequest request = server.takeRequest(); + + final List> accepts = request.getHeaders().toMultimap().entrySet().stream() + .filter(entry -> entry.getKey().equalsIgnoreCase("accept")) + .map(Entry::getValue) + .collect(Collectors.toList()); + + assertThat(accepts).hasSize(1); + assertThat(accepts.get(0)).hasSize(1); + assertThat(accepts.get(0).get(0)).isEqualTo("application/json"); + } + finally { + server.shutdown(); + } + } + @Test public void patchForObject() throws Exception { mockTextPlainHttpMessageConverter(); @@ -532,6 +581,21 @@ public class RestTemplateTests { verify(response).close(); } + @Test // gh-23740 + public void headerAcceptAllOnDelete() throws Exception { + MockWebServer server = new MockWebServer(); + server.enqueue(new MockResponse().setResponseCode(500).setBody("internal server error")); + server.start(); + try { + template.setRequestFactory(new SimpleClientHttpRequestFactory()); + template.delete(server.url("/internal/server/error").uri()); + assertThat(server.takeRequest().getHeader("Accept")).isEqualTo("*/*"); + } + finally { + server.shutdown(); + } + } + @Test public void optionsForAllow() throws Exception { mockSentRequest(OPTIONS, "https://example.com");