From 89d746ddf86550dd7177159838797823f20868e1 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Thu, 22 Feb 2024 11:20:17 +0100 Subject: [PATCH] Avoid async dispatch if completed in AsyncServerResponse This commit checks whether the CompletableFuture passed to AsyncServerResponse.async has been completed, and if so returns a CompletedAsyncServerResponse that simply delegates to the completed response, instead of the DefaultAsyncServerResponse that uses async dispatch. Closes gh-32223 --- .../servlet/function/AsyncServerResponse.java | 46 ++++++++++- .../CompletedAsyncServerResponse.java | 82 +++++++++++++++++++ .../function/DefaultAsyncServerResponse.java | 31 +------ .../web/servlet/function/ServerResponse.java | 4 +- .../DefaultAsyncServerResponseTests.java | 19 ++++- 5 files changed, 147 insertions(+), 35 deletions(-) create mode 100644 spring-webmvc/src/main/java/org/springframework/web/servlet/function/CompletedAsyncServerResponse.java diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/AsyncServerResponse.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/AsyncServerResponse.java index b2fca283a00..84d278305f0 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/AsyncServerResponse.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/AsyncServerResponse.java @@ -18,10 +18,14 @@ package org.springframework.web.servlet.function; import java.time.Duration; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import org.reactivestreams.Publisher; +import org.springframework.core.ReactiveAdapter; import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; /** * Asynchronous subtype of {@link ServerResponse} that exposes the future @@ -53,7 +57,7 @@ public interface AsyncServerResponse extends ServerResponse { * @return the asynchronous response */ static AsyncServerResponse create(Object asyncResponse) { - return DefaultAsyncServerResponse.create(asyncResponse, null); + return createInternal(asyncResponse, null); } /** @@ -69,7 +73,45 @@ public interface AsyncServerResponse extends ServerResponse { * @return the asynchronous response */ static AsyncServerResponse create(Object asyncResponse, Duration timeout) { - return DefaultAsyncServerResponse.create(asyncResponse, timeout); + return createInternal(asyncResponse, timeout); + } + + private static AsyncServerResponse createInternal(Object asyncResponse, @Nullable Duration timeout) { + Assert.notNull(asyncResponse, "AsyncResponse must not be null"); + + CompletableFuture futureResponse = toCompletableFuture(asyncResponse); + if (futureResponse.isDone() && + !futureResponse.isCancelled() && + !futureResponse.isCompletedExceptionally()) { + + try { + ServerResponse completedResponse = futureResponse.get(); + return new CompletedAsyncServerResponse(completedResponse); + } + catch (InterruptedException | ExecutionException ignored) { + // fall through to use DefaultAsyncServerResponse + } + } + return new DefaultAsyncServerResponse(futureResponse, timeout); + } + + @SuppressWarnings("unchecked") + private static CompletableFuture toCompletableFuture(Object obj) { + if (obj instanceof CompletableFuture futureResponse) { + return (CompletableFuture) futureResponse; + } + else if (DefaultAsyncServerResponse.reactiveStreamsPresent) { + ReactiveAdapterRegistry registry = ReactiveAdapterRegistry.getSharedInstance(); + ReactiveAdapter publisherAdapter = registry.getAdapter(obj.getClass()); + if (publisherAdapter != null) { + Publisher publisher = publisherAdapter.toPublisher(obj); + ReactiveAdapter futureAdapter = registry.getAdapter(CompletableFuture.class); + if (futureAdapter != null) { + return (CompletableFuture) futureAdapter.fromPublisher(publisher); + } + } + } + throw new IllegalArgumentException("Asynchronous type not supported: " + obj.getClass()); } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/CompletedAsyncServerResponse.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/CompletedAsyncServerResponse.java new file mode 100644 index 00000000000..676ab9a2126 --- /dev/null +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/CompletedAsyncServerResponse.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.servlet.function; + +import java.io.IOException; + +import jakarta.servlet.ServletException; +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatusCode; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; +import org.springframework.web.servlet.ModelAndView; + +/** + * {@link AsyncServerResponse} implementation for completed futures. + * + * @author Arjen Poutsma + * @since 6.2 + */ +final class CompletedAsyncServerResponse implements AsyncServerResponse { + + private final ServerResponse serverResponse; + + + CompletedAsyncServerResponse(ServerResponse serverResponse) { + Assert.notNull(serverResponse, "ServerResponse must not be null"); + this.serverResponse = serverResponse; + } + + @Override + public ServerResponse block() { + return this.serverResponse; + } + + @Override + public HttpStatusCode statusCode() { + return this.serverResponse.statusCode(); + } + + @Override + @Deprecated + public int rawStatusCode() { + return this.serverResponse.rawStatusCode(); + } + + @Override + public HttpHeaders headers() { + return this.serverResponse.headers(); + } + + @Override + public MultiValueMap cookies() { + return this.serverResponse.cookies(); + } + + @Nullable + @Override + public ModelAndView writeTo(HttpServletRequest request, HttpServletResponse response, Context context) + throws ServletException, IOException { + + return this.serverResponse.writeTo(request, response, context); + } +} diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultAsyncServerResponse.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultAsyncServerResponse.java index 1a41098a3ba..62442492122 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultAsyncServerResponse.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultAsyncServerResponse.java @@ -29,14 +29,10 @@ import jakarta.servlet.ServletException; import jakarta.servlet.http.Cookie; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; -import org.reactivestreams.Publisher; -import org.springframework.core.ReactiveAdapter; -import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatusCode; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.MultiValueMap; import org.springframework.web.context.request.async.AsyncWebRequest; @@ -62,7 +58,7 @@ final class DefaultAsyncServerResponse extends ErrorHandlingServerResponse imple private final Duration timeout; - private DefaultAsyncServerResponse(CompletableFuture futureResponse, @Nullable Duration timeout) { + DefaultAsyncServerResponse(CompletableFuture futureResponse, @Nullable Duration timeout) { this.futureResponse = futureResponse; this.timeout = timeout; } @@ -167,29 +163,4 @@ final class DefaultAsyncServerResponse extends ErrorHandlingServerResponse imple }); return result; } - - @SuppressWarnings({"rawtypes", "unchecked"}) - public static AsyncServerResponse create(Object obj, @Nullable Duration timeout) { - Assert.notNull(obj, "Argument to async must not be null"); - - if (obj instanceof CompletableFuture futureResponse) { - return new DefaultAsyncServerResponse(futureResponse, timeout); - } - else if (reactiveStreamsPresent) { - ReactiveAdapterRegistry registry = ReactiveAdapterRegistry.getSharedInstance(); - ReactiveAdapter publisherAdapter = registry.getAdapter(obj.getClass()); - if (publisherAdapter != null) { - Publisher publisher = publisherAdapter.toPublisher(obj); - ReactiveAdapter futureAdapter = registry.getAdapter(CompletableFuture.class); - if (futureAdapter != null) { - CompletableFuture futureResponse = - (CompletableFuture) futureAdapter.fromPublisher(publisher); - return new DefaultAsyncServerResponse(futureResponse, timeout); - } - } - } - throw new IllegalArgumentException("Asynchronous type not supported: " + obj.getClass()); - } - - } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerResponse.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerResponse.java index 790e37e8fba..26dbd6a31ab 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerResponse.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerResponse.java @@ -246,7 +246,7 @@ public interface ServerResponse { * @since 5.3 */ static ServerResponse async(Object asyncResponse) { - return DefaultAsyncServerResponse.create(asyncResponse, null); + return AsyncServerResponse.create(asyncResponse); } /** @@ -267,7 +267,7 @@ public interface ServerResponse { * @since 5.3.2 */ static ServerResponse async(Object asyncResponse, Duration timeout) { - return DefaultAsyncServerResponse.create(asyncResponse, timeout); + return AsyncServerResponse.create(asyncResponse, timeout); } /** diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultAsyncServerResponseTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultAsyncServerResponseTests.java index f1fe0c5e0ce..0d4add7f3e6 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultAsyncServerResponseTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/DefaultAsyncServerResponseTests.java @@ -28,7 +28,7 @@ import static org.assertj.core.api.Assertions.assertThat; class DefaultAsyncServerResponseTests { @Test - void block() { + void blockCompleted() { ServerResponse wrappee = ServerResponse.ok().build(); CompletableFuture future = CompletableFuture.completedFuture(wrappee); AsyncServerResponse response = AsyncServerResponse.create(future); @@ -36,4 +36,21 @@ class DefaultAsyncServerResponseTests { assertThat(response.block()).isSameAs(wrappee); } + @Test + void blockNotCompleted() { + ServerResponse wrappee = ServerResponse.ok().build(); + CompletableFuture future = CompletableFuture.supplyAsync(() -> { + try { + Thread.sleep(500); + return wrappee; + } + catch (InterruptedException ex) { + throw new RuntimeException(ex); + } + }); + AsyncServerResponse response = AsyncServerResponse.create(future); + + assertThat(response.block()).isSameAs(wrappee); + } + }