diff --git a/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/request/MockMvcRequestBuilders.java b/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/request/MockMvcRequestBuilders.java index b812ad2d9d1..acdc8ae15f7 100644 --- a/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/request/MockMvcRequestBuilders.java +++ b/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/request/MockMvcRequestBuilders.java @@ -15,8 +15,15 @@ */ package org.springframework.test.web.servlet.request; +import java.lang.reflect.Method; + +import javax.servlet.ServletContext; + import org.springframework.http.HttpMethod; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.RequestBuilder; +import org.springframework.util.ReflectionUtils; /** * Static factory methods for {@link RequestBuilder}s. @@ -83,4 +90,37 @@ public abstract class MockMvcRequestBuilders { return new MockMultipartHttpServletRequestBuilder(urlTemplate, urlVariables); } + /** + * Create a {@link RequestBuilder} for an async dispatch from the + * {@link MvcResult} of the request that started async processing. + * + *

Usage involves performing one request first that starts async processing: + *

+	 * MvcResult mvcResult = this.mockMvc.perform(get("/1"))
+	 *	.andExpect(request().asyncStarted())
+	 *	.andReturn();
+	 *  
+ * + *

And then performing the async dispatch re-using the {@code MvcResult}: + *

+	 * this.mockMvc.perform(asyncDispatch(mvcResult))
+	 * 	.andExpect(status().isOk())
+	 * 	.andExpect(content().contentType(MediaType.APPLICATION_JSON))
+	 * 	.andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}"));
+	 * 
+ * + * @param mvcResult the result from the request that started async processing + */ + public static RequestBuilder asyncDispatch(final MvcResult mvcResult) { + return new RequestBuilder() { + public MockHttpServletRequest buildRequest(ServletContext servletContext) { + MockHttpServletRequest request = mvcResult.getRequest(); + Method method = ReflectionUtils.findMethod(request.getClass(), "setAsyncStarted", boolean.class); + method.setAccessible(true); + ReflectionUtils.invokeMethod(method, request, false); + return request; + } + }; + } + } diff --git a/spring-test-mvc/src/test/java/org/springframework/test/web/servlet/samples/standalone/AsyncTests.java b/spring-test-mvc/src/test/java/org/springframework/test/web/servlet/samples/standalone/AsyncTests.java index 73bdffa7c5b..4ef7ce93355 100644 --- a/spring-test-mvc/src/test/java/org/springframework/test/web/servlet/samples/standalone/AsyncTests.java +++ b/spring-test-mvc/src/test/java/org/springframework/test/web/servlet/samples/standalone/AsyncTests.java @@ -15,19 +15,27 @@ */ package org.springframework.test.web.servlet.samples.standalone; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.asyncDispatch; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standaloneSetup; +import java.util.Collection; import java.util.concurrent.Callable; +import java.util.concurrent.CopyOnWriteArrayList; import org.junit.Before; import org.junit.Test; +import org.springframework.http.MediaType; import org.springframework.stereotype.Controller; import org.springframework.test.web.Person; import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.ui.Model; import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.context.request.async.DeferredResult; /** @@ -39,37 +47,52 @@ public class AsyncTests { private MockMvc mockMvc; + private AsyncController asyncController; + + @Before public void setup() { - this.mockMvc = standaloneSetup(new AsyncController()).build(); + this.asyncController = new AsyncController(); + this.mockMvc = standaloneSetup(this.asyncController).build(); } @Test - public void testDeferredResult() throws Exception { - this.mockMvc.perform(get("/1").param("deferredResult", "true")) + public void testCallable() throws Exception { + MvcResult mvcResult = this.mockMvc.perform(get("/1").param("callable", "true")) + .andExpect(request().asyncStarted()) + .andExpect(request().asyncResult(new Person("Joe"))) + .andReturn(); + + this.mockMvc.perform(asyncDispatch(mvcResult)) .andExpect(status().isOk()) - .andExpect(request().asyncStarted()); + .andExpect(content().contentType(MediaType.APPLICATION_JSON)) + .andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}")); } @Test - public void testCallable() throws Exception { - this.mockMvc.perform(get("/1").param("callable", "true")) - .andExpect(status().isOk()) + public void testDeferredResult() throws Exception { + MvcResult mvcResult = this.mockMvc.perform(get("/1").param("deferredResult", "true")) .andExpect(request().asyncStarted()) - .andExpect(request().asyncResult(new Person("Joe"))); + .andReturn(); + + this.asyncController.onMessage("Joe"); + + this.mockMvc.perform(asyncDispatch(mvcResult)) + .andExpect(status().isOk()) + .andExpect(content().contentType(MediaType.APPLICATION_JSON)) + .andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}")); } @Controller private static class AsyncController { - @RequestMapping(value="/{id}", params="deferredResult", produces="application/json") - public DeferredResult getDeferredResult() { - return new DeferredResult(); - } + private Collection> deferredResults = new CopyOnWriteArrayList>(); + @RequestMapping(value="/{id}", params="callable", produces="application/json") - public Callable getCallable() { + @ResponseBody + public Callable getCallable(final Model model) { return new Callable() { public Person call() throws Exception { return new Person("Joe"); @@ -77,6 +100,20 @@ public class AsyncTests { }; } + @RequestMapping(value="/{id}", params="deferredResult", produces="application/json") + @ResponseBody + public DeferredResult getDeferredResult() { + DeferredResult deferredResult = new DeferredResult(); + this.deferredResults.add(deferredResult); + return deferredResult; + } + + public void onMessage(String name) { + for (DeferredResult deferredResult : this.deferredResults) { + deferredResult.setResult(new Person(name)); + this.deferredResults.remove(deferredResult); + } + } } }