diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java index 535ec014daa..a87b57d5a88 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupport.java @@ -682,7 +682,14 @@ public class WebMvcConfigurationSupport implements ApplicationContextAware, Serv */ @Bean public HandlerFunctionAdapter handlerFunctionAdapter() { - return new HandlerFunctionAdapter(); + HandlerFunctionAdapter adapter = new HandlerFunctionAdapter(); + + AsyncSupportConfigurer configurer = new AsyncSupportConfigurer(); + configureAsyncSupport(configurer); + if (configurer.getTimeout() != null) { + adapter.setAsyncRequestTimeout(configurer.getTimeout()); + } + return adapter; } /** diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/AsyncServerResponse.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/AsyncServerResponse.java index f099a3699d3..279a9b165c3 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/AsyncServerResponse.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/AsyncServerResponse.java @@ -17,19 +17,14 @@ package org.springframework.web.servlet.function; import java.io.IOException; +import java.time.Duration; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.CompletionException; import java.util.function.Function; -import javax.servlet.AsyncContext; -import javax.servlet.AsyncListener; -import javax.servlet.ServletContext; import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; import org.reactivestreams.Publisher; @@ -42,6 +37,10 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.MultiValueMap; +import org.springframework.web.context.request.async.AsyncWebRequest; +import org.springframework.web.context.request.async.DeferredResult; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.WebAsyncUtils; import org.springframework.web.servlet.ModelAndView; /** @@ -59,9 +58,13 @@ final class AsyncServerResponse extends ErrorHandlingServerResponse { private final CompletableFuture futureResponse; + @Nullable + private final Duration timeout; + - private AsyncServerResponse(CompletableFuture futureResponse) { + private AsyncServerResponse(CompletableFuture futureResponse, @Nullable Duration timeout) { this.futureResponse = futureResponse; + this.timeout = timeout; } @Override @@ -96,44 +99,62 @@ final class AsyncServerResponse extends ErrorHandlingServerResponse { @Nullable @Override - public ModelAndView writeTo(HttpServletRequest request, HttpServletResponse response, Context context) { + public ModelAndView writeTo(HttpServletRequest request, HttpServletResponse response, Context context) + throws ServletException, IOException { - SharedAsyncContextHttpServletRequest sharedRequest = new SharedAsyncContextHttpServletRequest(request); - AsyncContext asyncContext = sharedRequest.startAsync(request, response); - this.futureResponse.whenComplete((futureResponse, futureThrowable) -> { - try { - if (futureResponse != null) { - ModelAndView mav = futureResponse.writeTo(sharedRequest, response, context); - Assert.state(mav == null, "Asynchronous, rendering ServerResponse implementations are not " + - "supported in WebMvc.fn. Please use WebFlux.fn instead."); - } - else if (futureThrowable != null) { - handleError(futureThrowable, request, response, context); - } - } - catch (Throwable throwable) { - try { - handleError(throwable, request, response, context); - } - catch (ServletException | IOException ex) { - logger.warn("Asynchronous execution resulted in exception", ex); + writeAsync(request, response, createDeferredResult()); + return null; + } + + static void writeAsync(HttpServletRequest request, HttpServletResponse response, DeferredResult deferredResult) + throws ServletException, IOException { + + WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request); + AsyncWebRequest asyncWebRequest = WebAsyncUtils.createAsyncWebRequest(request, response); + asyncManager.setAsyncWebRequest(asyncWebRequest); + try { + asyncManager.startDeferredResultProcessing(deferredResult); + } + catch (IOException | ServletException ex) { + throw ex; + } + catch (Exception ex) { + throw new ServletException("Async processing failed", ex); + } + + } + + private DeferredResult createDeferredResult() { + DeferredResult result; + if (this.timeout != null) { + result = new DeferredResult<>(this.timeout.toMillis()); + } + else { + result = new DeferredResult<>(); + } + this.futureResponse.handle((value, ex) -> { + if (ex != null) { + if (ex instanceof CompletionException && ex.getCause() != null) { + ex = ex.getCause(); } + result.setErrorResult(ex); } - finally { - asyncContext.complete(); + else { + result.setResult(value); } + return null; }); - return null; + return result; } @SuppressWarnings({"unchecked"}) - public static ServerResponse create(Object o) { + public static ServerResponse create(Object o, @Nullable Duration timeout) { Assert.notNull(o, "Argument to async must not be null"); if (o instanceof CompletableFuture) { CompletableFuture futureResponse = (CompletableFuture) o; - return new AsyncServerResponse(futureResponse); + return new AsyncServerResponse(futureResponse, timeout); } else if (reactiveStreamsPresent) { ReactiveAdapterRegistry registry = ReactiveAdapterRegistry.getSharedInstance(); @@ -144,7 +165,7 @@ final class AsyncServerResponse extends ErrorHandlingServerResponse { if (futureAdapter != null) { CompletableFuture futureResponse = (CompletableFuture) futureAdapter.fromPublisher(publisher); - return new AsyncServerResponse(futureResponse); + return new AsyncServerResponse(futureResponse, timeout); } } } @@ -152,150 +173,4 @@ final class AsyncServerResponse extends ErrorHandlingServerResponse { } - /** - * HttpServletRequestWrapper that shares its AsyncContext between this - * AsyncServerResponse class and other, subsequent ServerResponse - * implementations, keeping track of how many contexts where - * started with startAsync(). This way, we make sure that - * {@link AsyncContext#complete()} only completes for the response that - * finishes last, and is not closed prematurely. - */ - private static final class SharedAsyncContextHttpServletRequest extends HttpServletRequestWrapper { - - private final AsyncContext asyncContext; - - private final AtomicInteger startedContexts; - - public SharedAsyncContextHttpServletRequest(HttpServletRequest request) { - super(request); - this.asyncContext = request.startAsync(); - this.startedContexts = new AtomicInteger(0); - } - - private SharedAsyncContextHttpServletRequest(HttpServletRequest request, AsyncContext asyncContext, - AtomicInteger startedContexts) { - super(request); - this.asyncContext = asyncContext; - this.startedContexts = startedContexts; - } - - @Override - public AsyncContext startAsync() throws IllegalStateException { - this.startedContexts.incrementAndGet(); - return new SharedAsyncContext(this.asyncContext, this, this.asyncContext.getResponse(), - this.startedContexts); - } - - @Override - public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) - throws IllegalStateException { - this.startedContexts.incrementAndGet(); - SharedAsyncContextHttpServletRequest sharedRequest; - if (servletRequest instanceof SharedAsyncContextHttpServletRequest) { - sharedRequest = (SharedAsyncContextHttpServletRequest) servletRequest; - } - else { - sharedRequest = new SharedAsyncContextHttpServletRequest((HttpServletRequest) servletRequest, - this.asyncContext, this.startedContexts); - } - return new SharedAsyncContext(this.asyncContext, sharedRequest, servletResponse, this.startedContexts); - } - - @Override - public AsyncContext getAsyncContext() { - return new SharedAsyncContext(this.asyncContext, this, this.asyncContext.getResponse(), this.startedContexts); - } - - - private static final class SharedAsyncContext implements AsyncContext { - - private final AsyncContext delegate; - - private final AtomicInteger openContexts; - - private final ServletRequest request; - - private final ServletResponse response; - - - public SharedAsyncContext(AsyncContext delegate, SharedAsyncContextHttpServletRequest request, - ServletResponse response, AtomicInteger usageCount) { - - this.delegate = delegate; - this.request = request; - this.response = response; - this.openContexts = usageCount; - } - - @Override - public void complete() { - if (this.openContexts.decrementAndGet() == 0) { - this.delegate.complete(); - } - } - - @Override - public ServletRequest getRequest() { - return this.request; - } - - @Override - public ServletResponse getResponse() { - return this.response; - } - - @Override - public boolean hasOriginalRequestAndResponse() { - return this.delegate.hasOriginalRequestAndResponse(); - } - - @Override - public void dispatch() { - this.delegate.dispatch(); - } - - @Override - public void dispatch(String path) { - this.delegate.dispatch(path); - } - - @Override - public void dispatch(ServletContext context, String path) { - this.delegate.dispatch(context, path); - } - - @Override - public void start(Runnable run) { - this.delegate.start(run); - } - - @Override - public void addListener(AsyncListener listener) { - this.delegate.addListener(listener); - } - - @Override - public void addListener(AsyncListener listener, - ServletRequest servletRequest, - ServletResponse servletResponse) { - - this.delegate.addListener(listener, servletRequest, servletResponse); - } - - @Override - public T createListener(Class clazz) throws ServletException { - return this.delegate.createListener(clazz); - } - - @Override - public void setTimeout(long timeout) { - this.delegate.setTimeout(timeout); - } - - @Override - public long getTimeout() { - return this.delegate.getTimeout(); - } - } - } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java index 868b2d6a7ae..e27b0f48c1f 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java @@ -25,11 +25,11 @@ import java.util.Arrays; import java.util.LinkedHashSet; import java.util.List; import java.util.Set; +import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.function.Consumer; import java.util.stream.Collectors; -import javax.servlet.AsyncContext; import javax.servlet.ServletException; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; @@ -61,6 +61,7 @@ import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.HttpMediaTypeNotAcceptableException; +import org.springframework.web.context.request.async.DeferredResult; import org.springframework.web.servlet.ModelAndView; /** @@ -358,32 +359,39 @@ final class DefaultEntityResponseBuilder implements EntityResponse.Builder } @Override - protected ModelAndView writeToInternal(HttpServletRequest servletRequest, - HttpServletResponse servletResponse, Context context) { + protected ModelAndView writeToInternal(HttpServletRequest servletRequest, HttpServletResponse servletResponse, + Context context) throws ServletException, IOException { - AsyncContext asyncContext = servletRequest.startAsync(servletRequest, servletResponse); - entity().whenComplete((entity, throwable) -> { - try { - if (entity != null) { + DeferredResult deferredResult = createDeferredResult(servletRequest, servletResponse, context); + AsyncServerResponse.writeAsync(servletRequest, servletResponse, deferredResult); + return null; + } - tryWriteEntityWithMessageConverters(entity, - (HttpServletRequest) asyncContext.getRequest(), - (HttpServletResponse) asyncContext.getResponse(), - context); - } - else if (throwable != null) { - handleError(throwable, servletRequest, servletResponse, context); + private DeferredResult createDeferredResult(HttpServletRequest request, HttpServletResponse response, + Context context) { + + DeferredResult result = new DeferredResult<>(); + entity().handle((value, ex) -> { + if (ex != null) { + if (ex instanceof CompletionException && ex.getCause() != null) { + ex = ex.getCause(); } + result.setErrorResult(ex); } - catch (ServletException | IOException ex) { - logger.warn("Asynchronous execution resulted in exception", ex); - } - finally { - asyncContext.complete(); + else { + try { + tryWriteEntityWithMessageConverters(value, request, response, context); + result.setResult(null); + } + catch (ServletException | IOException writeException) { + result.setErrorResult(writeException); + } } + return null; }); - return null; + return result; } + } @@ -399,35 +407,46 @@ final class DefaultEntityResponseBuilder implements EntityResponse.Builder } @Override - protected ModelAndView writeToInternal(HttpServletRequest servletRequest, - HttpServletResponse servletResponse, Context context) { + protected ModelAndView writeToInternal(HttpServletRequest servletRequest, HttpServletResponse servletResponse, + Context context) throws ServletException, IOException { + + DeferredResult deferredResult = new DeferredResult<>(); + AsyncServerResponse.writeAsync(servletRequest, servletResponse, deferredResult); - AsyncContext asyncContext = servletRequest.startAsync(servletRequest, - new NoContentLengthResponseWrapper(servletResponse)); - entity().subscribe(new ProducingSubscriber(asyncContext, context)); + entity().subscribe(new DeferredResultSubscriber(servletRequest, servletResponse, context, deferredResult)); return null; } - @SuppressWarnings("SubscriberImplementation") - private class ProducingSubscriber implements Subscriber { - private final AsyncContext asyncContext; + private class DeferredResultSubscriber implements Subscriber { + + private final HttpServletRequest servletRequest; + + private final HttpServletResponse servletResponse; private final Context context; + private final DeferredResult deferredResult; + @Nullable private Subscription subscription; - public ProducingSubscriber(AsyncContext asyncContext, Context context) { - this.asyncContext = asyncContext; + + public DeferredResultSubscriber(HttpServletRequest servletRequest, + HttpServletResponse servletResponse, Context context, + DeferredResult deferredResult) { + + this.servletRequest = servletRequest; + this.servletResponse = new NoContentLengthResponseWrapper(servletResponse); this.context = context; + this.deferredResult = deferredResult; } @Override public void onSubscribe(Subscription s) { if (this.subscription == null) { this.subscription = s; - this.subscription.request(Long.MAX_VALUE); + this.subscription.request(1); } else { s.cancel(); @@ -435,32 +454,34 @@ final class DefaultEntityResponseBuilder implements EntityResponse.Builder } @Override - public void onNext(T element) { - HttpServletRequest servletRequest = (HttpServletRequest) this.asyncContext.getRequest(); - HttpServletResponse servletResponse = (HttpServletResponse) this.asyncContext.getResponse(); + public void onNext(T t) { + Assert.state(this.subscription != null, "No subscription"); try { - tryWriteEntityWithMessageConverters(element, servletRequest, servletResponse, this.context); + tryWriteEntityWithMessageConverters(t, this.servletRequest, this.servletResponse, this.context); + this.servletResponse.getOutputStream().flush(); + this.subscription.request(1); } catch (ServletException | IOException ex) { - onError(ex); + this.subscription.cancel(); + this.deferredResult.setErrorResult(ex); } } @Override public void onError(Throwable t) { - try { - handleError(t, (HttpServletRequest) this.asyncContext.getRequest(), - (HttpServletResponse) this.asyncContext.getResponse(), this.context); - } - catch (ServletException | IOException ex) { - logger.warn("Asynchronous execution resulted in exception", ex); - } - this.asyncContext.complete(); + this.deferredResult.setErrorResult(t); } @Override public void onComplete() { - this.asyncContext.complete(); + try { + this.servletResponse.getOutputStream().flush(); + this.deferredResult.setResult(null); + } + catch (IOException ex) { + this.deferredResult.setErrorResult(ex); + } + } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerResponse.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerResponse.java index 025317adb0a..ad932104989 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerResponse.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ServerResponse.java @@ -18,6 +18,7 @@ package org.springframework.web.servlet.function; import java.io.IOException; import java.net.URI; +import java.time.Duration; import java.time.Instant; import java.time.ZonedDateTime; import java.util.Collection; @@ -240,7 +241,33 @@ public interface ServerResponse { * @since 5.3 */ static ServerResponse async(Object asyncResponse) { - return AsyncServerResponse.create(asyncResponse); + return AsyncServerResponse.create(asyncResponse, null); + } + + /** + * Create a (built) response with the given asynchronous response. + * Parameter {@code asyncResponse} can be a + * {@link CompletableFuture CompletableFuture<ServerResponse>} or + * {@link Publisher Publisher<ServerResponse>} (or any + * asynchronous producer of a single {@code ServerResponse} that can be + * adapted via the {@link ReactiveAdapterRegistry}). + * + *

This method can be used to set the response status code, headers, and + * body based on an asynchronous result. If only the body is asynchronous, + * {@link BodyBuilder#body(Object)} can be used instead. + * + *

Note that + * {@linkplain RenderingResponse rendering responses}, as returned by + * {@link BodyBuilder#render}, are not supported as value + * for {@code asyncResponse}. Use WebFlux.fn for asynchronous rendering. + * @param asyncResponse a {@code CompletableFuture} or + * {@code Publisher} + * @param timeout maximum time period to wait for before timing out + * @return the asynchronous response + * @since 5.3.2 + */ + static ServerResponse async(Object asyncResponse, Duration timeout) { + return AsyncServerResponse.create(asyncResponse, timeout); } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/support/HandlerFunctionAdapter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/support/HandlerFunctionAdapter.java index 56ddf6137ca..20f206aa3e3 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/support/HandlerFunctionAdapter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/support/HandlerFunctionAdapter.java @@ -18,13 +18,21 @@ package org.springframework.web.servlet.function.support; import java.util.List; +import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + import org.springframework.core.Ordered; +import org.springframework.core.log.LogFormatUtils; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.web.context.request.async.AsyncWebRequest; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.WebAsyncUtils; import org.springframework.web.servlet.HandlerAdapter; import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.function.HandlerFunction; @@ -40,8 +48,12 @@ import org.springframework.web.servlet.function.ServerResponse; */ public class HandlerFunctionAdapter implements HandlerAdapter, Ordered { + private static final Log logger = LogFactory.getLog(HandlerFunctionAdapter.class); + private int order = Ordered.LOWEST_PRECEDENCE; + @Nullable + private Long asyncRequestTimeout; /** * Specify the order value for this HandlerAdapter bean. @@ -57,6 +69,19 @@ public class HandlerFunctionAdapter implements HandlerAdapter, Ordered { return this.order; } + /** + * Specify the amount of time, in milliseconds, before concurrent handling + * should time out. In Servlet 3, the timeout begins after the main request + * processing thread has exited and ends when the request is dispatched again + * for further processing of the concurrently produced result. + *

If this value is not set, the default timeout of the underlying + * implementation is used. + * @param timeout the timeout value in milliseconds + */ + public void setAsyncRequestTimeout(long timeout) { + this.asyncRequestTimeout = timeout; + } + @Override public boolean supports(Object handler) { return handler instanceof HandlerFunction; @@ -68,14 +93,34 @@ public class HandlerFunctionAdapter implements HandlerAdapter, Ordered { HttpServletResponse servletResponse, Object handler) throws Exception { - - HandlerFunction handlerFunction = (HandlerFunction) handler; + WebAsyncManager asyncManager = getWebAsyncManager(servletRequest, servletResponse); ServerRequest serverRequest = getServerRequest(servletRequest); - ServerResponse serverResponse = handlerFunction.handle(serverRequest); + ServerResponse serverResponse; + + if (asyncManager.hasConcurrentResult()) { + serverResponse = handleAsync(asyncManager); + } + else { + HandlerFunction handlerFunction = (HandlerFunction) handler; + serverResponse = handlerFunction.handle(serverRequest); + } + + if (serverResponse != null) { + return serverResponse.writeTo(servletRequest, servletResponse, new ServerRequestContext(serverRequest)); + } + else { + return null; + } + } - return serverResponse.writeTo(servletRequest, servletResponse, - new ServerRequestContext(serverRequest)); + private WebAsyncManager getWebAsyncManager(HttpServletRequest servletRequest, HttpServletResponse servletResponse) { + AsyncWebRequest asyncWebRequest = WebAsyncUtils.createAsyncWebRequest(servletRequest, servletResponse); + asyncWebRequest.setTimeout(this.asyncRequestTimeout); + + WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(servletRequest); + asyncManager.setAsyncWebRequest(asyncWebRequest); + return asyncManager; } private ServerRequest getServerRequest(HttpServletRequest servletRequest) { @@ -86,6 +131,31 @@ public class HandlerFunctionAdapter implements HandlerAdapter, Ordered { return serverRequest; } + @Nullable + private ServerResponse handleAsync(WebAsyncManager asyncManager) throws Exception { + Object result = asyncManager.getConcurrentResult(); + asyncManager.clearConcurrentResult(); + LogFormatUtils.traceDebug(logger, traceOn -> { + String formatted = LogFormatUtils.formatValue(result, !traceOn); + return "Resume with async result [" + formatted + "]"; + }); + if (result instanceof ServerResponse) { + return (ServerResponse) result; + } + else if (result instanceof Exception) { + throw (Exception) result; + } + else if (result instanceof Throwable) { + throw new ServletException("Async processing failed", (Throwable) result); + } + else if (result == null) { + return null; + } + else { + throw new IllegalArgumentException("Unknown result from WebAsyncManager: [" + result + "]"); + } + } + @Override public long getLastModified(HttpServletRequest request, Object handler) { return -1L; diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java index 9e5ea44432e..935e494491a 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java @@ -880,7 +880,10 @@ public class RequestMappingHandlerAdapter extends AbstractHandlerMethodAdapter asyncManager.registerCallableInterceptors(this.callableInterceptors); asyncManager.registerDeferredResultInterceptors(this.deferredResultInterceptors); - if (asyncManager.hasConcurrentResult()) { + if (asyncManager.hasConcurrentResult() && + asyncManager.getConcurrentResultContext().length > 0 && + asyncManager.getConcurrentResultContext()[0] instanceof ModelAndViewContainer) { + Object result = asyncManager.getConcurrentResult(); mavContainer = (ModelAndViewContainer) asyncManager.getConcurrentResultContext()[0]; asyncManager.clearConcurrentResult();