diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandler.java index 016c56fe097..d4b4b2438a0 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandler.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandler.java @@ -28,6 +28,7 @@ import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; import org.springframework.http.codec.HttpMessageWriter; @@ -100,8 +101,9 @@ public class ResponseEntityResultHandler extends AbstractMessageWriterResultHand } private boolean isSupportedType(@Nullable Class clazz) { - return (clazz != null && HttpEntity.class.isAssignableFrom(clazz) && - !RequestEntity.class.isAssignableFrom(clazz)); + return (clazz != null && ((HttpEntity.class.isAssignableFrom(clazz) && + !RequestEntity.class.isAssignableFrom(clazz)) || + HttpHeaders.class.isAssignableFrom(clazz))); } @@ -123,8 +125,17 @@ public class ResponseEntityResultHandler extends AbstractMessageWriterResultHand } return returnValueMono.flatMap(returnValue -> { - Assert.isInstanceOf(HttpEntity.class, returnValue, "HttpEntity expected"); - HttpEntity httpEntity = (HttpEntity) returnValue; + HttpEntity httpEntity; + if (returnValue instanceof HttpEntity) { + httpEntity = (HttpEntity) returnValue; + } + else if (returnValue instanceof HttpHeaders) { + httpEntity = new ResponseEntity((HttpHeaders) returnValue, HttpStatus.OK); + } + else { + throw new IllegalArgumentException( + "HttpEntity or HttpHeaders expected but got: " + returnValue.getClass()); + } if (httpEntity instanceof ResponseEntity) { ResponseEntity responseEntity = (ResponseEntity) httpEntity; @@ -139,7 +150,8 @@ public class ResponseEntityResultHandler extends AbstractMessageWriterResultHand .filter(entry -> !responseHeaders.containsKey(entry.getKey())) .forEach(entry -> responseHeaders.put(entry.getKey(), entry.getValue())); } - if(httpEntity.getBody() == null) { + + if(httpEntity.getBody() == null || returnValue instanceof HttpHeaders) { return exchange.getResponse().setComplete(); } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandlerTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandlerTests.java index bf6e74b697d..b1a675ef17c 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandlerTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/ResponseEntityResultHandlerTests.java @@ -24,6 +24,7 @@ import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.LinkedHashSet; import java.util.List; import java.util.concurrent.CompletableFuture; @@ -41,6 +42,7 @@ import org.springframework.core.codec.ByteBufferEncoder; import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.core.io.buffer.support.DataBufferTestUtils; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.http.codec.EncoderHttpMessageWriter; @@ -121,6 +123,9 @@ public class ResponseEntityResultHandlerTests { returnType = on(TestController.class).resolveReturnType(CompletableFuture.class, entity(String.class)); assertTrue(this.resultHandler.supports(handlerResult(value, returnType))); + returnType = on(TestController.class).resolveReturnType(HttpHeaders.class); + assertTrue(this.resultHandler.supports(handlerResult(value, returnType))); + // SPR-15785 value = ResponseEntity.ok("testing"); returnType = on(TestController.class).resolveReturnType(Object.class); @@ -150,7 +155,7 @@ public class ResponseEntityResultHandlerTests { } @Test - public void statusCode() throws Exception { + public void responseEntityStatusCode() throws Exception { ResponseEntity value = ResponseEntity.noContent().build(); MethodParameter returnType = on(TestController.class).resolveReturnType(entity(Void.class)); HandlerResult result = handlerResult(value, returnType); @@ -163,7 +168,22 @@ public class ResponseEntityResultHandlerTests { } @Test - public void headers() throws Exception { + public void httpHeaders() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.setAllow(new LinkedHashSet<>(Arrays.asList(HttpMethod.GET, HttpMethod.POST, HttpMethod.OPTIONS))); + MethodParameter returnType = on(TestController.class).resolveReturnType(entity(Void.class)); + HandlerResult result = handlerResult(headers, returnType); + MockServerWebExchange exchange = get("/path").toExchange(); + this.resultHandler.handleResult(exchange, result).block(Duration.ofSeconds(5)); + + assertEquals(HttpStatus.OK, exchange.getResponse().getStatusCode()); + assertEquals(1, exchange.getResponse().getHeaders().size()); + assertEquals("GET,POST,OPTIONS", exchange.getResponse().getHeaders().getFirst("Allow")); + assertResponseBodyIsEmpty(exchange); + } + + @Test + public void responseEntityHeaders() throws Exception { URI location = new URI("/path"); ResponseEntity value = ResponseEntity.created(location).build(); MethodParameter returnType = on(TestController.class).resolveReturnType(entity(Void.class)); @@ -382,6 +402,8 @@ public class ResponseEntityResultHandlerTests { ResponseEntity responseEntityVoid() { return null; } + HttpHeaders httpHeaders() { return null; } + Mono> mono() { return null; } Single> single() { return null; } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/RequestMappingInfoHandlerMapping.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/RequestMappingInfoHandlerMapping.java index 079b020fd32..9d7c2a16786 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/RequestMappingInfoHandlerMapping.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/RequestMappingInfoHandlerMapping.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.