diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncRequestNotUsableException.java b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncRequestNotUsableException.java
new file mode 100644
index 00000000000..45198fe728d
--- /dev/null
+++ b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncRequestNotUsableException.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright 2002-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.context.request.async;
+
+import java.io.IOException;
+
+/**
+ * Raised when the response for an asynchronous request becomes unusable as
+ * indicated by a write failure, or a Servlet container error notification, or
+ * after the async request has completed.
+ *
+ *
The exception relies on response wrapping, and on {@code AsyncListener}
+ * notifications, managed by {@link StandardServletAsyncWebRequest}.
+ *
+ * @author Rossen Stoyanchev
+ * @since 5.3.33
+ */
+@SuppressWarnings("serial")
+public class AsyncRequestNotUsableException extends IOException {
+
+
+ public AsyncRequestNotUsableException(String message) {
+ super(message);
+ }
+
+ public AsyncRequestNotUsableException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+}
diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java
index ebae26d6785..aa48427b2ad 100644
--- a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java
+++ b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java
@@ -17,16 +17,22 @@
package org.springframework.web.context.request.async;
import java.io.IOException;
+import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
-import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.Locale;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncEvent;
import jakarta.servlet.AsyncListener;
+import jakarta.servlet.ServletOutputStream;
+import jakarta.servlet.WriteListener;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
+import jakarta.servlet.http.HttpServletResponseWrapper;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
@@ -45,8 +51,6 @@ import org.springframework.web.context.request.ServletWebRequest;
*/
public class StandardServletAsyncWebRequest extends ServletWebRequest implements AsyncWebRequest, AsyncListener {
- private final AtomicBoolean asyncCompleted = new AtomicBoolean();
-
private final List timeoutHandlers = new ArrayList<>();
private final List> exceptionHandlers = new ArrayList<>();
@@ -59,6 +63,10 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Nullable
private AsyncContext asyncContext;
+ private State state;
+
+ private final ReentrantLock stateLock = new ReentrantLock();
+
/**
* Create a new instance for the given request/response pair.
@@ -66,7 +74,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
* @param response current HTTP response
*/
public StandardServletAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) {
- super(request, response);
+ this(request, response, null);
+ }
+
+ /**
+ * Constructor to wrap the request and response for the current dispatch that
+ * also picks up the state of the last (probably the REQUEST) dispatch.
+ * @param request current HTTP request
+ * @param response current HTTP response
+ * @param previousRequest the existing request from the last dispatch
+ * @since 5.3.33
+ */
+ StandardServletAsyncWebRequest(HttpServletRequest request, HttpServletResponse response,
+ @Nullable StandardServletAsyncWebRequest previousRequest) {
+
+ super(request, new LifecycleHttpServletResponse(response));
+
+ this.state = (previousRequest != null ? previousRequest.state : State.NEW);
+
+ //noinspection DataFlowIssue
+ ((LifecycleHttpServletResponse) getResponse()).setAsyncWebRequest(this);
}
@@ -107,7 +134,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
*/
@Override
public boolean isAsyncComplete() {
- return this.asyncCompleted.get();
+ return (this.state == State.COMPLETED);
}
@Override
@@ -117,11 +144,18 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
"in async request processing. This is done in Java code using the Servlet API " +
"or by adding \"true\" to servlet and " +
"filter declarations in web.xml.");
- Assert.state(!isAsyncComplete(), "Async processing has already completed");
if (isAsyncStarted()) {
return;
}
+
+ if (this.state == State.NEW) {
+ this.state = State.ASYNC;
+ }
+ else {
+ Assert.state(this.state == State.ASYNC, "Cannot start async: [" + this.state + "]");
+ }
+
this.asyncContext = getRequest().startAsync(getRequest(), getResponse());
this.asyncContext.addListener(this);
if (this.timeout != null) {
@@ -131,8 +165,10 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public void dispatch() {
- Assert.state(this.asyncContext != null, "Cannot dispatch without an AsyncContext");
- this.asyncContext.dispatch();
+ Assert.state(this.asyncContext != null, "AsyncContext not yet initialized");
+ if (!this.isAsyncComplete()) {
+ this.asyncContext.dispatch();
+ }
}
@@ -151,14 +187,478 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public void onError(AsyncEvent event) throws IOException {
- this.exceptionHandlers.forEach(consumer -> consumer.accept(event.getThrowable()));
+ this.stateLock.lock();
+ try {
+ transitionToErrorState();
+ Throwable ex = event.getThrowable();
+ this.exceptionHandlers.forEach(consumer -> consumer.accept(ex));
+ }
+ finally {
+ this.stateLock.unlock();
+ }
+ }
+
+ private void transitionToErrorState() {
+ if (!isAsyncComplete()) {
+ this.state = State.ERROR;
+ }
}
@Override
public void onComplete(AsyncEvent event) throws IOException {
- this.completionHandlers.forEach(Runnable::run);
- this.asyncContext = null;
- this.asyncCompleted.set(true);
+ this.stateLock.lock();
+ try {
+ this.completionHandlers.forEach(Runnable::run);
+ this.asyncContext = null;
+ this.state = State.COMPLETED;
+ }
+ finally {
+ this.stateLock.unlock();
+ }
+ }
+
+
+ /**
+ * Response wrapper to wrap the output stream with {@link LifecycleServletOutputStream}.
+ */
+ private static final class LifecycleHttpServletResponse extends HttpServletResponseWrapper {
+
+ @Nullable
+ private StandardServletAsyncWebRequest asyncWebRequest;
+
+ @Nullable
+ private ServletOutputStream outputStream;
+
+ @Nullable
+ private PrintWriter writer;
+
+ public LifecycleHttpServletResponse(HttpServletResponse response) {
+ super(response);
+ }
+
+ public void setAsyncWebRequest(StandardServletAsyncWebRequest asyncWebRequest) {
+ this.asyncWebRequest = asyncWebRequest;
+ }
+
+ @Override
+ public ServletOutputStream getOutputStream() {
+ if (this.outputStream == null) {
+ Assert.notNull(this.asyncWebRequest, "Not initialized");
+ this.outputStream = new LifecycleServletOutputStream(
+ (HttpServletResponse) getResponse(), this.asyncWebRequest);
+ }
+ return this.outputStream;
+ }
+
+ @Override
+ public PrintWriter getWriter() throws IOException {
+ if (this.writer == null) {
+ Assert.notNull(this.asyncWebRequest, "Not initialized");
+ this.writer = new LifecyclePrintWriter(getResponse().getWriter(), this.asyncWebRequest);
+ }
+ return this.writer;
+ }
+ }
+
+
+ /**
+ * Wraps a ServletOutputStream to prevent use after Servlet container onError
+ * notifications, and after async request completion.
+ */
+ private static final class LifecycleServletOutputStream extends ServletOutputStream {
+
+ private final HttpServletResponse delegate;
+
+ private final StandardServletAsyncWebRequest asyncWebRequest;
+
+ private LifecycleServletOutputStream(
+ HttpServletResponse delegate, StandardServletAsyncWebRequest asyncWebRequest) {
+
+ this.delegate = delegate;
+ this.asyncWebRequest = asyncWebRequest;
+ }
+
+ @Override
+ public boolean isReady() {
+ return false;
+ }
+
+ @Override
+ public void setWriteListener(WriteListener writeListener) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void write(int b) throws IOException {
+ obtainLockAndCheckState();
+ try {
+ this.delegate.getOutputStream().write(b);
+ }
+ catch (IOException ex) {
+ handleIOException(ex, "ServletOutputStream failed to write");
+ }
+ finally {
+ releaseLock();
+ }
+ }
+
+ public void write(byte[] buf, int offset, int len) throws IOException {
+ obtainLockAndCheckState();
+ try {
+ this.delegate.getOutputStream().write(buf, offset, len);
+ }
+ catch (IOException ex) {
+ handleIOException(ex, "ServletOutputStream failed to write");
+ }
+ finally {
+ releaseLock();
+ }
+ }
+
+ @Override
+ public void flush() throws IOException {
+ obtainLockAndCheckState();
+ try {
+ this.delegate.getOutputStream().flush();
+ }
+ catch (IOException ex) {
+ handleIOException(ex, "ServletOutputStream failed to flush");
+ }
+ finally {
+ releaseLock();
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ obtainLockAndCheckState();
+ try {
+ this.delegate.getOutputStream().close();
+ }
+ catch (IOException ex) {
+ handleIOException(ex, "ServletOutputStream failed to close");
+ }
+ finally {
+ releaseLock();
+ }
+ }
+
+ private void obtainLockAndCheckState() throws AsyncRequestNotUsableException {
+ if (state() != State.NEW) {
+ stateLock().lock();
+ if (state() != State.ASYNC) {
+ stateLock().unlock();
+ throw new AsyncRequestNotUsableException("Response not usable after " +
+ (state() == State.COMPLETED ?
+ "async request completion" : "onError notification") + ".");
+ }
+ }
+ }
+
+ private void handleIOException(IOException ex, String msg) throws AsyncRequestNotUsableException {
+ this.asyncWebRequest.transitionToErrorState();
+ throw new AsyncRequestNotUsableException(msg, ex);
+ }
+
+ private void releaseLock() {
+ if (state() != State.NEW) {
+ stateLock().unlock();
+ }
+ }
+
+ private State state() {
+ return this.asyncWebRequest.state;
+ }
+
+ private Lock stateLock() {
+ return this.asyncWebRequest.stateLock;
+ }
+
+ }
+
+
+ /**
+ * Wraps a PrintWriter to prevent use after Servlet container onError
+ * notifications, and after async request completion.
+ */
+ private static final class LifecyclePrintWriter extends PrintWriter {
+
+ private final PrintWriter delegate;
+
+ private final StandardServletAsyncWebRequest asyncWebRequest;
+
+ private LifecyclePrintWriter(PrintWriter delegate, StandardServletAsyncWebRequest asyncWebRequest) {
+ super(delegate);
+ this.delegate = delegate;
+ this.asyncWebRequest = asyncWebRequest;
+ }
+
+ @Override
+ public void flush() {
+ if (tryObtainLockAndCheckState()) {
+ try {
+ this.delegate.flush();
+ }
+ finally {
+ releaseLock();
+ }
+ }
+ }
+
+ @Override
+ public void close() {
+ if (tryObtainLockAndCheckState()) {
+ try {
+ this.delegate.close();
+ }
+ finally {
+ releaseLock();
+ }
+ }
+ }
+
+ @Override
+ public boolean checkError() {
+ return this.delegate.checkError();
+ }
+
+ @Override
+ public void write(int c) {
+ if (tryObtainLockAndCheckState()) {
+ try {
+ this.delegate.write(c);
+ }
+ finally {
+ releaseLock();
+ }
+ }
+ }
+
+ @Override
+ public void write(char[] buf, int off, int len) {
+ if (tryObtainLockAndCheckState()) {
+ try {
+ this.delegate.write(buf, off, len);
+ }
+ finally {
+ releaseLock();
+ }
+ }
+ }
+
+ @Override
+ public void write(char[] buf) {
+ this.delegate.write(buf);
+ }
+
+ @Override
+ public void write(String s, int off, int len) {
+ if (tryObtainLockAndCheckState()) {
+ try {
+ this.delegate.write(s, off, len);
+ }
+ finally {
+ releaseLock();
+ }
+ }
+ }
+
+ @Override
+ public void write(String s) {
+ this.delegate.write(s);
+ }
+
+ private boolean tryObtainLockAndCheckState() {
+ if (state() == State.NEW) {
+ return true;
+ }
+ if (stateLock().tryLock()) {
+ if (state() == State.ASYNC) {
+ return true;
+ }
+ stateLock().unlock();
+ }
+ return false;
+ }
+
+ private void releaseLock() {
+ if (state() != State.NEW) {
+ stateLock().unlock();
+ }
+ }
+
+ private State state() {
+ return this.asyncWebRequest.state;
+ }
+
+ private Lock stateLock() {
+ return this.asyncWebRequest.stateLock;
+ }
+
+ // Plain delegates
+
+ @Override
+ public void print(boolean b) {
+ this.delegate.print(b);
+ }
+
+ @Override
+ public void print(char c) {
+ this.delegate.print(c);
+ }
+
+ @Override
+ public void print(int i) {
+ this.delegate.print(i);
+ }
+
+ @Override
+ public void print(long l) {
+ this.delegate.print(l);
+ }
+
+ @Override
+ public void print(float f) {
+ this.delegate.print(f);
+ }
+
+ @Override
+ public void print(double d) {
+ this.delegate.print(d);
+ }
+
+ @Override
+ public void print(char[] s) {
+ this.delegate.print(s);
+ }
+
+ @Override
+ public void print(String s) {
+ this.delegate.print(s);
+ }
+
+ @Override
+ public void print(Object obj) {
+ this.delegate.print(obj);
+ }
+
+ @Override
+ public void println() {
+ this.delegate.println();
+ }
+
+ @Override
+ public void println(boolean x) {
+ this.delegate.println(x);
+ }
+
+ @Override
+ public void println(char x) {
+ this.delegate.println(x);
+ }
+
+ @Override
+ public void println(int x) {
+ this.delegate.println(x);
+ }
+
+ @Override
+ public void println(long x) {
+ this.delegate.println(x);
+ }
+
+ @Override
+ public void println(float x) {
+ this.delegate.println(x);
+ }
+
+ @Override
+ public void println(double x) {
+ this.delegate.println(x);
+ }
+
+ @Override
+ public void println(char[] x) {
+ this.delegate.println(x);
+ }
+
+ @Override
+ public void println(String x) {
+ this.delegate.println(x);
+ }
+
+ @Override
+ public void println(Object x) {
+ this.delegate.println(x);
+ }
+
+ @Override
+ public PrintWriter printf(String format, Object... args) {
+ return this.delegate.printf(format, args);
+ }
+
+ @Override
+ public PrintWriter printf(Locale l, String format, Object... args) {
+ return this.delegate.printf(l, format, args);
+ }
+
+ @Override
+ public PrintWriter format(String format, Object... args) {
+ return this.delegate.format(format, args);
+ }
+
+ @Override
+ public PrintWriter format(Locale l, String format, Object... args) {
+ return this.delegate.format(l, format, args);
+ }
+
+ @Override
+ public PrintWriter append(CharSequence csq) {
+ return this.delegate.append(csq);
+ }
+
+ @Override
+ public PrintWriter append(CharSequence csq, int start, int end) {
+ return this.delegate.append(csq, start, end);
+ }
+
+ @Override
+ public PrintWriter append(char c) {
+ return this.delegate.append(c);
+ }
+ }
+
+
+ /**
+ * Represents a state for {@link StandardServletAsyncWebRequest} to be in.
+ *
+ * NEW
+ * |
+ * v
+ * ASYNC----> +
+ * | |
+ * v |
+ * ERROR |
+ * | |
+ * v |
+ * COMPLETED <--+
+ *
+ * @since 5.3.33
+ */
+ private enum State {
+
+ /** New request (thas may not do async handling). */
+ NEW,
+
+ /** Async handling has started. */
+ ASYNC,
+
+ /** onError notification received, or ServletOutputStream failed. */
+ ERROR,
+
+ /** onComplete notification received. */
+ COMPLETED
+
}
}
diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java
index 35e901b59ea..59661bde988 100644
--- a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java
+++ b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java
@@ -22,6 +22,7 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicReference;
import jakarta.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
@@ -33,7 +34,6 @@ import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.async.DeferredResult.DeferredResultHandler;
-import org.springframework.web.util.DisconnectedClientHelper;
/**
* The central class for managing asynchronous request processing, mainly intended
@@ -67,16 +67,6 @@ public final class WebAsyncManager {
private static final Log logger = LogFactory.getLog(WebAsyncManager.class);
- /**
- * Log category to use for network failure after a client has gone away.
- * @see DisconnectedClientHelper
- */
- private static final String DISCONNECTED_CLIENT_LOG_CATEGORY =
- "org.springframework.web.server.DisconnectedClient";
-
- private static final DisconnectedClientHelper disconnectedClientHelper =
- new DisconnectedClientHelper(DISCONNECTED_CLIENT_LOG_CATEGORY);
-
private static final CallableProcessingInterceptor timeoutCallableInterceptor =
new TimeoutCallableProcessingInterceptor();
@@ -95,12 +85,7 @@ public final class WebAsyncManager {
@Nullable
private volatile Object[] concurrentResultContext;
- /*
- * Whether the concurrentResult is an error. If such errors remain unhandled, some
- * Servlet containers will call AsyncListener#onError at the end, after the ASYNC
- * and/or the ERROR dispatch (Boot's case), and we need to ignore those.
- */
- private volatile boolean errorHandlingInProgress;
+ private final AtomicReference state = new AtomicReference<>(State.NOT_STARTED);
private final Map