Browse Source

Refactor async result handling in Spring MVC Test

This change removes the use of a CountDownLatch to wait for the
asynchronously computed controller method return value. Instead we
check in a loop every 200 milliseconds if the result has been set.
If the result is not set within the specified amount of time to wait
an IllegalStateException is raised.

Additional changes:
 - Use AtomicReference to hold the async result
 - Remove @Ignore annotations on AsyncTests methods
 - Remove checks for the presence of Servlet 3

Issue: SPR-11516
pull/484/merge
Rossen Stoyanchev 12 years ago
parent
commit
74de35df1e
  1. 50
      spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java
  2. 26
      spring-test/src/main/java/org/springframework/test/web/servlet/MvcResult.java
  3. 23
      spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java
  4. 20
      spring-test/src/main/java/org/springframework/test/web/servlet/result/PrintingResultHandler.java
  5. 72
      spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java
  6. 2
      spring-test/src/test/java/org/springframework/test/web/servlet/StubMvcResult.java
  7. 5
      spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/AsyncTests.java

50
spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java

@ -16,12 +16,11 @@
package org.springframework.test.web.servlet; package org.springframework.test.web.servlet;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference;
import javax.servlet.http.HttpServletRequest;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.util.Assert;
import org.springframework.web.servlet.FlashMap; import org.springframework.web.servlet.FlashMap;
import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.ModelAndView;
@ -51,7 +50,7 @@ class DefaultMvcResult implements MvcResult {
private Exception resolvedException; private Exception resolvedException;
private Object asyncResult = RESULT_NONE; private final AtomicReference<Object> asyncResult = new AtomicReference<Object>(RESULT_NONE);
private CountDownLatch asyncResultLatch; private CountDownLatch asyncResultLatch;
@ -116,7 +115,7 @@ class DefaultMvcResult implements MvcResult {
} }
public void setAsyncResult(Object asyncResult) { public void setAsyncResult(Object asyncResult) {
this.asyncResult = asyncResult; this.asyncResult.set(asyncResult);
} }
@Override @Override
@ -125,35 +124,30 @@ class DefaultMvcResult implements MvcResult {
} }
@Override @Override
public Object getAsyncResult(long timeout) { public Object getAsyncResult(long timeToWait) {
if (this.asyncResult == RESULT_NONE) {
if ((timeout != 0) && this.mockRequest.isAsyncStarted()) { if (this.mockRequest.getAsyncContext() != null) {
if (timeout == -1) { timeToWait = (timeToWait == -1 ? this.mockRequest.getAsyncContext().getTimeout() : timeToWait);
timeout = this.mockRequest.getAsyncContext().getTimeout(); }
if (timeToWait > 0) {
long endTime = System.currentTimeMillis() + timeToWait;
while (System.currentTimeMillis() < endTime && this.asyncResult.get() == RESULT_NONE) {
try {
Thread.sleep(200);
} }
if (!awaitAsyncResult(timeout) && this.asyncResult == RESULT_NONE) { catch (InterruptedException ex) {
throw new IllegalStateException( throw new IllegalStateException("Interrupted while waiting for " +
"Gave up waiting on async result from handler [" + this.handler + "] to complete"); "async result to be set for handler [" + this.handler + "]", ex);
} }
} }
} }
return (this.asyncResult == RESULT_NONE ? null : this.asyncResult);
}
private boolean awaitAsyncResult(long timeout) { Assert.state(this.asyncResult.get() != RESULT_NONE,
if (this.asyncResultLatch != null) { "Async result for handler [" + this.handler + "] " +
try { "was not set during the specified timeToWait=" + timeToWait);
return this.asyncResultLatch.await(timeout, TimeUnit.MILLISECONDS);
}
catch (InterruptedException e) {
return false;
}
}
return true;
}
public void setAsyncResultLatch(CountDownLatch asyncResultLatch) { return this.asyncResult.get();
this.asyncResultLatch = asyncResultLatch;
} }
} }

26
spring-test/src/main/java/org/springframework/test/web/servlet/MvcResult.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2012 the original author or authors. * Copyright 2002-2014 the original author or authors.
* *
* Licensed under the Apache License; Version 2.0 (the "License"); * Licensed under the Apache License; Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -76,24 +76,24 @@ public interface MvcResult {
FlashMap getFlashMap(); FlashMap getFlashMap();
/** /**
* Get the result of asynchronous execution or {@code null} if concurrent * Get the result of async execution. This method will wait for the async result
* handling did not start. This method will hold and await the completion * to be set for up to the amount of time configured on the async request,
* of concurrent handling. * i.e. {@link org.springframework.mock.web.MockAsyncContext#getTimeout()}.
* *
* @throws IllegalStateException if concurrent handling does not complete * @throws IllegalStateException if the async result was not set.
* within the allocated async timeout value.
*/ */
Object getAsyncResult(); Object getAsyncResult();
/** /**
* Get the result of asynchronous execution or {@code null} if concurrent * Get the result of async execution. This method will wait for the async result
* handling did not start. This method will wait for up to the given timeout * to be set for up to the specified amount of time.
* for the completion of concurrent handling.
* *
* @param timeout how long to wait for the async result to be set in * @param timeToWait how long to wait for the async result to be set, in
* milliseconds; if -1, the wait will be as long as the async timeout set * milliseconds; if -1, then the async request timeout value is used,
* on the Servlet request * i.e.{@link org.springframework.mock.web.MockAsyncContext#getTimeout()}.
*
* @throws IllegalStateException if the async result was not set.
*/ */
Object getAsyncResult(long timeout); Object getAsyncResult(long timeToWait);
} }

23
spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2012 the original author or authors. * Copyright 2002-2014 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,14 +18,12 @@ package org.springframework.test.web.servlet;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.ServletRequest; import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.springframework.mock.web.MockAsyncContext;
import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.context.request.async.*; import org.springframework.web.context.request.async.*;
@ -57,15 +55,11 @@ final class TestDispatcherServlet extends DispatcherServlet {
@Override @Override
protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { protected void service(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
registerAsyncResultInterceptors(request); registerAsyncResultInterceptors(request);
super.service(request, response); super.service(request, response);
if (request.isAsyncStarted()) {
addAsyncResultLatch(request);
}
} }
private void registerAsyncResultInterceptors(final HttpServletRequest request) { private void registerAsyncResultInterceptors(final HttpServletRequest request) {
@ -84,17 +78,6 @@ final class TestDispatcherServlet extends DispatcherServlet {
}); });
} }
private void addAsyncResultLatch(HttpServletRequest request) {
final CountDownLatch latch = new CountDownLatch(1);
((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(new Runnable() {
@Override
public void run() {
latch.countDown();
}
});
getMvcResult(request).setAsyncResultLatch(latch);
}
protected DefaultMvcResult getMvcResult(ServletRequest request) { protected DefaultMvcResult getMvcResult(ServletRequest request) {
return (DefaultMvcResult) request.getAttribute(MockMvc.MVC_RESULT_ATTRIBUTE); return (DefaultMvcResult) request.getAttribute(MockMvc.MVC_RESULT_ATTRIBUTE);
} }

20
spring-test/src/main/java/org/springframework/test/web/servlet/result/PrintingResultHandler.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2012 the original author or authors. * Copyright 2002-2014 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -19,7 +19,6 @@ package org.springframework.test.web.servlet.result;
import java.util.Enumeration; import java.util.Enumeration;
import java.util.Map; import java.util.Map;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
@ -27,7 +26,6 @@ import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.ResultHandler; import org.springframework.test.web.servlet.ResultHandler;
import org.springframework.util.ClassUtils;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.validation.BindingResult; import org.springframework.validation.BindingResult;
@ -48,8 +46,6 @@ import org.springframework.web.servlet.support.RequestContextUtils;
*/ */
public class PrintingResultHandler implements ResultHandler { public class PrintingResultHandler implements ResultHandler {
private static final boolean servlet3Present = ClassUtils.hasMethod(ServletRequest.class, "startAsync");
private final ResultValuePrinter printer; private final ResultValuePrinter printer;
@ -80,10 +76,8 @@ public class PrintingResultHandler implements ResultHandler {
this.printer.printHeading("Handler"); this.printer.printHeading("Handler");
printHandler(result.getHandler(), result.getInterceptors()); printHandler(result.getHandler(), result.getInterceptors());
if (servlet3Present) { this.printer.printHeading("Async");
this.printer.printHeading("Async"); printAsyncResult(result);
printAsyncResult(result);
}
this.printer.printHeading("Resolved Exception"); this.printer.printHeading("Resolved Exception");
printResolvedException(result.getResolvedException()); printResolvedException(result.getResolvedException());
@ -133,11 +127,9 @@ public class PrintingResultHandler implements ResultHandler {
} }
protected void printAsyncResult(MvcResult result) throws Exception { protected void printAsyncResult(MvcResult result) throws Exception {
if (servlet3Present) { HttpServletRequest request = result.getRequest();
HttpServletRequest request = result.getRequest(); this.printer.printValue("Was async started", request.isAsyncStarted());
this.printer.printValue("Was async started", request.isAsyncStarted()); this.printer.printValue("Async result", (request.isAsyncStarted() ? result.getAsyncResult(0) : null));
this.printer.printValue("Async result", result.getAsyncResult(0));
}
} }
/** Print the handler */ /** Print the handler */

72
spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2013 the original author or authors. * Copyright 2002-2014 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -24,6 +24,7 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import static org.junit.Assert.assertEquals;
import static org.mockito.BDDMockito.*; import static org.mockito.BDDMockito.*;
/** /**
@ -37,80 +38,23 @@ public class DefaultMvcResultTests {
private DefaultMvcResult mvcResult; private DefaultMvcResult mvcResult;
private CountDownLatch countDownLatch;
@Before @Before
public void setup() { public void setup() {
ExtendedMockHttpServletRequest request = new ExtendedMockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setAsyncStarted(true); request.setAsyncStarted(true);
this.countDownLatch = mock(CountDownLatch.class);
this.mvcResult = new DefaultMvcResult(request, null); this.mvcResult = new DefaultMvcResult(request, null);
this.mvcResult.setAsyncResultLatch(this.countDownLatch);
}
@Test
public void getAsyncResultWithTimeout() throws Exception {
long timeout = 1234L;
given(this.countDownLatch.await(timeout, TimeUnit.MILLISECONDS)).willReturn(true);
this.mvcResult.getAsyncResult(timeout);
verify(this.countDownLatch).await(timeout, TimeUnit.MILLISECONDS);
} }
@Test @Test
public void getAsyncResultWithTimeoutNegativeOne() throws Exception { public void getAsyncResultSuccess() throws Exception {
given(this.countDownLatch.await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS)).willReturn(true); this.mvcResult.setAsyncResult("Foo");
this.mvcResult.getAsyncResult(-1); assertEquals("Foo", this.mvcResult.getAsyncResult(10 * 1000));
verify(this.countDownLatch).await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS);
} }
@Test @Test(expected = IllegalStateException.class)
public void getAsyncResultWithoutTimeout() throws Exception { public void getAsyncResultFailure() throws Exception {
given(this.countDownLatch.await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS)).willReturn(true);
this.mvcResult.getAsyncResult();
verify(this.countDownLatch).await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS);
}
@Test
public void getAsyncResultWithTimeoutZero() throws Exception {
this.mvcResult.getAsyncResult(0); this.mvcResult.getAsyncResult(0);
verifyZeroInteractions(this.countDownLatch);
}
@Test(expected=IllegalStateException.class)
public void getAsyncResultAndTimeOut() throws Exception {
this.mvcResult.getAsyncResult(-1);
verify(this.countDownLatch).await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS);
}
private static class ExtendedMockHttpServletRequest extends MockHttpServletRequest {
private boolean asyncStarted;
private AsyncContext asyncContext;
public ExtendedMockHttpServletRequest() {
super();
this.asyncContext = mock(AsyncContext.class);
given(this.asyncContext.getTimeout()).willReturn(new Long(DEFAULT_TIMEOUT));
}
@Override
public void setAsyncStarted(boolean asyncStarted) {
this.asyncStarted = asyncStarted;
}
@Override
public boolean isAsyncStarted() {
return this.asyncStarted;
}
@Override
public AsyncContext getAsyncContext() {
return asyncContext;
}
} }
} }

2
spring-test/src/test/java/org/springframework/test/web/servlet/StubMvcResult.java

@ -133,7 +133,7 @@ public class StubMvcResult implements MvcResult {
} }
@Override @Override
public Object getAsyncResult(long timeout) { public Object getAsyncResult(long timeToWait) {
return null; return null;
} }

5
spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/AsyncTests.java

@ -59,23 +59,19 @@ public class AsyncTests {
} }
@Test @Test
@Ignore
public void testCallable() throws Exception { public void testCallable() throws Exception {
MvcResult mvcResult = this.mockMvc.perform(get("/1").param("callable", "true")) MvcResult mvcResult = this.mockMvc.perform(get("/1").param("callable", "true"))
.andDo(print())
.andExpect(request().asyncStarted()) .andExpect(request().asyncStarted())
.andExpect(request().asyncResult(new Person("Joe"))) .andExpect(request().asyncResult(new Person("Joe")))
.andReturn(); .andReturn();
this.mockMvc.perform(asyncDispatch(mvcResult)) this.mockMvc.perform(asyncDispatch(mvcResult))
.andDo(print())
.andExpect(status().isOk()) .andExpect(status().isOk())
.andExpect(content().contentType(MediaType.APPLICATION_JSON)) .andExpect(content().contentType(MediaType.APPLICATION_JSON))
.andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}")); .andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}"));
} }
@Test @Test
@Ignore
public void testDeferredResult() throws Exception { public void testDeferredResult() throws Exception {
MvcResult mvcResult = this.mockMvc.perform(get("/1").param("deferredResult", "true")) MvcResult mvcResult = this.mockMvc.perform(get("/1").param("deferredResult", "true"))
.andExpect(request().asyncStarted()) .andExpect(request().asyncStarted())
@ -90,7 +86,6 @@ public class AsyncTests {
} }
@Test @Test
@Ignore
public void testDeferredResultWithSetValue() throws Exception { public void testDeferredResultWithSetValue() throws Exception {
MvcResult mvcResult = this.mockMvc.perform(get("/1").param("deferredResultWithSetValue", "true")) MvcResult mvcResult = this.mockMvc.perform(get("/1").param("deferredResultWithSetValue", "true"))
.andExpect(request().asyncStarted()) .andExpect(request().asyncStarted())

Loading…
Cancel
Save