From 529e62921d3aca4499689fcf7c7e2bbeb7ad221e Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 24 Jul 2012 16:00:05 -0400 Subject: [PATCH] Refactor Servlet 3 async support As a result of the refactoring, the AsyncContext dispatch mechanism is used much more centrally. Effectively every asynchronously processed request involves one initial (container) thread, a second thread to produce the handler return value asynchronously, and a third thread as a result of a dispatch back to the container to resume processing of the asynchronous resuilt. Other updates include the addition of a MockAsyncContext and support of related request method in the test packages of spring-web and spring-webmvc. Also an upgrade of a Jetty test dependency required to make tests pass. Issue: SPR-9433 --- build.gradle | 8 +- .../support/OpenSessionInViewFilter.java | 84 ++--- .../support/OpenSessionInViewInterceptor.java | 108 +++--- .../support/OpenSessionInViewFilter.java | 79 ++--- .../support/OpenSessionInViewInterceptor.java | 91 ++--- .../support/OpenSessionInViewTests.java | 97 ++++-- ...thExtensionContentNegotiationStrategy.java | 3 - .../request/WebRequestInterceptor.java | 8 + .../async/AbstractDelegatingCallable.java | 45 --- .../request/async/AsyncExecutionChain.java | 216 ------------ .../async/AsyncExecutionChainRunnable.java | 79 ----- .../request/async/AsyncWebRequest.java | 48 +-- .../async/AsyncWebRequestInterceptor.java | 66 ++-- .../context/request/async/AsyncWebUtils.java | 60 ++++ .../context/request/async/DeferredResult.java | 21 +- ...est.java => NoSupportAsyncWebRequest.java} | 28 +- .../StaleAsyncRequestCheckingCallable.java | 50 --- .../async/StaleAsyncWebRequestException.java | 3 - .../async/StandardServletAsyncWebRequest.java | 97 +++--- .../request/async/WebAsyncManager.java | 319 ++++++++++++++++++ .../filter/AbstractRequestLoggingFilter.java | 47 +-- .../web/filter/OncePerRequestFilter.java | 91 +++-- .../web/filter/RequestContextFilter.java | 43 +-- .../web/filter/ShallowEtagHeaderFilter.java | 44 ++- .../AbstractHttpRequestFactoryTestCase.java | 52 +-- .../mock/web/MockAsyncContext.java | 133 ++++++++ .../mock/web/MockHttpServletRequest.java | 54 ++- .../client/RestTemplateIntegrationTests.java | 47 +-- .../async/AsyncExecutionChainTests.java | 251 -------------- .../request/async/DeferredResultTests.java | 2 + ...taleAsyncRequestCheckingCallableTests.java | 71 ---- .../StandardServletAsyncWebRequestTests.java | 116 +++---- .../request/async/WebAsyncManagerTests.java | 222 ++++++++++++ .../filter/CharacterEncodingFilterTests.java | 25 +- .../web/servlet/AsyncHandlerInterceptor.java | 74 ++-- .../web/servlet/DispatcherServlet.java | 107 ++---- .../web/servlet/FrameworkServlet.java | 133 +++----- .../web/servlet/HandlerExecutionChain.java | 71 ++-- .../web/servlet/HandlerInterceptor.java | 8 + .../WebRequestHandlerInterceptorAdapter.java | 28 +- .../AsyncMethodReturnValueHandler.java | 12 +- .../RequestMappingHandlerAdapter.java | 87 +++-- .../ServletInvocableHandlerMethod.java | 67 ++-- .../mock/web/MockAsyncContext.java | 133 ++++++++ .../mock/web/MockHttpServletRequest.java | 51 ++- .../servlet/HandlerExecutionChainTests.java | 36 +- .../RequestPartIntegrationTests.java | 39 +-- 47 files changed, 1825 insertions(+), 1729 deletions(-) delete mode 100644 spring-web/src/main/java/org/springframework/web/context/request/async/AbstractDelegatingCallable.java delete mode 100644 spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChain.java delete mode 100644 spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChainRunnable.java create mode 100644 spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebUtils.java rename spring-web/src/main/java/org/springframework/web/context/request/async/{NoOpAsyncWebRequest.java => NoSupportAsyncWebRequest.java} (68%) delete mode 100644 spring-web/src/main/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallable.java create mode 100644 spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java create mode 100644 spring-web/src/test/java/org/springframework/mock/web/MockAsyncContext.java delete mode 100644 spring-web/src/test/java/org/springframework/web/context/request/async/AsyncExecutionChainTests.java delete mode 100644 spring-web/src/test/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallableTests.java create mode 100644 spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTests.java create mode 100644 spring-webmvc/src/test/java/org/springframework/mock/web/MockAsyncContext.java diff --git a/build.gradle b/build.gradle index 2ea0800567e..d72b3f0360a 100644 --- a/build.gradle +++ b/build.gradle @@ -366,9 +366,13 @@ project('spring-web') { compile("org.codehaus.jackson:jackson-mapper-asl:1.4.2", optional) compile("com.fasterxml.jackson.core:jackson-databind:2.0.1", optional) compile("taglibs:standard:1.1.2", optional) - compile("org.mortbay.jetty:jetty:6.1.9") { dep -> + compile("org.eclipse.jetty:jetty-servlet:8.1.5.v20120716") { dep -> optional dep - exclude group: 'org.mortbay.jetty', module: 'servlet-api-2.5' + exclude group: 'org.eclipse.jetty.orbit', module: 'javax.servlet' + } + compile("org.eclipse.jetty:jetty-server:8.1.5.v20120716") { dep -> + optional dep + exclude group: 'org.eclipse.jetty.orbit', module: 'javax.servlet' } testCompile project(":spring-context-support") // for JafMediaTypeFactory testCompile "xmlunit:xmlunit:1.2" diff --git a/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewFilter.java b/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewFilter.java index f55ae4e4366..1af4dc4d2f9 100644 --- a/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewFilter.java +++ b/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewFilter.java @@ -17,6 +17,7 @@ package org.springframework.orm.hibernate3.support; import java.io.IOException; + import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -25,14 +26,15 @@ import javax.servlet.http.HttpServletResponse; import org.hibernate.FlushMode; import org.hibernate.Session; import org.hibernate.SessionFactory; - import org.springframework.dao.DataAccessResourceFailureException; import org.springframework.orm.hibernate3.SessionFactoryUtils; import org.springframework.orm.hibernate3.SessionHolder; import org.springframework.transaction.support.TransactionSynchronizationManager; +import org.springframework.util.Assert; import org.springframework.web.context.WebApplicationContext; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; -import org.springframework.web.context.request.async.AsyncExecutionChain; +import org.springframework.web.context.request.async.AsyncWebUtils; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.WebAsyncManager.AsyncThreadInitializer; import org.springframework.web.context.support.WebApplicationContextUtils; import org.springframework.web.filter.OncePerRequestFilter; @@ -165,16 +167,27 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { } + /** + * The default value is "true" so that the filter may re-bind the opened + * {@code Session} to each asynchronously dispatched thread and postpone + * closing it until the very last asynchronous dispatch. + */ + @Override + protected boolean shouldFilterAsyncDispatches() { + return true; + } + @Override protected void doFilterInternal( HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - SessionFactory sessionFactory = lookupSessionFactory(request); boolean participate = false; + WebAsyncManager asyncManager = AsyncWebUtils.getAsyncManager(request); + String key = getAlreadyFilteredAttributeName(); + if (isSingleSession()) { // single session mode if (TransactionSynchronizationManager.hasResource(sessionFactory)) { @@ -182,16 +195,20 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { participate = true; } else { - logger.debug("Opening single Hibernate Session in OpenSessionInViewFilter"); - Session session = getSession(sessionFactory); - SessionHolder sessionHolder = new SessionHolder(session); - TransactionSynchronizationManager.bindResource(sessionFactory, sessionHolder); + if (!isAsyncDispatch(request) || !asyncManager.applyAsyncThreadInitializer(key)) { + logger.debug("Opening single Hibernate Session in OpenSessionInViewFilter"); + Session session = getSession(sessionFactory); + SessionHolder sessionHolder = new SessionHolder(session); + TransactionSynchronizationManager.bindResource(sessionFactory, sessionHolder); - chain.push(getAsyncCallable(request, sessionFactory, sessionHolder)); + AsyncThreadInitializer initializer = createAsyncThreadInitializer(sessionFactory, sessionHolder); + asyncManager.registerAsyncThreadInitializer(key, initializer); + } } } else { // deferred close mode + Assert.state(isLastRequestThread(request), "Deferred close mode is not supported on async dispatches"); if (SessionFactoryUtils.isDeferredCloseActive(sessionFactory)) { // Do not modify deferred close: just set the participate flag. participate = true; @@ -210,16 +227,12 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { // single session mode SessionHolder sessionHolder = (SessionHolder) TransactionSynchronizationManager.unbindResource(sessionFactory); - if (!chain.pop()) { - return; + if (isLastRequestThread(request)) { + logger.debug("Closing single Hibernate Session in OpenSessionInViewFilter"); + closeSession(sessionHolder.getSession(), sessionFactory); } - logger.debug("Closing single Hibernate Session in OpenSessionInViewFilter"); - closeSession(sessionHolder.getSession(), sessionFactory); } else { - if (chain.isAsyncStarted()) { - throw new IllegalStateException("Deferred close is not supported with async requests."); - } // deferred close mode SessionFactoryUtils.processDeferredClose(sessionFactory); } @@ -227,6 +240,19 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { } } + private AsyncThreadInitializer createAsyncThreadInitializer(final SessionFactory sessionFactory, + final SessionHolder sessionHolder) { + + return new AsyncThreadInitializer() { + public void initialize() { + TransactionSynchronizationManager.bindResource(sessionFactory, sessionHolder); + } + public void reset() { + TransactionSynchronizationManager.unbindResource(sessionFactory); + } + }; + } + /** * Look up the SessionFactory that this filter should use, * taking the current HTTP request as argument. @@ -291,28 +317,4 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { SessionFactoryUtils.closeSession(session); } - /** - * Create a Callable to extend the use of the open Hibernate Session to the - * async thread completing the request. - */ - private AbstractDelegatingCallable getAsyncCallable(final HttpServletRequest request, - final SessionFactory sessionFactory, final SessionHolder sessionHolder) { - - return new AbstractDelegatingCallable() { - public Object call() throws Exception { - TransactionSynchronizationManager.bindResource(sessionFactory, sessionHolder); - try { - getNext().call(); - } - finally { - SessionHolder sessionHolder = - (SessionHolder) TransactionSynchronizationManager.unbindResource(sessionFactory); - logger.debug("Closing Hibernate Session in OpenSessionInViewFilter"); - SessionFactoryUtils.closeSession(sessionHolder.getSession()); - } - return null; - } - }; - } - } diff --git a/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewInterceptor.java b/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewInterceptor.java index 253442389e3..5f7ca275490 100644 --- a/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewInterceptor.java +++ b/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewInterceptor.java @@ -25,8 +25,10 @@ import org.springframework.orm.hibernate3.SessionHolder; import org.springframework.transaction.support.TransactionSynchronizationManager; import org.springframework.ui.ModelMap; import org.springframework.web.context.request.WebRequest; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; import org.springframework.web.context.request.async.AsyncWebRequestInterceptor; +import org.springframework.web.context.request.async.AsyncWebUtils; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.WebAsyncManager.AsyncThreadInitializer; /** * Spring web request interceptor that binds a Hibernate Session to the @@ -140,10 +142,19 @@ public class OpenSessionInViewInterceptor extends HibernateAccessor implements A * @see org.springframework.orm.hibernate3.SessionFactoryUtils#getSession */ public void preHandle(WebRequest request) throws DataAccessException { + + WebAsyncManager asyncManager = AsyncWebUtils.getAsyncManager(request); + String participateAttributeName = getParticipateAttributeName(); + + if (asyncManager.hasConcurrentResult()) { + if (asyncManager.applyAsyncThreadInitializer(participateAttributeName)) { + return; + } + } + if ((isSingleSession() && TransactionSynchronizationManager.hasResource(getSessionFactory())) || SessionFactoryUtils.isDeferredCloseActive(getSessionFactory())) { // Do not modify the Session: just mark the request accordingly. - String participateAttributeName = getParticipateAttributeName(); Integer count = (Integer) request.getAttribute(participateAttributeName, WebRequest.SCOPE_REQUEST); int newCount = (count != null ? count + 1 : 1); request.setAttribute(getParticipateAttributeName(), newCount, WebRequest.SCOPE_REQUEST); @@ -157,6 +168,9 @@ public class OpenSessionInViewInterceptor extends HibernateAccessor implements A applyFlushMode(session, false); SessionHolder sessionHolder = new SessionHolder(session); TransactionSynchronizationManager.bindResource(getSessionFactory(), sessionHolder); + + AsyncThreadInitializer asyncThreadInitializer = createThreadInitializer(sessionHolder); + asyncManager.registerAsyncThreadInitializer(participateAttributeName, asyncThreadInitializer); } else { // deferred close mode @@ -165,44 +179,6 @@ public class OpenSessionInViewInterceptor extends HibernateAccessor implements A } } - /** - * Create a Callable to bind the Hibernate session - * to the async request thread. - */ - public AbstractDelegatingCallable getAsyncCallable(WebRequest request) { - String attributeName = getParticipateAttributeName(); - if ((request.getAttribute(attributeName, WebRequest.SCOPE_REQUEST) != null) || !isSingleSession()) { - return null; - } - - final SessionHolder sessionHolder = - (SessionHolder) TransactionSynchronizationManager.getResource(getSessionFactory()); - - return new AbstractDelegatingCallable() { - public Object call() throws Exception { - TransactionSynchronizationManager.bindResource(getSessionFactory(), sessionHolder); - getNext().call(); - return null; - } - }; - } - - /** - * Unbind the Hibernate Session from the main thread but leave - * the Session open for further use from the async thread. - */ - public void postHandleAsyncStarted(WebRequest request) { - String attributeName = getParticipateAttributeName(); - if (request.getAttribute(attributeName, WebRequest.SCOPE_REQUEST) == null) { - if (isSingleSession()) { - TransactionSynchronizationManager.unbindResource(getSessionFactory()); - } - else { - throw new IllegalStateException("Deferred close is not supported with async requests."); - } - } - } - /** * Flush the Hibernate Session before view rendering, if necessary. *

Note that this just applies in {@link #isSingleSession() single session mode}! @@ -232,18 +208,7 @@ public class OpenSessionInViewInterceptor extends HibernateAccessor implements A * @see org.springframework.transaction.support.TransactionSynchronizationManager */ public void afterCompletion(WebRequest request, Exception ex) throws DataAccessException { - String participateAttributeName = getParticipateAttributeName(); - Integer count = (Integer) request.getAttribute(participateAttributeName, WebRequest.SCOPE_REQUEST); - if (count != null) { - // Do not modify the Session: just clear the marker. - if (count > 1) { - request.setAttribute(participateAttributeName, count - 1, WebRequest.SCOPE_REQUEST); - } - else { - request.removeAttribute(participateAttributeName, WebRequest.SCOPE_REQUEST); - } - } - else { + if (!decrementParticipateCount(request)) { if (isSingleSession()) { // single session mode SessionHolder sessionHolder = @@ -258,6 +223,34 @@ public class OpenSessionInViewInterceptor extends HibernateAccessor implements A } } + public void afterConcurrentHandlingStarted(WebRequest request) { + if (!decrementParticipateCount(request)) { + if (isSingleSession()) { + TransactionSynchronizationManager.unbindResource(getSessionFactory()); + } + else { + throw new IllegalStateException("Deferred close mode is not supported with async requests."); + } + + } + } + + private boolean decrementParticipateCount(WebRequest request) { + String participateAttributeName = getParticipateAttributeName(); + Integer count = (Integer) request.getAttribute(participateAttributeName, WebRequest.SCOPE_REQUEST); + if (count == null) { + return false; + } + // Do not modify the Session: just clear the marker. + if (count > 1) { + request.setAttribute(participateAttributeName, count - 1, WebRequest.SCOPE_REQUEST); + } + else { + request.removeAttribute(participateAttributeName, WebRequest.SCOPE_REQUEST); + } + return true; + } + /** * Return the name of the request attribute that identifies that a request is * already intercepted. @@ -268,4 +261,15 @@ public class OpenSessionInViewInterceptor extends HibernateAccessor implements A return getSessionFactory().toString() + PARTICIPATE_SUFFIX; } + private AsyncThreadInitializer createThreadInitializer(final SessionHolder sessionHolder) { + return new AsyncThreadInitializer() { + public void initialize() { + TransactionSynchronizationManager.bindResource(getSessionFactory(), sessionHolder); + } + public void reset() { + TransactionSynchronizationManager.unbindResource(getSessionFactory()); + } + }; + } + } diff --git a/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewFilter.java b/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewFilter.java index 9f1f0b10dea..329a4e4eca1 100644 --- a/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewFilter.java +++ b/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewFilter.java @@ -17,6 +17,7 @@ package org.springframework.orm.hibernate4.support; import java.io.IOException; + import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -26,14 +27,14 @@ import org.hibernate.FlushMode; import org.hibernate.HibernateException; import org.hibernate.Session; import org.hibernate.SessionFactory; - import org.springframework.dao.DataAccessResourceFailureException; import org.springframework.orm.hibernate4.SessionFactoryUtils; import org.springframework.orm.hibernate4.SessionHolder; import org.springframework.transaction.support.TransactionSynchronizationManager; import org.springframework.web.context.WebApplicationContext; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; -import org.springframework.web.context.request.async.AsyncExecutionChain; +import org.springframework.web.context.request.async.AsyncWebUtils; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.WebAsyncManager.AsyncThreadInitializer; import org.springframework.web.context.support.WebApplicationContextUtils; import org.springframework.web.filter.OncePerRequestFilter; @@ -99,27 +100,41 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { } + /** + * The default value is "true" so that the filter may re-bind the opened + * {@code Session} to each asynchronously dispatched thread and postpone + * closing it until the very last asynchronous dispatch. + */ + @Override + protected boolean shouldFilterAsyncDispatches() { + return true; + } + @Override protected void doFilterInternal( HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - SessionFactory sessionFactory = lookupSessionFactory(request); boolean participate = false; + WebAsyncManager asyncManager = AsyncWebUtils.getAsyncManager(request); + String key = getAlreadyFilteredAttributeName(); + if (TransactionSynchronizationManager.hasResource(sessionFactory)) { // Do not modify the Session: just set the participate flag. participate = true; } else { - logger.debug("Opening Hibernate Session in OpenSessionInViewFilter"); - Session session = openSession(sessionFactory); - SessionHolder sessionHolder = new SessionHolder(session); - TransactionSynchronizationManager.bindResource(sessionFactory, sessionHolder); + if (!isAsyncDispatch(request) || !asyncManager.applyAsyncThreadInitializer(key)) { + logger.debug("Opening Hibernate Session in OpenSessionInViewFilter"); + Session session = openSession(sessionFactory); + SessionHolder sessionHolder = new SessionHolder(session); + TransactionSynchronizationManager.bindResource(sessionFactory, sessionHolder); - chain.push(getAsyncCallable(request, sessionFactory, sessionHolder)); + AsyncThreadInitializer initializer = createAsyncThreadInitializer(sessionFactory, sessionHolder); + asyncManager.registerAsyncThreadInitializer(key, initializer); + } } try { @@ -130,15 +145,27 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { if (!participate) { SessionHolder sessionHolder = (SessionHolder) TransactionSynchronizationManager.unbindResource(sessionFactory); - if (!chain.pop()) { - return; + if (isLastRequestThread(request)) { + logger.debug("Closing Hibernate Session in OpenSessionInViewFilter"); + SessionFactoryUtils.closeSession(sessionHolder.getSession()); } - logger.debug("Closing Hibernate Session in OpenSessionInViewFilter"); - SessionFactoryUtils.closeSession(sessionHolder.getSession()); } } } + private AsyncThreadInitializer createAsyncThreadInitializer(final SessionFactory sessionFactory, + final SessionHolder sessionHolder) { + + return new AsyncThreadInitializer() { + public void initialize() { + TransactionSynchronizationManager.bindResource(sessionFactory, sessionHolder); + } + public void reset() { + TransactionSynchronizationManager.unbindResource(sessionFactory); + } + }; + } + /** * Look up the SessionFactory that this filter should use, * taking the current HTTP request as argument. @@ -187,28 +214,4 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { } } - /** - * Create a Callable to extend the use of the open Hibernate Session to the - * async thread completing the request. - */ - private AbstractDelegatingCallable getAsyncCallable(final HttpServletRequest request, - final SessionFactory sessionFactory, final SessionHolder sessionHolder) { - - return new AbstractDelegatingCallable() { - public Object call() throws Exception { - TransactionSynchronizationManager.bindResource(sessionFactory, sessionHolder); - try { - getNext().call(); - } - finally { - SessionHolder sessionHolder = - (SessionHolder) TransactionSynchronizationManager.unbindResource(sessionFactory); - logger.debug("Closing Hibernate Session in OpenSessionInViewFilter"); - SessionFactoryUtils.closeSession(sessionHolder.getSession()); - } - return null; - } - }; - } - } diff --git a/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewInterceptor.java b/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewInterceptor.java index 7bd43ac9394..6132430c38a 100644 --- a/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewInterceptor.java +++ b/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewInterceptor.java @@ -29,9 +29,10 @@ import org.springframework.orm.hibernate4.SessionHolder; import org.springframework.transaction.support.TransactionSynchronizationManager; import org.springframework.ui.ModelMap; import org.springframework.web.context.request.WebRequest; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; -import org.springframework.web.context.request.async.AsyncExecutionChain; import org.springframework.web.context.request.async.AsyncWebRequestInterceptor; +import org.springframework.web.context.request.async.AsyncWebUtils; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.WebAsyncManager.AsyncThreadInitializer; /** * Spring web request interceptor that binds a Hibernate Session to the @@ -103,9 +104,18 @@ public class OpenSessionInViewInterceptor implements AsyncWebRequestInterceptor * {@link org.springframework.transaction.support.TransactionSynchronizationManager}. */ public void preHandle(WebRequest request) throws DataAccessException { + + WebAsyncManager asyncManager = AsyncWebUtils.getAsyncManager(request); + String participateAttributeName = getParticipateAttributeName(); + + if (asyncManager.hasConcurrentResult()) { + if (asyncManager.applyAsyncThreadInitializer(participateAttributeName)) { + return; + } + } + if (TransactionSynchronizationManager.hasResource(getSessionFactory())) { // Do not modify the Session: just mark the request accordingly. - String participateAttributeName = getParticipateAttributeName(); Integer count = (Integer) request.getAttribute(participateAttributeName, WebRequest.SCOPE_REQUEST); int newCount = (count != null ? count + 1 : 1); request.setAttribute(getParticipateAttributeName(), newCount, WebRequest.SCOPE_REQUEST); @@ -115,6 +125,9 @@ public class OpenSessionInViewInterceptor implements AsyncWebRequestInterceptor Session session = openSession(); SessionHolder sessionHolder = new SessionHolder(session); TransactionSynchronizationManager.bindResource(getSessionFactory(), sessionHolder); + + AsyncThreadInitializer asyncThreadInitializer = createThreadInitializer(sessionHolder); + asyncManager.registerAsyncThreadInitializer(participateAttributeName, asyncThreadInitializer); } } @@ -122,60 +135,39 @@ public class OpenSessionInViewInterceptor implements AsyncWebRequestInterceptor } /** - * Create a Callable to bind the Hibernate session - * to the async request thread. + * Unbind the Hibernate Session from the thread and close it). + * @see org.springframework.transaction.support.TransactionSynchronizationManager */ - public AbstractDelegatingCallable getAsyncCallable(WebRequest request) { - String attributeName = getParticipateAttributeName(); - if (request.getAttribute(attributeName, WebRequest.SCOPE_REQUEST) != null) { - return null; - } - - final SessionHolder sessionHolder = - (SessionHolder) TransactionSynchronizationManager.getResource(getSessionFactory()); + public void afterCompletion(WebRequest request, Exception ex) throws DataAccessException { + if (!decrementParticipateCount(request)) { + SessionHolder sessionHolder = + (SessionHolder) TransactionSynchronizationManager.unbindResource(getSessionFactory()); + logger.debug("Closing Hibernate Session in OpenSessionInViewInterceptor"); + SessionFactoryUtils.closeSession(sessionHolder.getSession()); - return new AbstractDelegatingCallable() { - public Object call() throws Exception { - TransactionSynchronizationManager.bindResource(getSessionFactory(), sessionHolder); - getNext().call(); - return null; - } - }; + } } - /** - * Unbind the Hibernate Session from the main thread leaving - * it open for further use from an async thread. - */ - public void postHandleAsyncStarted(WebRequest request) { - String attributeName = getParticipateAttributeName(); - if (request.getAttribute(attributeName, WebRequest.SCOPE_REQUEST) == null) { + public void afterConcurrentHandlingStarted(WebRequest request) { + if (!decrementParticipateCount(request)) { TransactionSynchronizationManager.unbindResource(getSessionFactory()); } } - /** - * Unbind the Hibernate Session from the thread and close it). - * @see org.springframework.transaction.support.TransactionSynchronizationManager - */ - public void afterCompletion(WebRequest request, Exception ex) throws DataAccessException { + private boolean decrementParticipateCount(WebRequest request) { String participateAttributeName = getParticipateAttributeName(); Integer count = (Integer) request.getAttribute(participateAttributeName, WebRequest.SCOPE_REQUEST); - if (count != null) { - // Do not modify the Session: just clear the marker. - if (count > 1) { - request.setAttribute(participateAttributeName, count - 1, WebRequest.SCOPE_REQUEST); - } - else { - request.removeAttribute(participateAttributeName, WebRequest.SCOPE_REQUEST); - } + if (count == null) { + return false; + } + // Do not modify the Session: just clear the marker. + if (count > 1) { + request.setAttribute(participateAttributeName, count - 1, WebRequest.SCOPE_REQUEST); } else { - SessionHolder sessionHolder = - (SessionHolder) TransactionSynchronizationManager.unbindResource(getSessionFactory()); - logger.debug("Closing Hibernate Session in OpenSessionInViewInterceptor"); - SessionFactoryUtils.closeSession(sessionHolder.getSession()); + request.removeAttribute(participateAttributeName, WebRequest.SCOPE_REQUEST); } + return true; } /** @@ -208,4 +200,15 @@ public class OpenSessionInViewInterceptor implements AsyncWebRequestInterceptor return getSessionFactory().toString() + PARTICIPATE_SUFFIX; } + private AsyncThreadInitializer createThreadInitializer(final SessionHolder sessionHolder) { + return new AsyncThreadInitializer() { + public void initialize() { + TransactionSynchronizationManager.bindResource(getSessionFactory(), sessionHolder); + } + public void reset() { + TransactionSynchronizationManager.unbindResource(getSessionFactory()); + } + }; + } + } diff --git a/spring-orm/src/test/java/org/springframework/orm/hibernate3/support/OpenSessionInViewTests.java b/spring-orm/src/test/java/org/springframework/orm/hibernate3/support/OpenSessionInViewTests.java index b31747558d2..9a845214d5f 100644 --- a/spring-orm/src/test/java/org/springframework/orm/hibernate3/support/OpenSessionInViewTests.java +++ b/spring-orm/src/test/java/org/springframework/orm/hibernate3/support/OpenSessionInViewTests.java @@ -16,8 +16,10 @@ package org.springframework.orm.hibernate3.support; +import static org.easymock.EasyMock.anyObject; import static org.easymock.EasyMock.createStrictMock; import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; import static org.easymock.EasyMock.replay; import static org.easymock.EasyMock.reset; import static org.easymock.EasyMock.verify; @@ -29,6 +31,7 @@ import static org.junit.Assert.assertTrue; import java.io.IOException; import java.sql.Connection; import java.util.concurrent.Callable; +import java.util.concurrent.atomic.AtomicInteger; import javax.servlet.FilterChain; import javax.servlet.ServletException; @@ -59,9 +62,9 @@ import org.springframework.transaction.support.DefaultTransactionDefinition; import org.springframework.transaction.support.TransactionSynchronizationManager; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.request.ServletWebRequest; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; -import org.springframework.web.context.request.async.AsyncExecutionChain; import org.springframework.web.context.request.async.AsyncWebRequest; +import org.springframework.web.context.request.async.AsyncWebUtils; +import org.springframework.web.context.request.async.WebAsyncManager; import org.springframework.web.context.support.StaticWebApplicationContext; @@ -152,6 +155,8 @@ public class OpenSessionInViewTests { @Test public void testOpenSessionInViewInterceptorAsyncScenario() throws Exception { + // Initial request thread + final SessionFactory sf = createStrictMock(SessionFactory.class); Session session = createStrictMock(Session.class); @@ -167,39 +172,52 @@ public class OpenSessionInViewTests { interceptor.preHandle(this.webRequest); assertTrue(TransactionSynchronizationManager.hasResource(sf)); - AbstractDelegatingCallable asyncCallable = interceptor.getAsyncCallable(this.webRequest); - assertNotNull(asyncCallable); - - interceptor.postHandleAsyncStarted(this.webRequest); - assertFalse(TransactionSynchronizationManager.hasResource(sf)); - verify(sf); verify(session); - asyncCallable.setNext(new Callable() { - public Object call() { - return null; + AsyncWebRequest asyncWebRequest = createStrictMock(AsyncWebRequest.class); + asyncWebRequest.addCompletionHandler((Runnable) anyObject()); + asyncWebRequest.startAsync(); + replay(asyncWebRequest); + + WebAsyncManager asyncManager = AsyncWebUtils.getAsyncManager(this.request); + asyncManager.setAsyncWebRequest(asyncWebRequest); + + asyncManager.startCallableProcessing(new Callable() { + public String call() throws Exception { + return "anything"; } }); - asyncCallable.call(); + verify(asyncWebRequest); + + interceptor.afterConcurrentHandlingStarted(this.webRequest); + assertFalse(TransactionSynchronizationManager.hasResource(sf)); + + // Async dispatch thread + + interceptor.preHandle(this.webRequest); assertTrue("Session not bound to async thread", TransactionSynchronizationManager.hasResource(sf)); verify(sf); - verify(session); reset(sf); - reset(session); replay(sf); + + verify(session); + reset(session); replay(session); interceptor.postHandle(this.webRequest, null); assertTrue(TransactionSynchronizationManager.hasResource(sf)); verify(sf); - verify(session); reset(sf); + + verify(session); reset(session); + expect(session.close()).andReturn(null); + replay(sf); replay(session); @@ -438,10 +456,12 @@ public class OpenSessionInViewTests { } @Test - public void testOpenSessionInViewFilterWithSingleSessionAsyncScenario() throws Exception { + public void testOpenSessionInViewFilterAsyncScenario() throws Exception { final SessionFactory sf = createStrictMock(SessionFactory.class); Session session = createStrictMock(Session.class); + // Initial request during which concurrent handler execution starts.. + expect(sf.openSession()).andReturn(session); expect(session.getSessionFactory()).andReturn(sf); session.setFlushMode(FlushMode.MANUAL); @@ -456,55 +476,60 @@ public class OpenSessionInViewTests { MockFilterConfig filterConfig = new MockFilterConfig(wac.getServletContext(), "filter"); + final AtomicInteger count = new AtomicInteger(0); + final OpenSessionInViewFilter filter = new OpenSessionInViewFilter(); filter.init(filterConfig); final FilterChain filterChain = new FilterChain() { public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) { assertTrue(TransactionSynchronizationManager.hasResource(sf)); - servletRequest.setAttribute("invoked", Boolean.TRUE); + count.incrementAndGet(); } }; AsyncWebRequest asyncWebRequest = createStrictMock(AsyncWebRequest.class); + asyncWebRequest.addCompletionHandler((Runnable) anyObject()); + asyncWebRequest.startAsync(); expect(asyncWebRequest.isAsyncStarted()).andReturn(true); - expect(asyncWebRequest.isAsyncStarted()).andReturn(true); + expectLastCall().anyTimes(); replay(asyncWebRequest); - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(this.request); - chain.setAsyncWebRequest(asyncWebRequest); + WebAsyncManager asyncManager = AsyncWebUtils.getAsyncManager(this.request); + asyncManager.setAsyncWebRequest(asyncWebRequest); + asyncManager.startCallableProcessing(new Callable() { + public String call() throws Exception { + return "anything"; + } + }); assertFalse(TransactionSynchronizationManager.hasResource(sf)); filter.doFilter(this.request, this.response, filterChain); assertFalse(TransactionSynchronizationManager.hasResource(sf)); - assertNotNull(this.request.getAttribute("invoked")); + assertEquals(1, count.get()); verify(sf); verify(session); verify(asyncWebRequest); - chain.setTaskExecutor(new SyncTaskExecutor()); - chain.setLastCallable(new Callable() { - public Object call() { - assertTrue(TransactionSynchronizationManager.hasResource(sf)); - return null; - } - }); - - reset(asyncWebRequest); - asyncWebRequest.startAsync(); - expect(asyncWebRequest.isAsyncCompleted()).andReturn(false); - asyncWebRequest.complete(); - replay(asyncWebRequest); - reset(sf); reset(session); + reset(asyncWebRequest); + + // Async dispatch after concurrent handler execution results ready.. + expect(session.close()).andReturn(null); + expect(asyncWebRequest.isAsyncStarted()).andReturn(false); + expectLastCall().anyTimes(); + replay(sf); replay(session); + replay(asyncWebRequest); - chain.startCallableProcessing(); assertFalse(TransactionSynchronizationManager.hasResource(sf)); + filter.doFilter(this.request, this.response, filterChain); + assertFalse(TransactionSynchronizationManager.hasResource(sf)); + assertEquals(2, count.get()); verify(sf); verify(session); diff --git a/spring-web/src/main/java/org/springframework/web/accept/PathExtensionContentNegotiationStrategy.java b/spring-web/src/main/java/org/springframework/web/accept/PathExtensionContentNegotiationStrategy.java index eb76db51610..4636183eee5 100644 --- a/spring-web/src/main/java/org/springframework/web/accept/PathExtensionContentNegotiationStrategy.java +++ b/spring-web/src/main/java/org/springframework/web/accept/PathExtensionContentNegotiationStrategy.java @@ -108,9 +108,6 @@ public class PathExtensionContentNegotiationStrategy extends AbstractMappingCont @Override protected void handleMatch(String extension, MediaType mediaType) { - if (logger.isDebugEnabled()) { - logger.debug("Requested media type is '" + mediaType + "' (based on file extension '" + extension + "')"); - } } @Override diff --git a/spring-web/src/main/java/org/springframework/web/context/request/WebRequestInterceptor.java b/spring-web/src/main/java/org/springframework/web/context/request/WebRequestInterceptor.java index d5d9910c70e..bad61d5dac9 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/WebRequestInterceptor.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/WebRequestInterceptor.java @@ -28,6 +28,14 @@ import org.springframework.ui.ModelMap; * Alternatively, a handler may also process the request completely, with no * view to be rendered. * + *

In an async processing scenario, the handler may be executed in a separate + * thread while the main thread exits without rendering or invoking the + * {@code postHandle} and {@code afterCompletion} callbacks. When concurrent + * handler execution completes, the request is dispatched back in order to + * proceed with rendering the model and all methods of this contract are invoked + * again. For further options and comments see + * {@code org.springframework.web.context.request.async.AsyncWebRequestInterceptor} + * *

This interface is deliberately minimalistic to keep the dependencies of * generic request interceptors as minimal as feasible. * diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/AbstractDelegatingCallable.java b/spring-web/src/main/java/org/springframework/web/context/request/async/AbstractDelegatingCallable.java deleted file mode 100644 index 265f722c0f0..00000000000 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/AbstractDelegatingCallable.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2002-2012 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.context.request.async; - -import java.util.concurrent.Callable; - -/** - * A base class for a Callable used to form a chain of Callable instances. - * Instances of this class are typically registered via - * {@link AsyncExecutionChain#push(AbstractDelegatingCallable)} in which case - * there is no need to set the next Callable. Implementations can simply use - * {@link #getNext()} to delegate to the next Callable and assume it will be set. - * - * @author Rossen Stoyanchev - * @since 3.2 - * - * @see AsyncExecutionChain - */ -public abstract class AbstractDelegatingCallable implements Callable { - - private Callable next; - - protected Callable getNext() { - return this.next; - } - - public void setNext(Callable callable) { - this.next = callable; - } - -} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChain.java b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChain.java deleted file mode 100644 index 48d0a165cc6..00000000000 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChain.java +++ /dev/null @@ -1,216 +0,0 @@ -/* - * Copyright 2002-2012 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.context.request.async; - -import java.util.ArrayDeque; -import java.util.Deque; -import java.util.concurrent.Callable; - -import javax.servlet.ServletRequest; - -import org.springframework.core.task.AsyncTaskExecutor; -import org.springframework.core.task.SimpleAsyncTaskExecutor; -import org.springframework.util.Assert; -import org.springframework.web.context.request.RequestAttributes; -import org.springframework.web.context.request.WebRequest; -import org.springframework.web.context.request.async.DeferredResult.DeferredResultHandler; - -/** - * The central class for managing async request processing, mainly intended as - * an SPI and not typically used directly by application classes. - * - *

An async execution chain consists of a sequence of Callable instances that - * represent the work required to complete request processing in a separate thread. - * To construct the chain, each level of the call stack pushes an - * {@link AbstractDelegatingCallable} during the course of a normal request and - * pops (removes) it on the way out. If async processing has not started, the pop - * operation succeeds and the processing continues as normal, or otherwise if async - * processing has begun, the main processing thread must be exited. - * - *

For example the DispatcherServlet might contribute a Callable that completes - * view resolution or the HandlerAdapter might contribute a Callable that prepares a - * ModelAndView while the last Callable in the chain is usually associated with the - * application, e.g. the return value of an {@code @RequestMapping} method. - * - * @author Rossen Stoyanchev - * @since 3.2 - */ -public final class AsyncExecutionChain { - - public static final String CALLABLE_CHAIN_ATTRIBUTE = AsyncExecutionChain.class.getName() + ".CALLABLE_CHAIN"; - - private final Deque callables = new ArrayDeque(); - - private Callable lastCallable; - - private AsyncWebRequest asyncWebRequest; - - private AsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("MvcAsync"); - - /** - * Private constructor - * @see #getForCurrentRequest() - */ - private AsyncExecutionChain() { - } - - /** - * Obtain the AsyncExecutionChain for the current request. - * Or if not found, create it and associate it with the request. - */ - public static AsyncExecutionChain getForCurrentRequest(ServletRequest request) { - AsyncExecutionChain chain = (AsyncExecutionChain) request.getAttribute(CALLABLE_CHAIN_ATTRIBUTE); - if (chain == null) { - chain = new AsyncExecutionChain(); - request.setAttribute(CALLABLE_CHAIN_ATTRIBUTE, chain); - } - return chain; - } - - /** - * Obtain the AsyncExecutionChain for the current request. - * Or if not found, create it and associate it with the request. - */ - public static AsyncExecutionChain getForCurrentRequest(WebRequest request) { - int scope = RequestAttributes.SCOPE_REQUEST; - AsyncExecutionChain chain = (AsyncExecutionChain) request.getAttribute(CALLABLE_CHAIN_ATTRIBUTE, scope); - if (chain == null) { - chain = new AsyncExecutionChain(); - request.setAttribute(CALLABLE_CHAIN_ATTRIBUTE, chain, scope); - } - return chain; - } - - /** - * Provide an instance of an AsyncWebRequest -- required for async processing. - */ - public void setAsyncWebRequest(AsyncWebRequest asyncRequest) { - Assert.state(!isAsyncStarted(), "Cannot set AsyncWebRequest after the start of async processing."); - this.asyncWebRequest = asyncRequest; - } - - /** - * Provide an AsyncTaskExecutor for use with {@link #startCallableProcessing()}. - *

By default a {@link SimpleAsyncTaskExecutor} instance is used. Applications are - * advised to provide a TaskExecutor configured for production use. - * @see org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter#setAsyncTaskExecutor - */ - public void setTaskExecutor(AsyncTaskExecutor taskExecutor) { - this.taskExecutor = taskExecutor; - } - - /** - * Push an async Callable for the current stack level. This method should be - * invoked before delegating to the next level of the stack where async - * processing may start. - */ - public void push(AbstractDelegatingCallable callable) { - Assert.notNull(callable, "Async Callable is required"); - this.callables.addFirst(callable); - } - - /** - * Pop the Callable of the current stack level. Ensure this method is invoked - * after delegation to the next level of the stack where async processing may - * start. The pop operation succeeds if async processing did not start. - * @return {@code true} if the Callable was removed, or {@code false} - * otherwise (i.e. async started). - */ - public boolean pop() { - if (isAsyncStarted()) { - return false; - } - else { - this.callables.removeFirst(); - return true; - } - } - - /** - * Whether async request processing has started. - */ - public boolean isAsyncStarted() { - return ((this.asyncWebRequest != null) && this.asyncWebRequest.isAsyncStarted()); - } - - /** - * Set the last Callable, e.g. the one returned by the controller. - */ - public AsyncExecutionChain setLastCallable(Callable callable) { - Assert.notNull(callable, "Callable required"); - this.lastCallable = callable; - return this; - } - - /** - * Start async processing and execute the async chain with an AsyncTaskExecutor. - * This method returns immediately. - */ - public void startCallableProcessing() { - Assert.state(this.asyncWebRequest != null, "AsyncWebRequest was not set"); - this.asyncWebRequest.startAsync(); - this.taskExecutor.execute(new AsyncExecutionChainRunnable(this.asyncWebRequest, buildChain())); - } - - private Callable buildChain() { - Assert.state(this.lastCallable != null, "The last Callable was not set"); - AbstractDelegatingCallable head = new StaleAsyncRequestCheckingCallable(this.asyncWebRequest); - head.setNext(this.lastCallable); - for (AbstractDelegatingCallable callable : this.callables) { - callable.setNext(head); - head = callable; - } - return head; - } - - /** - * Start async processing and initialize the given DeferredResult so when - * its value is set, the async chain is executed with an AsyncTaskExecutor. - */ - public void startDeferredResultProcessing(final DeferredResult deferredResult) { - Assert.notNull(deferredResult, "DeferredResult is required"); - Assert.state(this.asyncWebRequest != null, "AsyncWebRequest was not set"); - this.asyncWebRequest.startAsync(); - - deferredResult.init(new DeferredResultHandler() { - public void handle(Object result) { - if (asyncWebRequest.isAsyncCompleted()) { - throw new StaleAsyncWebRequestException("Too late to set DeferredResult: " + result); - } - setLastCallable(new PassThroughCallable(result)); - taskExecutor.execute(new AsyncExecutionChainRunnable(asyncWebRequest, buildChain())); - } - }); - - this.asyncWebRequest.setTimeoutHandler(deferredResult.getTimeoutHandler()); - } - - - private static class PassThroughCallable implements Callable { - - private final Object value; - - public PassThroughCallable(Object value) { - this.value = value; - } - - public Object call() throws Exception { - return this.value; - } - } - -} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChainRunnable.java b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChainRunnable.java deleted file mode 100644 index 14d4d4b1b35..00000000000 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChainRunnable.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2002-2012 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.context.request.async; - -import java.util.concurrent.Callable; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.springframework.http.HttpStatus; -import org.springframework.util.Assert; - -/** - * A Runnable for invoking a chain of Callable instances and completing async - * request processing while also dealing with any unhandled exceptions. - * - * @author Rossen Stoyanchev - * @since 3.2 - * - * @see AsyncExecutionChain#startCallableProcessing() - * @see AsyncExecutionChain#startDeferredResultProcessing(DeferredResult) - */ -public class AsyncExecutionChainRunnable implements Runnable { - - private static final Log logger = LogFactory.getLog(AsyncExecutionChainRunnable.class); - - private final AsyncWebRequest asyncWebRequest; - - private final Callable callable; - - /** - * Class constructor. - * @param asyncWebRequest the async request - * @param callable the async execution chain - */ - public AsyncExecutionChainRunnable(AsyncWebRequest asyncWebRequest, Callable callable) { - Assert.notNull(asyncWebRequest, "An AsyncWebRequest is required"); - Assert.notNull(callable, "A Callable is required"); - this.asyncWebRequest = asyncWebRequest; - this.callable = callable; - } - - /** - * Run the async execution chain and complete the async request. - *

A {@link StaleAsyncWebRequestException} is logged at debug level and - * absorbed while any other unhandled {@link Exception} results in a 500 - * response code. - */ - public void run() { - try { - this.callable.call(); - } - catch (StaleAsyncWebRequestException ex) { - logger.trace("Could not complete async request", ex); - } - catch (Exception ex) { - logger.trace("Could not complete async request", ex); - this.asyncWebRequest.sendError(HttpStatus.INTERNAL_SERVER_ERROR, ex.getMessage()); - } - finally { - logger.debug("Completing async request processing"); - this.asyncWebRequest.complete(); - } - } - -} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebRequest.java index 51e6a84a7c8..256a5518343 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebRequest.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebRequest.java @@ -16,12 +16,11 @@ package org.springframework.web.context.request.async; -import org.springframework.http.HttpStatus; import org.springframework.web.context.request.NativeWebRequest; /** - * Extend {@link NativeWebRequest} with methods for async request processing. + * Extends {@link NativeWebRequest} with methods for asynchronous request processing. * * @author Rossen Stoyanchev * @since 3.2 @@ -29,50 +28,51 @@ import org.springframework.web.context.request.NativeWebRequest; public interface AsyncWebRequest extends NativeWebRequest { /** - * Set the timeout for asynchronous request processing in milliseconds. - * In Servlet 3 async request processing, the timeout begins when the - * main processing thread has exited. + * Set the time required for concurrent handling to complete. + * @param timeout amount of time in milliseconds */ void setTimeout(Long timeout); /** - * Invoked on a timeout to complete the response instead of the default - * behavior that sets the status to 503 (SERVICE_UNAVAILABLE). + * Provide a Runnable to invoke on timeout. */ void setTimeoutHandler(Runnable runnable); /** - * Mark the start of async request processing for example ensuring the - * request remains open in order to be completed in a separate thread. - * @throws IllegalStateException if async processing has started, if it is - * not supported, or if it has completed. + * Provide a Runnable to invoke at the end of asynchronous request processing. + */ + void addCompletionHandler(Runnable runnable); + + /** + * Mark the start of asynchronous request processing so that when the main + * processing thread exits, the response remains open for further processing + * in another thread. + * @throws IllegalStateException if async processing has completed or is not supported */ void startAsync(); /** - * Whether async processing is in progress and has not yet completed. + * Whether the request is in asynchronous mode after a call to {@link #startAsync()}. + * Returns "false" if asynchronous processing never started, has completed, or the + * request was dispatched for further processing. */ boolean isAsyncStarted(); /** - * Complete async request processing making a best effort but without any - * effect if async request processing has already completed for any reason - * including a timeout. + * Dispatch the request to the container in order to resume processing after + * concurrent execution in an application thread. */ - void complete(); + void dispatch(); /** - * Whether async processing has completed either normally via calls to - * {@link #complete()} or for other reasons such as a timeout likely - * detected in a separate thread during async request processing. + * Whether the request was dispatched to the container. */ - boolean isAsyncCompleted(); + boolean isDispatched(); /** - * Send an error to the client making a best effort to do so but without any - * effect if async request processing has already completed, for example due - * to a timeout. + * Whether asynchronous processing has completed in which case the request + * response should no longer be used. */ - void sendError(HttpStatus status, String message); + boolean isAsyncComplete(); } diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebRequestInterceptor.java b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebRequestInterceptor.java index 017e0bd7a4f..93f85f6082a 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebRequestInterceptor.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebRequestInterceptor.java @@ -13,61 +13,43 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.web.context.request.async; -import java.util.concurrent.Callable; - import org.springframework.web.context.request.WebRequest; import org.springframework.web.context.request.WebRequestInterceptor; /** - * Extends {@link WebRequestInterceptor} with lifecycle methods specific to async - * request processing. + * Extends the WebRequestInterceptor contract for scenarios where a handler may be + * executed asynchronously. Since the handler will complete execution in another + * thread, the results are not available in the current thread, and therefore the + * DispatcherServlet exits quickly and on its way out invokes + * {@link #afterConcurrentHandlingStarted(WebRequest)} instead of + * {@code postHandle} and {@code afterCompletion}. + * When the async handler execution completes, and the request is dispatched back + * for further processing, the DispatcherServlet will invoke {@code preHandle} + * again, as well as {@code postHandle} and {@code afterCompletion}. * - *

This is the sequence of events on the main thread in an async scenario: - *

    - *
  1. {@link #preHandle(WebRequest)} - *
  2. {@link #getAsyncCallable(WebRequest)} - *
  3. ... handler execution - *
  4. {@link #postHandleAsyncStarted(WebRequest)} - *
- * - *

This is the sequence of events on the async thread: - *

    - *
  1. Async {@link Callable#call()} (the {@code Callable} returned by {@code getAsyncCallable}) - *
  2. ... async handler execution - *
  3. {@link #postHandle(WebRequest, org.springframework.ui.ModelMap)} - *
  4. {@link #afterCompletion(WebRequest, Exception)} - *
+ *

Existing implementations should consider the fact that {@code preHandle} may + * be invoked twice before {@code postHandle} and {@code afterCompletion} are + * called if they don't implement this contract. Once before the start of concurrent + * handling and a second time as part of an asynchronous dispatch after concurrent + * handling is done. This may be not important in most cases but when some work + * needs to be done after concurrent handling starts (e.g. clearing thread locals) + * then this contract can be implemented. * * @author Rossen Stoyanchev * @since 3.2 + * + * @see WebAsyncManager */ -public interface AsyncWebRequestInterceptor extends WebRequestInterceptor { - - /** - * Invoked after {@link #preHandle(WebRequest)} and before - * the handler is executed. The returned {@link Callable} is used only if - * handler execution leads to teh start of async processing. It is invoked - * the async thread before the request is handled fro. - *

Implementations can use this Callable to initialize - * ThreadLocal attributes on the async thread. - * @return a {@link Callable} instance or null - */ - AbstractDelegatingCallable getAsyncCallable(WebRequest request); +public interface AsyncWebRequestInterceptor extends WebRequestInterceptor{ /** - * Invoked after the execution of a handler if the handler started - * async processing instead of handling the request. Effectively this method - * is invoked on the way out of the main processing thread instead of - * {@link #postHandle(WebRequest, org.springframework.ui.ModelMap)}. The - * postHandle method is invoked after the request is handled - * in the async thread. - *

Implementations of this method can ensure ThreadLocal attributes bound - * to the main thread are cleared and also prepare for binding them to the - * async thread. + * Called instead of {@code postHandle} and {@code afterCompletion}, when the + * the handler started handling the request concurrently. + * + * @param request the current request */ - void postHandleAsyncStarted(WebRequest request); + void afterConcurrentHandlingStarted(WebRequest request); } diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebUtils.java b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebUtils.java new file mode 100644 index 00000000000..02f138c8ccb --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncWebUtils.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.context.request.async; + +import javax.servlet.ServletRequest; + +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.WebRequest; + +/** + * Utility methods related to processing asynchronous web requests. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public abstract class AsyncWebUtils { + + public static final String WEB_ASYNC_MANAGER_ATTRIBUTE = WebAsyncManager.class.getName() + ".WEB_ASYNC_MANAGER"; + + /** + * Obtain the {@link WebAsyncManager} for the current request, or if not + * found, create and associate it with the request. + */ + public static WebAsyncManager getAsyncManager(ServletRequest servletRequest) { + WebAsyncManager asyncManager = (WebAsyncManager) servletRequest.getAttribute(WEB_ASYNC_MANAGER_ATTRIBUTE); + if (asyncManager == null) { + asyncManager = new WebAsyncManager(); + servletRequest.setAttribute(WEB_ASYNC_MANAGER_ATTRIBUTE, asyncManager); + } + return asyncManager; + } + + /** + * Obtain the {@link WebAsyncManager} for the current request, or if not + * found, create and associate it with the request. + */ + public static WebAsyncManager getAsyncManager(WebRequest webRequest) { + int scope = RequestAttributes.SCOPE_REQUEST; + WebAsyncManager asyncManager = (WebAsyncManager) webRequest.getAttribute(WEB_ASYNC_MANAGER_ATTRIBUTE, scope); + if (asyncManager == null) { + asyncManager = new WebAsyncManager(); + webRequest.setAttribute(WEB_ASYNC_MANAGER_ATTRIBUTE, asyncManager, scope); + } + return asyncManager; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResult.java b/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResult.java index 8a5eadccc2c..3c3cd2f0da4 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResult.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/DeferredResult.java @@ -20,6 +20,8 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.ReentrantLock; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.util.Assert; @@ -49,6 +51,8 @@ import org.springframework.util.Assert; */ public final class DeferredResult { + private static final Log logger = LogFactory.getLog(DeferredResult.class); + private V result; private DeferredResultHandler resultHandler; @@ -123,19 +127,28 @@ public final class DeferredResult { Assert.isNull(this.result, "A deferred result can be set once only"); this.result = result; this.timeoutValueUsed = (this.timeoutValueSet && (this.result == this.timeoutValue)); - try { - this.initializationLatch.await(10, TimeUnit.SECONDS); - } - catch (InterruptedException e) { + if (!await()) { throw new IllegalStateException( "Gave up on waiting for DeferredResult to be initialized. " + "Are you perhaps creating and setting a DeferredResult in the same thread? " + "The DeferredResult must be fully initialized before you can set it. " + "See the class javadoc for more details"); } + if (this.timeoutValueUsed) { + logger.debug("Using default timeout value"); + } this.resultHandler.handle(result); } + private boolean await() { + try { + return this.initializationLatch.await(10, TimeUnit.SECONDS); + } + catch (InterruptedException e) { + return false; + } + } + /** * Return a handler to use to complete processing using the default timeout value * provided via {@link #DeferredResult(Object)} or {@code null} if no timeout diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/NoOpAsyncWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/async/NoSupportAsyncWebRequest.java similarity index 68% rename from spring-web/src/main/java/org/springframework/web/context/request/async/NoOpAsyncWebRequest.java rename to spring-web/src/main/java/org/springframework/web/context/request/async/NoSupportAsyncWebRequest.java index fd425bc1389..a4f0eb07024 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/NoOpAsyncWebRequest.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/NoSupportAsyncWebRequest.java @@ -19,47 +19,51 @@ package org.springframework.web.context.request.async; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import org.springframework.http.HttpStatus; import org.springframework.web.context.request.ServletWebRequest; /** - * An implementation of {@link AsyncWebRequest} used when no underlying support - * for async request processing is available in which case {@link #startAsync()} - * results in an {@link UnsupportedOperationException}. + * An implementation of {@link AsyncWebRequest} to use when no underlying support is available. + * The methods {@link #startAsync()} and {@link #dispatch()} raise {@link UnsupportedOperationException}. * * @author Rossen Stoyanchev * @since 3.2 */ -public class NoOpAsyncWebRequest extends ServletWebRequest implements AsyncWebRequest { +public class NoSupportAsyncWebRequest extends ServletWebRequest implements AsyncWebRequest { - public NoOpAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) { + public NoSupportAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) { super(request, response); } + public void addCompletionHandler(Runnable runnable) { + // ignored + } + public void setTimeout(Long timeout) { + // ignored } public void setTimeoutHandler(Runnable runnable) { + // ignored } public boolean isAsyncStarted() { return false; } - public boolean isAsyncCompleted() { + public boolean isDispatched() { return false; } - public void startAsync() { + public boolean isAsyncComplete() { throw new UnsupportedOperationException("No async support in a pre-Servlet 3.0 runtime"); } - public void complete() { - throw new UnsupportedOperationException("No async support in a pre-Servlet 3.0 environment"); + public void startAsync() { + throw new UnsupportedOperationException("No async support in a pre-Servlet 3.0 runtime"); } - public void sendError(HttpStatus status, String message) { - throw new UnsupportedOperationException("No async support in a pre-Servlet 3.0 environment"); + public void dispatch() { + throw new UnsupportedOperationException("No async support in a pre-Servlet 3.0 runtime"); } } diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallable.java b/spring-web/src/main/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallable.java deleted file mode 100644 index 42a6a4d5eca..00000000000 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallable.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright 2002-2012 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.context.request.async; - - -/** - * Invokes the next Callable in a chain and then checks if the AsyncWebRequest - * provided to the constructor has ended before returning. Since a timeout or a - * (client) error may occur in a separate thread while async request processing - * is still in progress in its own thread, inserting this Callable in the chain - * protects against use of stale async requests. - * - *

If an async request was terminated while the next Callable was still - * processing, a {@link StaleAsyncWebRequestException} is raised. - * - * @author Rossen Stoyanchev - * @since 3.2 - */ -public class StaleAsyncRequestCheckingCallable extends AbstractDelegatingCallable { - - private final AsyncWebRequest asyncWebRequest; - - public StaleAsyncRequestCheckingCallable(AsyncWebRequest asyncWebRequest) { - this.asyncWebRequest = asyncWebRequest; - } - - public Object call() throws Exception { - Object result = getNext().call(); - if (this.asyncWebRequest.isAsyncCompleted()) { - throw new StaleAsyncWebRequestException( - "Async request no longer available due to a timeout or a (client) error"); - } - return result; - } - -} \ No newline at end of file diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/StaleAsyncWebRequestException.java b/spring-web/src/main/java/org/springframework/web/context/request/async/StaleAsyncWebRequestException.java index fbbb2f46cf9..4f80fc681aa 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/StaleAsyncWebRequestException.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/StaleAsyncWebRequestException.java @@ -21,9 +21,6 @@ package org.springframework.web.context.request.async; * * @author Rossen Stoyanchev * @since 3.2 - * - * @see DeferredResult#set(Object) - * @see AsyncExecutionChainRunnable#run() */ @SuppressWarnings("serial") public class StaleAsyncWebRequestException extends RuntimeException { diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java index 3b2c9de7c2a..7ed148c9c72 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java @@ -17,11 +17,14 @@ package org.springframework.web.context.request.async; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import javax.servlet.AsyncContext; import javax.servlet.AsyncEvent; import javax.servlet.AsyncListener; +import javax.servlet.DispatcherType; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -48,36 +51,60 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements private AtomicBoolean asyncCompleted = new AtomicBoolean(false); - private Runnable timeoutHandler; + private Runnable timeoutHandler = new DefaultTimeoutHandler(); + + private final List completionHandlers = new ArrayList(); public StandardServletAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) { super(request, response); } + /** + * {@inheritDoc} + *

The timeout period begins when the main processing thread has exited. + */ public void setTimeout(Long timeout) { this.timeout = timeout; } + public void setTimeoutHandler(Runnable timeoutHandler) { + if (timeoutHandler != null) { + this.timeoutHandler = timeoutHandler; + } + } + + public void addCompletionHandler(Runnable runnable) { + this.completionHandlers.add(runnable); + } + public boolean isAsyncStarted() { return ((this.asyncContext != null) && getRequest().isAsyncStarted()); } - public boolean isAsyncCompleted() { - return this.asyncCompleted.get(); + public boolean isDispatched() { + return (DispatcherType.ASYNC.equals(getRequest().getDispatcherType())); } - public void setTimeoutHandler(Runnable timeoutHandler) { - this.timeoutHandler = timeoutHandler; + /** + * Whether async request processing has completed. + *

It is important to avoid use of request and response objects after async + * processing has completed. Servlet containers often re-use them. + */ + public boolean isAsyncComplete() { + return this.asyncCompleted.get(); } + public void startAsync() { Assert.state(getRequest().isAsyncSupported(), "Async support must be enabled on a servlet and for all filters involved " + "in async request processing. This is done in Java code using the Servlet API " + "or by adding \"true\" to servlet and " + "filter declarations in web.xml."); - Assert.state(!isAsyncStarted(), "Async processing already started"); - Assert.state(!isAsyncCompleted(), "Cannot use async request that has completed"); + Assert.state(!isAsyncComplete(), "Async processing has already completed"); + if (isAsyncStarted()) { + return; + } this.asyncContext = getRequest().startAsync(getRequest(), getResponse()); this.asyncContext.addListener(this); if (this.timeout != null) { @@ -85,52 +112,44 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements } } - public void complete() { - if (!isAsyncCompleted()) { - this.asyncContext.complete(); - completeInternal(); - } - } - - private void completeInternal() { - this.asyncContext = null; - this.asyncCompleted.set(true); - } - - public void sendError(HttpStatus status, String message) { - try { - if (!isAsyncCompleted()) { - getResponse().sendError(500, message); - } - } - catch (IOException ioEx) { - // absorb - } + public void dispatch() { + Assert.notNull(this.asyncContext, "Cannot dispatch without an AsyncContext"); + this.asyncContext.dispatch(); } // --------------------------------------------------------------------- // Implementation of AsyncListener methods // --------------------------------------------------------------------- - public void onTimeout(AsyncEvent event) throws IOException { - if (this.timeoutHandler == null) { - getResponse().sendError(HttpStatus.SERVICE_UNAVAILABLE.value()); - } - else { - this.timeoutHandler.run(); - } - completeInternal(); + public void onStartAsync(AsyncEvent event) throws IOException { } public void onError(AsyncEvent event) throws IOException { - completeInternal(); } - public void onStartAsync(AsyncEvent event) throws IOException { + public void onTimeout(AsyncEvent event) throws IOException { + this.timeoutHandler.run(); } public void onComplete(AsyncEvent event) throws IOException { - completeInternal(); + for (Runnable runnable : this.completionHandlers) { + runnable.run(); + } + this.asyncContext = null; + this.asyncCompleted.set(true); + } + + + private class DefaultTimeoutHandler implements Runnable { + + public void run() { + try { + getResponse().sendError(HttpStatus.SERVICE_UNAVAILABLE.value()); + } + catch (IOException ex) { + // ignore + } + } } } diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java new file mode 100644 index 00000000000..0d7a78c2abc --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java @@ -0,0 +1,319 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.context.request.async; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; + +import javax.servlet.http.HttpServletRequest; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.util.Assert; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.async.DeferredResult.DeferredResultHandler; +import org.springframework.web.util.UrlPathHelper; + +/** + * The central class for managing asynchronous request processing, mainly intended + * as an SPI and not typically used directly by application classes. + * + *

An async scenario starts with request processing as usual in a thread (T1). + * When a handler decides to handle the request concurrently, it calls + * {@linkplain #startCallableProcessing(Callable, Object...) startCallableProcessing} or + * {@linkplain #startDeferredResultProcessing(DeferredResult, Object...) startDeferredResultProcessing} + * both of which will process in a separate thread (T2). + * After the start of concurrent handling {@link #isConcurrentHandlingStarted()} + * returns "true" and this can be used by classes involved in processing on the + * main thread (T1) quickly and with very minimal processing. + * + *

When the concurrent handling completes in a separate thread (T2), both + * {@code startCallableProcessing} and {@code startDeferredResultProcessing} + * save the results and dispatched to the container, essentially to the + * same request URI as the one that started concurrent handling. This allows for + * further processing of the concurrent results. Classes in the dispatched + * thread (T3), can access the results via {@link #getConcurrentResult()} or + * detect their presence via {@link #hasConcurrentResult()}. Also in the + * dispatched thread {@link #isConcurrentHandlingStarted()} will return "false" + * unless concurrent handling is started once again. + * + * TODO .. mention Servlet 3 configuration + * + * @author Rossen Stoyanchev + * @since 3.2 + * + * @see org.springframework.web.context.request.async.AsyncWebRequestInterceptor + * @see org.springframework.web.servlet.AsyncHandlerInterceptor + * @see org.springframework.web.filter.OncePerRequestFilter#shouldFilterAsyncDispatches + * @see org.springframework.web.filter.OncePerRequestFilter#isAsyncDispatch + * @see org.springframework.web.filter.OncePerRequestFilter#isLastRequestThread + */ +public final class WebAsyncManager { + + private static final Object RESULT_NONE = new Object(); + + private static final Log logger = LogFactory.getLog(WebAsyncManager.class); + + private AsyncWebRequest asyncWebRequest; + + private AsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(this.getClass().getSimpleName()); + + private final Map threadInitializers = new LinkedHashMap(); + + private Object concurrentResult = RESULT_NONE; + + private Object[] concurrentResultContext; + + private static final UrlPathHelper urlPathHelper = new UrlPathHelper(); + + /** + * Package private constructor + * @see AsyncWebUtils + */ + WebAsyncManager() { + } + + /** + * Configure an AsyncTaskExecutor for use with {@link #startCallableProcessing(Callable)}. + *

By default a {@link SimpleAsyncTaskExecutor} instance is used. Applications + * are advised to provide a TaskExecutor configured for production use. + * @see org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter#setAsyncTaskExecutor + */ + public void setTaskExecutor(AsyncTaskExecutor taskExecutor) { + this.taskExecutor = taskExecutor; + } + + /** + * Provide an {@link AsyncWebRequest} to use to start and to dispatch request. + * This property must be set before the start of concurrent handling. + * @param asyncWebRequest the request to use + */ + public void setAsyncWebRequest(final AsyncWebRequest asyncWebRequest) { + Assert.notNull(asyncWebRequest, "Expected AsyncWebRequest"); + Assert.state(!isConcurrentHandlingStarted(), "Can't set AsyncWebRequest with concurrent handling in progress"); + this.asyncWebRequest = asyncWebRequest; + this.asyncWebRequest.addCompletionHandler(new Runnable() { + public void run() { + asyncWebRequest.removeAttribute(AsyncWebUtils.WEB_ASYNC_MANAGER_ATTRIBUTE, RequestAttributes.SCOPE_REQUEST); + } + }); + } + + /** + * Whether the handler for the current request is executed concurrently. + * Once concurrent handling is done, the result is saved, and the request + * dispatched again to resume processing where the result of concurrent + * handling is available via {@link #getConcurrentResult()}. + */ + public boolean isConcurrentHandlingStarted() { + return ((this.asyncWebRequest != null) && (this.asyncWebRequest.isAsyncStarted())); + } + + /** + * Whether the current thread was dispatched to continue processing the result + * of concurrent handler execution. + */ + public boolean hasConcurrentResult() { + + // TODO: + // Add check for asyncWebRequest.isDispatched() once Apache id=53632 is fixed. + // That ensure "true" is returned in the dispatched thread only. + + return this.concurrentResult != RESULT_NONE; + } + + /** + * Return the result of concurrent handler execution. This may be an Object + * value on successful return or an {@code Exception} or {@code Throwable}. + */ + public Object getConcurrentResult() { + return this.concurrentResult; + } + + /** + * Return the processing context saved at the start of concurrent handling. + */ + public Object[] getConcurrentResultContext() { + return this.concurrentResultContext; + } + + /** + * Reset the {@linkplain #getConcurrentResult() concurrentResult} and the + * {@linkplain #getConcurrentResultContext() concurrentResultContext}. + */ + public void resetConcurrentResult() { + this.concurrentResult = RESULT_NONE; + this.concurrentResultContext = null; + } + + /** + * Register an {@link AsyncThreadInitializer} with the WebAsyncManager instance + * for the current request. It may later be accessed and applied via + * {@link #applyAsyncThreadInitializer(String)} and will also be used to + * initialize and reset threads for concurrent handler execution. + * @param key a unique the key under which to keep the initializer + * @param initializer the initializer instance + */ + public void registerAsyncThreadInitializer(Object key, AsyncThreadInitializer initializer) { + Assert.notNull(initializer, "An AsyncThreadInitializer instance is required"); + this.threadInitializers.put(key, initializer); + } + + /** + * Invoke the {@linkplain AsyncThreadInitializer#initialize() initialize()} + * method of the named {@link AsyncThreadInitializer}. + * @param key the key under which the initializer was registered + * @return whether an initializer was found and applied + */ + public boolean applyAsyncThreadInitializer(Object key) { + AsyncThreadInitializer initializer = this.threadInitializers.get(key); + if (initializer != null) { + initializer.initialize(); + return true; + } + return false; + } + + /** + * Submit a request handling task for concurrent execution. Returns immediately + * and subsequent calls to {@link #isConcurrentHandlingStarted()} return "true". + *

When concurrent handling is done, the resulting value, which may be an + * Object or a raised {@code Exception} or {@code Throwable}, is saved and the + * request is dispatched for further processing of that result. In the dispatched + * thread, the result can be accessed via {@link #getConcurrentResult()} while + * {@link #hasConcurrentResult()} returns "true" and + * {@link #isConcurrentHandlingStarted()} is back to returning "false". + * + * @param callable a unit of work to be executed asynchronously + * @param processingContext additional context to save for later access via + * {@link #getConcurrentResultContext()} + */ + public void startCallableProcessing(final Callable callable, Object... processingContext) { + Assert.notNull(callable, "Callable is required"); + + startAsyncProcessing(processingContext); + + this.taskExecutor.submit(new Runnable() { + + public void run() { + List initializers = + new ArrayList(threadInitializers.values()); + + try { + for (AsyncThreadInitializer initializer : initializers) { + initializer.initialize(); + } + concurrentResult = callable.call(); + } + catch (Throwable t) { + concurrentResult = t; + } + finally { + Collections.reverse(initializers); + for (AsyncThreadInitializer initializer : initializers) { + initializer.reset(); + } + } + + if (logger.isDebugEnabled()) { + logger.debug("Concurrent result value [" + concurrentResult + "]"); + } + + if (asyncWebRequest.isAsyncComplete()) { + logger.error("Could not complete processing due to a timeout or network error"); + return; + } + + logger.debug("Dispatching request to continue processing"); + asyncWebRequest.dispatch(); + } + }); + } + + /** + * Initialize the given given {@link DeferredResult} so that whenever the + * DeferredResult is set, the resulting value, which may be an Object or a + * raised {@code Exception} or {@code Throwable}, is saved and the request + * is dispatched for further processing of the result. In the dispatch + * thread, the result value can be accessed via {@link #getConcurrentResult()}. + *

The method returns immediately and it's up to the caller to set the + * DeferredResult. Subsequent calls to {@link #isConcurrentHandlingStarted()} + * return "true" until after the dispatch when {@link #hasConcurrentResult()} + * returns "true" and {@link #isConcurrentHandlingStarted()} is back to "false". + * + * @param deferredResult the DeferredResult instance to initialize + * @param processingContext additional context to save for later access via + * {@link #getConcurrentResultContext()} + */ + public void startDeferredResultProcessing(final DeferredResult deferredResult, Object... processingContext) { + Assert.notNull(deferredResult, "DeferredResult is required"); + + startAsyncProcessing(processingContext); + + deferredResult.init(new DeferredResultHandler() { + + public void handle(Object result) { + concurrentResult = result; + if (logger.isDebugEnabled()) { + logger.debug("Deferred result value [" + concurrentResult + "]"); + } + + if (asyncWebRequest.isAsyncComplete()) { + throw new StaleAsyncWebRequestException("Could not complete processing due to a timeout or network error"); + } + + logger.debug("Dispatching request to complete processing"); + asyncWebRequest.dispatch(); + } + }); + + this.asyncWebRequest.setTimeoutHandler(deferredResult.getTimeoutHandler()); + } + + private void startAsyncProcessing(Object... context) { + + Assert.state(this.asyncWebRequest != null, "AsyncWebRequest was not set"); + this.asyncWebRequest.startAsync(); + + this.concurrentResult = null; + this.concurrentResultContext = context; + + if (logger.isDebugEnabled()) { + HttpServletRequest request = asyncWebRequest.getNativeRequest(HttpServletRequest.class); + String requestUri = urlPathHelper.getRequestUri(request); + logger.debug("Concurrent handling starting for " + request.getMethod() + " [" + requestUri + "]"); + } + } + + + /** + * A contract for initializing and resetting a thread. + */ + public interface AsyncThreadInitializer { + + void initialize(); + + void reset(); + + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java b/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java index b1ad7a9e096..7bdef199ab1 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java @@ -32,8 +32,6 @@ import javax.servlet.http.HttpSession; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; -import org.springframework.web.context.request.async.AsyncExecutionChain; import org.springframework.web.util.WebUtils; /** @@ -179,6 +177,16 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter this.afterMessageSuffix = afterMessageSuffix; } + /** + * The default value is "true" so that the filter may log a "before" message + * at the start of request processing and an "after" message at the end from + * when the last asynchronously dispatched thread is exiting. + */ + @Override + protected boolean shouldFilterAsyncDispatches() { + return true; + } + /** * Forwards the request to the next filter in the chain and delegates down to the subclasses to perform the actual * request logging both before and after the request is processed. @@ -190,22 +198,28 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { + boolean isAsyncDispatch = isAsyncDispatch(request); + if (isIncludePayload()) { - request = new RequestCachingRequestWrapper(request); + if (isAsyncDispatch) { + request = WebUtils.getNativeRequest(request, RequestCachingRequestWrapper.class); + Assert.notNull(request, "Expected wrapped request"); + } + else { + request = new RequestCachingRequestWrapper(request); + } } - beforeRequest(request, getBeforeMessage(request)); - - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - chain.push(getAsyncCallable(request)); + if (!isAsyncDispatch) { + beforeRequest(request, getBeforeMessage(request)); + } try { filterChain.doFilter(request, response); } finally { - if (!chain.pop()) { - return; + if (isLastRequestThread(request)) { + afterRequest(request, getAfterMessage(request)); } - afterRequest(request, getAfterMessage(request)); } } @@ -290,19 +304,6 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter */ protected abstract void afterRequest(HttpServletRequest request, String message); - /** - * Create a Callable to use to complete processing in an async execution chain. - */ - private AbstractDelegatingCallable getAsyncCallable(final HttpServletRequest request) { - return new AbstractDelegatingCallable() { - public Object call() throws Exception { - getNext().call(); - afterRequest(request, getAfterMessage(request)); - return null; - } - }; - } - private static class RequestCachingRequestWrapper extends HttpServletRequestWrapper { diff --git a/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java b/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java index 29cc896f131..6b2f6c84c25 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java @@ -25,14 +25,19 @@ import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; -import org.springframework.web.context.request.async.AsyncExecutionChain; +import org.springframework.web.context.request.async.AsyncWebUtils; /** * Filter base class that guarantees to be just executed once per request, * on any servlet container. It provides a {@link #doFilterInternal} * method with HttpServletRequest and HttpServletResponse arguments. * + *

In an async scenario a filter may be invoked again in additional threads + * as part of an {@linkplain javax.servlet.DispatcherType.ASYNC ASYNC} dispatch. + * Sub-classes may decide whether to be invoked once per request or once per + * request thread for as long as the same request is being processed. + * See {@link #shouldFilterAsyncDispatches()}. + * *

The {@link #getAlreadyFilteredAttributeName} method determines how * to identify that a request is already filtered. The default implementation * is based on the configured name of the concrete filter instance. @@ -68,26 +73,27 @@ public abstract class OncePerRequestFilter extends GenericFilterBean { HttpServletRequest httpRequest = (HttpServletRequest) request; HttpServletResponse httpResponse = (HttpServletResponse) response; + boolean processAsyncRequestThread = isAsyncDispatch(httpRequest) && shouldFilterAsyncDispatches(); + String alreadyFilteredAttributeName = getAlreadyFilteredAttributeName(); - if (request.getAttribute(alreadyFilteredAttributeName) != null || shouldNotFilter(httpRequest)) { + boolean hasAlreadyFilteredAttribute = request.getAttribute(alreadyFilteredAttributeName) != null; + + if ((hasAlreadyFilteredAttribute && (!processAsyncRequestThread)) || shouldNotFilter(httpRequest)) { + // Proceed without invoking this filter... filterChain.doFilter(request, response); } else { - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - chain.push(getAsyncCallable(request, alreadyFilteredAttributeName)); - // Do invoke this filter... request.setAttribute(alreadyFilteredAttributeName, Boolean.TRUE); try { doFilterInternal(httpRequest, httpResponse, filterChain); } finally { - if (!chain.pop()) { - return; + if (isLastRequestThread(httpRequest)) { + // Remove the "already filtered" request attribute for this request. + request.removeAttribute(alreadyFilteredAttributeName); } - // Remove the "already filtered" request attribute for this request. - request.removeAttribute(alreadyFilteredAttributeName); } } } @@ -122,25 +128,62 @@ public abstract class OncePerRequestFilter extends GenericFilterBean { } /** - * Create a Callable to use to complete processing in an async execution chain. + * Whether to filter once per request or once per request thread. The dispatcher + * type {@code javax.servlet.DispatcherType.ASYNC} introduced in Servlet 3.0 + * means a filter can be invoked in more than one thread (and exited) over the + * course of a single request. Some filters only need to filter the initial + * thread (e.g. request wrapping) while others may need to be invoked at least + * once in each additional thread for example for setting up thread locals or + * to perform final processing at the very end. + *

Note that although a filter can be mapped to handle specific dispatcher + * types via {@code web.xml} or in Java through the {@code ServletContext}, + * servlet containers may enforce different defaults with regards to dispatcher + * types. This flag enforces the design intent of the filter. + *

The default setting is "false", which means the filter will be invoked + * once only per request and only on the initial request thread. If "true", the + * filter will also be invoked once only on each additional thread. + * + * @see org.springframework.web.context.request.async.WebAsyncManager */ - private AbstractDelegatingCallable getAsyncCallable(final ServletRequest request, - final String alreadyFilteredAttributeName) { - - return new AbstractDelegatingCallable() { - public Object call() throws Exception { - getNext().call(); - request.removeAttribute(alreadyFilteredAttributeName); - return null; - } - }; + protected boolean shouldFilterAsyncDispatches() { + return false; + } + + /** + * Whether the request was dispatched to complete processing of results produced + * in another thread. This aligns with the Servlet 3.0 dispatcher type + * {@code javax.servlet.DispatcherType.ASYNC} and can be used by filters that + * return "true" from {@link #shouldFilterAsyncDispatches()} to detect when + * the filter is being invoked subsequently in additional thread(s). + * + * @see org.springframework.web.context.request.async.WebAsyncManager + */ + protected final boolean isAsyncDispatch(HttpServletRequest request) { + return AsyncWebUtils.getAsyncManager(request).hasConcurrentResult(); + } + + /** + * Whether this is the last thread processing the request. Note the returned + * value may change from {@code true} to {@code false} if the method is + * invoked before and after delegating to the next filter, since the next filter + * or servlet may begin concurrent processing. Therefore this method is most + * useful after delegation for final, end-of-request type processing. + * @param request the current request + * @return {@code true} if the response will be committed when the current + * thread exits; {@code false} if the response will remain open. + * + * @see org.springframework.web.context.request.async.WebAsyncManager + */ + protected final boolean isLastRequestThread(HttpServletRequest request) { + return (!AsyncWebUtils.getAsyncManager(request).isConcurrentHandlingStarted()); } /** * Same contract as for doFilter, but guaranteed to be - * just invoked once per request. Provides HttpServletRequest and - * HttpServletResponse arguments instead of the default ServletRequest - * and ServletResponse ones. + * just invoked once per request or once per request thread. + * See {@link #shouldFilterAsyncDispatches()} for details. + *

Provides HttpServletRequest and HttpServletResponse arguments instead of the + * default ServletRequest and ServletResponse ones. */ protected abstract void doFilterInternal( HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) diff --git a/spring-web/src/main/java/org/springframework/web/filter/RequestContextFilter.java b/spring-web/src/main/java/org/springframework/web/filter/RequestContextFilter.java index 3c640fa4e19..9fe7cfed8a5 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/RequestContextFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/RequestContextFilter.java @@ -26,8 +26,6 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.context.i18n.LocaleContextHolder; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; -import org.springframework.web.context.request.async.AsyncExecutionChain; /** * Servlet 2.3 Filter that exposes the request to the current thread, @@ -72,6 +70,15 @@ public class RequestContextFilter extends OncePerRequestFilter { } + /** + * The default value is "true" in which case the filter will set up the request + * context in each asynchronously dispatched thread. + */ + @Override + protected boolean shouldFilterAsyncDispatches() { + return true; + } + @Override protected void doFilterInternal( HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) @@ -80,21 +87,15 @@ public class RequestContextFilter extends OncePerRequestFilter { ServletRequestAttributes attributes = new ServletRequestAttributes(request); initContextHolders(request, attributes); - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - chain.push(getChainedCallable(request, attributes)); - try { filterChain.doFilter(request, response); } finally { resetContextHolders(); - if (!chain.pop()) { - return; - } - attributes.requestCompleted(); if (logger.isDebugEnabled()) { logger.debug("Cleared thread-bound request context: " + request); } + attributes.requestCompleted(); } } @@ -111,28 +112,4 @@ public class RequestContextFilter extends OncePerRequestFilter { RequestContextHolder.resetRequestAttributes(); } - /** - * Create a Callable to use to complete processing in an async execution chain. - */ - private AbstractDelegatingCallable getChainedCallable(final HttpServletRequest request, - final ServletRequestAttributes requestAttributes) { - - return new AbstractDelegatingCallable() { - public Object call() throws Exception { - initContextHolders(request, requestAttributes); - try { - getNext().call(); - } - finally { - resetContextHolders(); - requestAttributes.requestCompleted(); - if (logger.isDebugEnabled()) { - logger.debug("Cleared thread-bound request context: " + request); - } - } - return null; - } - }; - } - } diff --git a/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java index 79d704b7d27..6f0e33c6a94 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java @@ -29,10 +29,9 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; +import org.springframework.util.Assert; import org.springframework.util.DigestUtils; import org.springframework.util.FileCopyUtils; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; -import org.springframework.web.context.request.async.AsyncExecutionChain; import org.springframework.web.util.WebUtils; /** @@ -54,37 +53,34 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { private static String HEADER_IF_NONE_MATCH = "If-None-Match"; + /** + * The default value is "true" so that the filter may delay the generation of + * an ETag until the last asynchronously dispatched thread. + */ + @Override + protected boolean shouldFilterAsyncDispatches() { + return true; + } + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - ShallowEtagResponseWrapper responseWrapper = new ShallowEtagResponseWrapper(response); + ShallowEtagResponseWrapper responseWrapper; - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - chain.push(getAsyncCallable(request, response, responseWrapper)); + if (isAsyncDispatch(request)) { + responseWrapper = WebUtils.getNativeResponse(response, ShallowEtagResponseWrapper.class); + Assert.notNull(responseWrapper, "Expected wrapped response"); + } + else { + responseWrapper = new ShallowEtagResponseWrapper(response); + } filterChain.doFilter(request, responseWrapper); - if (!chain.pop()) { - return; + if (isLastRequestThread(request)) { + updateResponse(request, response, responseWrapper); } - - updateResponse(request, response, responseWrapper); - } - - /** - * Create a Callable to use to complete processing in an async execution chain. - */ - private AbstractDelegatingCallable getAsyncCallable(final HttpServletRequest request, - final HttpServletResponse response, final ShallowEtagResponseWrapper responseWrapper) { - - return new AbstractDelegatingCallable() { - public Object call() throws Exception { - getNext().call(); - updateResponse(request, response, responseWrapper); - return null; - } - }; } private void updateResponse(HttpServletRequest request, HttpServletResponse response, diff --git a/spring-web/src/test/java/org/springframework/http/client/AbstractHttpRequestFactoryTestCase.java b/spring-web/src/test/java/org/springframework/http/client/AbstractHttpRequestFactoryTestCase.java index 58ed4ebbe07..0ce0dc19e4e 100644 --- a/spring-web/src/test/java/org/springframework/http/client/AbstractHttpRequestFactoryTestCase.java +++ b/spring-web/src/test/java/org/springframework/http/client/AbstractHttpRequestFactoryTestCase.java @@ -16,12 +16,16 @@ package org.springframework.http.client; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + import java.io.IOException; import java.io.InputStream; import java.net.URI; import java.util.Arrays; import java.util.Enumeration; import java.util.Locale; + import javax.servlet.GenericServlet; import javax.servlet.ServletException; import javax.servlet.ServletRequest; @@ -30,20 +34,16 @@ import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import org.springframework.http.HttpMethod; -import org.springframework.http.HttpStatus; -import org.springframework.util.FileCopyUtils; - +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; -import org.mortbay.jetty.Server; -import org.mortbay.jetty.servlet.Context; -import org.mortbay.jetty.servlet.ServletHolder; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.util.FileCopyUtils; public abstract class AbstractHttpRequestFactoryTestCase { @@ -58,17 +58,23 @@ public abstract class AbstractHttpRequestFactoryTestCase { int port = FreePortScanner.getFreePort(); jettyServer = new Server(port); baseUrl = "http://localhost:" + port; - Context jettyContext = new Context(jettyServer, "/"); - jettyContext.addServlet(new ServletHolder(new EchoServlet()), "/echo"); - jettyContext.addServlet(new ServletHolder(new StatusServlet(200)), "/status/ok"); - jettyContext.addServlet(new ServletHolder(new StatusServlet(404)), "/status/notfound"); - jettyContext.addServlet(new ServletHolder(new MethodServlet("DELETE")), "/methods/delete"); - jettyContext.addServlet(new ServletHolder(new MethodServlet("GET")), "/methods/get"); - jettyContext.addServlet(new ServletHolder(new MethodServlet("HEAD")), "/methods/head"); - jettyContext.addServlet(new ServletHolder(new MethodServlet("OPTIONS")), "/methods/options"); - jettyContext.addServlet(new ServletHolder(new PostServlet()), "/methods/post"); - jettyContext.addServlet(new ServletHolder(new MethodServlet("PUT")), "/methods/put"); - jettyContext.addServlet(new ServletHolder(new MethodServlet("PATCH")), "/methods/patch"); + + ServletContextHandler handler = new ServletContextHandler(); + handler.setContextPath("/"); + + handler.addServlet(new ServletHolder(new EchoServlet()), "/echo"); + handler.addServlet(new ServletHolder(new EchoServlet()), "/echo"); + handler.addServlet(new ServletHolder(new StatusServlet(200)), "/status/ok"); + handler.addServlet(new ServletHolder(new StatusServlet(404)), "/status/notfound"); + handler.addServlet(new ServletHolder(new MethodServlet("DELETE")), "/methods/delete"); + handler.addServlet(new ServletHolder(new MethodServlet("GET")), "/methods/get"); + handler.addServlet(new ServletHolder(new MethodServlet("HEAD")), "/methods/head"); + handler.addServlet(new ServletHolder(new MethodServlet("OPTIONS")), "/methods/options"); + handler.addServlet(new ServletHolder(new PostServlet()), "/methods/post"); + handler.addServlet(new ServletHolder(new MethodServlet("PUT")), "/methods/put"); + handler.addServlet(new ServletHolder(new MethodServlet("PATCH")), "/methods/patch"); + + jettyServer.setHandler(handler); jettyServer.start(); } @@ -179,6 +185,7 @@ public abstract class AbstractHttpRequestFactoryTestCase { /** * Servlet that sets a given status code. */ + @SuppressWarnings("serial") private static class StatusServlet extends GenericServlet { private final int sc; @@ -193,6 +200,7 @@ public abstract class AbstractHttpRequestFactoryTestCase { } } + @SuppressWarnings("serial") private static class MethodServlet extends GenericServlet { private final String method; @@ -210,6 +218,7 @@ public abstract class AbstractHttpRequestFactoryTestCase { } } + @SuppressWarnings("serial") private static class PostServlet extends MethodServlet { private PostServlet() { @@ -233,6 +242,7 @@ public abstract class AbstractHttpRequestFactoryTestCase { } } + @SuppressWarnings("serial") private static class EchoServlet extends HttpServlet { @Override diff --git a/spring-web/src/test/java/org/springframework/mock/web/MockAsyncContext.java b/spring-web/src/test/java/org/springframework/mock/web/MockAsyncContext.java new file mode 100644 index 00000000000..9607213e456 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/mock/web/MockAsyncContext.java @@ -0,0 +1,133 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.mock.web; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; +import javax.servlet.DispatcherType; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + +import org.springframework.beans.BeanUtils; +import org.springframework.web.util.WebUtils; + +/** + * Mock implementation of the {@link AsyncContext} interface. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class MockAsyncContext implements AsyncContext { + + private final ServletRequest request; + + private final ServletResponse response; + + private final MockHttpServletRequest mockRequest; + + private final List listeners = new ArrayList(); + + private String dispatchPath; + + private long timeout = 10 * 60 * 1000L; + + public MockAsyncContext(ServletRequest request, ServletResponse response) { + this.request = request; + this.response = response; + this.mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class); + } + + public ServletRequest getRequest() { + return this.request; + } + + public ServletResponse getResponse() { + return this.response; + } + + public boolean hasOriginalRequestAndResponse() { + return false; + } + + public String getDispatchPath() { + return this.dispatchPath; + } + + public void dispatch() { + dispatch(null); + } + + public void dispatch(String path) { + dispatch(null, path); + } + + public void dispatch(ServletContext context, String path) { + this.dispatchPath = path; + if (this.mockRequest != null) { + this.mockRequest.setDispatcherType(DispatcherType.ASYNC); + this.mockRequest.setAsyncStarted(false); + } + } + + public void complete() { + if (this.mockRequest != null) { + this.mockRequest.setAsyncStarted(false); + } + for (AsyncListener listener : this.listeners) { + try { + listener.onComplete(new AsyncEvent(this, this.request, this.response)); + } + catch (IOException e) { + throw new IllegalStateException("AsyncListener failure", e); + } + } + } + + public void start(Runnable run) { + } + + public List getListeners() { + return this.listeners; + } + + public void addListener(AsyncListener listener) { + this.listeners.add(listener); + } + + public void addListener(AsyncListener listener, ServletRequest request, ServletResponse response) { + this.listeners.add(listener); + } + + public T createListener(Class clazz) throws ServletException { + return BeanUtils.instantiateClass(clazz); + } + + public long getTimeout() { + return this.timeout; + } + + public void setTimeout(long timeout) { + this.timeout = timeout; + } + +} diff --git a/spring-web/src/test/java/org/springframework/mock/web/MockHttpServletRequest.java b/spring-web/src/test/java/org/springframework/mock/web/MockHttpServletRequest.java index 5d23662a3fc..d08fa5ce8ec 100644 --- a/spring-web/src/test/java/org/springframework/mock/web/MockHttpServletRequest.java +++ b/spring-web/src/test/java/org/springframework/mock/web/MockHttpServletRequest.java @@ -56,8 +56,8 @@ import org.springframework.util.LinkedCaseInsensitiveMap; /** * Mock implementation of the {@link javax.servlet.http.HttpServletRequest} - * interface. Supports the Servlet 2.5 API level; throws - * {@link UnsupportedOperationException} for all methods introduced in Servlet 3.0. + * interface. Supports the Servlet 2.5 API leve. Throws + * {@link UnsupportedOperationException} for some methods introduced in Servlet 3.0. * *

Used for testing the web framework; also useful for testing * application controllers. @@ -102,7 +102,7 @@ public class MockHttpServletRequest implements HttpServletRequest { public static final String DEFAULT_REMOTE_HOST = "localhost"; private static final String CONTENT_TYPE_HEADER = "Content-Type"; - + private static final String CHARSET_PREFIX = "charset="; @@ -190,6 +190,13 @@ public class MockHttpServletRequest implements HttpServletRequest { private boolean requestedSessionIdFromURL = false; + private boolean asyncSupported = false; + + private boolean asyncStarted = false; + + private MockAsyncContext asyncContext; + + private DispatcherType dispatcherType = DispatcherType.REQUEST; //--------------------------------------------------------------------- // Constructors @@ -312,7 +319,7 @@ public class MockHttpServletRequest implements HttpServletRequest { this.characterEncoding = characterEncoding; updateContentTypeHeader(); } - + private void updateContentTypeHeader() { if (this.contentType != null) { StringBuilder sb = new StringBuilder(this.contentType); @@ -679,7 +686,7 @@ public class MockHttpServletRequest implements HttpServletRequest { } doAddHeaderValue(name, value, false); } - + @SuppressWarnings("rawtypes") private void doAddHeaderValue(String name, Object value, boolean replace) { HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name); @@ -898,33 +905,54 @@ public class MockHttpServletRequest implements HttpServletRequest { //--------------------------------------------------------------------- public AsyncContext getAsyncContext() { - throw new UnsupportedOperationException(); + return this.asyncContext; + } + + public void setAsyncContext(MockAsyncContext asyncContext) { + this.asyncContext = asyncContext; } public DispatcherType getDispatcherType() { - throw new UnsupportedOperationException(); + return this.dispatcherType; + } + + public void setDispatcherType(DispatcherType dispatcherType) { + this.dispatcherType = dispatcherType; + } + + public void setAsyncSupported(boolean asyncSupported) { + this.asyncSupported = asyncSupported; } public boolean isAsyncSupported() { - throw new UnsupportedOperationException(); + return this.asyncSupported; } public AsyncContext startAsync() { - throw new UnsupportedOperationException(); + return startAsync(this, null); } - public AsyncContext startAsync(ServletRequest arg0, ServletResponse arg1) { - throw new UnsupportedOperationException(); + public AsyncContext startAsync(ServletRequest request, ServletResponse response) { + if (!this.asyncSupported) { + throw new IllegalStateException("Async not supported"); + } + this.asyncStarted = true; + this.asyncContext = new MockAsyncContext(request, response); + return this.asyncContext; + } + + public void setAsyncStarted(boolean asyncStarted) { + this.asyncStarted = asyncStarted; } public boolean isAsyncStarted() { - throw new UnsupportedOperationException(); + return this.asyncStarted; } public boolean authenticate(HttpServletResponse arg0) throws IOException, ServletException { throw new UnsupportedOperationException(); } - + public void addPart(Part part) { parts.put(part.getName(), part); } diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java index edae6e4afc4..00306c05316 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java @@ -16,6 +16,14 @@ package org.springframework.web.client; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + import java.io.IOException; import java.io.UnsupportedEncodingException; import java.net.URI; @@ -25,6 +33,7 @@ import java.util.Collections; import java.util.EnumSet; import java.util.List; import java.util.Set; + import javax.servlet.GenericServlet; import javax.servlet.ServletException; import javax.servlet.ServletRequest; @@ -38,14 +47,13 @@ import org.apache.commons.fileupload.FileItemFactory; import org.apache.commons.fileupload.FileUploadException; import org.apache.commons.fileupload.disk.DiskFileItemFactory; import org.apache.commons.fileupload.servlet.ServletFileUpload; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; -import org.mortbay.jetty.Server; -import org.mortbay.jetty.servlet.Context; -import org.mortbay.jetty.servlet.ServletHolder; - import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.http.HttpEntity; @@ -60,8 +68,6 @@ import org.springframework.util.FileCopyUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import static org.junit.Assert.*; - /** @author Arjen Poutsma */ public class RestTemplateIntegrationTests { @@ -80,21 +86,22 @@ public class RestTemplateIntegrationTests { int port = FreePortScanner.getFreePort(); jettyServer = new Server(port); baseUrl = "http://localhost:" + port; - Context jettyContext = new Context(jettyServer, "/"); + ServletContextHandler handler = new ServletContextHandler(); byte[] bytes = helloWorld.getBytes("UTF-8"); - contentType = new MediaType("text", "plain", Collections.singletonMap("charset", "utf-8")); - jettyContext.addServlet(new ServletHolder(new GetServlet(bytes, contentType)), "/get"); - jettyContext.addServlet(new ServletHolder(new GetServlet(new byte[0], contentType)), "/get/nothing"); - jettyContext.addServlet(new ServletHolder(new GetServlet(bytes, null)), "/get/nocontenttype"); - jettyContext.addServlet( + contentType = new MediaType("text", "plain", Collections.singletonMap("charset", "UTF-8")); + handler.addServlet(new ServletHolder(new GetServlet(bytes, contentType)), "/get"); + handler.addServlet(new ServletHolder(new GetServlet(new byte[0], contentType)), "/get/nothing"); + handler.addServlet(new ServletHolder(new GetServlet(bytes, null)), "/get/nocontenttype"); + handler.addServlet( new ServletHolder(new PostServlet(helloWorld, baseUrl + "/post/1", bytes, contentType)), "/post"); - jettyContext.addServlet(new ServletHolder(new StatusCodeServlet(204)), "/status/nocontent"); - jettyContext.addServlet(new ServletHolder(new StatusCodeServlet(304)), "/status/notmodified"); - jettyContext.addServlet(new ServletHolder(new ErrorServlet(404)), "/status/notfound"); - jettyContext.addServlet(new ServletHolder(new ErrorServlet(500)), "/status/server"); - jettyContext.addServlet(new ServletHolder(new UriServlet()), "/uri/*"); - jettyContext.addServlet(new ServletHolder(new MultipartServlet()), "/multipart"); + handler.addServlet(new ServletHolder(new StatusCodeServlet(204)), "/status/nocontent"); + handler.addServlet(new ServletHolder(new StatusCodeServlet(304)), "/status/notmodified"); + handler.addServlet(new ServletHolder(new ErrorServlet(404)), "/status/notfound"); + handler.addServlet(new ServletHolder(new ErrorServlet(500)), "/status/server"); + handler.addServlet(new ServletHolder(new UriServlet()), "/uri/*"); + handler.addServlet(new ServletHolder(new MultipartServlet()), "/multipart"); + jettyServer.setHandler(handler); jettyServer.start(); } @@ -130,7 +137,7 @@ public class RestTemplateIntegrationTests { String s = template.getForObject(baseUrl + "/get/nothing", String.class); assertNull("Invalid content", s); } - + @Test public void getNoContentTypeHeader() throws UnsupportedEncodingException { byte[] bytes = template.getForObject(baseUrl + "/get/nocontenttype", byte[].class); @@ -141,7 +148,7 @@ public class RestTemplateIntegrationTests { public void getNoContent() { String s = template.getForObject(baseUrl + "/status/nocontent", String.class); assertNull("Invalid content", s); - + ResponseEntity entity = template.getForEntity(baseUrl + "/status/nocontent", String.class); assertEquals("Invalid response code", HttpStatus.NO_CONTENT, entity.getStatusCode()); assertNull("Invalid content", entity.getBody()); diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncExecutionChainTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncExecutionChainTests.java deleted file mode 100644 index ebbb4048108..00000000000 --- a/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncExecutionChainTests.java +++ /dev/null @@ -1,251 +0,0 @@ -/* - * Copyright 2002-2012 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.context.request.async; - -import static org.hamcrest.Matchers.containsString; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; - -import java.util.concurrent.Callable; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import org.junit.Before; -import org.junit.Test; -import org.springframework.core.task.SimpleAsyncTaskExecutor; -import org.springframework.http.HttpStatus; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.web.context.request.ServletWebRequest; - - -/** - * Test fixture with an AsyncExecutionChain. - * - * @author Rossen Stoyanchev - */ -public class AsyncExecutionChainTests { - - private AsyncExecutionChain chain; - - private MockHttpServletRequest request; - - private SimpleAsyncWebRequest asyncWebRequest; - - private ResultSavingCallable resultSavingCallable; - - @Before - public void setUp() { - this.request = new MockHttpServletRequest(); - this.asyncWebRequest = new SimpleAsyncWebRequest(this.request, new MockHttpServletResponse()); - this.resultSavingCallable = new ResultSavingCallable(); - - this.chain = AsyncExecutionChain.getForCurrentRequest(this.request); - this.chain.setTaskExecutor(new SyncTaskExecutor()); - this.chain.setAsyncWebRequest(this.asyncWebRequest); - this.chain.push(this.resultSavingCallable); - } - - @Test - public void getForCurrentRequest() throws Exception { - assertNotNull(this.chain); - assertSame(this.chain, AsyncExecutionChain.getForCurrentRequest(this.request)); - assertSame(this.chain, this.request.getAttribute(AsyncExecutionChain.CALLABLE_CHAIN_ATTRIBUTE)); - } - - @Test - public void isAsyncStarted() { - assertFalse(this.chain.isAsyncStarted()); - - this.asyncWebRequest.startAsync(); - assertTrue(this.chain.isAsyncStarted()); - } - - @Test(expected=IllegalStateException.class) - public void setAsyncWebRequestAfterAsyncStarted() { - this.asyncWebRequest.startAsync(); - this.chain.setAsyncWebRequest(null); - } - - @Test - public void startCallableChainProcessing() throws Exception { - this.chain.push(new IntegerIncrementingCallable()); - this.chain.push(new IntegerIncrementingCallable()); - this.chain.setLastCallable(new Callable() { - public Object call() throws Exception { - return 1; - } - }); - - this.chain.startCallableProcessing(); - - assertEquals(3, this.resultSavingCallable.result); - } - - @Test - public void startCallableChainProcessing_staleRequest() { - this.chain.setLastCallable(new Callable() { - public Object call() throws Exception { - return 1; - } - }); - - this.asyncWebRequest.startAsync(); - this.asyncWebRequest.complete(); - this.chain.startCallableProcessing(); - Exception ex = this.resultSavingCallable.exception; - - assertNotNull(ex); - assertEquals(StaleAsyncWebRequestException.class, ex.getClass()); - } - - @Test - public void startCallableChainProcessing_requiredCallable() { - try { - this.chain.startCallableProcessing(); - fail("Expected exception"); - } - catch (IllegalStateException ex) { - assertEquals(ex.getMessage(), "The last Callable was not set"); - } - } - - @Test - public void startCallableChainProcessing_requiredAsyncWebRequest() { - this.chain.setAsyncWebRequest(null); - try { - this.chain.startCallableProcessing(); - fail("Expected exception"); - } - catch (IllegalStateException ex) { - assertEquals(ex.getMessage(), "AsyncWebRequest was not set"); - } - } - - @Test - public void startDeferredResultProcessing() throws Exception { - this.chain.push(new IntegerIncrementingCallable()); - this.chain.push(new IntegerIncrementingCallable()); - - DeferredResult deferredResult = new DeferredResult(); - this.chain.startDeferredResultProcessing(deferredResult); - - assertTrue(this.asyncWebRequest.isAsyncStarted()); - - deferredResult.set(1); - - assertEquals(3, this.resultSavingCallable.result); - } - - @Test(expected=StaleAsyncWebRequestException.class) - public void startDeferredResultProcessing_staleRequest() throws Exception { - this.asyncWebRequest.startAsync(); - this.asyncWebRequest.complete(); - - DeferredResult deferredResult = new DeferredResult(); - this.chain.startDeferredResultProcessing(deferredResult); - deferredResult.set(1); - } - - @Test - public void startDeferredResultProcessing_requiredDeferredResult() { - try { - this.chain.startDeferredResultProcessing(null); - fail("Expected exception"); - } - catch (IllegalArgumentException ex) { - assertThat(ex.getMessage(), containsString("DeferredResult is required")); - } - } - - - private static class SimpleAsyncWebRequest extends ServletWebRequest implements AsyncWebRequest { - - private boolean asyncStarted; - - private boolean asyncCompleted; - - public SimpleAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) { - super(request, response); - } - - public void setTimeout(Long timeout) { } - - public void setTimeoutHandler(Runnable runnable) { } - - public void startAsync() { - this.asyncStarted = true; - } - - public boolean isAsyncStarted() { - return this.asyncStarted; - } - - public void complete() { - this.asyncStarted = false; - this.asyncCompleted = true; - } - - public boolean isAsyncCompleted() { - return this.asyncCompleted; - } - - public void sendError(HttpStatus status, String message) { - } - } - - @SuppressWarnings("serial") - private static class SyncTaskExecutor extends SimpleAsyncTaskExecutor { - - @Override - public void execute(Runnable task, long startTimeout) { - task.run(); - } - } - - private static class ResultSavingCallable extends AbstractDelegatingCallable { - - Object result; - - Exception exception; - - public Object call() throws Exception { - try { - this.result = getNext().call(); - } - catch (Exception ex) { - this.exception = ex; - throw ex; - } - return this.result; - } - } - - private static class IntegerIncrementingCallable extends AbstractDelegatingCallable { - - public Object call() throws Exception { - return ((Integer) getNext().call() + 1); - } - } - -} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/DeferredResultTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/DeferredResultTests.java index 667e084de66..30f4a4749b6 100644 --- a/spring-web/src/test/java/org/springframework/web/context/request/async/DeferredResultTests.java +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/DeferredResultTests.java @@ -25,6 +25,8 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; import org.junit.Test; +import org.springframework.web.context.request.async.DeferredResult; +import org.springframework.web.context.request.async.StaleAsyncWebRequestException; import org.springframework.web.context.request.async.DeferredResult.DeferredResultHandler; /** diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallableTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallableTests.java deleted file mode 100644 index abd5c68bd27..00000000000 --- a/spring-web/src/test/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallableTests.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright 2002-2012 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.web.context.request.async; - -import static org.easymock.EasyMock.*; - -import java.util.concurrent.Callable; - -import org.easymock.EasyMock; -import org.junit.Before; -import org.junit.Test; - -/** - * A test fixture with a {@link StaleAsyncRequestCheckingCallable}. - * - * @author Rossen Stoyanchev - */ -public class StaleAsyncRequestCheckingCallableTests { - - private StaleAsyncRequestCheckingCallable callable; - - private AsyncWebRequest asyncWebRequest; - - @Before - public void setUp() { - this.asyncWebRequest = EasyMock.createMock(AsyncWebRequest.class); - this.callable = new StaleAsyncRequestCheckingCallable(asyncWebRequest); - this.callable.setNext(new Callable() { - public Object call() throws Exception { - return 1; - } - }); - } - - @Test - public void call_notStale() throws Exception { - expect(this.asyncWebRequest.isAsyncCompleted()).andReturn(false); - replay(this.asyncWebRequest); - - this.callable.call(); - - verify(this.asyncWebRequest); - } - - @Test(expected=StaleAsyncWebRequestException.class) - public void call_stale() throws Exception { - expect(this.asyncWebRequest.isAsyncCompleted()).andReturn(true); - replay(this.asyncWebRequest); - - try { - this.callable.call(); - } - finally { - verify(this.asyncWebRequest); - } - } -} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java index 75263147e9a..2c80e6f8b73 100644 --- a/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java @@ -17,25 +17,25 @@ package org.springframework.web.context.request.async; -import static org.easymock.EasyMock.expect; import static org.easymock.EasyMock.replay; -import static org.easymock.EasyMock.reset; import static org.easymock.EasyMock.verify; import static org.hamcrest.Matchers.containsString; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import javax.servlet.AsyncContext; import javax.servlet.AsyncEvent; -import javax.servlet.http.HttpServletRequest; import org.easymock.EasyMock; import org.junit.Before; import org.junit.Test; import org.springframework.http.HttpStatus; +import org.springframework.mock.web.MockAsyncContext; +import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; /** @@ -47,124 +47,92 @@ public class StandardServletAsyncWebRequestTests { private StandardServletAsyncWebRequest asyncRequest; - private HttpServletRequest request; + private MockHttpServletRequest request; private MockHttpServletResponse response; @Before public void setup() { - this.request = EasyMock.createMock(HttpServletRequest.class); + this.request = new MockHttpServletRequest(); + this.request.setAsyncSupported(true); this.response = new MockHttpServletResponse(); this.asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response); - this.asyncRequest.setTimeout(60*1000L); + this.asyncRequest.setTimeout(44*1000L); } @Test public void isAsyncStarted() throws Exception { - replay(this.request); - assertEquals("Should be \"false\" before startAsync()", false, this.asyncRequest.isAsyncStarted()); - verify(this.request); - - startAsync(); - - reset(this.request); - expect(this.request.isAsyncStarted()).andReturn(true); - replay(this.request); - - assertTrue("Should be \"true\" true startAsync()", this.asyncRequest.isAsyncStarted()); - verify(this.request); - - this.asyncRequest.onComplete(new AsyncEvent(null)); + assertFalse(this.asyncRequest.isAsyncStarted()); - assertFalse("Should be \"false\" after complete()", this.asyncRequest.isAsyncStarted()); + this.asyncRequest.startAsync(); + assertTrue(this.asyncRequest.isAsyncStarted()); } @Test public void startAsync() throws Exception { - AsyncContext asyncContext = EasyMock.createMock(AsyncContext.class); - - reset(this.request); - expect(this.request.isAsyncSupported()).andReturn(true); - expect(this.request.startAsync(this.request, this.response)).andStubReturn(asyncContext); - replay(this.request); - - asyncContext.addListener(this.asyncRequest); - asyncContext.setTimeout(60*1000); - replay(asyncContext); - this.asyncRequest.startAsync(); - verify(this.request); - } + MockAsyncContext asyncContext = (MockAsyncContext) this.request.getAsyncContext(); - @Test - public void startAsync_notSupported() throws Exception { - expect(this.request.isAsyncSupported()).andReturn(false); - replay(this.request); - try { - this.asyncRequest.startAsync(); - fail("expected exception"); - } - catch (IllegalStateException ex) { - assertThat(ex.getMessage(), containsString("Async support must be enabled")); - } + assertNotNull(asyncContext); + assertEquals("Timeout value not set", 44 * 1000, asyncContext.getTimeout()); + assertEquals(1, asyncContext.getListeners().size()); + assertSame(this.asyncRequest, asyncContext.getListeners().get(0)); } @Test - public void startAsync_alreadyStarted() throws Exception { - startAsync(); + public void startAsyncMultipleTimes() throws Exception { + this.asyncRequest.startAsync(); + this.asyncRequest.startAsync(); + this.asyncRequest.startAsync(); + this.asyncRequest.startAsync(); // idempotent - reset(this.request); + MockAsyncContext asyncContext = (MockAsyncContext) this.request.getAsyncContext(); - expect(this.request.isAsyncSupported()).andReturn(true); - expect(this.request.isAsyncStarted()).andReturn(true); - replay(this.request); + assertNotNull(asyncContext); + assertEquals(1, asyncContext.getListeners().size()); + } + @Test + public void startAsyncNotSupported() throws Exception { + this.request.setAsyncSupported(false); try { this.asyncRequest.startAsync(); fail("expected exception"); } catch (IllegalStateException ex) { - assertEquals("Async processing already started", ex.getMessage()); + assertThat(ex.getMessage(), containsString("Async support must be enabled")); } - - verify(this.request); } @Test - public void startAsync_stale() throws Exception { - expect(this.request.isAsyncSupported()).andReturn(true); - replay(this.request); + public void startAsyncAfterCompleted() throws Exception { this.asyncRequest.onComplete(new AsyncEvent(null)); try { this.asyncRequest.startAsync(); fail("expected exception"); } catch (IllegalStateException ex) { - assertEquals("Cannot use async request that has completed", ex.getMessage()); + assertEquals("Async processing has already completed", ex.getMessage()); } } @Test - public void complete_stale() throws Exception { - this.asyncRequest.onComplete(new AsyncEvent(null)); - this.asyncRequest.complete(); - - assertFalse(this.asyncRequest.isAsyncStarted()); - assertTrue(this.asyncRequest.isAsyncCompleted()); + public void onTimeoutDefaultBehavior() throws Exception { + this.asyncRequest.onTimeout(new AsyncEvent(null)); + assertEquals(HttpStatus.SERVICE_UNAVAILABLE.value(), this.response.getStatus()); } @Test - public void sendError() throws Exception { - this.asyncRequest.sendError(HttpStatus.INTERNAL_SERVER_ERROR, "error"); - assertEquals(500, this.response.getStatus()); - } + public void onTimeoutTimeoutHandler() throws Exception { + Runnable timeoutHandler = EasyMock.createMock(Runnable.class); + timeoutHandler.run(); + replay(timeoutHandler); - @Test - public void sendError_stale() throws Exception { - this.asyncRequest.onComplete(new AsyncEvent(null)); - this.asyncRequest.sendError(HttpStatus.INTERNAL_SERVER_ERROR, "error"); - assertEquals(200, this.response.getStatus()); + this.asyncRequest.setTimeoutHandler(timeoutHandler); + this.asyncRequest.onTimeout(new AsyncEvent(null)); + + verify(timeoutHandler); } } diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTests.java new file mode 100644 index 00000000000..060b2ccc014 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerTests.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.util.concurrent.Callable; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.web.context.request.ServletWebRequest; + + +/** + * Test fixture with an {@link WebAsyncManager}. + * + * @author Rossen Stoyanchev + */ +public class WebAsyncManagerTests { + + private WebAsyncManager asyncManager; + + private MockHttpServletRequest request; + + private StubAsyncWebRequest stubAsyncWebRequest; + + @Before + public void setUp() { + this.request = new MockHttpServletRequest(); + this.stubAsyncWebRequest = new StubAsyncWebRequest(this.request, new MockHttpServletResponse()); + + this.asyncManager = AsyncWebUtils.getAsyncManager(this.request); + this.asyncManager.setTaskExecutor(new SyncTaskExecutor()); + this.asyncManager.setAsyncWebRequest(this.stubAsyncWebRequest); + } + + @Test + public void getForCurrentRequest() throws Exception { + assertNotNull(this.asyncManager); + assertSame(this.asyncManager, AsyncWebUtils.getAsyncManager(this.request)); + assertSame(this.asyncManager, this.request.getAttribute(AsyncWebUtils.WEB_ASYNC_MANAGER_ATTRIBUTE)); + } + + @Test + public void isConcurrentHandlingStarted() { + assertFalse(this.asyncManager.isConcurrentHandlingStarted()); + + this.stubAsyncWebRequest.startAsync(); + assertTrue(this.asyncManager.isConcurrentHandlingStarted()); + } + + @Test(expected=IllegalArgumentException.class) + public void setAsyncWebRequestAfterAsyncStarted() { + this.stubAsyncWebRequest.startAsync(); + this.asyncManager.setAsyncWebRequest(null); + } + + @Test + public void startCallableChainProcessing() throws Exception { + this.asyncManager.startCallableProcessing(new Callable() { + public Object call() throws Exception { + return 1; + } + }); + + assertTrue(this.asyncManager.isConcurrentHandlingStarted()); + assertTrue(this.stubAsyncWebRequest.isDispatched()); + } + + @Test + public void startCallableChainProcessingStaleRequest() { + this.stubAsyncWebRequest.setAsyncComplete(true); + this.asyncManager.startCallableProcessing(new Callable() { + public Object call() throws Exception { + return 1; + } + }); + + assertFalse(this.stubAsyncWebRequest.isDispatched()); + } + + @Test + public void startCallableChainProcessingCallableRequired() { + try { + this.asyncManager.startCallableProcessing(null); + fail("Expected exception"); + } + catch (IllegalArgumentException ex) { + assertEquals(ex.getMessage(), "Callable is required"); + } + } + + @Test + public void startCallableChainProcessingAsyncWebRequestRequired() { + this.request.removeAttribute(AsyncWebUtils.WEB_ASYNC_MANAGER_ATTRIBUTE); + this.asyncManager = AsyncWebUtils.getAsyncManager(this.request); + try { + this.asyncManager.startCallableProcessing(new Callable() { + public Object call() throws Exception { + return null; + } + }); + fail("Expected exception"); + } + catch (IllegalStateException ex) { + assertEquals(ex.getMessage(), "AsyncWebRequest was not set"); + } + } + + @Test + public void startDeferredResultProcessing() throws Exception { + DeferredResult deferredResult = new DeferredResult(); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + assertTrue(this.asyncManager.isConcurrentHandlingStarted()); + + deferredResult.set(25); + assertEquals(25, this.asyncManager.getConcurrentResult()); + } + + @Test(expected=StaleAsyncWebRequestException.class) + public void startDeferredResultProcessing_staleRequest() throws Exception { + DeferredResult deferredResult = new DeferredResult(); + this.asyncManager.startDeferredResultProcessing(deferredResult); + + this.stubAsyncWebRequest.setAsyncComplete(true); + deferredResult.set(1); + } + + @Test + public void startDeferredResultProcessingDeferredResultRequired() { + try { + this.asyncManager.startDeferredResultProcessing(null); + fail("Expected exception"); + } + catch (IllegalArgumentException ex) { + assertThat(ex.getMessage(), containsString("DeferredResult is required")); + } + } + + + private static class StubAsyncWebRequest extends ServletWebRequest implements AsyncWebRequest { + + private boolean asyncStarted; + + private boolean dispatched; + + private boolean asyncComplete; + + public StubAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) { + super(request, response); + } + + public void setTimeout(Long timeout) { } + + public void setTimeoutHandler(Runnable runnable) { } + + public void startAsync() { + this.asyncStarted = true; + } + + public boolean isAsyncStarted() { + return this.asyncStarted; + } + + public void dispatch() { + this.dispatched = true; + } + + public boolean isDispatched() { + return dispatched; + } + + public void setAsyncComplete(boolean asyncComplete) { + this.asyncComplete = asyncComplete; + } + + public boolean isAsyncComplete() { + return this.asyncComplete; + } + + public void addCompletionHandler(Runnable runnable) { + } + } + + @SuppressWarnings("serial") + private static class SyncTaskExecutor extends SimpleAsyncTaskExecutor { + + @Override + public void execute(Runnable task, long startTimeout) { + task.run(); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/filter/CharacterEncodingFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/CharacterEncodingFilterTests.java index 8e30c50109d..415d0c3b307 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/CharacterEncodingFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/CharacterEncodingFilterTests.java @@ -18,6 +18,7 @@ package org.springframework.web.filter; import static org.easymock.EasyMock.createMock; import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; import static org.easymock.EasyMock.notNull; import static org.easymock.EasyMock.replay; import static org.easymock.EasyMock.same; @@ -32,7 +33,7 @@ import junit.framework.TestCase; import org.springframework.mock.web.MockFilterConfig; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockServletContext; -import org.springframework.web.context.request.async.AsyncExecutionChain; +import org.springframework.web.context.request.async.AsyncWebUtils; /** * @author Rick Evans @@ -47,8 +48,7 @@ public class CharacterEncodingFilterTests extends TestCase { public void testForceAlwaysSetsEncoding() throws Exception { HttpServletRequest request = createMock(HttpServletRequest.class); - expect(request.getAttribute(AsyncExecutionChain.CALLABLE_CHAIN_ATTRIBUTE)).andReturn(null); - request.setAttribute(same(AsyncExecutionChain.CALLABLE_CHAIN_ATTRIBUTE), notNull()); + addAsyncManagerExpectations(request); request.setCharacterEncoding(ENCODING); expect(request.getAttribute(FILTER_NAME + OncePerRequestFilter.ALREADY_FILTERED_SUFFIX)).andReturn(null); request.setAttribute(FILTER_NAME + OncePerRequestFilter.ALREADY_FILTERED_SUFFIX, Boolean.TRUE); @@ -76,8 +76,7 @@ public class CharacterEncodingFilterTests extends TestCase { public void testEncodingIfEmptyAndNotForced() throws Exception { HttpServletRequest request = createMock(HttpServletRequest.class); - expect(request.getAttribute(AsyncExecutionChain.CALLABLE_CHAIN_ATTRIBUTE)).andReturn(null); - request.setAttribute(same(AsyncExecutionChain.CALLABLE_CHAIN_ATTRIBUTE), notNull()); + addAsyncManagerExpectations(request); expect(request.getCharacterEncoding()).andReturn(null); request.setCharacterEncoding(ENCODING); expect(request.getAttribute(FILTER_NAME + OncePerRequestFilter.ALREADY_FILTERED_SUFFIX)).andReturn(null); @@ -103,8 +102,7 @@ public class CharacterEncodingFilterTests extends TestCase { public void testDoesNowtIfEncodingIsNotEmptyAndNotForced() throws Exception { HttpServletRequest request = createMock(HttpServletRequest.class); - expect(request.getAttribute(AsyncExecutionChain.CALLABLE_CHAIN_ATTRIBUTE)).andReturn(null); - request.setAttribute(same(AsyncExecutionChain.CALLABLE_CHAIN_ATTRIBUTE), notNull()); + addAsyncManagerExpectations(request); expect(request.getCharacterEncoding()).andReturn(ENCODING); expect(request.getAttribute(FILTER_NAME + OncePerRequestFilter.ALREADY_FILTERED_SUFFIX)).andReturn(null); request.setAttribute(FILTER_NAME + OncePerRequestFilter.ALREADY_FILTERED_SUFFIX, Boolean.TRUE); @@ -128,8 +126,7 @@ public class CharacterEncodingFilterTests extends TestCase { public void testWithBeanInitialization() throws Exception { HttpServletRequest request = createMock(HttpServletRequest.class); - expect(request.getAttribute(AsyncExecutionChain.CALLABLE_CHAIN_ATTRIBUTE)).andReturn(null); - request.setAttribute(same(AsyncExecutionChain.CALLABLE_CHAIN_ATTRIBUTE), notNull()); + addAsyncManagerExpectations(request); expect(request.getCharacterEncoding()).andReturn(null); request.setCharacterEncoding(ENCODING); expect(request.getAttribute(FILTER_NAME + OncePerRequestFilter.ALREADY_FILTERED_SUFFIX)).andReturn(null); @@ -155,8 +152,7 @@ public class CharacterEncodingFilterTests extends TestCase { public void testWithIncompleteInitialization() throws Exception { HttpServletRequest request = createMock(HttpServletRequest.class); - expect(request.getAttribute(AsyncExecutionChain.CALLABLE_CHAIN_ATTRIBUTE)).andReturn(null); - request.setAttribute(same(AsyncExecutionChain.CALLABLE_CHAIN_ATTRIBUTE), notNull()); + addAsyncManagerExpectations(request); expect(request.getCharacterEncoding()).andReturn(null); request.setCharacterEncoding(ENCODING); expect(request.getAttribute(CharacterEncodingFilter.class.getName() + OncePerRequestFilter.ALREADY_FILTERED_SUFFIX)).andReturn(null); @@ -178,4 +174,11 @@ public class CharacterEncodingFilterTests extends TestCase { verify(filterChain); } + + private void addAsyncManagerExpectations(HttpServletRequest request) { + expect(request.getAttribute(AsyncWebUtils.WEB_ASYNC_MANAGER_ATTRIBUTE)).andReturn(null); + expectLastCall().anyTimes(); + request.setAttribute(same(AsyncWebUtils.WEB_ASYNC_MANAGER_ATTRIBUTE), notNull()); + expectLastCall().anyTimes(); + } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/AsyncHandlerInterceptor.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/AsyncHandlerInterceptor.java index 167059e5005..3ce34718351 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/AsyncHandlerInterceptor.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/AsyncHandlerInterceptor.java @@ -16,65 +16,45 @@ package org.springframework.web.servlet; -import java.util.concurrent.Callable; - import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import org.springframework.web.context.request.WebRequest; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; - /** - * Extends {@link HanderInterceptor} with lifecycle methods specific to async - * request processing. + * Extends the HandlerInterceptor contract for scenarios where a handler may be + * executed asynchronously. Since the handler will complete execution in another + * thread, the results are not available in the current thread, and therefore the + * DispatcherServlet exits quickly and on its way out invokes + * {@link #afterConcurrentHandlingStarted(HttpServletRequest, HttpServletResponse)} + * instead of {@code postHandle} and {@code afterCompletion}. + * When the async handler execution completes, and the request is dispatched back + * for further processing, the DispatcherServlet will invoke {@code preHandle} + * again, as well as {@code postHandle} and {@code afterCompletion}. * - *

This is the sequence of events on the main thread in an async scenario: - *

    - *
  1. {@link #preHandle(WebRequest)} - *
  2. {@link #getAsyncCallable(WebRequest)} - *
  3. ... handler execution - *
  4. {@link #postHandleAsyncStarted(WebRequest)} - *
- * - *

This is the sequence of events on the async thread: - *

    - *
  1. Async {@link Callable#call()} (the {@code Callable} returned by {@code getAsyncCallable}) - *
  2. ... async handler execution - *
  3. {@link #postHandle(WebRequest, org.springframework.ui.ModelMap)} - *
  4. {@link #afterCompletion(WebRequest, Exception)} - *
+ *

Existing implementations should consider the fact that {@code preHandle} may + * be invoked twice before {@code postHandle} and {@code afterCompletion} are + * called if they don't implement this contract. Once before the start of concurrent + * handling and a second time as part of an asynchronous dispatch after concurrent + * handling is done. This may be not important in most cases but when some work + * needs to be done after concurrent handling starts (e.g. clearing thread locals) + * then this contract can be implemented. * * @author Rossen Stoyanchev * @since 3.2 + * + * @see org.springframework.web.context.request.async.WebAsyncManager */ public interface AsyncHandlerInterceptor extends HandlerInterceptor { /** - * Invoked after {@link #preHandle(WebRequest)} and before - * the handler is executed. The returned {@link Callable} is used only if - * handler execution leads to teh start of async processing. It is invoked - * the async thread before the request is handled fro. - *

Implementations can use this Callable to initialize - * ThreadLocal attributes on the async thread. - * @param request current HTTP request - * @param response current HTTP response - * @param handler chosen handler to execute, for type and/or instance examination - * @return a {@link Callable} instance or null - */ - AbstractDelegatingCallable getAsyncCallable(HttpServletRequest request, HttpServletResponse response, Object handler); - - /** - * Invoked after the execution of a handler but only if the handler started - * async processing instead of handling the request. Effectively this method - * is invoked instead of {@link #postHandle(WebRequest, org.springframework.ui.ModelMap)} - * on the way out of the main processing thread allowing implementations - * to ensure ThreadLocal attributes are cleared. The postHandle - * invocation is effectively delayed until after async processing when the - * request has actually been handled. - * @param request current HTTP request - * @param response current HTTP response - * @param handler chosen handler to execute, for type and/or instance examination + * Called instead of {@code postHandle} and {@code afterCompletion}, when the + * a handler is being executed concurrently. Implementations may use the provided + * request and response but should avoid modifying them in ways that would + * conflict with the concurrent execution of the handler. A typical use of + * this method would be to clean thread local variables. + * + * @param request the current request + * @param response the current response */ - void postHandleAfterAsyncStarted(HttpServletRequest request, HttpServletResponse response, Object handler); + void afterConcurrentHandlingStarted(HttpServletRequest request, HttpServletResponse response); } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java index 7119add6437..6e31065868d 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java @@ -50,8 +50,8 @@ import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.request.ServletWebRequest; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; -import org.springframework.web.context.request.async.AsyncExecutionChain; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.AsyncWebUtils; import org.springframework.web.multipart.MultipartException; import org.springframework.web.multipart.MultipartHttpServletRequest; import org.springframework.web.multipart.MultipartResolver; @@ -819,8 +819,9 @@ public class DispatcherServlet extends FrameworkServlet { if (logger.isDebugEnabled()) { String requestUri = urlPathHelper.getRequestUri(request); - logger.debug("DispatcherServlet with name '" + getServletName() + "' processing " + request.getMethod() + - " request for [" + requestUri + "]"); + String resumed = AsyncWebUtils.getAsyncManager(request).hasConcurrentResult() ? " resumed" : ""; + logger.debug("DispatcherServlet with name '" + getServletName() + "'" + resumed + + " processing " + request.getMethod() + " request for [" + requestUri + "]"); } // Keep a snapshot of the request attributes in case of an include, @@ -851,14 +852,11 @@ public class DispatcherServlet extends FrameworkServlet { request.setAttribute(OUTPUT_FLASH_MAP_ATTRIBUTE, new FlashMap()); request.setAttribute(FLASH_MAP_MANAGER_ATTRIBUTE, this.flashMapManager); - AsyncExecutionChain asyncChain = AsyncExecutionChain.getForCurrentRequest(request); - asyncChain.push(getServiceAsyncCallable(request, attributesSnapshot)); - try { doDispatch(request, response); } finally { - if (!asyncChain.pop()) { + if (AsyncWebUtils.getAsyncManager(request).isConcurrentHandlingStarted()) { return; } // Restore the original attribute snapshot, in case of an include. @@ -868,27 +866,6 @@ public class DispatcherServlet extends FrameworkServlet { } } - /** - * Create a Callable to complete doService() processing asynchronously. - */ - private AbstractDelegatingCallable getServiceAsyncCallable( - final HttpServletRequest request, final Map attributesSnapshot) { - - return new AbstractDelegatingCallable() { - public Object call() throws Exception { - if (logger.isDebugEnabled()) { - logger.debug("Resuming asynchronous processing of " + request.getMethod() + - " request for [" + urlPathHelper.getRequestUri(request) + "]"); - } - getNext().call(); - if (attributesSnapshot != null) { - restoreAttributesAfterInclude(request, attributesSnapshot); - } - return null; - } - }; - } - /** * Process the actual dispatching to the handler. *

The handler will be obtained by applying the servlet's HandlerMappings in order. @@ -903,9 +880,9 @@ public class DispatcherServlet extends FrameworkServlet { protected void doDispatch(HttpServletRequest request, HttpServletResponse response) throws Exception { HttpServletRequest processedRequest = request; HandlerExecutionChain mappedHandler = null; + boolean multipartRequestParsed = false; - AsyncExecutionChain asyncChain = AsyncExecutionChain.getForCurrentRequest(request); - boolean asyncStarted = false; + WebAsyncManager asyncManager = AsyncWebUtils.getAsyncManager(request); try { ModelAndView mv = null; @@ -913,6 +890,7 @@ public class DispatcherServlet extends FrameworkServlet { try { processedRequest = checkMultipart(request); + multipartRequestParsed = processedRequest != request; // Determine handler for the current request. mappedHandler = getHandler(processedRequest, false); @@ -942,18 +920,12 @@ public class DispatcherServlet extends FrameworkServlet { return; } - mappedHandler.pushInterceptorCallables(processedRequest, response); - asyncChain.push(getDispatchAsyncCallable(mappedHandler, request, response, processedRequest)); - try { // Actually invoke the handler. mv = ha.handle(processedRequest, response, mappedHandler.getHandler()); } finally { - asyncStarted = !asyncChain.pop(); - mappedHandler.popInterceptorCallables(processedRequest, response, asyncStarted); - if (asyncStarted) { - logger.debug("Exiting request thread and leaving the response open"); + if (asyncManager.isConcurrentHandlingStarted()) { return; } } @@ -973,11 +945,13 @@ public class DispatcherServlet extends FrameworkServlet { triggerAfterCompletionWithError(processedRequest, response, mappedHandler, err); } finally { - if (asyncStarted) { + if (asyncManager.isConcurrentHandlingStarted()) { + // Instead of postHandle and afterCompletion + mappedHandler.applyAfterConcurrentHandlingStarted(processedRequest, response); return; } // Clean up any resources used by a multipart request. - if (processedRequest != request) { + if (multipartRequestParsed) { cleanupMultipart(processedRequest); } } @@ -1027,50 +1001,16 @@ public class DispatcherServlet extends FrameworkServlet { } } + if (AsyncWebUtils.getAsyncManager(request).isConcurrentHandlingStarted()) { + // Concurrent handling started during a forward + return; + } + if (mappedHandler != null) { mappedHandler.triggerAfterCompletion(request, response, null); } } - /** - * Create a Callable to complete doDispatch processing asynchronously. - */ - private AbstractDelegatingCallable getDispatchAsyncCallable( - final HandlerExecutionChain mappedHandler, - final HttpServletRequest request, final HttpServletResponse response, - final HttpServletRequest processedRequest) throws Exception { - - return new AbstractDelegatingCallable() { - public Object call() throws Exception { - try { - ModelAndView mv = null; - Exception dispatchException = null; - try { - mv = (ModelAndView) getNext().call(); - applyDefaultViewName(processedRequest, mv); - mappedHandler.applyPostHandle(request, response, mv); - } - catch (Exception ex) { - dispatchException = ex; - } - processDispatchResult(processedRequest, response, mappedHandler, mv, dispatchException); - } - catch (Exception ex) { - triggerAfterCompletion(processedRequest, response, mappedHandler, ex); - } - catch (Error err) { - triggerAfterCompletionWithError(processedRequest, response, mappedHandler, err); - } - finally { - if (processedRequest != request) { - cleanupMultipart(processedRequest); - } - } - return null; - } - }; - } - /** * Build a LocaleContext for the given request, exposing the request's primary locale as current locale. *

The default implementation uses the dispatcher's LocaleResolver to obtain the current locale, @@ -1113,12 +1053,13 @@ public class DispatcherServlet extends FrameworkServlet { /** * Clean up any resources used by the given multipart request (if any). - * @param request current HTTP request + * @param servletRequest current HTTP request * @see MultipartResolver#cleanupMultipart */ - protected void cleanupMultipart(HttpServletRequest request) { - if (request instanceof MultipartHttpServletRequest) { - this.multipartResolver.cleanupMultipart((MultipartHttpServletRequest) request); + protected void cleanupMultipart(HttpServletRequest servletRequest) { + MultipartHttpServletRequest req = WebUtils.getNativeRequest(servletRequest, MultipartHttpServletRequest.class); + if (req != null) { + this.multipartResolver.cleanupMultipart(req); } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java index b9258a056d4..0336487e46e 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java @@ -47,8 +47,9 @@ import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; -import org.springframework.web.context.request.async.AsyncExecutionChain; +import org.springframework.web.context.request.async.AsyncWebUtils; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.WebAsyncManager.AsyncThreadInitializer; import org.springframework.web.context.support.ServletRequestHandledEvent; import org.springframework.web.context.support.WebApplicationContextUtils; import org.springframework.web.context.support.XmlWebApplicationContext; @@ -195,6 +196,7 @@ public abstract class FrameworkServlet extends HttpServletBean { private ArrayList> contextInitializers = new ArrayList>(); + /** * Create a new {@code FrameworkServlet} that will create its own internal web * application context based on defaults and values provided through servlet @@ -905,22 +907,55 @@ public abstract class FrameworkServlet extends HttpServletBean { initContextHolders(request, localeContext, requestAttributes); - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - chain.push(getAsyncCallable(startTime, request, response, - previousLocaleContext, previousAttributes, localeContext, requestAttributes)); + WebAsyncManager asyncManager = AsyncWebUtils.getAsyncManager(request); + asyncManager.registerAsyncThreadInitializer(this.getClass().getName(), createAsyncThreadInitializer(request)); try { doService(request, response); } - catch (Throwable t) { - failureCause = t; + catch (ServletException ex) { + failureCause = ex; + throw ex; + } + catch (IOException ex) { + failureCause = ex; + throw ex; } + catch (Throwable ex) { + failureCause = ex; + throw new NestedServletException("Request processing failed", ex); + } + finally { resetContextHolders(request, previousLocaleContext, previousAttributes); - if (!chain.pop()) { - return; + if (requestAttributes != null) { + requestAttributes.requestCompleted(); + } + + if (logger.isDebugEnabled()) { + if (failureCause != null) { + this.logger.debug("Could not complete request", failureCause); + } else { + if (asyncManager.isConcurrentHandlingStarted()) { + if (logger.isDebugEnabled()) { + logger.debug("Leaving response open for concurrent processing"); + } + } + else { + this.logger.debug("Successfully completed request"); + } + } + } + if (this.publishEvents) { + // Whether or not we succeeded, publish an event. + long processingTime = System.currentTimeMillis() - startTime; + this.webApplicationContext.publishEvent( + new ServletRequestHandledEvent(this, + request.getRequestURI(), request.getRemoteAddr(), + request.getMethod(), getServletConfig().getServletName(), + WebUtils.getSessionId(request), getUsernameForRequest(request), + processingTime, failureCause)); } - finalizeProcessing(startTime, request, response, requestAttributes, failureCause); } } @@ -956,78 +991,14 @@ public abstract class FrameworkServlet extends HttpServletBean { } } - /** - * Log and re-throw unhandled exceptions, publish a ServletRequestHandledEvent, etc. - */ - private void finalizeProcessing(long startTime, HttpServletRequest request, HttpServletResponse response, - ServletRequestAttributes requestAttributes, Throwable t) throws ServletException, IOException { + private AsyncThreadInitializer createAsyncThreadInitializer(final HttpServletRequest request) { - Throwable failureCause = null; - try { - if (t != null) { - if (t instanceof ServletException) { - failureCause = t; - throw (ServletException) t; - } - else if (t instanceof IOException) { - failureCause = t; - throw (IOException) t; - } - else { - NestedServletException ex = new NestedServletException("Request processing failed", t); - failureCause = ex; - throw ex; - } + return new AsyncThreadInitializer() { + public void initialize() { + initContextHolders(request, buildLocaleContext(request), new ServletRequestAttributes(request)); } - } - finally { - if (requestAttributes != null) { - requestAttributes.requestCompleted(); - } - if (logger.isDebugEnabled()) { - if (failureCause != null) { - this.logger.debug("Could not complete request", failureCause); - } - else { - this.logger.debug("Successfully completed request"); - } - } - if (this.publishEvents) { - // Whether or not we succeeded, publish an event. - long processingTime = System.currentTimeMillis() - startTime; - this.webApplicationContext.publishEvent( - new ServletRequestHandledEvent(this, - request.getRequestURI(), request.getRemoteAddr(), - request.getMethod(), getServletConfig().getServletName(), - WebUtils.getSessionId(request), getUsernameForRequest(request), - processingTime, failureCause)); - } - } - } - - /** - * Create a Callable to use to complete processing in an async execution chain. - */ - private AbstractDelegatingCallable getAsyncCallable(final long startTime, - final HttpServletRequest request, final HttpServletResponse response, - final LocaleContext previousLocaleContext, final RequestAttributes previousAttributes, - final LocaleContext localeContext, final ServletRequestAttributes requestAttributes) { - - return new AbstractDelegatingCallable() { - public Object call() throws Exception { - initContextHolders(request, localeContext, requestAttributes); - Throwable unhandledFailure = null; - try { - getNext().call(); - } - catch (Throwable t) { - unhandledFailure = t; - } - finally { - resetContextHolders(request, previousLocaleContext, previousAttributes); - finalizeProcessing(startTime, request, response, requestAttributes, unhandledFailure); - } - return null; + public void reset() { + resetContextHolders(request, null, null); } }; } @@ -1061,6 +1032,7 @@ public abstract class FrameworkServlet extends HttpServletBean { protected abstract void doService(HttpServletRequest request, HttpServletResponse response) throws Exception; + /** * Close the WebApplicationContext of this servlet. * @see org.springframework.context.ConfigurableApplicationContext#close() @@ -1073,6 +1045,7 @@ public abstract class FrameworkServlet extends HttpServletBean { } } + /** * ApplicationListener endpoint that receives events from this servlet's WebApplicationContext * only, delegating to onApplicationEvent on the FrameworkServlet instance. diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerExecutionChain.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerExecutionChain.java index 6509a477a33..1e7323a2771 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerExecutionChain.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerExecutionChain.java @@ -26,8 +26,6 @@ import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.util.CollectionUtils; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; -import org.springframework.web.context.request.async.AsyncExecutionChain; /** * Handler execution chain, consisting of handler object and any handler interceptors. @@ -49,7 +47,6 @@ public class HandlerExecutionChain { private int interceptorIndex = -1; - private int pushedCallableCount; /** * Create a new HandlerExecutionChain. @@ -140,27 +137,6 @@ public class HandlerExecutionChain { return true; } - void pushInterceptorCallables(HttpServletRequest request, HttpServletResponse response) { - if (getInterceptors() == null) { - return; - } - for (HandlerInterceptor interceptor : getInterceptors()) { - if (interceptor instanceof AsyncHandlerInterceptor) { - try { - AsyncHandlerInterceptor asyncInterceptor = (AsyncHandlerInterceptor) interceptor; - AbstractDelegatingCallable callable = asyncInterceptor.getAsyncCallable(request, response, this.handler); - if (callable != null) { - AsyncExecutionChain.getForCurrentRequest(request).push(callable); - this.pushedCallableCount++; - } - } - catch (Throwable ex) { - logger.error("HandlerInterceptor failed to return an async Callable", ex); - } - } - } - } - /** * Apply postHandle methods of registered interceptors. */ @@ -174,34 +150,6 @@ public class HandlerExecutionChain { } } - /** - * Remove pushed callables and apply postHandleAsyncStarted callbacks. - */ - void popInterceptorCallables(HttpServletRequest request, HttpServletResponse response, - boolean asyncStarted) throws Exception { - - if (getInterceptors() == null) { - return; - } - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - for ( ; this.pushedCallableCount > 0; this.pushedCallableCount--) { - chain.pop(); - } - if (asyncStarted) { - for (int i = getInterceptors().length - 1; i >= 0; i--) { - HandlerInterceptor interceptor = getInterceptors()[i]; - if (interceptor instanceof AsyncHandlerInterceptor) { - try { - ((AsyncHandlerInterceptor) interceptor).postHandleAfterAsyncStarted(request, response, this.handler); - } - catch (Throwable ex) { - logger.error("HandlerInterceptor.postHandleAsyncStarted(..) failed", ex); - } - } - } - } - } - /** * Trigger afterCompletion callbacks on the mapped HandlerInterceptors. * Will just invoke afterCompletion for all interceptors whose preHandle invocation @@ -224,6 +172,25 @@ public class HandlerExecutionChain { } } + /** + * Apply afterConcurrentHandlerStarted callback on mapped AsyncHandlerInterceptors. + */ + void applyAfterConcurrentHandlingStarted(HttpServletRequest request, HttpServletResponse response) { + if (getInterceptors() == null) { + return; + } + for (int i = getInterceptors().length - 1; i >= 0; i--) { + if (interceptors[i] instanceof AsyncHandlerInterceptor) { + try { + AsyncHandlerInterceptor asyncInterceptor = (AsyncHandlerInterceptor) interceptors[i]; + asyncInterceptor.afterConcurrentHandlingStarted(request, response); + } + catch (Throwable ex) { + logger.error("Interceptor [" + interceptors[i] + "] failed in afterConcurrentHandlingStarted", ex); + } + } + } + } /** * Delegates to the handler's toString(). diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerInterceptor.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerInterceptor.java index c2c8deb729a..b93a3862b9a 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerInterceptor.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerInterceptor.java @@ -31,6 +31,14 @@ import javax.servlet.http.HttpServletResponse; * or common handler behavior like locale or theme changes. Its main purpose * is to allow for factoring out repetitive handler code. * + *

In an async processing scenario, the handler may be executed in a separate + * thread while the main thread exits without rendering or invoking the + * {@code postHandle} and {@code afterCompletion} callbacks. When concurrent + * handler execution completes, the request is dispatched back in order to + * proceed with rendering the model and all methods of this contract are invoked + * again. For further options and comments see + * {@code org.springframework.web.servlet.HandlerInterceptor} + * *

Typically an interceptor chain is defined per HandlerMapping bean, * sharing its granularity. To be able to apply a certain interceptor chain * to a group of handlers, one needs to map the desired handlers via one diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/WebRequestHandlerInterceptorAdapter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/WebRequestHandlerInterceptorAdapter.java index 250baedcf09..d50885ab0f2 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/WebRequestHandlerInterceptorAdapter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/WebRequestHandlerInterceptorAdapter.java @@ -21,7 +21,6 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.util.Assert; import org.springframework.web.context.request.WebRequestInterceptor; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; import org.springframework.web.context.request.async.AsyncWebRequestInterceptor; import org.springframework.web.servlet.AsyncHandlerInterceptor; import org.springframework.web.servlet.ModelAndView; @@ -57,25 +56,6 @@ public class WebRequestHandlerInterceptorAdapter implements AsyncHandlerIntercep return true; } - public AbstractDelegatingCallable getAsyncCallable(HttpServletRequest request, - HttpServletResponse response, Object handler) { - - if (this.requestInterceptor instanceof AsyncWebRequestInterceptor) { - AsyncWebRequestInterceptor asyncInterceptor = (AsyncWebRequestInterceptor) this.requestInterceptor; - DispatcherServletWebRequest webRequest = new DispatcherServletWebRequest(request, response); - return asyncInterceptor.getAsyncCallable(webRequest); - } - return null; - } - - public void postHandleAfterAsyncStarted(HttpServletRequest request, HttpServletResponse response, Object handler) { - if (this.requestInterceptor instanceof AsyncWebRequestInterceptor) { - AsyncWebRequestInterceptor asyncInterceptor = (AsyncWebRequestInterceptor) this.requestInterceptor; - DispatcherServletWebRequest webRequest = new DispatcherServletWebRequest(request, response); - asyncInterceptor.postHandleAsyncStarted(webRequest); - } - } - public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception { @@ -89,4 +69,12 @@ public class WebRequestHandlerInterceptorAdapter implements AsyncHandlerIntercep this.requestInterceptor.afterCompletion(new DispatcherServletWebRequest(request, response), ex); } + public void afterConcurrentHandlingStarted(HttpServletRequest request, HttpServletResponse response) { + if (this.requestInterceptor instanceof AsyncWebRequestInterceptor) { + AsyncWebRequestInterceptor asyncInterceptor = (AsyncWebRequestInterceptor) this.requestInterceptor; + DispatcherServletWebRequest webRequest = new DispatcherServletWebRequest(request, response); + asyncInterceptor.afterConcurrentHandlingStarted(webRequest); + } + } + } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AsyncMethodReturnValueHandler.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AsyncMethodReturnValueHandler.java index d07ff2e3f69..31d432af2c3 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AsyncMethodReturnValueHandler.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AsyncMethodReturnValueHandler.java @@ -24,8 +24,9 @@ import javax.servlet.ServletRequest; import org.springframework.core.MethodParameter; import org.springframework.util.Assert; import org.springframework.web.context.request.NativeWebRequest; -import org.springframework.web.context.request.async.AsyncExecutionChain; import org.springframework.web.context.request.async.DeferredResult; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.AsyncWebUtils; import org.springframework.web.method.support.HandlerMethodReturnValueHandler; import org.springframework.web.method.support.ModelAndViewContainer; @@ -52,18 +53,15 @@ public class AsyncMethodReturnValueHandler implements HandlerMethodReturnValueHa Assert.notNull(returnValue, "A Callable or a DeferredValue is required"); - mavContainer.setRequestHandled(true); - Class paramType = returnType.getParameterType(); ServletRequest servletRequest = webRequest.getNativeRequest(ServletRequest.class); - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(servletRequest); + WebAsyncManager asyncManager = AsyncWebUtils.getAsyncManager(servletRequest); if (Callable.class.isAssignableFrom(paramType)) { - chain.setLastCallable((Callable) returnValue); - chain.startCallableProcessing(); + asyncManager.startCallableProcessing((Callable) returnValue, mavContainer); } else if (DeferredResult.class.isAssignableFrom(paramType)) { - chain.startDeferredResultProcessing((DeferredResult) returnValue); + asyncManager.startDeferredResultProcessing((DeferredResult) returnValue, mavContainer); } else { // should never happen.. diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java index e113f147e05..139e874a806 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java @@ -61,10 +61,10 @@ import org.springframework.web.bind.support.WebDataBinderFactory; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.ServletWebRequest; import org.springframework.web.context.request.WebRequest; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; -import org.springframework.web.context.request.async.AsyncExecutionChain; import org.springframework.web.context.request.async.AsyncWebRequest; -import org.springframework.web.context.request.async.NoOpAsyncWebRequest; +import org.springframework.web.context.request.async.NoSupportAsyncWebRequest; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.AsyncWebUtils; import org.springframework.web.method.HandlerMethod; import org.springframework.web.method.HandlerMethodSelector; import org.springframework.web.method.annotation.ErrorsMethodArgumentResolver; @@ -146,7 +146,7 @@ public class RequestMappingHandlerAdapter extends AbstractHandlerMethodAdapter i private final Map, Set> modelFactoryCache = new ConcurrentHashMap, Set>(); - private AsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(); + private AsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("MvcAsync"); private Long asyncRequestTimeout; @@ -652,48 +652,36 @@ public class RequestMappingHandlerAdapter extends AbstractHandlerMethodAdapter i modelFactory.initModel(webRequest, mavContainer, requestMappingMethod); mavContainer.setIgnoreDefaultModelOnRedirect(this.ignoreDefaultModelOnRedirect); - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - chain.setAsyncWebRequest(createAsyncWebRequest(request, response)); - chain.setTaskExecutor(this.taskExecutor); - chain.push(getAsyncCallable(mavContainer, modelFactory, webRequest)); + AsyncWebRequest asyncWebRequest = createAsyncWebRequest(request, response); + asyncWebRequest.setTimeout(this.asyncRequestTimeout); - try { - requestMappingMethod.invokeAndHandle(webRequest, mavContainer); - } - finally { - if (!chain.pop()) { - return null; + final WebAsyncManager asyncManager = AsyncWebUtils.getAsyncManager(request); + asyncManager.setTaskExecutor(this.taskExecutor); + asyncManager.setAsyncWebRequest(asyncWebRequest); + + if (asyncManager.hasConcurrentResult()) { + Object result = asyncManager.getConcurrentResult(); + mavContainer = (ModelAndViewContainer) asyncManager.getConcurrentResultContext()[0]; + asyncManager.resetConcurrentResult(); + + if (logger.isDebugEnabled()) { + logger.debug("Found concurrent result value [" + result + "]"); } + requestMappingMethod = requestMappingMethod.wrapConcurrentProcessingResult(result); } - return getModelAndView(mavContainer, modelFactory, webRequest); - } + requestMappingMethod.invokeAndHandle(webRequest, mavContainer); - private AsyncWebRequest createAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) { - if (ClassUtils.hasMethod(ServletRequest.class, "startAsync")) { - AsyncWebRequest asyncRequest = instantiateStandardServletAsyncWebRequest(request, response); - asyncRequest.setTimeout(this.asyncRequestTimeout); - return asyncRequest; - } - else { - return new NoOpAsyncWebRequest(request, response); + if (asyncManager.isConcurrentHandlingStarted()) { + return null; } - } - private AsyncWebRequest instantiateStandardServletAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) { - String className = "org.springframework.web.context.request.async.StandardServletAsyncWebRequest"; - try { - Class clazz = ClassUtils.forName(className, this.getClass().getClassLoader()); - Constructor constructor = clazz.getConstructor(HttpServletRequest.class, HttpServletResponse.class); - return (AsyncWebRequest) BeanUtils.instantiateClass(constructor , request, response); - } - catch (Throwable t) { - throw new IllegalStateException("Failed to instantiate StandardServletAsyncWebRequest", t); - } + return getModelAndView(mavContainer, modelFactory, webRequest); } - private ServletInvocableHandlerMethod createRequestMappingMethod(HandlerMethod handlerMethod, - WebDataBinderFactory binderFactory) { + private ServletInvocableHandlerMethod createRequestMappingMethod( + HandlerMethod handlerMethod, WebDataBinderFactory binderFactory) { + ServletInvocableHandlerMethod requestMethod; requestMethod = new ServletInvocableHandlerMethod(handlerMethod.getBean(), handlerMethod.getMethod()); requestMethod.setHandlerMethodArgumentResolvers(this.argumentResolvers); @@ -753,18 +741,21 @@ public class RequestMappingHandlerAdapter extends AbstractHandlerMethodAdapter i return new ServletRequestDataBinderFactory(binderMethods, getWebBindingInitializer()); } - /** - * Create a Callable to produce a ModelAndView asynchronously. - */ - private AbstractDelegatingCallable getAsyncCallable(final ModelAndViewContainer mavContainer, - final ModelFactory modelFactory, final NativeWebRequest webRequest) { + private AsyncWebRequest createAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) { + return ClassUtils.hasMethod(ServletRequest.class, "startAsync") ? + createStandardServletAsyncWebRequest(request, response) : new NoSupportAsyncWebRequest(request, response); + } - return new AbstractDelegatingCallable() { - public Object call() throws Exception { - getNext().call(); - return getModelAndView(mavContainer, modelFactory, webRequest); - } - }; + private AsyncWebRequest createStandardServletAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) { + try { + String className = "org.springframework.web.context.request.async.StandardServletAsyncWebRequest"; + Class clazz = ClassUtils.forName(className, this.getClass().getClassLoader()); + Constructor constructor = clazz.getConstructor(HttpServletRequest.class, HttpServletResponse.class); + return (AsyncWebRequest) BeanUtils.instantiateClass(constructor, request, response); + } + catch (Throwable t) { + throw new IllegalStateException("Failed to instantiate StandardServletAsyncWebRequest", t); + } } private ModelAndView getModelAndView(ModelAndViewContainer mavContainer, diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ServletInvocableHandlerMethod.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ServletInvocableHandlerMethod.java index def48a6ea10..657537967bc 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ServletInvocableHandlerMethod.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ServletInvocableHandlerMethod.java @@ -26,13 +26,12 @@ import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; import org.springframework.web.bind.annotation.ResponseStatus; import org.springframework.web.context.request.ServletWebRequest; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; -import org.springframework.web.context.request.async.AsyncExecutionChain; import org.springframework.web.method.support.HandlerMethodReturnValueHandler; import org.springframework.web.method.support.HandlerMethodReturnValueHandlerComposite; import org.springframework.web.method.support.InvocableHandlerMethod; import org.springframework.web.method.support.ModelAndViewContainer; import org.springframework.web.servlet.View; +import org.springframework.web.util.NestedServletException; /** * Extends {@link InvocableHandlerMethod} with the ability to handle return @@ -108,9 +107,6 @@ public class ServletInvocableHandlerMethod extends InvocableHandlerMethod { mavContainer.setRequestHandled(false); - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(webRequest.getRequest()); - chain.push(geAsyncCallable(webRequest, mavContainer, providedArgs)); - try { this.returnValueHandlers.handleReturnValue(returnValue, getReturnValueType(returnValue), mavContainer, webRequest); } @@ -120,33 +116,6 @@ public class ServletInvocableHandlerMethod extends InvocableHandlerMethod { } throw ex; } - finally { - chain.pop(); - } - } - - /** - * Create a Callable to populate the ModelAndViewContainer asynchronously. - */ - private AbstractDelegatingCallable geAsyncCallable(final ServletWebRequest webRequest, - final ModelAndViewContainer mavContainer, final Object... providedArgs) { - - return new AbstractDelegatingCallable() { - public Object call() throws Exception { - mavContainer.setRequestHandled(false); - new CallableHandlerMethod(getNext()).invokeAndHandle(webRequest, mavContainer, providedArgs); - return null; - } - }; - } - - private String getReturnValueHandlingErrorMessage(String message, Object returnValue) { - StringBuilder sb = new StringBuilder(message); - if (returnValue != null) { - sb.append(" [type=" + returnValue.getClass().getName() + "] "); - } - sb.append("[value=" + returnValue + "]"); - return getDetailedErrorMessage(sb.toString()); } /** @@ -184,11 +153,39 @@ public class ServletInvocableHandlerMethod extends InvocableHandlerMethod { return responseStatus != null; } + private String getReturnValueHandlingErrorMessage(String message, Object returnValue) { + StringBuilder sb = new StringBuilder(message); + if (returnValue != null) { + sb.append(" [type=" + returnValue.getClass().getName() + "] "); + } + sb.append("[value=" + returnValue + "]"); + return getDetailedErrorMessage(sb.toString()); + } + + /** + * Return a ServletInvocableHandlerMethod that will process the value returned + * from an async operation essentially either applying return value handling or + * raising an exception if the end result is an Exception. + */ + ServletInvocableHandlerMethod wrapConcurrentProcessingResult(final Object result) { + + return new CallableHandlerMethod(new Callable() { + + public Object call() throws Exception { + if (result instanceof Exception) { + throw (Exception) result; + } + else if (result instanceof Throwable) { + throw new NestedServletException("Async processing failed", (Throwable) result); + } + return result; + } + }); + } + /** - * Wraps the Callable returned from a HandlerMethod so may be invoked just - * like the HandlerMethod with the same return value handling guarantees. - * Method-level annotations must be on the HandlerMethod, not the Callable. + * Wrap a Callable as a ServletInvocableHandlerMethod inheriting method-level annotations. */ private class CallableHandlerMethod extends ServletInvocableHandlerMethod { diff --git a/spring-webmvc/src/test/java/org/springframework/mock/web/MockAsyncContext.java b/spring-webmvc/src/test/java/org/springframework/mock/web/MockAsyncContext.java new file mode 100644 index 00000000000..9607213e456 --- /dev/null +++ b/spring-webmvc/src/test/java/org/springframework/mock/web/MockAsyncContext.java @@ -0,0 +1,133 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.mock.web; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; +import javax.servlet.DispatcherType; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; + +import org.springframework.beans.BeanUtils; +import org.springframework.web.util.WebUtils; + +/** + * Mock implementation of the {@link AsyncContext} interface. + * + * @author Rossen Stoyanchev + * @since 3.2 + */ +public class MockAsyncContext implements AsyncContext { + + private final ServletRequest request; + + private final ServletResponse response; + + private final MockHttpServletRequest mockRequest; + + private final List listeners = new ArrayList(); + + private String dispatchPath; + + private long timeout = 10 * 60 * 1000L; + + public MockAsyncContext(ServletRequest request, ServletResponse response) { + this.request = request; + this.response = response; + this.mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class); + } + + public ServletRequest getRequest() { + return this.request; + } + + public ServletResponse getResponse() { + return this.response; + } + + public boolean hasOriginalRequestAndResponse() { + return false; + } + + public String getDispatchPath() { + return this.dispatchPath; + } + + public void dispatch() { + dispatch(null); + } + + public void dispatch(String path) { + dispatch(null, path); + } + + public void dispatch(ServletContext context, String path) { + this.dispatchPath = path; + if (this.mockRequest != null) { + this.mockRequest.setDispatcherType(DispatcherType.ASYNC); + this.mockRequest.setAsyncStarted(false); + } + } + + public void complete() { + if (this.mockRequest != null) { + this.mockRequest.setAsyncStarted(false); + } + for (AsyncListener listener : this.listeners) { + try { + listener.onComplete(new AsyncEvent(this, this.request, this.response)); + } + catch (IOException e) { + throw new IllegalStateException("AsyncListener failure", e); + } + } + } + + public void start(Runnable run) { + } + + public List getListeners() { + return this.listeners; + } + + public void addListener(AsyncListener listener) { + this.listeners.add(listener); + } + + public void addListener(AsyncListener listener, ServletRequest request, ServletResponse response) { + this.listeners.add(listener); + } + + public T createListener(Class clazz) throws ServletException { + return BeanUtils.instantiateClass(clazz); + } + + public long getTimeout() { + return this.timeout; + } + + public void setTimeout(long timeout) { + this.timeout = timeout; + } + +} diff --git a/spring-webmvc/src/test/java/org/springframework/mock/web/MockHttpServletRequest.java b/spring-webmvc/src/test/java/org/springframework/mock/web/MockHttpServletRequest.java index 5d23662a3fc..debe0165e82 100644 --- a/spring-webmvc/src/test/java/org/springframework/mock/web/MockHttpServletRequest.java +++ b/spring-webmvc/src/test/java/org/springframework/mock/web/MockHttpServletRequest.java @@ -102,7 +102,7 @@ public class MockHttpServletRequest implements HttpServletRequest { public static final String DEFAULT_REMOTE_HOST = "localhost"; private static final String CONTENT_TYPE_HEADER = "Content-Type"; - + private static final String CHARSET_PREFIX = "charset="; @@ -190,6 +190,14 @@ public class MockHttpServletRequest implements HttpServletRequest { private boolean requestedSessionIdFromURL = false; + private boolean asyncSupported = false; + + private boolean asyncStarted = false; + + private MockAsyncContext asyncContext; + + private DispatcherType dispatcherType = DispatcherType.REQUEST; + //--------------------------------------------------------------------- // Constructors @@ -312,7 +320,7 @@ public class MockHttpServletRequest implements HttpServletRequest { this.characterEncoding = characterEncoding; updateContentTypeHeader(); } - + private void updateContentTypeHeader() { if (this.contentType != null) { StringBuilder sb = new StringBuilder(this.contentType); @@ -679,7 +687,7 @@ public class MockHttpServletRequest implements HttpServletRequest { } doAddHeaderValue(name, value, false); } - + @SuppressWarnings("rawtypes") private void doAddHeaderValue(String name, Object value, boolean replace) { HeaderValueHolder header = HeaderValueHolder.getByName(this.headers, name); @@ -898,33 +906,54 @@ public class MockHttpServletRequest implements HttpServletRequest { //--------------------------------------------------------------------- public AsyncContext getAsyncContext() { - throw new UnsupportedOperationException(); + return this.asyncContext; + } + + public void setAsyncContext(MockAsyncContext asyncContext) { + this.asyncContext = asyncContext; } public DispatcherType getDispatcherType() { - throw new UnsupportedOperationException(); + return this.dispatcherType; + } + + public void setDispatcherType(DispatcherType dispatcherType) { + this.dispatcherType = dispatcherType; + } + + public void setAsyncSupported(boolean asyncSupported) { + this.asyncSupported = asyncSupported; } public boolean isAsyncSupported() { - throw new UnsupportedOperationException(); + return this.asyncSupported; } public AsyncContext startAsync() { - throw new UnsupportedOperationException(); + return startAsync(this, null); } - public AsyncContext startAsync(ServletRequest arg0, ServletResponse arg1) { - throw new UnsupportedOperationException(); + public AsyncContext startAsync(ServletRequest request, ServletResponse response) { + if (!this.asyncSupported) { + throw new IllegalStateException("Async not supported"); + } + this.asyncStarted = true; + this.asyncContext = new MockAsyncContext(request, response); + return this.asyncContext; + } + + public void setAsyncStarted(boolean asyncStarted) { + this.asyncStarted = asyncStarted; } public boolean isAsyncStarted() { - throw new UnsupportedOperationException(); + return this.asyncStarted; } public boolean authenticate(HttpServletResponse arg0) throws IOException, ServletException { throw new UnsupportedOperationException(); } - + public void addPart(Part part) { parts.put(part.getName(), part); } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/HandlerExecutionChainTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/HandlerExecutionChainTests.java index 33463473bcf..d537ec71587 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/HandlerExecutionChainTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/HandlerExecutionChainTests.java @@ -26,7 +26,6 @@ import org.junit.Before; import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.web.context.request.async.AbstractDelegatingCallable; /** * A test fixture with HandlerExecutionChain and mock handler interceptors. @@ -74,10 +73,6 @@ public class HandlerExecutionChainTests { expect(this.interceptor2.preHandle(this.request, this.response, this.handler)).andReturn(true); expect(this.interceptor3.preHandle(this.request, this.response, this.handler)).andReturn(true); - expect(this.interceptor1.getAsyncCallable(request, response, this.handler)).andReturn(new TestAsyncCallable()); - expect(this.interceptor2.getAsyncCallable(request, response, this.handler)).andReturn(new TestAsyncCallable()); - expect(this.interceptor3.getAsyncCallable(request, response, this.handler)).andReturn(new TestAsyncCallable()); - this.interceptor1.postHandle(this.request, this.response, this.handler, mav); this.interceptor2.postHandle(this.request, this.response, this.handler, mav); this.interceptor3.postHandle(this.request, this.response, this.handler, mav); @@ -89,7 +84,6 @@ public class HandlerExecutionChainTests { replay(this.interceptor1, this.interceptor2, this.interceptor3); this.chain.applyPreHandle(request, response); - this.chain.pushInterceptorCallables(request, response); this.chain.applyPostHandle(request, response, mav); this.chain.triggerAfterCompletion(this.request, this.response, null); @@ -104,28 +98,14 @@ public class HandlerExecutionChainTests { expect(this.interceptor2.preHandle(this.request, this.response, this.handler)).andReturn(true); expect(this.interceptor3.preHandle(this.request, this.response, this.handler)).andReturn(true); - expect(this.interceptor1.getAsyncCallable(request, response, this.handler)).andReturn(new TestAsyncCallable()); - expect(this.interceptor2.getAsyncCallable(request, response, this.handler)).andReturn(new TestAsyncCallable()); - expect(this.interceptor3.getAsyncCallable(request, response, this.handler)).andReturn(new TestAsyncCallable()); - - this.interceptor1.postHandleAfterAsyncStarted(request, response, this.handler); - this.interceptor2.postHandleAfterAsyncStarted(request, response, this.handler); - this.interceptor3.postHandleAfterAsyncStarted(request, response, this.handler); - - this.interceptor1.postHandle(this.request, this.response, this.handler, mav); - this.interceptor2.postHandle(this.request, this.response, this.handler, mav); - this.interceptor3.postHandle(this.request, this.response, this.handler, mav); - - this.interceptor3.afterCompletion(this.request, this.response, this.handler, null); - this.interceptor2.afterCompletion(this.request, this.response, this.handler, null); - this.interceptor1.afterCompletion(this.request, this.response, this.handler, null); + this.interceptor1.afterConcurrentHandlingStarted(request, response); + this.interceptor2.afterConcurrentHandlingStarted(request, response); + this.interceptor3.afterConcurrentHandlingStarted(request, response); replay(this.interceptor1, this.interceptor2, this.interceptor3); this.chain.applyPreHandle(request, response); - this.chain.pushInterceptorCallables(request, response); - this.chain.popInterceptorCallables(request, response, true); - this.chain.applyPostHandle(request, response, mav); + this.chain.applyAfterConcurrentHandlingStarted(request, response); this.chain.triggerAfterCompletion(this.request, this.response, null); verify(this.interceptor1, this.interceptor2, this.interceptor3); @@ -196,12 +176,4 @@ public class HandlerExecutionChainTests { verify(this.interceptor1, this.interceptor2, this.interceptor3); } - - private static class TestAsyncCallable extends AbstractDelegatingCallable { - - public Object call() throws Exception { - return null; - } - } - } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartIntegrationTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartIntegrationTests.java index fe203d38202..f5730fcb690 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartIntegrationTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartIntegrationTests.java @@ -21,14 +21,14 @@ import static org.junit.Assert.assertEquals; import java.net.URI; import java.util.Arrays; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; -import org.mortbay.jetty.Server; -import org.mortbay.jetty.servlet.Context; -import org.mortbay.jetty.servlet.ServletHolder; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.ClassPathResource; @@ -59,8 +59,8 @@ import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter; /** - * Test access to parts of a multipart request with {@link RequestPart}. - * + * Test access to parts of a multipart request with {@link RequestPart}. + * * @author Rossen Stoyanchev */ public class RequestPartIntegrationTests { @@ -71,39 +71,41 @@ public class RequestPartIntegrationTests { private static String baseUrl; - + @BeforeClass public static void startServer() throws Exception { - + int port = FreePortScanner.getFreePort(); baseUrl = "http://localhost:" + port; server = new Server(port); - Context context = new Context(server, "/"); + ServletContextHandler handler = new ServletContextHandler(); + handler.setContextPath("/"); Class config = CommonsMultipartResolverTestConfig.class; ServletHolder commonsResolverServlet = new ServletHolder(DispatcherServlet.class); commonsResolverServlet.setInitParameter("contextConfigLocation", config.getName()); commonsResolverServlet.setInitParameter("contextClass", AnnotationConfigWebApplicationContext.class.getName()); - context.addServlet(commonsResolverServlet, "/commons-resolver/*"); + handler.addServlet(commonsResolverServlet, "/commons-resolver/*"); config = StandardMultipartResolverTestConfig.class; ServletHolder standardResolverServlet = new ServletHolder(DispatcherServlet.class); standardResolverServlet.setInitParameter("contextConfigLocation", config.getName()); standardResolverServlet.setInitParameter("contextClass", AnnotationConfigWebApplicationContext.class.getName()); - context.addServlet(standardResolverServlet, "/standard-resolver/*"); + handler.addServlet(standardResolverServlet, "/standard-resolver/*"); - // TODO: add Servlet 3.0 test case without MultipartResolver + // TODO: add Servlet 3.0 test case without MultipartResolver + server.setHandler(handler); server.start(); } - + @Before public void setUp() { XmlAwareFormHttpMessageConverter converter = new XmlAwareFormHttpMessageConverter(); converter.setPartConverters(Arrays.>asList( new ResourceHttpMessageConverter(), new MappingJacksonHttpMessageConverter())); - + restTemplate = new RestTemplate(new HttpComponentsClientHttpRequestFactory()); restTemplate.setMessageConverters(Arrays.>asList(converter)); } @@ -115,7 +117,7 @@ public class RequestPartIntegrationTests { } } - + @Test public void commonsMultipartResolver() throws Exception { testCreate(baseUrl + "/commons-resolver/test"); @@ -147,7 +149,7 @@ public class RequestPartIntegrationTests { return new RequestPartTestController(); } } - + @Configuration static class CommonsMultipartResolverTestConfig extends RequestPartTestConfig { @@ -166,7 +168,6 @@ public class RequestPartIntegrationTests { } } - @SuppressWarnings("unused") @Controller private static class RequestPartTestController { @@ -178,10 +179,10 @@ public class RequestPartIntegrationTests { return new ResponseEntity(headers, HttpStatus.CREATED); } } - + @SuppressWarnings("unused") private static class TestData { - + private String name; public TestData() { @@ -200,5 +201,5 @@ public class RequestPartIntegrationTests { this.name = name; } } - + }