diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java index e83f5f1d9fa..67e137a55d5 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java @@ -21,7 +21,6 @@ import java.io.PrintWriter; import java.util.ArrayList; import java.util.List; import java.util.Locale; -import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; @@ -189,7 +188,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements public void onError(AsyncEvent event) throws IOException { this.stateLock.lock(); try { - transitionToErrorState(); + this.state = State.ERROR; Throwable ex = event.getThrowable(); this.exceptionHandlers.forEach(consumer -> consumer.accept(ex)); } @@ -198,12 +197,6 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements } } - private void transitionToErrorState() { - if (!isAsyncComplete()) { - this.state = State.ERROR; - } - } - @Override public void onComplete(AsyncEvent event) throws IOException { this.stateLock.lock(); @@ -218,8 +211,17 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements } + /** + * Package private access for testing only. + */ + ReentrantLock stateLock() { + return this.stateLock; + } + + /** * Response wrapper to wrap the output stream with {@link LifecycleServletOutputStream}. + * @since 5.3.33 */ private static final class LifecycleHttpServletResponse extends HttpServletResponseWrapper { @@ -242,26 +244,44 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public ServletOutputStream getOutputStream() throws IOException { - if (this.outputStream == null) { - Assert.notNull(this.asyncWebRequest, "Not initialized"); - ServletOutputStream delegate = getResponse().getOutputStream(); - this.outputStream = new LifecycleServletOutputStream(delegate, this); + int level = obtainLockAndCheckState(); + try { + if (this.outputStream == null) { + Assert.notNull(this.asyncWebRequest, "Not initialized"); + ServletOutputStream delegate = getResponse().getOutputStream(); + this.outputStream = new LifecycleServletOutputStream(delegate, this); + } + } + catch (IOException ex) { + handleIOException(ex, "Failed to get ServletResponseOutput"); + } + finally { + releaseLock(level); } return this.outputStream; } @Override public PrintWriter getWriter() throws IOException { - if (this.writer == null) { - Assert.notNull(this.asyncWebRequest, "Not initialized"); - this.writer = new LifecyclePrintWriter(getResponse().getWriter(), this.asyncWebRequest); + int level = obtainLockAndCheckState(); + try { + if (this.writer == null) { + Assert.notNull(this.asyncWebRequest, "Not initialized"); + this.writer = new LifecyclePrintWriter(getResponse().getWriter(), this.asyncWebRequest); + } + } + catch (IOException ex) { + handleIOException(ex, "Failed to get PrintWriter"); + } + finally { + releaseLock(level); } return this.writer; } @Override public void flushBuffer() throws IOException { - obtainLockAndCheckState(); + int level = obtainLockAndCheckState(); try { getResponse().flushBuffer(); } @@ -269,32 +289,40 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements handleIOException(ex, "ServletResponse failed to flushBuffer"); } finally { - releaseLock(); + releaseLock(level); } } - private void obtainLockAndCheckState() throws AsyncRequestNotUsableException { + /** + * Return 0 if checks passed and lock is not needed, 1 if checks passed + * and lock is held, or raise AsyncRequestNotUsableException. + */ + private int obtainLockAndCheckState() throws AsyncRequestNotUsableException { Assert.notNull(this.asyncWebRequest, "Not initialized"); - if (this.asyncWebRequest.state != State.NEW) { - this.asyncWebRequest.stateLock.lock(); - if (this.asyncWebRequest.state != State.ASYNC) { - this.asyncWebRequest.stateLock.unlock(); - throw new AsyncRequestNotUsableException("Response not usable after " + - (this.asyncWebRequest.state == State.COMPLETED ? - "async request completion" : "onError notification") + "."); - } + if (this.asyncWebRequest.state == State.NEW) { + return 0; + } + + this.asyncWebRequest.stateLock.lock(); + if (this.asyncWebRequest.state == State.ASYNC) { + return 1; } + + this.asyncWebRequest.stateLock.unlock(); + throw new AsyncRequestNotUsableException("Response not usable after " + + (this.asyncWebRequest.state == State.COMPLETED ? + "async request completion" : "response errors") + "."); } void handleIOException(IOException ex, String msg) throws AsyncRequestNotUsableException { Assert.notNull(this.asyncWebRequest, "Not initialized"); - this.asyncWebRequest.transitionToErrorState(); - throw new AsyncRequestNotUsableException(msg, ex); + this.asyncWebRequest.state = State.ERROR; + throw new AsyncRequestNotUsableException(msg + ": " + ex.getMessage(), ex); } - void releaseLock() { + void releaseLock(int level) { Assert.notNull(this.asyncWebRequest, "Not initialized"); - if (this.asyncWebRequest.state != State.NEW) { + if (level > 0) { this.asyncWebRequest.stateLock.unlock(); } } @@ -304,6 +332,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements /** * Wraps a ServletOutputStream to prevent use after Servlet container onError * notifications, and after async request completion. + * @since 5.3.33 */ private static final class LifecycleServletOutputStream extends ServletOutputStream { @@ -328,7 +357,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public void write(int b) throws IOException { - this.response.obtainLockAndCheckState(); + int level = this.response.obtainLockAndCheckState(); try { this.delegate.write(b); } @@ -336,12 +365,12 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements this.response.handleIOException(ex, "ServletOutputStream failed to write"); } finally { - this.response.releaseLock(); + this.response.releaseLock(level); } } public void write(byte[] buf, int offset, int len) throws IOException { - this.response.obtainLockAndCheckState(); + int level = this.response.obtainLockAndCheckState(); try { this.delegate.write(buf, offset, len); } @@ -349,13 +378,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements this.response.handleIOException(ex, "ServletOutputStream failed to write"); } finally { - this.response.releaseLock(); + this.response.releaseLock(level); } } @Override public void flush() throws IOException { - this.response.obtainLockAndCheckState(); + int level = this.response.obtainLockAndCheckState(); try { this.delegate.flush(); } @@ -363,13 +392,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements this.response.handleIOException(ex, "ServletOutputStream failed to flush"); } finally { - this.response.releaseLock(); + this.response.releaseLock(level); } } @Override public void close() throws IOException { - this.response.obtainLockAndCheckState(); + int level = this.response.obtainLockAndCheckState(); try { this.delegate.close(); } @@ -377,7 +406,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements this.response.handleIOException(ex, "ServletOutputStream failed to close"); } finally { - this.response.releaseLock(); + this.response.releaseLock(level); } } @@ -387,6 +416,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements /** * Wraps a PrintWriter to prevent use after Servlet container onError * notifications, and after async request completion. + * @since 5.3.33 */ private static final class LifecyclePrintWriter extends PrintWriter { @@ -402,24 +432,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public void flush() { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.flush(); } finally { - releaseLock(); + releaseLock(level); } } } @Override public void close() { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.close(); } finally { - releaseLock(); + releaseLock(level); } } } @@ -431,24 +463,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public void write(int c) { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.write(c); } finally { - releaseLock(); + releaseLock(level); } } } @Override public void write(char[] buf, int off, int len) { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.write(buf, off, len); } finally { - releaseLock(); + releaseLock(level); } } } @@ -460,12 +494,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @Override public void write(String s, int off, int len) { - if (tryObtainLockAndCheckState()) { + int level = tryObtainLockAndCheckState(); + if (level > -1) { try { this.delegate.write(s, off, len); } finally { - releaseLock(); + releaseLock(level); } } } @@ -475,33 +510,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements this.delegate.write(s); } - private boolean tryObtainLockAndCheckState() { - if (state() == State.NEW) { - return true; + /** + * Return 0 if checks passed and lock is not needed, 1 if checks passed + * and lock is held, and -1 if checks did not pass. + */ + private int tryObtainLockAndCheckState() { + if (this.asyncWebRequest.state == State.NEW) { + return 0; } - if (stateLock().tryLock()) { - if (state() == State.ASYNC) { - return true; - } - stateLock().unlock(); + this.asyncWebRequest.stateLock.lock(); + if (this.asyncWebRequest.state == State.ASYNC) { + return 1; } - return false; + this.asyncWebRequest.stateLock.unlock(); + return -1; } - private void releaseLock() { - if (state() != State.NEW) { - stateLock().unlock(); + private void releaseLock(int level) { + if (level > 0) { + this.asyncWebRequest.stateLock.unlock(); } } - private State state() { - return this.asyncWebRequest.state; - } - - private Lock stateLock() { - return this.asyncWebRequest.stateLock; - } - // Plain delegates @Override @@ -639,28 +669,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements /** * Represents a state for {@link StandardServletAsyncWebRequest} to be in. *
- * NEW - * | - * v - * ASYNC----> + - * | | - * v | - * ERROR | - * | | - * v | - * COMPLETED <--+ + * +------ NEW + * | | + * | v + * | ASYNC ----> + + * | | | + * | v | + * +----> ERROR | + * | | + * v | + * COMPLETED <---+ ** @since 5.3.33 */ private enum State { - /** New request (thas may not do async handling). */ + /** New request (may not start async handling). */ NEW, /** Async handling has started. */ ASYNC, - /** onError notification received, or ServletOutputStream failed. */ + /** ServletOutputStream failed, or onError notification received. */ ERROR, /** onComplete notification received. */ diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncRequestNotUsableTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncRequestNotUsableTests.java new file mode 100644 index 00000000000..0033627dfb9 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncRequestNotUsableTests.java @@ -0,0 +1,357 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.context.request.async; + +import java.io.IOException; +import java.io.PrintWriter; +import java.util.concurrent.atomic.AtomicInteger; + +import jakarta.servlet.AsyncEvent; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.mockito.stubbing.Answer; + +import org.springframework.web.testfixture.servlet.MockAsyncContext; +import org.springframework.web.testfixture.servlet.MockHttpServletRequest; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.BDDMockito.doAnswer; +import static org.mockito.BDDMockito.doThrow; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.verify; +import static org.mockito.BDDMockito.verifyNoInteractions; + +/** + * {@link StandardServletAsyncWebRequest} tests related to response wrapping in + * order to enforce thread safety and prevent use after errors. + * + * @author Rossen Stoyanchev + */ +public class AsyncRequestNotUsableTests { + + private final MockHttpServletRequest request = new MockHttpServletRequest(); + + private final HttpServletResponse response = mock(); + + private final ServletOutputStream outputStream = mock(); + + private final PrintWriter writer = mock(); + + private StandardServletAsyncWebRequest asyncRequest; + + + @BeforeEach + void setup() throws IOException { + this.request.setAsyncSupported(true); + given(this.response.getOutputStream()).willReturn(this.outputStream); + given(this.response.getWriter()).willReturn(this.writer); + + this.asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response); + } + + @AfterEach + void tearDown() { + assertThat(this.asyncRequest.stateLock().isLocked()).isFalse(); + } + + + @SuppressWarnings("DataFlowIssue") + private ServletOutputStream getWrappedOutputStream() throws IOException { + return this.asyncRequest.getResponse().getOutputStream(); + } + + @SuppressWarnings("DataFlowIssue") + private PrintWriter getWrappedWriter() throws IOException { + return this.asyncRequest.getResponse().getWriter(); + } + + + @Nested + class ResponseTests { + + @Test + void notUsableAfterError() throws IOException { + asyncRequest.startAsync(); + asyncRequest.onError(new AsyncEvent(new MockAsyncContext(request, response), new Exception())); + + HttpServletResponse wrapped = asyncRequest.getResponse(); + assertThat(wrapped).isNotNull(); + assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after response errors."); + assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after response errors."); + assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after response errors."); + } + + @Test + void notUsableAfterCompletion() throws IOException { + asyncRequest.startAsync(); + asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response))); + + HttpServletResponse wrapped = asyncRequest.getResponse(); + assertThat(wrapped).isNotNull(); + assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after async request completion."); + } + + @Test + void notUsableWhenRecreatedAfterCompletion() throws IOException { + asyncRequest.startAsync(); + asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response))); + + StandardServletAsyncWebRequest newWebRequest = + new StandardServletAsyncWebRequest(request, response, asyncRequest); + + HttpServletResponse wrapped = newWebRequest.getResponse(); + assertThat(wrapped).isNotNull(); + assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after async request completion."); + } + } + + + @Nested + class OutputStreamTests { + + @Test + void use() throws IOException { + testUseOutputStream(); + } + + @Test + void useInAsyncState() throws IOException { + asyncRequest.startAsync(); + testUseOutputStream(); + } + + private void testUseOutputStream() throws IOException { + ServletOutputStream wrapped = getWrappedOutputStream(); + + wrapped.write('a'); + wrapped.write(new byte[0], 1, 2); + wrapped.flush(); + wrapped.close(); + + verify(outputStream).write('a'); + verify(outputStream).write(new byte[0], 1, 2); + verify(outputStream).flush(); + verify(outputStream).close(); + } + + @Test + void notUsableAfterCompletion() throws IOException { + asyncRequest.startAsync(); + ServletOutputStream wrapped = getWrappedOutputStream(); + + asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response))); + + assertThatThrownBy(() -> wrapped.write('a')).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(() -> wrapped.write(new byte[0])).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(() -> wrapped.write(new byte[0], 0, 0)).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::flush).hasMessage("Response not usable after async request completion."); + assertThatThrownBy(wrapped::close).hasMessage("Response not usable after async request completion."); + } + + @Test + void lockingNotUsed() throws IOException { + AtomicInteger count = new AtomicInteger(-1); + doAnswer((Answer