From fc2f3ecf44ae3e43ebbdcbce4b446a07102de086 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 9 Oct 2018 11:51:21 -0400 Subject: [PATCH] More defensive check for MockAsyncContext Avoid automatically unwrapping the request in TestDispatcherServlet, if we find the MockAsyncContext. Issue: SPR-17353 --- .../web/servlet/TestDispatcherServlet.java | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java index 4b6e4e6bccf..8d87cf48d77 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java @@ -71,13 +71,22 @@ final class TestDispatcherServlet extends DispatcherServlet { super.service(request, response); if (request.getAsyncContext() != null) { - MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class); - Assert.notNull(mockRequest, "Expected MockHttpServletRequest"); - MockAsyncContext mockAsyncContext = ((MockAsyncContext) mockRequest.getAsyncContext()); - Assert.notNull(mockAsyncContext, "MockAsyncContext not found. Did request wrapper not delegate startAsync?"); + MockAsyncContext asyncContext; + if (request.getAsyncContext() instanceof MockAsyncContext) { + asyncContext = (MockAsyncContext) request.getAsyncContext(); + } + else { + MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class); + Assert.notNull(mockRequest, "Expected MockHttpServletRequest"); + asyncContext = (MockAsyncContext) mockRequest.getAsyncContext(); + Assert.notNull(asyncContext, () -> + "Outer request wrapper " + request.getClass().getName() + " has an AsyncContext," + + "but it is not a MockAsyncContext, while the nested " + + mockRequest.getClass().getName() + " does not have an AsyncContext at all."); + } CountDownLatch dispatchLatch = new CountDownLatch(1); - mockAsyncContext.addDispatchHandler(dispatchLatch::countDown); + asyncContext.addDispatchHandler(dispatchLatch::countDown); getMvcResult(request).setAsyncDispatchLatch(dispatchLatch); } }