diff --git a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java index 9ccd5e4dd73..4e21a9f9e06 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java @@ -739,6 +739,9 @@ public class RestTemplate extends InterceptingHttpAccessor implements RestOperat if (requestCallback != null) { requestCallback.doWithRequest(request); } + if ((method == HttpMethod.DELETE || method == HttpMethod.PUT) && request.getHeaders().getAccept().isEmpty()) { + request.getHeaders().add("Accept", "*/*"); + } response = request.execute(); handleResponse(url, method, response); return (responseExtractor != null ? responseExtractor.extractData(response) : null); diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java index fbe20803a06..f899e6802ae 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java @@ -25,8 +25,13 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Set; +import java.util.stream.Collectors; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -43,6 +48,7 @@ import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestInitializer; import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.util.StreamUtils; @@ -486,6 +492,47 @@ public class RestTemplateTests { verify(response).close(); } + @Test + public void headerAcceptAllOnPut() throws Exception { + MockWebServer server = new MockWebServer(); + server.enqueue(new MockResponse().setResponseCode(500).setBody("internal server error")); + server.start(); + template.setRequestFactory(new SimpleClientHttpRequestFactory()); + + template.put(server.url("/internal/server/error").uri(), null); + + RecordedRequest request = server.takeRequest(); + assertThat(request.getHeader("Accept")).isEqualTo("*/*"); + + server.shutdown(); + } + + @Test + public void keepGivenAcceptHeaderOnPut() throws Exception { + MockWebServer server = new MockWebServer(); + server.enqueue(new MockResponse().setResponseCode(500).setBody("internal server error")); + server.start(); + + template.setRequestFactory(new SimpleClientHttpRequestFactory()); + HttpHeaders headers = new HttpHeaders(); + headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); + HttpEntity entity = new HttpEntity<>(null, headers); + template.exchange(server.url("/internal/server/error").uri(), PUT, entity, Void.class); + + RecordedRequest request = server.takeRequest(); + + final List> accepts = request.getHeaders().toMultimap().entrySet().stream() + .filter(entry -> entry.getKey().equalsIgnoreCase("accept")) + .map(Entry::getValue) + .collect(Collectors.toList()); + + assertThat(accepts).hasSize(1); + assertThat(accepts.get(0)).hasSize(1); + assertThat(accepts.get(0).get(0)).isEqualTo("application/json"); + + server.shutdown(); + } + @Test public void patchForObject() throws Exception { mockTextPlainHttpMessageConverter(); @@ -532,6 +579,21 @@ public class RestTemplateTests { verify(response).close(); } + @Test + public void headerAcceptAllOnDelete() throws Exception { + MockWebServer server = new MockWebServer(); + server.enqueue(new MockResponse().setResponseCode(500).setBody("internal server error")); + server.start(); + template.setRequestFactory(new SimpleClientHttpRequestFactory()); + + template.delete(server.url("/internal/server/error").uri()); + + RecordedRequest request = server.takeRequest(); + assertThat(request.getHeader("Accept")).isEqualTo("*/*"); + + server.shutdown(); + } + @Test public void optionsForAllow() throws Exception { mockSentRequest(OPTIONS, "https://example.com");