diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java index 4b6e4e6bccf..8d87cf48d77 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java @@ -71,13 +71,22 @@ final class TestDispatcherServlet extends DispatcherServlet { super.service(request, response); if (request.getAsyncContext() != null) { - MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class); - Assert.notNull(mockRequest, "Expected MockHttpServletRequest"); - MockAsyncContext mockAsyncContext = ((MockAsyncContext) mockRequest.getAsyncContext()); - Assert.notNull(mockAsyncContext, "MockAsyncContext not found. Did request wrapper not delegate startAsync?"); + MockAsyncContext asyncContext; + if (request.getAsyncContext() instanceof MockAsyncContext) { + asyncContext = (MockAsyncContext) request.getAsyncContext(); + } + else { + MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class); + Assert.notNull(mockRequest, "Expected MockHttpServletRequest"); + asyncContext = (MockAsyncContext) mockRequest.getAsyncContext(); + Assert.notNull(asyncContext, () -> + "Outer request wrapper " + request.getClass().getName() + " has an AsyncContext," + + "but it is not a MockAsyncContext, while the nested " + + mockRequest.getClass().getName() + " does not have an AsyncContext at all."); + } CountDownLatch dispatchLatch = new CountDownLatch(1); - mockAsyncContext.addDispatchHandler(dispatchLatch::countDown); + asyncContext.addDispatchHandler(dispatchLatch::countDown); getMvcResult(request).setAsyncDispatchLatch(dispatchLatch); } }