diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index 1ac20207f88..434916de5f4 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -813,19 +813,31 @@ final class DefaultRestClient implements RestClient { } @Override - @SuppressWarnings("NullAway") // See https://github.com/uber/NullAway/issues/1290 public @Nullable T body(Class bodyType) { return executeAndExtract((request, response) -> readBody(request, response, bodyType, bodyType, this.hints)); } @Override - @SuppressWarnings("NullAway") // See https://github.com/uber/NullAway/issues/1290 + public T requiredBody(Class bodyType) { + T body = body(bodyType); + Assert.state(body != null, "The body must not be null"); + return body; + } + + @Override public @Nullable T body(ParameterizedTypeReference bodyType) { Type type = bodyType.getType(); Class bodyClass = bodyClass(type); return executeAndExtract((request, response) -> readBody(request, response, type, bodyClass, this.hints)); } + @Override + public T requiredBody(ParameterizedTypeReference bodyType) { + T body = body(bodyType); + Assert.state(body != null, "The body must not be null"); + return body; + } + @Override public ResponseEntity toEntity(Class bodyType) { return toEntityInternal(bodyType, bodyType); diff --git a/spring-web/src/main/java/org/springframework/web/client/RestClient.java b/spring-web/src/main/java/org/springframework/web/client/RestClient.java index 2dc85e0867c..facd6111c0f 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestClient.java @@ -1027,9 +1027,24 @@ public interface RestClient { * response with a status code of 4xx or 5xx. Use * {@link #onStatus(Predicate, ErrorHandler)} to customize error response * handling. + * @see #requiredBody(Class) */ @Nullable T body(Class bodyType); + /** + * Extract the body as an object of the given type. + * @param bodyType the type of return value + * @param the body type + * @return the body + * @throws IllegalStateException if no response body was available + * @throws RestClientResponseException by default when receiving a + * response with a status code of 4xx or 5xx. Use + * {@link #onStatus(Predicate, ErrorHandler)} to customize error response + * handling. + * @since 7.0.4 + */ + T requiredBody(Class bodyType); + /** * Extract the body as an object of the given type. * @param bodyType the type of return value @@ -1039,9 +1054,24 @@ public interface RestClient { * response with a status code of 4xx or 5xx. Use * {@link #onStatus(Predicate, ErrorHandler)} to customize error response * handling. + * @see #requiredBody(ParameterizedTypeReference) */ @Nullable T body(ParameterizedTypeReference bodyType); + /** + * Extract the body as an object of the given type. + * @param bodyType the type of return value + * @param the body type + * @return the body + * @throws IllegalStateException if no response body was available + * @throws RestClientResponseException by default when receiving a + * response with a status code of 4xx or 5xx. Use + * {@link #onStatus(Predicate, ErrorHandler)} to customize error response + * handling. + * @since 7.0.4 + */ + T requiredBody(ParameterizedTypeReference bodyType); + /** * Return a {@code ResponseEntity} with the body decoded to an Object of * the given type. diff --git a/spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java b/spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java new file mode 100644 index 00000000000..45847f586f9 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/client/DefaultRestClientTests.java @@ -0,0 +1,149 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.client; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.net.URI; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpResponse; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link DefaultRestClient}. + * + * @author Sebastien Deleuze + */ +class DefaultRestClientTests { + + private final ClientHttpRequestFactory requestFactory = mock(); + + private final ClientHttpRequest request = mock(); + + private final ClientHttpResponse response = mock(); + + private RestClient client; + + + @BeforeEach + void setup() { + this.client = RestClient.builder() + .requestFactory(this.requestFactory) + .build(); + } + + + @Test + void requiredBodyWithClass() throws IOException { + mockSentRequest(HttpMethod.GET, "https://example.org"); + mockResponseStatus(HttpStatus.OK); + mockResponseBody("Hello World", MediaType.TEXT_PLAIN); + + String result = this.client.get() + .uri("https://example.org") + .retrieve() + .requiredBody(String.class); + + assertThat(result).isEqualTo("Hello World"); + } + + @Test + void requiredBodyWithClassAndNullBody() throws IOException { + mockSentRequest(HttpMethod.GET, "https://example.org"); + mockResponseStatus(HttpStatus.OK); + mockEmptyResponseBody(); + + assertThatIllegalStateException().isThrownBy(() -> + this.client.get() + .uri("https://example.org") + .retrieve() + .requiredBody(String.class) + ); + } + + @Test + void requiredBodyWithParameterizedTypeReference() throws IOException { + mockSentRequest(HttpMethod.GET, "https://example.org"); + mockResponseStatus(HttpStatus.OK); + mockResponseBody("Hello World", MediaType.TEXT_PLAIN); + + String result = this.client.get() + .uri("https://example.org") + .retrieve() + .requiredBody(new ParameterizedTypeReference<>() {}); + + assertThat(result).isEqualTo("Hello World"); + } + + @Test + void requiredBodyWithParameterizedTypeReferenceAndNullBody() throws IOException { + mockSentRequest(HttpMethod.GET, "https://example.org"); + mockResponseStatus(HttpStatus.OK); + mockEmptyResponseBody(); + + assertThatIllegalStateException().isThrownBy(() -> + this.client.get() + .uri("https://example.org") + .retrieve() + .requiredBody(new ParameterizedTypeReference() {}) + ); + } + + + private void mockSentRequest(HttpMethod method, String uri) throws IOException { + given(this.requestFactory.createRequest(URI.create(uri), method)).willReturn(this.request); + given(this.request.getHeaders()).willReturn(new HttpHeaders()); + given(this.request.getMethod()).willReturn(method); + given(this.request.getURI()).willReturn(URI.create(uri)); + } + + private void mockResponseStatus(HttpStatus responseStatus) throws IOException { + given(this.request.execute()).willReturn(this.response); + given(this.response.getStatusCode()).willReturn(responseStatus); + given(this.response.getStatusText()).willReturn(responseStatus.getReasonPhrase()); + } + + private void mockResponseBody(String expectedBody, MediaType mediaType) throws IOException { + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentType(mediaType); + responseHeaders.setContentLength(expectedBody.length()); + given(this.response.getHeaders()).willReturn(responseHeaders); + given(this.response.getBody()).willReturn(new ByteArrayInputStream(expectedBody.getBytes())); + } + + private void mockEmptyResponseBody() throws IOException { + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentLength(0); + given(this.response.getHeaders()).willReturn(responseHeaders); + given(this.response.getBody()).willReturn(new ByteArrayInputStream(new byte[0])); + } + +}