diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java index aa48427b2ad..67e137a55d5 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java @@ -21,7 +21,6 @@ import java.io.PrintWriter; import java.util.ArrayList; import java.util.List; import java.util.Locale; -import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; @@ -189,7 +188,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements public void onError(AsyncEvent event) throws IOException { this.stateLock.lock(); try { - transitionToErrorState(); + this.state = State.ERROR; Throwable ex = event.getThrowable(); this.exceptionHandlers.forEach(consumer -> consumer.accept(ex)); } @@ -198,12 +197,6 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements } } - private void transitionToErrorState() { - if (!isAsyncComplete()) { - this.state = State.ERROR; - } - } - @Override public void onComplete(AsyncEvent event) throws IOException { this.stateLock.lock(); @@ -218,8 +211,17 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements } + /** + * Package private access for testing only. + */ + ReentrantLock stateLock() { + return this.stateLock; + } + + /** * Response wrapper to wrap the output stream with {@link LifecycleServletOutputStream}. + * @since 5.3.33 */ private static final class LifecycleHttpServletResponse extends HttpServletResponseWrapper { @@ -241,145 +243,180 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements } @Override - public ServletOutputStream getOutputStream() { - if (this.outputStream == null) { - Assert.notNull(this.asyncWebRequest, "Not initialized"); - this.outputStream = new LifecycleServletOutputStream( - (HttpServletResponse) getResponse(), this.asyncWebRequest); + public ServletOutputStream getOutputStream() throws IOException { + int level = obtainLockAndCheckState(); + try { + if (this.outputStream == null) { + Assert.notNull(this.asyncWebRequest, "Not initialized"); + ServletOutputStream delegate = getResponse().getOutputStream(); + this.outputStream = new LifecycleServletOutputStream(delegate, this); + } + } + catch (IOException ex) { + handleIOException(ex, "Failed to get ServletResponseOutput"); + } + finally { + releaseLock(level); } return this.outputStream; } @Override public PrintWriter getWriter() throws IOException { - if (this.writer == null) { - Assert.notNull(this.asyncWebRequest, "Not initialized"); - this.writer = new LifecyclePrintWriter(getResponse().getWriter(), this.asyncWebRequest); + int level = obtainLockAndCheckState(); + try { + if (this.writer == null) { + Assert.notNull(this.asyncWebRequest, "Not initialized"); + this.writer = new LifecyclePrintWriter(getResponse().getWriter(), this.asyncWebRequest); + } + } + catch (IOException ex) { + handleIOException(ex, "Failed to get PrintWriter"); + } + finally { + releaseLock(level); } return this.writer; } + + @Override + public void flushBuffer() throws IOException { + int level = obtainLockAndCheckState(); + try { + getResponse().flushBuffer(); + } + catch (IOException ex) { + handleIOException(ex, "ServletResponse failed to flushBuffer"); + } + finally { + releaseLock(level); + } + } + + /** + * Return 0 if checks passed and lock is not needed, 1 if checks passed + * and lock is held, or raise AsyncRequestNotUsableException. + */ + private int obtainLockAndCheckState() throws AsyncRequestNotUsableException { + Assert.notNull(this.asyncWebRequest, "Not initialized"); + if (this.asyncWebRequest.state == State.NEW) { + return 0; + } + + this.asyncWebRequest.stateLock.lock(); + if (this.asyncWebRequest.state == State.ASYNC) { + return 1; + } + + this.asyncWebRequest.stateLock.unlock(); + throw new AsyncRequestNotUsableException("Response not usable after " + + (this.asyncWebRequest.state == State.COMPLETED ? + "async request completion" : "response errors") + "."); + } + + void handleIOException(IOException ex, String msg) throws AsyncRequestNotUsableException { + Assert.notNull(this.asyncWebRequest, "Not initialized"); + this.asyncWebRequest.state = State.ERROR; + throw new AsyncRequestNotUsableException(msg + ": " + ex.getMessage(), ex); + } + + void releaseLock(int level) { + Assert.notNull(this.asyncWebRequest, "Not initialized"); + if (level > 0) { + this.asyncWebRequest.stateLock.unlock(); + } + } } /** * Wraps a ServletOutputStream to prevent use after Servlet container onError * notifications, and after async request completion. + * @since 5.3.33 */ private static final class LifecycleServletOutputStream extends ServletOutputStream { - private final HttpServletResponse delegate; - - private final StandardServletAsyncWebRequest asyncWebRequest; + private final ServletOutputStream delegate; - private LifecycleServletOutputStream( - HttpServletResponse delegate, StandardServletAsyncWebRequest asyncWebRequest) { + private final LifecycleHttpServletResponse response; + private LifecycleServletOutputStream(ServletOutputStream delegate, LifecycleHttpServletResponse response) { this.delegate = delegate; - this.asyncWebRequest = asyncWebRequest; + this.response = response; } @Override public boolean isReady() { - return false; + return this.delegate.isReady(); } @Override public void setWriteListener(WriteListener writeListener) { - throw new UnsupportedOperationException(); + this.delegate.setWriteListener(writeListener); } @Override public void write(int b) throws IOException { - obtainLockAndCheckState(); + int level = this.response.obtainLockAndCheckState(); try { - this.delegate.getOutputStream().write(b); + this.delegate.write(b); } catch (IOException ex) { - handleIOException(ex, "ServletOutputStream failed to write"); + this.response.handleIOException(ex, "ServletOutputStream failed to write"); } finally { - releaseLock(); + this.response.releaseLock(level); } } public void write(byte[] buf, int offset, int len) throws IOException { - obtainLockAndCheckState(); + int level = this.response.obtainLockAndCheckState(); try { - this.delegate.getOutputStream().write(buf, offset, len); + this.delegate.write(buf, offset, len); } catch (IOException ex) { - handleIOException(ex, "ServletOutputStream failed to write"); + this.response.handleIOException(ex, "ServletOutputStream failed to write"); } finally { - releaseLock(); + this.response.releaseLock(level); } } @Override public void flush() throws IOException { - obtainLockAndCheckState(); + int level = this.response.obtainLockAndCheckState(); try { - this.delegate.getOutputStream().flush(); + this.delegate.flush(); } catch (IOException ex) { - handleIOException(ex, "ServletOutputStream failed to flush"); + this.response.handleIOException(ex, "ServletOutputStream failed to flush"); } finally { - releaseLock(); + this.response.releaseLock(level); } } @Override public void close() throws IOException { - obtainLockAndCheckState(); + int level = this.response.obtainLockAndCheckState(); try { - this.delegate.getOutputStream().close(); + this.delegate.close(); } catch (IOException ex) { - handleIOException(ex, "ServletOutputStream failed to close"); + this.response.handleIOException(ex, "ServletOutputStream failed to close"); } finally { - releaseLock(); + this.response.releaseLock(level); } } - private void obtainLockAndCheckState() throws AsyncRequestNotUsableException { - if (state() != State.NEW) { - stateLock().lock(); - if (state() != State.ASYNC) { - stateLock().unlock(); - throw new AsyncRequestNotUsableException("Response not usable after " + - (state() == State.COMPLETED ? - "async request completion" : "onError notification") + "."); - } - } - } - - private void handleIOException(IOException ex, String msg) throws AsyncRequestNotUsableException { - this.asyncWebRequest.transitionToErrorState(); - throw new AsyncRequestNotUsableException(msg, ex); - } - - private void releaseLock() { - if (state() != State.NEW) { - stateLock().unlock(); - } - } - - private State state() { - return this.asyncWebRequest.state; - } - - private Lock stateLock() { - return this.asyncWebRequest.stateLock; - } - } /** * Wraps a PrintWriter to prevent use after Servlet container onError * notifications, and after async request completion. + * @since 5.3.33 */ private static final class LifecyclePrintWriter extends PrintWriter { @@ -395,24 +432,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public void flush() { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.flush(); } finally { - releaseLock(); + releaseLock(level); } } } @Override public void close() { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.close(); } finally { - releaseLock(); + releaseLock(level); } } } @@ -424,24 +463,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public void write(int c) { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.write(c); } finally { - releaseLock(); + releaseLock(level); } } } @Override public void write(char[] buf, int off, int len) { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.write(buf, off, len); } finally { - releaseLock(); + releaseLock(level); } } } @@ -453,12 +494,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public void write(String s, int off, int len) { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.write(s, off, len); } finally { - releaseLock(); + releaseLock(level); } } } @@ -468,33 +510,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements this.delegate.write(s); } - private boolean tryObtainLockAndCheckState() { - if (state() == State.NEW) { - return true; + /** + * Return 0 if checks passed and lock is not needed, 1 if checks passed + * and lock is held, and -1 if checks did not pass. + */ + private int tryObtainLockAndCheckState() { + if (this.asyncWebRequest.state == State.NEW) { + return 0; } - if (stateLock().tryLock()) { - if (state() == State.ASYNC) { - return true; - } - stateLock().unlock(); + this.asyncWebRequest.stateLock.lock(); + if (this.asyncWebRequest.state == State.ASYNC) { + return 1; } - return false; + this.asyncWebRequest.stateLock.unlock(); + return -1; } - private void releaseLock() { - if (state() != State.NEW) { - stateLock().unlock(); + private void releaseLock(int level) { + if (level > 0) { + this.asyncWebRequest.stateLock.unlock(); } } - private State state() { - return this.asyncWebRequest.state; - } - - private Lock stateLock() { - return this.asyncWebRequest.stateLock; - } - // Plain delegates @Override @@ -632,28 +669,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements /** * Represents a state for {@link StandardServletAsyncWebRequest} to be in. *

-	 *        NEW
-	 *         |
-	 *         v
-	 *       ASYNC----> +
-	 *         |        |
-	 *         v        |
-	 *       ERROR      |
-	 *         |        |
-	 *         v        |
-	 *     COMPLETED <--+
+	 *    +------ NEW
+	 *    |        |
+	 *    |        v
+	 *    |      ASYNC ----> +
+	 *    |        |         |
+	 *    |        v         |
+	 *    +----> ERROR       |
+	 *             |         |
+	 *             v         |
+	 *         COMPLETED <---+
 	 * 
* @since 5.3.33 */ private enum State { - /** New request (thas may not do async handling). */ + /** New request (may not start async handling). */ NEW, /** Async handling has started. */ ASYNC, - /** onError notification received, or ServletOutputStream failed. */ + /** ServletOutputStream failed, or onError notification received. */ ERROR, /** onComplete notification received. */ diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java index 5ff3485a670..56b3d84e5e5 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java @@ -514,14 +514,11 @@ public final class WebAsyncManager { } this.asyncWebRequest.startAsync(); - if (logger.isDebugEnabled()) { - logger.debug("Started async request"); - } } private static String formatUri(AsyncWebRequest asyncWebRequest) { HttpServletRequest request = asyncWebRequest.getNativeRequest(HttpServletRequest.class); - return (request != null ? request.getRequestURI() : "servlet container"); + return (request != null ? "\"" + request.getRequestURI() + "\"" : "servlet container"); } diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncRequestNotUsableTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncRequestNotUsableTests.java new file mode 100644 index 00000000000..0033627dfb9 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncRequestNotUsableTests.java @@ -0,0 +1,357 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.io.IOException; +import java.io.PrintWriter; +import java.util.concurrent.atomic.AtomicInteger; + +import jakarta.servlet.AsyncEvent; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.mockito.stubbing.Answer; + +import org.springframework.web.testfixture.servlet.MockAsyncContext; +import org.springframework.web.testfixture.servlet.MockHttpServletRequest; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.BDDMockito.doAnswer; +import static org.mockito.BDDMockito.doThrow; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.verify; +import static org.mockito.BDDMockito.verifyNoInteractions; + +/** + * {@link StandardServletAsyncWebRequest} tests related to response wrapping in + * order to enforce thread safety and prevent use after errors. + * + * @author Rossen Stoyanchev + */ +public class AsyncRequestNotUsableTests { + + private final MockHttpServletRequest request = new MockHttpServletRequest(); + + private final HttpServletResponse response = mock(); + + private final ServletOutputStream outputStream = mock(); + + private final PrintWriter writer = mock(); + + private StandardServletAsyncWebRequest asyncRequest; + + + @BeforeEach + void setup() throws IOException { + this.request.setAsyncSupported(true); + given(this.response.getOutputStream()).willReturn(this.outputStream); + given(this.response.getWriter()).willReturn(this.writer); + + this.asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response); + } + + @AfterEach + void tearDown() { + assertThat(this.asyncRequest.stateLock().isLocked()).isFalse(); + } + + + @SuppressWarnings("DataFlowIssue") + private ServletOutputStream getWrappedOutputStream() throws IOException { + return this.asyncRequest.getResponse().getOutputStream(); + } + + @SuppressWarnings("DataFlowIssue") + private PrintWriter getWrappedWriter() throws IOException { + return this.asyncRequest.getResponse().getWriter(); + } + + + @Nested + class ResponseTests { + + @Test + void notUsableAfterError() throws IOException { + asyncRequest.startAsync(); + asyncRequest.onError(new AsyncEvent(new MockAsyncContext(request, response), new Exception())); + + HttpServletResponse wrapped = asyncRequest.getResponse(); + assertThat(wrapped).isNotNull(); + assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after response errors."); + assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after response errors."); + assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after response errors."); + } + + @Test + void notUsableAfterCompletion() throws IOException { + asyncRequest.startAsync(); + asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response))); + + HttpServletResponse wrapped = asyncRequest.getResponse(); + assertThat(wrapped).isNotNull(); + assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after async request completion."); + } + + @Test + void notUsableWhenRecreatedAfterCompletion() throws IOException { + asyncRequest.startAsync(); + asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response))); + + StandardServletAsyncWebRequest newWebRequest = + new StandardServletAsyncWebRequest(request, response, asyncRequest); + + HttpServletResponse wrapped = newWebRequest.getResponse(); + assertThat(wrapped).isNotNull(); + assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after async request completion."); + } + } + + + @Nested + class OutputStreamTests { + + @Test + void use() throws IOException { + testUseOutputStream(); + } + + @Test + void useInAsyncState() throws IOException { + asyncRequest.startAsync(); + testUseOutputStream(); + } + + private void testUseOutputStream() throws IOException { + ServletOutputStream wrapped = getWrappedOutputStream(); + + wrapped.write('a'); + wrapped.write(new byte[0], 1, 2); + wrapped.flush(); + wrapped.close(); + + verify(outputStream).write('a'); + verify(outputStream).write(new byte[0], 1, 2); + verify(outputStream).flush(); + verify(outputStream).close(); + } + + @Test + void notUsableAfterCompletion() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response))); + + assertThatThrownBy(() -> wrapped.write('a')).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(() -> wrapped.write(new byte[0])).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(() -> wrapped.write(new byte[0], 0, 0)).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::flush).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::close).hasMessage("Response not usable after async request completion."); + } + + @Test + void lockingNotUsed() throws IOException { + AtomicInteger count = new AtomicInteger(-1); + doAnswer((Answer) invocation -> { + count.set(asyncRequest.stateLock().getHoldCount()); + return null; + }).when(outputStream).write('a'); + + // Access ServletOutputStream in NEW state (no async handling) without locking + getWrappedOutputStream().write('a'); + + assertThat(count.get()).isEqualTo(0); + } + + @Test + void lockingUsedInAsyncState() throws IOException { + AtomicInteger count = new AtomicInteger(-1); + doAnswer((Answer) invocation -> { + count.set(asyncRequest.stateLock().getHoldCount()); + return null; + }).when(outputStream).write('a'); + + // Access ServletOutputStream in ASYNC state with locking + asyncRequest.startAsync(); + getWrappedOutputStream().write('a'); + + assertThat(count.get()).isEqualTo(1); + } + } + + + @Nested + class OutputStreamErrorTests { + + @Test + void writeInt() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + doThrow(new IOException("Broken pipe")).when(outputStream).write('a'); + assertThatThrownBy(() -> wrapped.write('a')).hasMessage("ServletOutputStream failed to write: Broken pipe"); + } + + @Test + void writeBytesFull() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + byte[] bytes = new byte[0]; + doThrow(new IOException("Broken pipe")).when(outputStream).write(bytes, 0, 0); + assertThatThrownBy(() -> wrapped.write(bytes)).hasMessage("ServletOutputStream failed to write: Broken pipe"); + } + + @Test + void writeBytes() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + byte[] bytes = new byte[0]; + doThrow(new IOException("Broken pipe")).when(outputStream).write(bytes, 0, 0); + assertThatThrownBy(() -> wrapped.write(bytes, 0, 0)).hasMessage("ServletOutputStream failed to write: Broken pipe"); + } + + @Test + void flush() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + doThrow(new IOException("Broken pipe")).when(outputStream).flush(); + assertThatThrownBy(wrapped::flush).hasMessage("ServletOutputStream failed to flush: Broken pipe"); + } + + @Test + void close() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + doThrow(new IOException("Broken pipe")).when(outputStream).close(); + assertThatThrownBy(wrapped::close).hasMessage("ServletOutputStream failed to close: Broken pipe"); + } + + @Test + void writeErrorPreventsFurtherWriting() throws IOException { + ServletOutputStream wrapped = getWrappedOutputStream(); + + doThrow(new IOException("Broken pipe")).when(outputStream).write('a'); + assertThatThrownBy(() -> wrapped.write('a')).hasMessage("ServletOutputStream failed to write: Broken pipe"); + assertThatThrownBy(() -> wrapped.write('a')).hasMessage("Response not usable after response errors."); + } + + @Test + void writeErrorInAsyncStatePreventsFurtherWriting() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + doThrow(new IOException("Broken pipe")).when(outputStream).write('a'); + assertThatThrownBy(() -> wrapped.write('a')).hasMessage("ServletOutputStream failed to write: Broken pipe"); + assertThatThrownBy(() -> wrapped.write('a')).hasMessage("Response not usable after response errors."); + } + } + + + @Nested + class WriterTests { + + @Test + void useWriter() throws IOException { + testUseWriter(); + } + + @Test + void useWriterInAsyncState() throws IOException { + asyncRequest.startAsync(); + testUseWriter(); + } + + private void testUseWriter() throws IOException { + PrintWriter wrapped = getWrappedWriter(); + + wrapped.write('a'); + wrapped.write(new char[0], 1, 2); + wrapped.write("abc", 1, 2); + wrapped.flush(); + wrapped.close(); + + verify(writer).write('a'); + verify(writer).write(new char[0], 1, 2); + verify(writer).write("abc", 1, 2); + verify(writer).flush(); + verify(writer).close(); + } + + @Test + void writerNotUsableAfterCompletion() throws IOException { + asyncRequest.startAsync(); + PrintWriter wrapped = getWrappedWriter(); + + asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response))); + + char[] chars = new char[0]; + wrapped.write('a'); + wrapped.write(chars, 1, 2); + wrapped.flush(); + wrapped.close(); + + verifyNoInteractions(writer); + } + + @Test + void lockingNotUsed() throws IOException { + AtomicInteger count = new AtomicInteger(-1); + + doAnswer((Answer) invocation -> { + count.set(asyncRequest.stateLock().getHoldCount()); + return null; + }).when(writer).write('a'); + + // Use Writer in NEW state (no async handling) without locking + PrintWriter wrapped = getWrappedWriter(); + wrapped.write('a'); + + assertThat(count.get()).isEqualTo(0); + } + + @Test + void lockingUsedInAsyncState() throws IOException { + AtomicInteger count = new AtomicInteger(-1); + + doAnswer((Answer) invocation -> { + count.set(asyncRequest.stateLock().getHoldCount()); + return null; + }).when(writer).write('a'); + + // Use Writer in ASYNC state with locking + asyncRequest.startAsync(); + PrintWriter wrapped = getWrappedWriter(); + wrapped.write('a'); + + assertThat(count.get()).isEqualTo(1); + } + } + +} diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapterTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapterTests.java index cdf6f16b283..ea5b5a00ac2 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapterTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,11 @@ package org.springframework.web.servlet.mvc.method.annotation; +import java.io.IOException; +import java.io.OutputStream; import java.lang.reflect.Method; import java.lang.reflect.Type; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -25,6 +28,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import jakarta.servlet.AsyncEvent; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -46,6 +50,10 @@ import org.springframework.ui.Model; import org.springframework.web.bind.annotation.ControllerAdvice; import org.springframework.web.bind.annotation.ModelAttribute; import org.springframework.web.bind.annotation.SessionAttributes; +import org.springframework.web.context.request.async.AsyncRequestNotUsableException; +import org.springframework.web.context.request.async.StandardServletAsyncWebRequest; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.WebAsyncUtils; import org.springframework.web.context.support.StaticWebApplicationContext; import org.springframework.web.method.HandlerMethod; import org.springframework.web.method.annotation.ModelMethodProcessor; @@ -55,13 +63,15 @@ import org.springframework.web.method.support.InvocableHandlerMethod; import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.FlashMap; import org.springframework.web.servlet.ModelAndView; +import org.springframework.web.testfixture.servlet.MockAsyncContext; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.testfixture.servlet.MockHttpServletResponse; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** - * Unit tests for {@link RequestMappingHandlerAdapter}. + * Tests for {@link RequestMappingHandlerAdapter}. * * @author Rossen Stoyanchev * @author Sam Brannen @@ -249,9 +259,7 @@ public class RequestMappingHandlerAdapterTests { assertThat(mav.getModel().get("attr3")).isNull(); } - // SPR-10859 - - @Test + @Test // gh-15486 public void responseBodyAdvice() throws Exception { List> converters = new ArrayList<>(); converters.add(new MappingJackson2HttpMessageConverter()); @@ -271,6 +279,26 @@ public class RequestMappingHandlerAdapterTests { assertThat(this.response.getContentAsString()).isEqualTo("{\"status\":400,\"message\":\"body\"}"); } + @Test + void asyncRequestNotUsable() throws Exception { + + // Put AsyncWebRequest in ERROR state + StandardServletAsyncWebRequest asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response); + asyncRequest.onError(new AsyncEvent(new MockAsyncContext(this.request, this.response), new Exception())); + + // Set it as the current AsyncWebRequest, from the initial REQUEST dispatch + WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(this.request); + asyncManager.setAsyncWebRequest(asyncRequest); + + // AsyncWebRequest created for current dispatch should inherit state + HandlerMethod handlerMethod = handlerMethod(new SimpleController(), "handleOutputStream", OutputStream.class); + this.handlerAdapter.afterPropertiesSet(); + + // Use of response should be rejected + assertThatThrownBy(() -> this.handlerAdapter.handle(this.request, this.response, handlerMethod)) + .isInstanceOf(AsyncRequestNotUsableException.class); + } + private HandlerMethod handlerMethod(Object handler, String methodName, Class... paramTypes) throws Exception { Method method = handler.getClass().getDeclaredMethod(methodName, paramTypes); return new InvocableHandlerMethod(handler, method); @@ -296,14 +324,16 @@ public class RequestMappingHandlerAdapterTests { } public ResponseEntity> handleWithResponseEntity() { - return new ResponseEntity<>(Collections.singletonMap( - "foo", "bar"), HttpStatus.OK); + return new ResponseEntity<>(Collections.singletonMap("foo", "bar"), HttpStatus.OK); } public ResponseEntity handleBadRequest() { return new ResponseEntity<>("body", HttpStatus.BAD_REQUEST); } + public void handleOutputStream(OutputStream outputStream) throws IOException { + outputStream.write("body".getBytes(StandardCharsets.UTF_8)); + } }