diff --git a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java index 4b56e73d52c..41f1c6c9a6c 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java @@ -852,6 +852,7 @@ public class RestTemplate extends InterceptingHttpAccessor implements RestOperat Class requestBodyClass = requestBody.getClass(); Type requestBodyType = (this.requestEntity instanceof RequestEntity ? ((RequestEntity)this.requestEntity).getType() : requestBodyClass); + HttpHeaders httpHeaders = httpRequest.getHeaders(); HttpHeaders requestHeaders = this.requestEntity.getHeaders(); MediaType requestContentType = requestHeaders.getContentType(); for (HttpMessageConverter messageConverter : getMessageConverters()) { @@ -859,7 +860,9 @@ public class RestTemplate extends InterceptingHttpAccessor implements RestOperat GenericHttpMessageConverter genericMessageConverter = (GenericHttpMessageConverter) messageConverter; if (genericMessageConverter.canWrite(requestBodyType, requestBodyClass, requestContentType)) { if (!requestHeaders.isEmpty()) { - httpRequest.getHeaders().putAll(requestHeaders); + for (Map.Entry> entry : requestHeaders.entrySet()) { + httpHeaders.put(entry.getKey(), new LinkedList(entry.getValue())); + } } if (logger.isDebugEnabled()) { if (requestContentType != null) { @@ -878,7 +881,9 @@ public class RestTemplate extends InterceptingHttpAccessor implements RestOperat } else if (messageConverter.canWrite(requestBodyClass, requestContentType)) { if (!requestHeaders.isEmpty()) { - httpRequest.getHeaders().putAll(requestHeaders); + for (Map.Entry> entry : requestHeaders.entrySet()) { + httpHeaders.put(entry.getKey(), new LinkedList(entry.getValue())); + } } if (logger.isDebugEnabled()) { if (requestContentType != null) { diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java index 021c377089a..1ceb8ff7603 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java @@ -836,7 +836,7 @@ public class RestTemplateTests { } @Test // SPR-15066 - public void requestInterceptorCanAddExistingHeaderValue() throws Exception { + public void requestInterceptorCanAddExistingHeaderValueWithoutBody() throws Exception { ClientHttpRequestInterceptor interceptor = (request, body, execution) -> { request.getHeaders().add("MyHeader", "MyInterceptorValue"); return execution.execute(request, body); @@ -861,4 +861,35 @@ public class RestTemplateTests { verify(response).close(); } + @Test // SPR-15066 + public void requestInterceptorCanAddExistingHeaderValueWithBody() throws Exception { + ClientHttpRequestInterceptor interceptor = (request, body, execution) -> { + request.getHeaders().add("MyHeader", "MyInterceptorValue"); + return execution.execute(request, body); + }; + template.setInterceptors(Collections.singletonList(interceptor)); + + MediaType contentType = MediaType.TEXT_PLAIN; + given(converter.canWrite(String.class, contentType)).willReturn(true); + given(requestFactory.createRequest(new URI("http://example.com"), HttpMethod.POST)).willReturn(request); + String helloWorld = "Hello World"; + HttpHeaders requestHeaders = new HttpHeaders(); + given(request.getHeaders()).willReturn(requestHeaders); + converter.write(helloWorld, contentType, request); + given(request.execute()).willReturn(response); + given(errorHandler.hasError(response)).willReturn(false); + HttpStatus status = HttpStatus.OK; + given(response.getStatusCode()).willReturn(status); + given(response.getStatusText()).willReturn(status.getReasonPhrase()); + + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.setContentType(contentType); + entityHeaders.add("MyHeader", "MyEntityValue"); + HttpEntity entity = new HttpEntity<>(helloWorld, entityHeaders); + template.exchange("http://example.com", HttpMethod.POST, entity, Void.class); + assertThat(requestHeaders.get("MyHeader"), contains("MyEntityValue", "MyInterceptorValue")); + + verify(response).close(); + } + }