diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java index 9bd4ac5c509..69ed94fe684 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java @@ -50,6 +50,8 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements private final List timeoutHandlers = new ArrayList(); + private ErrorHandler errorHandler; + private final List completionHandlers = new ArrayList(); @@ -78,6 +80,10 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements this.timeoutHandlers.add(timeoutHandler); } + void setErrorHandler(ErrorHandler errorHandler) { + this.errorHandler = errorHandler; + } + @Override public void addCompletionHandler(Runnable runnable) { this.completionHandlers.add(runnable); @@ -134,7 +140,9 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public void onError(AsyncEvent event) throws IOException { - onComplete(event); + if (this.errorHandler != null) { + this.errorHandler.handle(event.getThrowable()); + } } @Override @@ -153,4 +161,11 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements this.asyncCompleted.set(true); } + + interface ErrorHandler { + + void handle(Throwable ex); + + } + } diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java index 597f9871802..92fcb48a633 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java @@ -298,6 +298,16 @@ public final class WebAsyncManager { } }); + if (this.asyncWebRequest instanceof StandardServletAsyncWebRequest) { + ((StandardServletAsyncWebRequest) this.asyncWebRequest).setErrorHandler( + new StandardServletAsyncWebRequest.ErrorHandler() { + @Override + public void handle(Throwable ex) { + setConcurrentResultAndDispatch(ex); + } + }); + } + this.asyncWebRequest.addCompletionHandler(new Runnable() { @Override public void run() { @@ -399,6 +409,16 @@ public final class WebAsyncManager { } }); + if (this.asyncWebRequest instanceof StandardServletAsyncWebRequest) { + ((StandardServletAsyncWebRequest) this.asyncWebRequest).setErrorHandler( + new StandardServletAsyncWebRequest.ErrorHandler() { + @Override + public void handle(Throwable ex) { + deferredResult.setErrorResult(ex); + } + }); + } + this.asyncWebRequest.addCompletionHandler(new Runnable() { @Override public void run() { diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java index 2596796e865..964b1c6fddb 100644 --- a/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java @@ -25,6 +25,7 @@ import org.junit.Test; import org.springframework.mock.web.test.MockAsyncContext; import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletResponse; +import org.springframework.web.context.request.async.StandardServletAsyncWebRequest.ErrorHandler; import static org.hamcrest.Matchers.containsString; import static org.junit.Assert.assertEquals; @@ -148,13 +149,26 @@ public class StandardServletAsyncWebRequestTests { // SPR-13292 + @SuppressWarnings("unchecked") @Test - public void onCompletionHandlerAfterOnErrorEvent() throws Exception { + public void onErrorHandlerAfterOnErrorEvent() throws Exception { + ErrorHandler handler = mock(ErrorHandler.class); + this.asyncRequest.setErrorHandler(handler); + + this.asyncRequest.startAsync(); + Exception e = new Exception(); + this.asyncRequest.onError(new AsyncEvent(this.request.getAsyncContext(), e)); + + verify(handler).handle(e); + } + + @Test + public void onCompletionHandlerAfterOnCompleteEvent() throws Exception { Runnable handler = mock(Runnable.class); this.asyncRequest.addCompletionHandler(handler); this.asyncRequest.startAsync(); - this.asyncRequest.onError(new AsyncEvent(null)); + this.asyncRequest.onComplete(new AsyncEvent(this.request.getAsyncContext())); verify(handler).run(); assertTrue(this.asyncRequest.isAsyncComplete()); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java index 18d96b07710..209cd522a5c 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java @@ -166,11 +166,9 @@ public class ResponseBodyEmitter { this.handler.send(object, mediaType); } catch (IOException ex) { - completeWithError(ex); throw ex; } catch (Throwable ex) { - completeWithError(ex); throw new IllegalStateException("Failed to send " + object, ex); } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java index 3709a5e8ed5..2e211ad433a 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java @@ -150,7 +150,6 @@ public class ResponseBodyEmitterTests { // expected } verify(this.handler).send("foo", MediaType.TEXT_PLAIN); - verify(this.handler).completeWithError(failure); verifyNoMoreInteractions(this.handler); }