diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java index 838c7d2642c..53f61e06666 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java @@ -29,8 +29,10 @@ import reactor.core.publisher.Mono; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ResolvableType; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.http.MediaType; import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.http.client.reactive.ClientHttpResponse; import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.reactive.ServerHttpRequest; @@ -42,6 +44,7 @@ import org.springframework.util.MultiValueMap; * @author Arjen Poutsma * @author Sebastien Deleuze * @author Rossen Stoyanchev + * @author Brian Clozel * @since 5.0 */ public abstract class BodyExtractors { @@ -81,8 +84,8 @@ public abstract class BodyExtractors { return (inputMessage, context) -> readWithMessageReaders(inputMessage, context, elementType, (HttpMessageReader reader) -> readToMono(inputMessage, context, elementType, reader), - ex -> Mono.from(unsupportedErrorHandler(inputMessage, ex)), - Mono::empty); + ex -> Mono.from(unsupportedErrorHandler(inputMessage, context, ex)), + skipBodyAsMono(inputMessage, context)); } /** @@ -110,8 +113,8 @@ public abstract class BodyExtractors { return (inputMessage, context) -> readWithMessageReaders(inputMessage, context, elementType, (HttpMessageReader reader) -> readToFlux(inputMessage, context, elementType, reader), - ex -> unsupportedErrorHandler(inputMessage, ex), - Flux::empty); + ex -> unsupportedErrorHandler(inputMessage, context, ex), + skipBodyAsFlux(inputMessage, context)); } @@ -183,7 +186,6 @@ public abstract class BodyExtractors { if (VOID_TYPE.equals(elementType)) { return emptySupplier.get(); } - MediaType contentType = Optional.ofNullable(message.getHeaders().getContentType()) .orElse(MediaType.APPLICATION_OCTET_STREAM); @@ -195,6 +197,28 @@ public abstract class BodyExtractors { .orElseGet(() -> errorFunction.apply(unsupportedError(context, elementType, contentType))); } + private static Supplier> skipBodyAsFlux(ReactiveHttpInputMessage message, + BodyExtractor.Context context) { + + if (isExtractingForClient(message, context)) { + return () -> consumeAndCancel(message).thenMany(Flux.empty()); + } + else { + return Flux::empty; + } + } + + private static Supplier> skipBodyAsMono(ReactiveHttpInputMessage message, + BodyExtractor.Context context) { + + if (isExtractingForClient(message, context)) { + return () -> consumeAndCancel(message).then(Mono.empty()); + } + else { + return Mono::empty; + } + } + private static UnsupportedMediaTypeException unsupportedError(BodyExtractor.Context context, ResolvableType elementType, MediaType contentType) { @@ -222,17 +246,21 @@ public abstract class BodyExtractors { } private static Flux unsupportedErrorHandler( - ReactiveHttpInputMessage inputMessage, UnsupportedMediaTypeException ex) { + ReactiveHttpInputMessage inputMessage, BodyExtractor.Context context, + UnsupportedMediaTypeException ex) { + Flux result; if (inputMessage.getHeaders().getContentType() == null) { // Empty body with no content type is ok - return inputMessage.getBody().map(o -> { + result = inputMessage.getBody().map(o -> { throw ex; }); } else { - return Flux.error(ex); + result = Flux.error(ex); } + return isExtractingForClient(inputMessage, context) ? + consumeAndCancel(inputMessage).thenMany(result) : result; } private static HttpMessageReader findReader( @@ -251,4 +279,23 @@ public abstract class BodyExtractors { return (HttpMessageReader) reader; } + private static boolean isExtractingForClient(ReactiveHttpInputMessage message, + BodyExtractor.Context context) { + return !context.serverResponse().isPresent() + && message instanceof ClientHttpResponse; + } + + private static Mono consumeAndCancel(ReactiveHttpInputMessage message) { + return message.getBody() + .map(buffer -> { + DataBufferUtils.release(buffer); + throw new ReadCancellationException(); + }) + .onErrorResume(ReadCancellationException.class, ex -> Mono.empty()) + .then(); + } + + @SuppressWarnings("serial") + private static class ReadCancellationException extends RuntimeException { + } } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponse.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponse.java index f39a5d56f37..51e91f7dcb3 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponse.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultClientResponse.java @@ -27,7 +27,6 @@ import reactor.core.publisher.Mono; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.codec.Hints; -import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; @@ -111,73 +110,32 @@ class DefaultClientResponse implements ClientResponse { @Override public Mono bodyToMono(Class elementClass) { - if (Void.class.isAssignableFrom(elementClass)) { - return consumeAndCancel(); - } - else { - return body(BodyExtractors.toMono(elementClass)); - } - } - - @SuppressWarnings("unchecked") - private Mono consumeAndCancel() { - return (Mono) this.response.getBody() - .map(buffer -> { - DataBufferUtils.release(buffer); - throw new ReadCancellationException(); - }) - .onErrorResume(ReadCancellationException.class, ex -> Mono.empty()) - .then(); + return body(BodyExtractors.toMono(elementClass)); } @Override public Mono bodyToMono(ParameterizedTypeReference typeReference) { - if (Void.class.isAssignableFrom(typeReference.getType().getClass())) { - return consumeAndCancel(); - } - else { - return body(BodyExtractors.toMono(typeReference)); - } + return body(BodyExtractors.toMono(typeReference)); } @Override public Flux bodyToFlux(Class elementClass) { - if (Void.class.isAssignableFrom(elementClass)) { - return Flux.from(consumeAndCancel()); - } - else { - return body(BodyExtractors.toFlux(elementClass)); - } + return body(BodyExtractors.toFlux(elementClass)); } @Override public Flux bodyToFlux(ParameterizedTypeReference typeReference) { - if (Void.class.isAssignableFrom(typeReference.getType().getClass())) { - return Flux.from(consumeAndCancel()); - } - else { - return body(BodyExtractors.toFlux(typeReference)); - } + return body(BodyExtractors.toFlux(typeReference)); } @Override public Mono> toEntity(Class bodyType) { - if (Void.class.isAssignableFrom(bodyType)) { - return toEntityInternal(consumeAndCancel()); - } - else { - return toEntityInternal(bodyToMono(bodyType)); - } + return toEntityInternal(bodyToMono(bodyType)); } @Override public Mono> toEntity(ParameterizedTypeReference typeReference) { - if (Void.class.isAssignableFrom(typeReference.getType().getClass())) { - return toEntityInternal(consumeAndCancel()); - } - else { - return toEntityInternal(bodyToMono(typeReference)); - } + return toEntityInternal(bodyToMono(typeReference)); } private Mono> toEntityInternal(Mono bodyMono) { @@ -254,9 +212,4 @@ class DefaultClientResponse implements ClientResponse { } } - - @SuppressWarnings("serial") - private static class ReadCancellationException extends RuntimeException { - } - } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java index 17bf294cc1d..8439155701c 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java @@ -27,11 +27,15 @@ import java.util.Map; import java.util.Optional; import com.fasterxml.jackson.annotation.JsonView; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.codec.ByteBufferDecoder; @@ -39,6 +43,9 @@ import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DefaultDataBuffer; import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.io.buffer.NettyDataBuffer; +import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.ReactiveHttpInputMessage; import org.springframework.http.codec.DecoderHttpMessageReader; @@ -53,15 +60,18 @@ import org.springframework.http.codec.multipart.SynchronossPartHttpMessageReader import org.springframework.http.codec.xml.Jaxb2XmlDecoder; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.mock.http.client.reactive.test.MockClientHttpResponse; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.util.MultiValueMap; import static org.junit.Assert.*; +import static org.mockito.Mockito.when; import static org.springframework.http.codec.json.Jackson2CodecSupport.*; /** * @author Arjen Poutsma * @author Sebastien Deleuze + * @author Brian Clozel */ public class BodyExtractorsTests { @@ -69,6 +79,8 @@ public class BodyExtractorsTests { private Map hints; + private Optional serverResponse = Optional.empty(); + @Before public void createContext() { @@ -92,7 +104,7 @@ public class BodyExtractorsTests { @Override public Optional serverResponse() { - return Optional.empty(); + return serverResponse; } @Override @@ -180,6 +192,43 @@ public class BodyExtractorsTests { StepVerifier.create(result).expectComplete().verify(); } + @Test + public void toMonoVoidAsClientShouldConsumeAndCancel() { + DefaultDataBufferFactory factory = new DefaultDataBufferFactory(); + DefaultDataBuffer dataBuffer = + factory.wrap(ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8))); + TestPublisher body = TestPublisher.create(); + + BodyExtractor, ReactiveHttpInputMessage> extractor = BodyExtractors.toMono(Void.class); + MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); + response.setBody(body.flux()); + + StepVerifier.create(extractor.extract(response, this.context)) + .then(() -> { + body.assertWasSubscribed(); + body.emit(dataBuffer); + }) + .verifyComplete(); + + body.assertCancelled(); + } + + @Test + public void toMonoVoidAsClientWithEmptyBody() { + TestPublisher body = TestPublisher.create(); + + BodyExtractor, ReactiveHttpInputMessage> extractor = BodyExtractors.toMono(Void.class); + MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); + response.setBody(body.flux()); + + StepVerifier.create(extractor.extract(response, this.context)) + .then(() -> { + body.assertWasSubscribed(); + body.complete(); + }) + .verifyComplete(); + } + @Test public void toFlux() { BodyExtractor, ReactiveHttpInputMessage> extractor = BodyExtractors.toFlux(String.class); @@ -366,6 +415,34 @@ public class BodyExtractorsTests { .verify(); } + @Test // SPR-17054 + public void unsupportedMediaTypeShouldConsumeAndCancel() { + NettyDataBufferFactory factory = new NettyDataBufferFactory(new PooledByteBufAllocator(true)); + NettyDataBuffer buffer = factory.wrap(ByteBuffer.wrap("spring".getBytes(StandardCharsets.UTF_8))); + TestPublisher body = TestPublisher.create(); + + MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); + response.getHeaders().setContentType(MediaType.APPLICATION_PDF); + response.setBody(body.flux()); + + BodyExtractor, ReactiveHttpInputMessage> extractor = BodyExtractors.toMono(User.class); + StepVerifier.create(extractor.extract(response, this.context)) + .then(() -> { + body.assertWasSubscribed(); + body.emit(buffer); + }) + .expectErrorSatisfies(throwable -> { + assertTrue(throwable instanceof UnsupportedMediaTypeException); + try { + buffer.release(); + Assert.fail("releasing the buffer should have failed"); + } catch (IllegalReferenceCountException exc) { + + } + body.assertCancelled(); + }).verify(); + } + interface SafeToDeserialize {} diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultClientResponseTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultClientResponseTests.java index 6fcd30d518d..f4fb4e421d7 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultClientResponseTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/DefaultClientResponseTests.java @@ -28,8 +28,6 @@ import org.junit.Before; import org.junit.Test; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; -import reactor.test.publisher.TestPublisher; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.codec.StringDecoder; @@ -48,9 +46,12 @@ import org.springframework.http.codec.HttpMessageReader; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; -import static org.springframework.web.reactive.function.BodyExtractors.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.springframework.web.reactive.function.BodyExtractors.toMono; /** * @author Arjen Poutsma @@ -127,11 +128,7 @@ public class DefaultClientResponseTests { DefaultDataBuffer dataBuffer = factory.wrap(ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8))); Flux body = Flux.just(dataBuffer); - - HttpHeaders httpHeaders = new HttpHeaders(); - httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockResponse.getHeaders()).thenReturn(httpHeaders); - when(mockResponse.getBody()).thenReturn(body); + mockTextPlainResponse(body); List> messageReaders = Collections .singletonList(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); @@ -147,13 +144,7 @@ public class DefaultClientResponseTests { DefaultDataBuffer dataBuffer = factory.wrap(ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8))); Flux body = Flux.just(dataBuffer); - - HttpHeaders httpHeaders = new HttpHeaders(); - httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockResponse.getHeaders()).thenReturn(httpHeaders); - when(mockResponse.getStatusCode()).thenReturn(HttpStatus.OK); - when(mockResponse.getRawStatusCode()).thenReturn(HttpStatus.OK.value()); - when(mockResponse.getBody()).thenReturn(body); + mockTextPlainResponse(body); List> messageReaders = Collections .singletonList(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); @@ -169,13 +160,7 @@ public class DefaultClientResponseTests { DefaultDataBuffer dataBuffer = factory.wrap(ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8))); Flux body = Flux.just(dataBuffer); - - HttpHeaders httpHeaders = new HttpHeaders(); - httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockResponse.getHeaders()).thenReturn(httpHeaders); - when(mockResponse.getStatusCode()).thenReturn(HttpStatus.OK); - when(mockResponse.getRawStatusCode()).thenReturn(HttpStatus.OK.value()); - when(mockResponse.getBody()).thenReturn(body); + mockTextPlainResponse(body); List> messageReaders = Collections .singletonList(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); @@ -193,13 +178,7 @@ public class DefaultClientResponseTests { DefaultDataBuffer dataBuffer = factory.wrap(ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8))); Flux body = Flux.just(dataBuffer); - - HttpHeaders httpHeaders = new HttpHeaders(); - httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockResponse.getHeaders()).thenReturn(httpHeaders); - when(mockResponse.getStatusCode()).thenReturn(HttpStatus.OK); - when(mockResponse.getRawStatusCode()).thenReturn(HttpStatus.OK.value()); - when(mockResponse.getBody()).thenReturn(body); + mockTextPlainResponse(body); List> messageReaders = Collections .singletonList(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); @@ -216,13 +195,7 @@ public class DefaultClientResponseTests { DefaultDataBuffer dataBuffer = factory.wrap(ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8))); Flux body = Flux.just(dataBuffer); - - HttpHeaders httpHeaders = new HttpHeaders(); - httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockResponse.getHeaders()).thenReturn(httpHeaders); - when(mockResponse.getStatusCode()).thenReturn(HttpStatus.OK); - when(mockResponse.getRawStatusCode()).thenReturn(HttpStatus.OK.value()); - when(mockResponse.getBody()).thenReturn(body); + mockTextPlainResponse(body); List> messageReaders = Collections .singletonList(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); @@ -241,13 +214,7 @@ public class DefaultClientResponseTests { DefaultDataBuffer dataBuffer = factory.wrap(ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8))); Flux body = Flux.just(dataBuffer); - - HttpHeaders httpHeaders = new HttpHeaders(); - httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockResponse.getHeaders()).thenReturn(httpHeaders); - when(mockResponse.getStatusCode()).thenReturn(HttpStatus.OK); - when(mockResponse.getRawStatusCode()).thenReturn(HttpStatus.OK.value()); - when(mockResponse.getBody()).thenReturn(body); + mockTextPlainResponse(body); List> messageReaders = Collections .singletonList(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); @@ -296,13 +263,7 @@ public class DefaultClientResponseTests { DefaultDataBuffer dataBuffer = factory.wrap(ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8))); Flux body = Flux.just(dataBuffer); - - HttpHeaders httpHeaders = new HttpHeaders(); - httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockResponse.getHeaders()).thenReturn(httpHeaders); - when(mockResponse.getStatusCode()).thenReturn(HttpStatus.OK); - when(mockResponse.getRawStatusCode()).thenReturn(HttpStatus.OK.value()); - when(mockResponse.getBody()).thenReturn(body); + mockTextPlainResponse(body); List> messageReaders = Collections .singletonList(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); @@ -323,13 +284,7 @@ public class DefaultClientResponseTests { DefaultDataBuffer dataBuffer = factory.wrap(ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8))); Flux body = Flux.just(dataBuffer); - - HttpHeaders httpHeaders = new HttpHeaders(); - httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockResponse.getHeaders()).thenReturn(httpHeaders); - when(mockResponse.getStatusCode()).thenReturn(HttpStatus.OK); - when(mockResponse.getRawStatusCode()).thenReturn(HttpStatus.OK.value()); - when(mockResponse.getBody()).thenReturn(body); + mockTextPlainResponse(body); List> messageReaders = Collections .singletonList(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); @@ -379,12 +334,7 @@ public class DefaultClientResponseTests { factory.wrap(ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8))); Flux body = Flux.just(dataBuffer); - HttpHeaders httpHeaders = new HttpHeaders(); - httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockResponse.getHeaders()).thenReturn(httpHeaders); - when(mockResponse.getStatusCode()).thenReturn(HttpStatus.OK); - when(mockResponse.getRawStatusCode()).thenReturn(HttpStatus.OK.value()); - when(mockResponse.getBody()).thenReturn(body); + mockTextPlainResponse(body); List> messageReaders = Collections .singletonList(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); @@ -399,55 +349,14 @@ public class DefaultClientResponseTests { assertEquals(MediaType.TEXT_PLAIN, result.getHeaders().getContentType()); } - @Test - public void toMonoVoid() { - TestPublisher body = TestPublisher.create(); - - HttpHeaders httpHeaders = new HttpHeaders(); - httpHeaders.setContentType(MediaType.TEXT_PLAIN); - when(mockResponse.getHeaders()).thenReturn(httpHeaders); - when(mockResponse.getStatusCode()).thenReturn(HttpStatus.OK); - when(mockResponse.getRawStatusCode()).thenReturn(HttpStatus.OK.value()); - when(mockResponse.getBody()).thenReturn(body.flux()); - - List> messageReaders = Collections - .singletonList(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); - when(mockExchangeStrategies.messageReaders()).thenReturn(messageReaders); - - StepVerifier.create(defaultClientResponse.bodyToMono(Void.class)) - .then(() -> { - body.assertWasSubscribed(); - body.complete(); - }) - .verifyComplete(); - } - - @Test - public void toMonoVoidNonEmptyBody() { - DefaultDataBufferFactory factory = new DefaultDataBufferFactory(); - DefaultDataBuffer dataBuffer = - factory.wrap(ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8))); - TestPublisher body = TestPublisher.create(); + private void mockTextPlainResponse(Flux body) { HttpHeaders httpHeaders = new HttpHeaders(); httpHeaders.setContentType(MediaType.TEXT_PLAIN); when(mockResponse.getHeaders()).thenReturn(httpHeaders); when(mockResponse.getStatusCode()).thenReturn(HttpStatus.OK); when(mockResponse.getRawStatusCode()).thenReturn(HttpStatus.OK.value()); - when(mockResponse.getBody()).thenReturn(body.flux()); - - List> messageReaders = Collections - .singletonList(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes())); - when(mockExchangeStrategies.messageReaders()).thenReturn(messageReaders); - - StepVerifier.create(defaultClientResponse.bodyToMono(Void.class)) - .then(() -> { - body.assertWasSubscribed(); - body.emit(dataBuffer); - }) - .verifyComplete(); - - body.assertCancelled(); + when(mockResponse.getBody()).thenReturn(body); } }