diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockAsyncContext.java b/spring-test/src/main/java/org/springframework/mock/web/MockAsyncContext.java index 12a58a77d22..588ae02c730 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockAsyncContext.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockAsyncContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,7 +62,14 @@ public class MockAsyncContext implements AsyncContext { public void addDispatchHandler(Runnable handler) { Assert.notNull(handler, "Dispatch handler must not be null"); - this.dispatchHandlers.add(handler); + synchronized (this) { + if (this.dispatchedPath == null) { + this.dispatchHandlers.add(handler); + } + else { + handler.run(); + } + } } @Override @@ -92,9 +99,11 @@ public class MockAsyncContext implements AsyncContext { @Override public void dispatch(ServletContext context, String path) { - this.dispatchedPath = path; - for (Runnable r : this.dispatchHandlers) { - r.run(); + synchronized (this) { + this.dispatchedPath = path; + for (Runnable r : this.dispatchHandlers) { + r.run(); + } } } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java b/spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java index 2273388fbf1..a38a86bf3db 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java @@ -16,10 +16,13 @@ package org.springframework.test.web.servlet; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.util.Assert; import org.springframework.web.servlet.FlashMap; import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.ModelAndView; @@ -51,6 +54,8 @@ class DefaultMvcResult implements MvcResult { private final AtomicReference asyncResult = new AtomicReference(RESULT_NONE); + private CountDownLatch asyncDispatchLatch; + /** * Create a new instance with the given request and response. @@ -126,27 +131,31 @@ class DefaultMvcResult implements MvcResult { if (this.mockRequest.getAsyncContext() != null) { timeToWait = (timeToWait == -1 ? this.mockRequest.getAsyncContext().getTimeout() : timeToWait); } - - if (timeToWait > 0) { - long endTime = System.currentTimeMillis() + timeToWait; - while (System.currentTimeMillis() < endTime && this.asyncResult.get() == RESULT_NONE) { - try { - Thread.sleep(100); - } - catch (InterruptedException ex) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Interrupted while waiting for " + - "async result to be set for handler [" + this.handler + "]", ex); - } - } + if (!awaitAsyncDispatch(timeToWait)) { + throw new IllegalStateException("Async result for handler [" + this.handler + "]" + + " was not set during the specified timeToWait=" + timeToWait); } - Object result = this.asyncResult.get(); - if (result == RESULT_NONE) { - throw new IllegalStateException("Async result for handler [" + this.handler + "] " + - "was not set during the specified timeToWait=" + timeToWait); + Assert.state(result != RESULT_NONE, "Async result for handler [" + this.handler + "] was not set"); + return this.asyncResult.get(); + } + + /** + * True if is there a latch was not set, or the latch count reached 0. + */ + private boolean awaitAsyncDispatch(long timeout) { + Assert.state(this.asyncDispatchLatch != null, + "The asynDispatch CountDownLatch was not set by the TestDispatcherServlet.\n"); + try { + return this.asyncDispatchLatch.await(timeout, TimeUnit.MILLISECONDS); + } + catch (InterruptedException e) { + return false; } - return result; + } + + void setAsyncDispatchLatch(CountDownLatch asyncDispatchLatch) { + this.asyncDispatchLatch = asyncDispatchLatch; } } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java index f073f4deee8..d0a3a2599af 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java @@ -18,11 +18,13 @@ package org.springframework.test.web.servlet; import java.io.IOException; import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.mock.web.MockAsyncContext; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.async.CallableProcessingInterceptorAdapter; @@ -63,6 +65,7 @@ final class TestDispatcherServlet extends DispatcherServlet { registerAsyncResultInterceptors(request); super.service(request, response); + initAsyncDispatchLatch(request); } private void registerAsyncResultInterceptors(final HttpServletRequest request) { @@ -81,6 +84,19 @@ final class TestDispatcherServlet extends DispatcherServlet { }); } + private void initAsyncDispatchLatch(HttpServletRequest request) { + if (request.getAsyncContext() != null) { + final CountDownLatch dispatchLatch = new CountDownLatch(1); + ((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(new Runnable() { + @Override + public void run() { + dispatchLatch.countDown(); + } + }); + getMvcResult(request).setAsyncDispatchLatch(dispatchLatch); + } + } + protected DefaultMvcResult getMvcResult(ServletRequest request) { return (DefaultMvcResult) request.getAttribute(MockMvc.MVC_RESULT_ATTRIBUTE); } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java index cf17c503298..ad149847c30 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ */ package org.springframework.test.web.servlet; +import java.util.concurrent.CountDownLatch; + import org.junit.Before; import org.junit.Test; @@ -40,6 +42,7 @@ public class DefaultMvcResultTests { @Test public void getAsyncResultSuccess() throws Exception { this.mvcResult.setAsyncResult("Foo"); + this.mvcResult.setAsyncDispatchLatch(new CountDownLatch(0)); assertEquals("Foo", this.mvcResult.getAsyncResult()); } diff --git a/spring-web/src/test/java/org/springframework/mock/web/test/MockAsyncContext.java b/spring-web/src/test/java/org/springframework/mock/web/test/MockAsyncContext.java index a23f44e8abc..8fc030f68dd 100644 --- a/spring-web/src/test/java/org/springframework/mock/web/test/MockAsyncContext.java +++ b/spring-web/src/test/java/org/springframework/mock/web/test/MockAsyncContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,7 +62,14 @@ public class MockAsyncContext implements AsyncContext { public void addDispatchHandler(Runnable handler) { Assert.notNull(handler, "Dispatch handler must not be null"); - this.dispatchHandlers.add(handler); + synchronized (this) { + if (this.dispatchedPath == null) { + this.dispatchHandlers.add(handler); + } + else { + handler.run(); + } + } } @Override @@ -92,9 +99,11 @@ public class MockAsyncContext implements AsyncContext { @Override public void dispatch(ServletContext context, String path) { - this.dispatchedPath = path; - for (Runnable r : this.dispatchHandlers) { - r.run(); + synchronized (this) { + this.dispatchedPath = path; + for (Runnable r : this.dispatchHandlers) { + r.run(); + } } }