diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java index e83f5f1d9fa..67e137a55d5 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java @@ -21,7 +21,6 @@ import java.io.PrintWriter; import java.util.ArrayList; import java.util.List; import java.util.Locale; -import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; @@ -189,7 +188,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements public void onError(AsyncEvent event) throws IOException { this.stateLock.lock(); try { - transitionToErrorState(); + this.state = State.ERROR; Throwable ex = event.getThrowable(); this.exceptionHandlers.forEach(consumer -> consumer.accept(ex)); } @@ -198,12 +197,6 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements } } - private void transitionToErrorState() { - if (!isAsyncComplete()) { - this.state = State.ERROR; - } - } - @Override public void onComplete(AsyncEvent event) throws IOException { this.stateLock.lock(); @@ -218,8 +211,17 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements } + /** + * Package private access for testing only. + */ + ReentrantLock stateLock() { + return this.stateLock; + } + + /** * Response wrapper to wrap the output stream with {@link LifecycleServletOutputStream}. + * @since 5.3.33 */ private static final class LifecycleHttpServletResponse extends HttpServletResponseWrapper { @@ -242,26 +244,44 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public ServletOutputStream getOutputStream() throws IOException { - if (this.outputStream == null) { - Assert.notNull(this.asyncWebRequest, "Not initialized"); - ServletOutputStream delegate = getResponse().getOutputStream(); - this.outputStream = new LifecycleServletOutputStream(delegate, this); + int level = obtainLockAndCheckState(); + try { + if (this.outputStream == null) { + Assert.notNull(this.asyncWebRequest, "Not initialized"); + ServletOutputStream delegate = getResponse().getOutputStream(); + this.outputStream = new LifecycleServletOutputStream(delegate, this); + } + } + catch (IOException ex) { + handleIOException(ex, "Failed to get ServletResponseOutput"); + } + finally { + releaseLock(level); } return this.outputStream; } @Override public PrintWriter getWriter() throws IOException { - if (this.writer == null) { - Assert.notNull(this.asyncWebRequest, "Not initialized"); - this.writer = new LifecyclePrintWriter(getResponse().getWriter(), this.asyncWebRequest); + int level = obtainLockAndCheckState(); + try { + if (this.writer == null) { + Assert.notNull(this.asyncWebRequest, "Not initialized"); + this.writer = new LifecyclePrintWriter(getResponse().getWriter(), this.asyncWebRequest); + } + } + catch (IOException ex) { + handleIOException(ex, "Failed to get PrintWriter"); + } + finally { + releaseLock(level); } return this.writer; } @Override public void flushBuffer() throws IOException { - obtainLockAndCheckState(); + int level = obtainLockAndCheckState(); try { getResponse().flushBuffer(); } @@ -269,32 +289,40 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements handleIOException(ex, "ServletResponse failed to flushBuffer"); } finally { - releaseLock(); + releaseLock(level); } } - private void obtainLockAndCheckState() throws AsyncRequestNotUsableException { + /** + * Return 0 if checks passed and lock is not needed, 1 if checks passed + * and lock is held, or raise AsyncRequestNotUsableException. + */ + private int obtainLockAndCheckState() throws AsyncRequestNotUsableException { Assert.notNull(this.asyncWebRequest, "Not initialized"); - if (this.asyncWebRequest.state != State.NEW) { - this.asyncWebRequest.stateLock.lock(); - if (this.asyncWebRequest.state != State.ASYNC) { - this.asyncWebRequest.stateLock.unlock(); - throw new AsyncRequestNotUsableException("Response not usable after " + - (this.asyncWebRequest.state == State.COMPLETED ? - "async request completion" : "onError notification") + "."); - } + if (this.asyncWebRequest.state == State.NEW) { + return 0; + } + + this.asyncWebRequest.stateLock.lock(); + if (this.asyncWebRequest.state == State.ASYNC) { + return 1; } + + this.asyncWebRequest.stateLock.unlock(); + throw new AsyncRequestNotUsableException("Response not usable after " + + (this.asyncWebRequest.state == State.COMPLETED ? + "async request completion" : "response errors") + "."); } void handleIOException(IOException ex, String msg) throws AsyncRequestNotUsableException { Assert.notNull(this.asyncWebRequest, "Not initialized"); - this.asyncWebRequest.transitionToErrorState(); - throw new AsyncRequestNotUsableException(msg, ex); + this.asyncWebRequest.state = State.ERROR; + throw new AsyncRequestNotUsableException(msg + ": " + ex.getMessage(), ex); } - void releaseLock() { + void releaseLock(int level) { Assert.notNull(this.asyncWebRequest, "Not initialized"); - if (this.asyncWebRequest.state != State.NEW) { + if (level > 0) { this.asyncWebRequest.stateLock.unlock(); } } @@ -304,6 +332,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements /** * Wraps a ServletOutputStream to prevent use after Servlet container onError * notifications, and after async request completion. + * @since 5.3.33 */ private static final class LifecycleServletOutputStream extends ServletOutputStream { @@ -328,7 +357,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public void write(int b) throws IOException { - this.response.obtainLockAndCheckState(); + int level = this.response.obtainLockAndCheckState(); try { this.delegate.write(b); } @@ -336,12 +365,12 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements this.response.handleIOException(ex, "ServletOutputStream failed to write"); } finally { - this.response.releaseLock(); + this.response.releaseLock(level); } } public void write(byte[] buf, int offset, int len) throws IOException { - this.response.obtainLockAndCheckState(); + int level = this.response.obtainLockAndCheckState(); try { this.delegate.write(buf, offset, len); } @@ -349,13 +378,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements this.response.handleIOException(ex, "ServletOutputStream failed to write"); } finally { - this.response.releaseLock(); + this.response.releaseLock(level); } } @Override public void flush() throws IOException { - this.response.obtainLockAndCheckState(); + int level = this.response.obtainLockAndCheckState(); try { this.delegate.flush(); } @@ -363,13 +392,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements this.response.handleIOException(ex, "ServletOutputStream failed to flush"); } finally { - this.response.releaseLock(); + this.response.releaseLock(level); } } @Override public void close() throws IOException { - this.response.obtainLockAndCheckState(); + int level = this.response.obtainLockAndCheckState(); try { this.delegate.close(); } @@ -377,7 +406,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements this.response.handleIOException(ex, "ServletOutputStream failed to close"); } finally { - this.response.releaseLock(); + this.response.releaseLock(level); } } @@ -387,6 +416,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements /** * Wraps a PrintWriter to prevent use after Servlet container onError * notifications, and after async request completion. + * @since 5.3.33 */ private static final class LifecyclePrintWriter extends PrintWriter { @@ -402,24 +432,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public void flush() { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.flush(); } finally { - releaseLock(); + releaseLock(level); } } } @Override public void close() { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.close(); } finally { - releaseLock(); + releaseLock(level); } } } @@ -431,24 +463,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public void write(int c) { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.write(c); } finally { - releaseLock(); + releaseLock(level); } } } @Override public void write(char[] buf, int off, int len) { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.write(buf, off, len); } finally { - releaseLock(); + releaseLock(level); } } } @@ -460,12 +494,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public void write(String s, int off, int len) { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.write(s, off, len); } finally { - releaseLock(); + releaseLock(level); } } } @@ -475,33 +510,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements this.delegate.write(s); } - private boolean tryObtainLockAndCheckState() { - if (state() == State.NEW) { - return true; + /** + * Return 0 if checks passed and lock is not needed, 1 if checks passed + * and lock is held, and -1 if checks did not pass. + */ + private int tryObtainLockAndCheckState() { + if (this.asyncWebRequest.state == State.NEW) { + return 0; } - if (stateLock().tryLock()) { - if (state() == State.ASYNC) { - return true; - } - stateLock().unlock(); + this.asyncWebRequest.stateLock.lock(); + if (this.asyncWebRequest.state == State.ASYNC) { + return 1; } - return false; + this.asyncWebRequest.stateLock.unlock(); + return -1; } - private void releaseLock() { - if (state() != State.NEW) { - stateLock().unlock(); + private void releaseLock(int level) { + if (level > 0) { + this.asyncWebRequest.stateLock.unlock(); } } - private State state() { - return this.asyncWebRequest.state; - } - - private Lock stateLock() { - return this.asyncWebRequest.stateLock; - } - // Plain delegates @Override @@ -639,28 +669,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements /** * Represents a state for {@link StandardServletAsyncWebRequest} to be in. *

-	 *        NEW
-	 *         |
-	 *         v
-	 *       ASYNC----> +
-	 *         |        |
-	 *         v        |
-	 *       ERROR      |
-	 *         |        |
-	 *         v        |
-	 *     COMPLETED <--+
+	 *    +------ NEW
+	 *    |        |
+	 *    |        v
+	 *    |      ASYNC ----> +
+	 *    |        |         |
+	 *    |        v         |
+	 *    +----> ERROR       |
+	 *             |         |
+	 *             v         |
+	 *         COMPLETED <---+
 	 * 
* @since 5.3.33 */ private enum State { - /** New request (thas may not do async handling). */ + /** New request (may not start async handling). */ NEW, /** Async handling has started. */ ASYNC, - /** onError notification received, or ServletOutputStream failed. */ + /** ServletOutputStream failed, or onError notification received. */ ERROR, /** onComplete notification received. */ diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncRequestNotUsableTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncRequestNotUsableTests.java new file mode 100644 index 00000000000..0033627dfb9 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncRequestNotUsableTests.java @@ -0,0 +1,357 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.io.IOException; +import java.io.PrintWriter; +import java.util.concurrent.atomic.AtomicInteger; + +import jakarta.servlet.AsyncEvent; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.mockito.stubbing.Answer; + +import org.springframework.web.testfixture.servlet.MockAsyncContext; +import org.springframework.web.testfixture.servlet.MockHttpServletRequest; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.BDDMockito.doAnswer; +import static org.mockito.BDDMockito.doThrow; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.verify; +import static org.mockito.BDDMockito.verifyNoInteractions; + +/** + * {@link StandardServletAsyncWebRequest} tests related to response wrapping in + * order to enforce thread safety and prevent use after errors. + * + * @author Rossen Stoyanchev + */ +public class AsyncRequestNotUsableTests { + + private final MockHttpServletRequest request = new MockHttpServletRequest(); + + private final HttpServletResponse response = mock(); + + private final ServletOutputStream outputStream = mock(); + + private final PrintWriter writer = mock(); + + private StandardServletAsyncWebRequest asyncRequest; + + + @BeforeEach + void setup() throws IOException { + this.request.setAsyncSupported(true); + given(this.response.getOutputStream()).willReturn(this.outputStream); + given(this.response.getWriter()).willReturn(this.writer); + + this.asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response); + } + + @AfterEach + void tearDown() { + assertThat(this.asyncRequest.stateLock().isLocked()).isFalse(); + } + + + @SuppressWarnings("DataFlowIssue") + private ServletOutputStream getWrappedOutputStream() throws IOException { + return this.asyncRequest.getResponse().getOutputStream(); + } + + @SuppressWarnings("DataFlowIssue") + private PrintWriter getWrappedWriter() throws IOException { + return this.asyncRequest.getResponse().getWriter(); + } + + + @Nested + class ResponseTests { + + @Test + void notUsableAfterError() throws IOException { + asyncRequest.startAsync(); + asyncRequest.onError(new AsyncEvent(new MockAsyncContext(request, response), new Exception())); + + HttpServletResponse wrapped = asyncRequest.getResponse(); + assertThat(wrapped).isNotNull(); + assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after response errors."); + assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after response errors."); + assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after response errors."); + } + + @Test + void notUsableAfterCompletion() throws IOException { + asyncRequest.startAsync(); + asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response))); + + HttpServletResponse wrapped = asyncRequest.getResponse(); + assertThat(wrapped).isNotNull(); + assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after async request completion."); + } + + @Test + void notUsableWhenRecreatedAfterCompletion() throws IOException { + asyncRequest.startAsync(); + asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response))); + + StandardServletAsyncWebRequest newWebRequest = + new StandardServletAsyncWebRequest(request, response, asyncRequest); + + HttpServletResponse wrapped = newWebRequest.getResponse(); + assertThat(wrapped).isNotNull(); + assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after async request completion."); + } + } + + + @Nested + class OutputStreamTests { + + @Test + void use() throws IOException { + testUseOutputStream(); + } + + @Test + void useInAsyncState() throws IOException { + asyncRequest.startAsync(); + testUseOutputStream(); + } + + private void testUseOutputStream() throws IOException { + ServletOutputStream wrapped = getWrappedOutputStream(); + + wrapped.write('a'); + wrapped.write(new byte[0], 1, 2); + wrapped.flush(); + wrapped.close(); + + verify(outputStream).write('a'); + verify(outputStream).write(new byte[0], 1, 2); + verify(outputStream).flush(); + verify(outputStream).close(); + } + + @Test + void notUsableAfterCompletion() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response))); + + assertThatThrownBy(() -> wrapped.write('a')).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(() -> wrapped.write(new byte[0])).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(() -> wrapped.write(new byte[0], 0, 0)).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::flush).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::close).hasMessage("Response not usable after async request completion."); + } + + @Test + void lockingNotUsed() throws IOException { + AtomicInteger count = new AtomicInteger(-1); + doAnswer((Answer) invocation -> { + count.set(asyncRequest.stateLock().getHoldCount()); + return null; + }).when(outputStream).write('a'); + + // Access ServletOutputStream in NEW state (no async handling) without locking + getWrappedOutputStream().write('a'); + + assertThat(count.get()).isEqualTo(0); + } + + @Test + void lockingUsedInAsyncState() throws IOException { + AtomicInteger count = new AtomicInteger(-1); + doAnswer((Answer) invocation -> { + count.set(asyncRequest.stateLock().getHoldCount()); + return null; + }).when(outputStream).write('a'); + + // Access ServletOutputStream in ASYNC state with locking + asyncRequest.startAsync(); + getWrappedOutputStream().write('a'); + + assertThat(count.get()).isEqualTo(1); + } + } + + + @Nested + class OutputStreamErrorTests { + + @Test + void writeInt() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + doThrow(new IOException("Broken pipe")).when(outputStream).write('a'); + assertThatThrownBy(() -> wrapped.write('a')).hasMessage("ServletOutputStream failed to write: Broken pipe"); + } + + @Test + void writeBytesFull() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + byte[] bytes = new byte[0]; + doThrow(new IOException("Broken pipe")).when(outputStream).write(bytes, 0, 0); + assertThatThrownBy(() -> wrapped.write(bytes)).hasMessage("ServletOutputStream failed to write: Broken pipe"); + } + + @Test + void writeBytes() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + byte[] bytes = new byte[0]; + doThrow(new IOException("Broken pipe")).when(outputStream).write(bytes, 0, 0); + assertThatThrownBy(() -> wrapped.write(bytes, 0, 0)).hasMessage("ServletOutputStream failed to write: Broken pipe"); + } + + @Test + void flush() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + doThrow(new IOException("Broken pipe")).when(outputStream).flush(); + assertThatThrownBy(wrapped::flush).hasMessage("ServletOutputStream failed to flush: Broken pipe"); + } + + @Test + void close() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + doThrow(new IOException("Broken pipe")).when(outputStream).close(); + assertThatThrownBy(wrapped::close).hasMessage("ServletOutputStream failed to close: Broken pipe"); + } + + @Test + void writeErrorPreventsFurtherWriting() throws IOException { + ServletOutputStream wrapped = getWrappedOutputStream(); + + doThrow(new IOException("Broken pipe")).when(outputStream).write('a'); + assertThatThrownBy(() -> wrapped.write('a')).hasMessage("ServletOutputStream failed to write: Broken pipe"); + assertThatThrownBy(() -> wrapped.write('a')).hasMessage("Response not usable after response errors."); + } + + @Test + void writeErrorInAsyncStatePreventsFurtherWriting() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + doThrow(new IOException("Broken pipe")).when(outputStream).write('a'); + assertThatThrownBy(() -> wrapped.write('a')).hasMessage("ServletOutputStream failed to write: Broken pipe"); + assertThatThrownBy(() -> wrapped.write('a')).hasMessage("Response not usable after response errors."); + } + } + + + @Nested + class WriterTests { + + @Test + void useWriter() throws IOException { + testUseWriter(); + } + + @Test + void useWriterInAsyncState() throws IOException { + asyncRequest.startAsync(); + testUseWriter(); + } + + private void testUseWriter() throws IOException { + PrintWriter wrapped = getWrappedWriter(); + + wrapped.write('a'); + wrapped.write(new char[0], 1, 2); + wrapped.write("abc", 1, 2); + wrapped.flush(); + wrapped.close(); + + verify(writer).write('a'); + verify(writer).write(new char[0], 1, 2); + verify(writer).write("abc", 1, 2); + verify(writer).flush(); + verify(writer).close(); + } + + @Test + void writerNotUsableAfterCompletion() throws IOException { + asyncRequest.startAsync(); + PrintWriter wrapped = getWrappedWriter(); + + asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response))); + + char[] chars = new char[0]; + wrapped.write('a'); + wrapped.write(chars, 1, 2); + wrapped.flush(); + wrapped.close(); + + verifyNoInteractions(writer); + } + + @Test + void lockingNotUsed() throws IOException { + AtomicInteger count = new AtomicInteger(-1); + + doAnswer((Answer) invocation -> { + count.set(asyncRequest.stateLock().getHoldCount()); + return null; + }).when(writer).write('a'); + + // Use Writer in NEW state (no async handling) without locking + PrintWriter wrapped = getWrappedWriter(); + wrapped.write('a'); + + assertThat(count.get()).isEqualTo(0); + } + + @Test + void lockingUsedInAsyncState() throws IOException { + AtomicInteger count = new AtomicInteger(-1); + + doAnswer((Answer) invocation -> { + count.set(asyncRequest.stateLock().getHoldCount()); + return null; + }).when(writer).write('a'); + + // Use Writer in ASYNC state with locking + asyncRequest.startAsync(); + PrintWriter wrapped = getWrappedWriter(); + wrapped.write('a'); + + assertThat(count.get()).isEqualTo(1); + } + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java index 280024d944e..af274c47ca0 100644 --- a/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java @@ -20,6 +20,7 @@ import java.util.function.Consumer; import jakarta.servlet.AsyncEvent; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.springframework.web.testfixture.servlet.MockAsyncContext; @@ -32,7 +33,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; /** - * A test fixture with a {@link StandardServletAsyncWebRequest}. + * Tests for {@link StandardServletAsyncWebRequest}. + * * @author Rossen Stoyanchev */ class StandardServletAsyncWebRequestTests { @@ -49,120 +51,109 @@ class StandardServletAsyncWebRequestTests { this.request = new MockHttpServletRequest(); this.request.setAsyncSupported(true); this.response = new MockHttpServletResponse(); - this.asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response); - this.asyncRequest.setTimeout(44*1000L); - } - - - @Test - void isAsyncStarted() { - assertThat(this.asyncRequest.isAsyncStarted()).isFalse(); - this.asyncRequest.startAsync(); - assertThat(this.asyncRequest.isAsyncStarted()).isTrue(); - } - - @Test - void startAsync() { - this.asyncRequest.startAsync(); - - MockAsyncContext context = (MockAsyncContext) this.request.getAsyncContext(); - assertThat(context).isNotNull(); - assertThat(context.getTimeout()).as("Timeout value not set").isEqualTo((44 * 1000)); - assertThat(context.getListeners()).containsExactly(this.asyncRequest); - } - - @Test - void startAsyncMultipleTimes() { - this.asyncRequest.startAsync(); - this.asyncRequest.startAsync(); - this.asyncRequest.startAsync(); - this.asyncRequest.startAsync(); // idempotent - MockAsyncContext context = (MockAsyncContext) this.request.getAsyncContext(); - assertThat(context).isNotNull(); - assertThat(context.getListeners()).hasSize(1); + this.asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response); + this.asyncRequest.setTimeout(44 * 1000L); } - @Test - void startAsyncNotSupported() { - this.request.setAsyncSupported(false); - assertThatIllegalStateException().isThrownBy( - this.asyncRequest::startAsync) - .withMessageContaining("Async support must be enabled"); - } - @Test - void startAsyncAfterCompleted() throws Exception { - this.asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(this.request, this.response))); - assertThatIllegalStateException().isThrownBy(this.asyncRequest::startAsync) - .withMessage("Cannot start async: [COMPLETED]"); + @Nested + class StartAsync { + + @Test + void isAsyncStarted() { + assertThat(asyncRequest.isAsyncStarted()).isFalse(); + + asyncRequest.startAsync(); + + assertThat(asyncRequest.isAsyncStarted()).isTrue(); + } + + @Test + void startAsync() { + asyncRequest.startAsync(); + + MockAsyncContext context = (MockAsyncContext) request.getAsyncContext(); + assertThat(context).isNotNull(); + assertThat(context.getTimeout()).as("Timeout value not set").isEqualTo((44 * 1000)); + assertThat(context.getListeners()).containsExactly(asyncRequest); + } + + @Test + void startAsyncMultipleTimes() { + asyncRequest.startAsync(); + asyncRequest.startAsync(); + asyncRequest.startAsync(); + asyncRequest.startAsync(); + + MockAsyncContext context = (MockAsyncContext) request.getAsyncContext(); + assertThat(context).isNotNull(); + assertThat(context.getListeners()).hasSize(1); + } + + @Test + void startAsyncNotSupported() { + request.setAsyncSupported(false); + assertThatIllegalStateException() + .isThrownBy(asyncRequest::startAsync) + .withMessageContaining("Async support must be enabled"); + } + + @Test + void startAsyncAfterCompleted() throws Exception { + asyncRequest.startAsync(); + asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response))); + + assertThatIllegalStateException() + .isThrownBy(asyncRequest::startAsync) + .withMessage("Cannot start async: [COMPLETED]"); + } + + @Test + void startAsyncAndSetTimeout() { + asyncRequest.startAsync(); + assertThatIllegalStateException().isThrownBy(() -> asyncRequest.setTimeout(25L)); + } } - @Test - void onTimeoutDefaultBehavior() throws Exception { - this.asyncRequest.onTimeout(new AsyncEvent(new MockAsyncContext(this.request, this.response))); - assertThat(this.response.getStatus()).isEqualTo(200); - } - @Test - void onTimeoutHandler() throws Exception { - Runnable timeoutHandler = mock(); - this.asyncRequest.addTimeoutHandler(timeoutHandler); - this.asyncRequest.onTimeout(new AsyncEvent(new MockAsyncContext(this.request, this.response))); - verify(timeoutHandler).run(); - } + @Nested + class AsyncListenerHandling { - @Test - void onErrorHandler() throws Exception { - Consumer errorHandler = mock(); - this.asyncRequest.addErrorHandler(errorHandler); - Exception e = new Exception(); - this.asyncRequest.onError(new AsyncEvent(new MockAsyncContext(this.request, this.response), e)); - verify(errorHandler).accept(e); - } + @Test + void onTimeoutHandler() throws Exception { + Runnable handler = mock(); + asyncRequest.addTimeoutHandler(handler); - @Test - void setTimeoutDuringConcurrentHandling() { - this.asyncRequest.startAsync(); - assertThatIllegalStateException().isThrownBy(() -> - this.asyncRequest.setTimeout(25L)); - } + asyncRequest.startAsync(); + asyncRequest.onTimeout(new AsyncEvent(new MockAsyncContext(request, response))); - @Test - void onCompletionHandler() throws Exception { - Runnable handler = mock(); - this.asyncRequest.addCompletionHandler(handler); + verify(handler).run(); + } - this.asyncRequest.startAsync(); - this.asyncRequest.onComplete(new AsyncEvent(this.request.getAsyncContext())); + @Test + void onErrorHandler() throws Exception { + Exception ex = new Exception(); + Consumer handler = mock(); + asyncRequest.addErrorHandler(handler); - verify(handler).run(); - assertThat(this.asyncRequest.isAsyncComplete()).isTrue(); - } + asyncRequest.startAsync(); + asyncRequest.onError(new AsyncEvent(new MockAsyncContext(request, response), ex)); - // SPR-13292 + verify(handler).accept(ex); + } - @Test - void onErrorHandlerAfterOnErrorEvent() throws Exception { - Consumer handler = mock(); - this.asyncRequest.addErrorHandler(handler); + @Test + void onCompletionHandler() throws Exception { + Runnable handler = mock(); + asyncRequest.addCompletionHandler(handler); - this.asyncRequest.startAsync(); - Exception e = new Exception(); - this.asyncRequest.onError(new AsyncEvent(this.request.getAsyncContext(), e)); + asyncRequest.startAsync(); + asyncRequest.onComplete(new AsyncEvent(request.getAsyncContext())); - verify(handler).accept(e); + verify(handler).run(); + assertThat(asyncRequest.isAsyncComplete()).isTrue(); + } } - @Test - void onCompletionHandlerAfterOnCompleteEvent() throws Exception { - Runnable handler = mock(); - this.asyncRequest.addCompletionHandler(handler); - - this.asyncRequest.startAsync(); - this.asyncRequest.onComplete(new AsyncEvent(this.request.getAsyncContext())); - - verify(handler).run(); - assertThat(this.asyncRequest.isAsyncComplete()).isTrue(); - } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapterTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapterTests.java index 1fbd5307cdc..7943b17fed5 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapterTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapterTests.java @@ -16,8 +16,11 @@ package org.springframework.web.servlet.mvc.method.annotation; +import java.io.IOException; +import java.io.OutputStream; import java.lang.reflect.Method; import java.lang.reflect.Type; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -25,6 +28,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import jakarta.servlet.AsyncEvent; import org.apache.groovy.util.Maps; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; @@ -49,6 +53,10 @@ import org.springframework.web.bind.annotation.ModelAttribute; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.bind.annotation.SessionAttributes; +import org.springframework.web.context.request.async.AsyncRequestNotUsableException; +import org.springframework.web.context.request.async.StandardServletAsyncWebRequest; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.WebAsyncUtils; import org.springframework.web.context.support.StaticWebApplicationContext; import org.springframework.web.method.HandlerMethod; import org.springframework.web.method.annotation.ModelMethodProcessor; @@ -58,10 +66,12 @@ import org.springframework.web.method.support.InvocableHandlerMethod; import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.FlashMap; import org.springframework.web.servlet.ModelAndView; +import org.springframework.web.testfixture.servlet.MockAsyncContext; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.testfixture.servlet.MockHttpServletResponse; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Tests for {@link RequestMappingHandlerAdapter}. @@ -285,6 +295,26 @@ class RequestMappingHandlerAdapterTests { assertThat(this.response.getContentAsString()).isEqualTo("Body: {foo=bar}"); } + @Test + void asyncRequestNotUsable() throws Exception { + + // Put AsyncWebRequest in ERROR state + StandardServletAsyncWebRequest asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response); + asyncRequest.onError(new AsyncEvent(new MockAsyncContext(this.request, this.response), new Exception())); + + // Set it as the current AsyncWebRequest, from the initial REQUEST dispatch + WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(this.request); + asyncManager.setAsyncWebRequest(asyncRequest); + + // AsyncWebRequest created for current dispatch should inherit state + HandlerMethod handlerMethod = handlerMethod(new TestController(), "handleOutputStream", OutputStream.class); + this.handlerAdapter.afterPropertiesSet(); + + // Use of response should be rejected + assertThatThrownBy(() -> this.handlerAdapter.handle(this.request, this.response, handlerMethod)) + .isInstanceOf(AsyncRequestNotUsableException.class); + } + private HandlerMethod handlerMethod(Object handler, String methodName, Class... paramTypes) throws Exception { Method method = handler.getClass().getDeclaredMethod(methodName, paramTypes); return new InvocableHandlerMethod(handler, method); @@ -321,6 +351,10 @@ class RequestMappingHandlerAdapterTests { public String handleBody(@Nullable @RequestBody Map body) { return "Body: " + body; } + + public void handleOutputStream(OutputStream outputStream) throws IOException { + outputStream.write("body".getBytes(StandardCharsets.UTF_8)); + } }