Browse Source

Merge branch '6.1.x'

pull/32377/head
rstoyanchev 2 years ago
parent
commit
ddab971fca
  1. 190
      spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java
  2. 357
      spring-web/src/test/java/org/springframework/web/context/request/async/AsyncRequestNotUsableTests.java
  3. 189
      spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java
  4. 34
      spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapterTests.java

190
spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java

@ -21,7 +21,6 @@ import java.io.PrintWriter; @@ -21,7 +21,6 @@ import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
@ -189,7 +188,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -189,7 +188,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
public void onError(AsyncEvent event) throws IOException {
this.stateLock.lock();
try {
transitionToErrorState();
this.state = State.ERROR;
Throwable ex = event.getThrowable();
this.exceptionHandlers.forEach(consumer -> consumer.accept(ex));
}
@ -198,12 +197,6 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -198,12 +197,6 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
}
}
private void transitionToErrorState() {
if (!isAsyncComplete()) {
this.state = State.ERROR;
}
}
@Override
public void onComplete(AsyncEvent event) throws IOException {
this.stateLock.lock();
@ -218,8 +211,17 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -218,8 +211,17 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
}
/**
* Package private access for testing only.
*/
ReentrantLock stateLock() {
return this.stateLock;
}
/**
* Response wrapper to wrap the output stream with {@link LifecycleServletOutputStream}.
* @since 5.3.33
*/
private static final class LifecycleHttpServletResponse extends HttpServletResponseWrapper {
@ -242,26 +244,44 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -242,26 +244,44 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public ServletOutputStream getOutputStream() throws IOException {
if (this.outputStream == null) {
Assert.notNull(this.asyncWebRequest, "Not initialized");
ServletOutputStream delegate = getResponse().getOutputStream();
this.outputStream = new LifecycleServletOutputStream(delegate, this);
int level = obtainLockAndCheckState();
try {
if (this.outputStream == null) {
Assert.notNull(this.asyncWebRequest, "Not initialized");
ServletOutputStream delegate = getResponse().getOutputStream();
this.outputStream = new LifecycleServletOutputStream(delegate, this);
}
}
catch (IOException ex) {
handleIOException(ex, "Failed to get ServletResponseOutput");
}
finally {
releaseLock(level);
}
return this.outputStream;
}
@Override
public PrintWriter getWriter() throws IOException {
if (this.writer == null) {
Assert.notNull(this.asyncWebRequest, "Not initialized");
this.writer = new LifecyclePrintWriter(getResponse().getWriter(), this.asyncWebRequest);
int level = obtainLockAndCheckState();
try {
if (this.writer == null) {
Assert.notNull(this.asyncWebRequest, "Not initialized");
this.writer = new LifecyclePrintWriter(getResponse().getWriter(), this.asyncWebRequest);
}
}
catch (IOException ex) {
handleIOException(ex, "Failed to get PrintWriter");
}
finally {
releaseLock(level);
}
return this.writer;
}
@Override
public void flushBuffer() throws IOException {
obtainLockAndCheckState();
int level = obtainLockAndCheckState();
try {
getResponse().flushBuffer();
}
@ -269,32 +289,40 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -269,32 +289,40 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
handleIOException(ex, "ServletResponse failed to flushBuffer");
}
finally {
releaseLock();
releaseLock(level);
}
}
private void obtainLockAndCheckState() throws AsyncRequestNotUsableException {
/**
* Return 0 if checks passed and lock is not needed, 1 if checks passed
* and lock is held, or raise AsyncRequestNotUsableException.
*/
private int obtainLockAndCheckState() throws AsyncRequestNotUsableException {
Assert.notNull(this.asyncWebRequest, "Not initialized");
if (this.asyncWebRequest.state != State.NEW) {
this.asyncWebRequest.stateLock.lock();
if (this.asyncWebRequest.state != State.ASYNC) {
this.asyncWebRequest.stateLock.unlock();
throw new AsyncRequestNotUsableException("Response not usable after " +
(this.asyncWebRequest.state == State.COMPLETED ?
"async request completion" : "onError notification") + ".");
}
if (this.asyncWebRequest.state == State.NEW) {
return 0;
}
this.asyncWebRequest.stateLock.lock();
if (this.asyncWebRequest.state == State.ASYNC) {
return 1;
}
this.asyncWebRequest.stateLock.unlock();
throw new AsyncRequestNotUsableException("Response not usable after " +
(this.asyncWebRequest.state == State.COMPLETED ?
"async request completion" : "response errors") + ".");
}
void handleIOException(IOException ex, String msg) throws AsyncRequestNotUsableException {
Assert.notNull(this.asyncWebRequest, "Not initialized");
this.asyncWebRequest.transitionToErrorState();
throw new AsyncRequestNotUsableException(msg, ex);
this.asyncWebRequest.state = State.ERROR;
throw new AsyncRequestNotUsableException(msg + ": " + ex.getMessage(), ex);
}
void releaseLock() {
void releaseLock(int level) {
Assert.notNull(this.asyncWebRequest, "Not initialized");
if (this.asyncWebRequest.state != State.NEW) {
if (level > 0) {
this.asyncWebRequest.stateLock.unlock();
}
}
@ -304,6 +332,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -304,6 +332,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
/**
* Wraps a ServletOutputStream to prevent use after Servlet container onError
* notifications, and after async request completion.
* @since 5.3.33
*/
private static final class LifecycleServletOutputStream extends ServletOutputStream {
@ -328,7 +357,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -328,7 +357,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public void write(int b) throws IOException {
this.response.obtainLockAndCheckState();
int level = this.response.obtainLockAndCheckState();
try {
this.delegate.write(b);
}
@ -336,12 +365,12 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -336,12 +365,12 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
this.response.handleIOException(ex, "ServletOutputStream failed to write");
}
finally {
this.response.releaseLock();
this.response.releaseLock(level);
}
}
public void write(byte[] buf, int offset, int len) throws IOException {
this.response.obtainLockAndCheckState();
int level = this.response.obtainLockAndCheckState();
try {
this.delegate.write(buf, offset, len);
}
@ -349,13 +378,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -349,13 +378,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
this.response.handleIOException(ex, "ServletOutputStream failed to write");
}
finally {
this.response.releaseLock();
this.response.releaseLock(level);
}
}
@Override
public void flush() throws IOException {
this.response.obtainLockAndCheckState();
int level = this.response.obtainLockAndCheckState();
try {
this.delegate.flush();
}
@ -363,13 +392,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -363,13 +392,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
this.response.handleIOException(ex, "ServletOutputStream failed to flush");
}
finally {
this.response.releaseLock();
this.response.releaseLock(level);
}
}
@Override
public void close() throws IOException {
this.response.obtainLockAndCheckState();
int level = this.response.obtainLockAndCheckState();
try {
this.delegate.close();
}
@ -377,7 +406,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -377,7 +406,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
this.response.handleIOException(ex, "ServletOutputStream failed to close");
}
finally {
this.response.releaseLock();
this.response.releaseLock(level);
}
}
@ -387,6 +416,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -387,6 +416,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
/**
* Wraps a PrintWriter to prevent use after Servlet container onError
* notifications, and after async request completion.
* @since 5.3.33
*/
private static final class LifecyclePrintWriter extends PrintWriter {
@ -402,24 +432,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -402,24 +432,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public void flush() {
if (tryObtainLockAndCheckState()) {
int level = tryObtainLockAndCheckState();
if (level > -1) {
try {
this.delegate.flush();
}
finally {
releaseLock();
releaseLock(level);
}
}
}
@Override
public void close() {
if (tryObtainLockAndCheckState()) {
int level = tryObtainLockAndCheckState();
if (level > -1) {
try {
this.delegate.close();
}
finally {
releaseLock();
releaseLock(level);
}
}
}
@ -431,24 +463,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -431,24 +463,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public void write(int c) {
if (tryObtainLockAndCheckState()) {
int level = tryObtainLockAndCheckState();
if (level > -1) {
try {
this.delegate.write(c);
}
finally {
releaseLock();
releaseLock(level);
}
}
}
@Override
public void write(char[] buf, int off, int len) {
if (tryObtainLockAndCheckState()) {
int level = tryObtainLockAndCheckState();
if (level > -1) {
try {
this.delegate.write(buf, off, len);
}
finally {
releaseLock();
releaseLock(level);
}
}
}
@ -460,12 +494,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -460,12 +494,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public void write(String s, int off, int len) {
if (tryObtainLockAndCheckState()) {
int level = tryObtainLockAndCheckState();
if (level > -1) {
try {
this.delegate.write(s, off, len);
}
finally {
releaseLock();
releaseLock(level);
}
}
}
@ -475,33 +510,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -475,33 +510,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
this.delegate.write(s);
}
private boolean tryObtainLockAndCheckState() {
if (state() == State.NEW) {
return true;
/**
* Return 0 if checks passed and lock is not needed, 1 if checks passed
* and lock is held, and -1 if checks did not pass.
*/
private int tryObtainLockAndCheckState() {
if (this.asyncWebRequest.state == State.NEW) {
return 0;
}
if (stateLock().tryLock()) {
if (state() == State.ASYNC) {
return true;
}
stateLock().unlock();
this.asyncWebRequest.stateLock.lock();
if (this.asyncWebRequest.state == State.ASYNC) {
return 1;
}
return false;
this.asyncWebRequest.stateLock.unlock();
return -1;
}
private void releaseLock() {
if (state() != State.NEW) {
stateLock().unlock();
private void releaseLock(int level) {
if (level > 0) {
this.asyncWebRequest.stateLock.unlock();
}
}
private State state() {
return this.asyncWebRequest.state;
}
private Lock stateLock() {
return this.asyncWebRequest.stateLock;
}
// Plain delegates
@Override
@ -639,28 +669,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -639,28 +669,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
/**
* Represents a state for {@link StandardServletAsyncWebRequest} to be in.
* <p><pre>
* NEW
* |
* v
* ASYNC----> +
* | |
* v |
* ERROR |
* | |
* v |
* COMPLETED <--+
* +------ NEW
* | |
* | v
* | ASYNC ----> +
* | | |
* | v |
* +----> ERROR |
* | |
* v |
* COMPLETED <---+
* </pre>
* @since 5.3.33
*/
private enum State {
/** New request (thas may not do async handling). */
/** New request (may not start async handling). */
NEW,
/** Async handling has started. */
ASYNC,
/** onError notification received, or ServletOutputStream failed. */
/** ServletOutputStream failed, or onError notification received. */
ERROR,
/** onComplete notification received. */

357
spring-web/src/test/java/org/springframework/web/context/request/async/AsyncRequestNotUsableTests.java

@ -0,0 +1,357 @@ @@ -0,0 +1,357 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.context.request.async;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.concurrent.atomic.AtomicInteger;
import jakarta.servlet.AsyncEvent;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.mockito.stubbing.Answer;
import org.springframework.web.testfixture.servlet.MockAsyncContext;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.BDDMockito.doAnswer;
import static org.mockito.BDDMockito.doThrow;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.verify;
import static org.mockito.BDDMockito.verifyNoInteractions;
/**
* {@link StandardServletAsyncWebRequest} tests related to response wrapping in
* order to enforce thread safety and prevent use after errors.
*
* @author Rossen Stoyanchev
*/
public class AsyncRequestNotUsableTests {
private final MockHttpServletRequest request = new MockHttpServletRequest();
private final HttpServletResponse response = mock();
private final ServletOutputStream outputStream = mock();
private final PrintWriter writer = mock();
private StandardServletAsyncWebRequest asyncRequest;
@BeforeEach
void setup() throws IOException {
this.request.setAsyncSupported(true);
given(this.response.getOutputStream()).willReturn(this.outputStream);
given(this.response.getWriter()).willReturn(this.writer);
this.asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response);
}
@AfterEach
void tearDown() {
assertThat(this.asyncRequest.stateLock().isLocked()).isFalse();
}
@SuppressWarnings("DataFlowIssue")
private ServletOutputStream getWrappedOutputStream() throws IOException {
return this.asyncRequest.getResponse().getOutputStream();
}
@SuppressWarnings("DataFlowIssue")
private PrintWriter getWrappedWriter() throws IOException {
return this.asyncRequest.getResponse().getWriter();
}
@Nested
class ResponseTests {
@Test
void notUsableAfterError() throws IOException {
asyncRequest.startAsync();
asyncRequest.onError(new AsyncEvent(new MockAsyncContext(request, response), new Exception()));
HttpServletResponse wrapped = asyncRequest.getResponse();
assertThat(wrapped).isNotNull();
assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after response errors.");
assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after response errors.");
assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after response errors.");
}
@Test
void notUsableAfterCompletion() throws IOException {
asyncRequest.startAsync();
asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response)));
HttpServletResponse wrapped = asyncRequest.getResponse();
assertThat(wrapped).isNotNull();
assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after async request completion.");
}
@Test
void notUsableWhenRecreatedAfterCompletion() throws IOException {
asyncRequest.startAsync();
asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response)));
StandardServletAsyncWebRequest newWebRequest =
new StandardServletAsyncWebRequest(request, response, asyncRequest);
HttpServletResponse wrapped = newWebRequest.getResponse();
assertThat(wrapped).isNotNull();
assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after async request completion.");
}
}
@Nested
class OutputStreamTests {
@Test
void use() throws IOException {
testUseOutputStream();
}
@Test
void useInAsyncState() throws IOException {
asyncRequest.startAsync();
testUseOutputStream();
}
private void testUseOutputStream() throws IOException {
ServletOutputStream wrapped = getWrappedOutputStream();
wrapped.write('a');
wrapped.write(new byte[0], 1, 2);
wrapped.flush();
wrapped.close();
verify(outputStream).write('a');
verify(outputStream).write(new byte[0], 1, 2);
verify(outputStream).flush();
verify(outputStream).close();
}
@Test
void notUsableAfterCompletion() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response)));
assertThatThrownBy(() -> wrapped.write('a')).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(() -> wrapped.write(new byte[0])).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(() -> wrapped.write(new byte[0], 0, 0)).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(wrapped::flush).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(wrapped::close).hasMessage("Response not usable after async request completion.");
}
@Test
void lockingNotUsed() throws IOException {
AtomicInteger count = new AtomicInteger(-1);
doAnswer((Answer<Void>) invocation -> {
count.set(asyncRequest.stateLock().getHoldCount());
return null;
}).when(outputStream).write('a');
// Access ServletOutputStream in NEW state (no async handling) without locking
getWrappedOutputStream().write('a');
assertThat(count.get()).isEqualTo(0);
}
@Test
void lockingUsedInAsyncState() throws IOException {
AtomicInteger count = new AtomicInteger(-1);
doAnswer((Answer<Void>) invocation -> {
count.set(asyncRequest.stateLock().getHoldCount());
return null;
}).when(outputStream).write('a');
// Access ServletOutputStream in ASYNC state with locking
asyncRequest.startAsync();
getWrappedOutputStream().write('a');
assertThat(count.get()).isEqualTo(1);
}
}
@Nested
class OutputStreamErrorTests {
@Test
void writeInt() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
doThrow(new IOException("Broken pipe")).when(outputStream).write('a');
assertThatThrownBy(() -> wrapped.write('a')).hasMessage("ServletOutputStream failed to write: Broken pipe");
}
@Test
void writeBytesFull() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
byte[] bytes = new byte[0];
doThrow(new IOException("Broken pipe")).when(outputStream).write(bytes, 0, 0);
assertThatThrownBy(() -> wrapped.write(bytes)).hasMessage("ServletOutputStream failed to write: Broken pipe");
}
@Test
void writeBytes() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
byte[] bytes = new byte[0];
doThrow(new IOException("Broken pipe")).when(outputStream).write(bytes, 0, 0);
assertThatThrownBy(() -> wrapped.write(bytes, 0, 0)).hasMessage("ServletOutputStream failed to write: Broken pipe");
}
@Test
void flush() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
doThrow(new IOException("Broken pipe")).when(outputStream).flush();
assertThatThrownBy(wrapped::flush).hasMessage("ServletOutputStream failed to flush: Broken pipe");
}
@Test
void close() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
doThrow(new IOException("Broken pipe")).when(outputStream).close();
assertThatThrownBy(wrapped::close).hasMessage("ServletOutputStream failed to close: Broken pipe");
}
@Test
void writeErrorPreventsFurtherWriting() throws IOException {
ServletOutputStream wrapped = getWrappedOutputStream();
doThrow(new IOException("Broken pipe")).when(outputStream).write('a');
assertThatThrownBy(() -> wrapped.write('a')).hasMessage("ServletOutputStream failed to write: Broken pipe");
assertThatThrownBy(() -> wrapped.write('a')).hasMessage("Response not usable after response errors.");
}
@Test
void writeErrorInAsyncStatePreventsFurtherWriting() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
doThrow(new IOException("Broken pipe")).when(outputStream).write('a');
assertThatThrownBy(() -> wrapped.write('a')).hasMessage("ServletOutputStream failed to write: Broken pipe");
assertThatThrownBy(() -> wrapped.write('a')).hasMessage("Response not usable after response errors.");
}
}
@Nested
class WriterTests {
@Test
void useWriter() throws IOException {
testUseWriter();
}
@Test
void useWriterInAsyncState() throws IOException {
asyncRequest.startAsync();
testUseWriter();
}
private void testUseWriter() throws IOException {
PrintWriter wrapped = getWrappedWriter();
wrapped.write('a');
wrapped.write(new char[0], 1, 2);
wrapped.write("abc", 1, 2);
wrapped.flush();
wrapped.close();
verify(writer).write('a');
verify(writer).write(new char[0], 1, 2);
verify(writer).write("abc", 1, 2);
verify(writer).flush();
verify(writer).close();
}
@Test
void writerNotUsableAfterCompletion() throws IOException {
asyncRequest.startAsync();
PrintWriter wrapped = getWrappedWriter();
asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response)));
char[] chars = new char[0];
wrapped.write('a');
wrapped.write(chars, 1, 2);
wrapped.flush();
wrapped.close();
verifyNoInteractions(writer);
}
@Test
void lockingNotUsed() throws IOException {
AtomicInteger count = new AtomicInteger(-1);
doAnswer((Answer<Void>) invocation -> {
count.set(asyncRequest.stateLock().getHoldCount());
return null;
}).when(writer).write('a');
// Use Writer in NEW state (no async handling) without locking
PrintWriter wrapped = getWrappedWriter();
wrapped.write('a');
assertThat(count.get()).isEqualTo(0);
}
@Test
void lockingUsedInAsyncState() throws IOException {
AtomicInteger count = new AtomicInteger(-1);
doAnswer((Answer<Void>) invocation -> {
count.set(asyncRequest.stateLock().getHoldCount());
return null;
}).when(writer).write('a');
// Use Writer in ASYNC state with locking
asyncRequest.startAsync();
PrintWriter wrapped = getWrappedWriter();
wrapped.write('a');
assertThat(count.get()).isEqualTo(1);
}
}
}

189
spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java

@ -20,6 +20,7 @@ import java.util.function.Consumer; @@ -20,6 +20,7 @@ import java.util.function.Consumer;
import jakarta.servlet.AsyncEvent;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.springframework.web.testfixture.servlet.MockAsyncContext;
@ -32,7 +33,8 @@ import static org.mockito.Mockito.mock; @@ -32,7 +33,8 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* A test fixture with a {@link StandardServletAsyncWebRequest}.
* Tests for {@link StandardServletAsyncWebRequest}.
*
* @author Rossen Stoyanchev
*/
class StandardServletAsyncWebRequestTests {
@ -49,120 +51,109 @@ class StandardServletAsyncWebRequestTests { @@ -49,120 +51,109 @@ class StandardServletAsyncWebRequestTests {
this.request = new MockHttpServletRequest();
this.request.setAsyncSupported(true);
this.response = new MockHttpServletResponse();
this.asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response);
this.asyncRequest.setTimeout(44*1000L);
}
@Test
void isAsyncStarted() {
assertThat(this.asyncRequest.isAsyncStarted()).isFalse();
this.asyncRequest.startAsync();
assertThat(this.asyncRequest.isAsyncStarted()).isTrue();
}
@Test
void startAsync() {
this.asyncRequest.startAsync();
MockAsyncContext context = (MockAsyncContext) this.request.getAsyncContext();
assertThat(context).isNotNull();
assertThat(context.getTimeout()).as("Timeout value not set").isEqualTo((44 * 1000));
assertThat(context.getListeners()).containsExactly(this.asyncRequest);
}
@Test
void startAsyncMultipleTimes() {
this.asyncRequest.startAsync();
this.asyncRequest.startAsync();
this.asyncRequest.startAsync();
this.asyncRequest.startAsync(); // idempotent
MockAsyncContext context = (MockAsyncContext) this.request.getAsyncContext();
assertThat(context).isNotNull();
assertThat(context.getListeners()).hasSize(1);
this.asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response);
this.asyncRequest.setTimeout(44 * 1000L);
}
@Test
void startAsyncNotSupported() {
this.request.setAsyncSupported(false);
assertThatIllegalStateException().isThrownBy(
this.asyncRequest::startAsync)
.withMessageContaining("Async support must be enabled");
}
@Test
void startAsyncAfterCompleted() throws Exception {
this.asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(this.request, this.response)));
assertThatIllegalStateException().isThrownBy(this.asyncRequest::startAsync)
.withMessage("Cannot start async: [COMPLETED]");
@Nested
class StartAsync {
@Test
void isAsyncStarted() {
assertThat(asyncRequest.isAsyncStarted()).isFalse();
asyncRequest.startAsync();
assertThat(asyncRequest.isAsyncStarted()).isTrue();
}
@Test
void startAsync() {
asyncRequest.startAsync();
MockAsyncContext context = (MockAsyncContext) request.getAsyncContext();
assertThat(context).isNotNull();
assertThat(context.getTimeout()).as("Timeout value not set").isEqualTo((44 * 1000));
assertThat(context.getListeners()).containsExactly(asyncRequest);
}
@Test
void startAsyncMultipleTimes() {
asyncRequest.startAsync();
asyncRequest.startAsync();
asyncRequest.startAsync();
asyncRequest.startAsync();
MockAsyncContext context = (MockAsyncContext) request.getAsyncContext();
assertThat(context).isNotNull();
assertThat(context.getListeners()).hasSize(1);
}
@Test
void startAsyncNotSupported() {
request.setAsyncSupported(false);
assertThatIllegalStateException()
.isThrownBy(asyncRequest::startAsync)
.withMessageContaining("Async support must be enabled");
}
@Test
void startAsyncAfterCompleted() throws Exception {
asyncRequest.startAsync();
asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response)));
assertThatIllegalStateException()
.isThrownBy(asyncRequest::startAsync)
.withMessage("Cannot start async: [COMPLETED]");
}
@Test
void startAsyncAndSetTimeout() {
asyncRequest.startAsync();
assertThatIllegalStateException().isThrownBy(() -> asyncRequest.setTimeout(25L));
}
}
@Test
void onTimeoutDefaultBehavior() throws Exception {
this.asyncRequest.onTimeout(new AsyncEvent(new MockAsyncContext(this.request, this.response)));
assertThat(this.response.getStatus()).isEqualTo(200);
}
@Test
void onTimeoutHandler() throws Exception {
Runnable timeoutHandler = mock();
this.asyncRequest.addTimeoutHandler(timeoutHandler);
this.asyncRequest.onTimeout(new AsyncEvent(new MockAsyncContext(this.request, this.response)));
verify(timeoutHandler).run();
}
@Nested
class AsyncListenerHandling {
@Test
void onErrorHandler() throws Exception {
Consumer<Throwable> errorHandler = mock();
this.asyncRequest.addErrorHandler(errorHandler);
Exception e = new Exception();
this.asyncRequest.onError(new AsyncEvent(new MockAsyncContext(this.request, this.response), e));
verify(errorHandler).accept(e);
}
@Test
void onTimeoutHandler() throws Exception {
Runnable handler = mock();
asyncRequest.addTimeoutHandler(handler);
@Test
void setTimeoutDuringConcurrentHandling() {
this.asyncRequest.startAsync();
assertThatIllegalStateException().isThrownBy(() ->
this.asyncRequest.setTimeout(25L));
}
asyncRequest.startAsync();
asyncRequest.onTimeout(new AsyncEvent(new MockAsyncContext(request, response)));
@Test
void onCompletionHandler() throws Exception {
Runnable handler = mock();
this.asyncRequest.addCompletionHandler(handler);
verify(handler).run();
}
this.asyncRequest.startAsync();
this.asyncRequest.onComplete(new AsyncEvent(this.request.getAsyncContext()));
@Test
void onErrorHandler() throws Exception {
Exception ex = new Exception();
Consumer<Throwable> handler = mock();
asyncRequest.addErrorHandler(handler);
verify(handler).run();
assertThat(this.asyncRequest.isAsyncComplete()).isTrue();
}
asyncRequest.startAsync();
asyncRequest.onError(new AsyncEvent(new MockAsyncContext(request, response), ex));
// SPR-13292
verify(handler).accept(ex);
}
@Test
void onErrorHandlerAfterOnErrorEvent() throws Exception {
Consumer<Throwable> handler = mock();
this.asyncRequest.addErrorHandler(handler);
@Test
void onCompletionHandler() throws Exception {
Runnable handler = mock();
asyncRequest.addCompletionHandler(handler);
this.asyncRequest.startAsync();
Exception e = new Exception();
this.asyncRequest.onError(new AsyncEvent(this.request.getAsyncContext(), e));
asyncRequest.startAsync();
asyncRequest.onComplete(new AsyncEvent(request.getAsyncContext()));
verify(handler).accept(e);
verify(handler).run();
assertThat(asyncRequest.isAsyncComplete()).isTrue();
}
}
@Test
void onCompletionHandlerAfterOnCompleteEvent() throws Exception {
Runnable handler = mock();
this.asyncRequest.addCompletionHandler(handler);
this.asyncRequest.startAsync();
this.asyncRequest.onComplete(new AsyncEvent(this.request.getAsyncContext()));
verify(handler).run();
assertThat(this.asyncRequest.isAsyncComplete()).isTrue();
}
}

34
spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapterTests.java

@ -16,8 +16,11 @@ @@ -16,8 +16,11 @@
package org.springframework.web.servlet.mvc.method.annotation;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@ -25,6 +28,7 @@ import java.util.LinkedHashMap; @@ -25,6 +28,7 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import jakarta.servlet.AsyncEvent;
import org.apache.groovy.util.Maps;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
@ -49,6 +53,10 @@ import org.springframework.web.bind.annotation.ModelAttribute; @@ -49,6 +53,10 @@ import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.SessionAttributes;
import org.springframework.web.context.request.async.AsyncRequestNotUsableException;
import org.springframework.web.context.request.async.StandardServletAsyncWebRequest;
import org.springframework.web.context.request.async.WebAsyncManager;
import org.springframework.web.context.request.async.WebAsyncUtils;
import org.springframework.web.context.support.StaticWebApplicationContext;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.method.annotation.ModelMethodProcessor;
@ -58,10 +66,12 @@ import org.springframework.web.method.support.InvocableHandlerMethod; @@ -58,10 +66,12 @@ import org.springframework.web.method.support.InvocableHandlerMethod;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.FlashMap;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.testfixture.servlet.MockAsyncContext;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Tests for {@link RequestMappingHandlerAdapter}.
@ -285,6 +295,26 @@ class RequestMappingHandlerAdapterTests { @@ -285,6 +295,26 @@ class RequestMappingHandlerAdapterTests {
assertThat(this.response.getContentAsString()).isEqualTo("Body: {foo=bar}");
}
@Test
void asyncRequestNotUsable() throws Exception {
// Put AsyncWebRequest in ERROR state
StandardServletAsyncWebRequest asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response);
asyncRequest.onError(new AsyncEvent(new MockAsyncContext(this.request, this.response), new Exception()));
// Set it as the current AsyncWebRequest, from the initial REQUEST dispatch
WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(this.request);
asyncManager.setAsyncWebRequest(asyncRequest);
// AsyncWebRequest created for current dispatch should inherit state
HandlerMethod handlerMethod = handlerMethod(new TestController(), "handleOutputStream", OutputStream.class);
this.handlerAdapter.afterPropertiesSet();
// Use of response should be rejected
assertThatThrownBy(() -> this.handlerAdapter.handle(this.request, this.response, handlerMethod))
.isInstanceOf(AsyncRequestNotUsableException.class);
}
private HandlerMethod handlerMethod(Object handler, String methodName, Class<?>... paramTypes) throws Exception {
Method method = handler.getClass().getDeclaredMethod(methodName, paramTypes);
return new InvocableHandlerMethod(handler, method);
@ -321,6 +351,10 @@ class RequestMappingHandlerAdapterTests { @@ -321,6 +351,10 @@ class RequestMappingHandlerAdapterTests {
public String handleBody(@Nullable @RequestBody Map<String, String> body) {
return "Body: " + body;
}
public void handleOutputStream(OutputStream outputStream) throws IOException {
outputStream.write("body".getBytes(StandardCharsets.UTF_8));
}
}

Loading…
Cancel
Save