Browse Source

Backport tests for wrapping of response for async requests

This is a backport of commits 4b96cd and ef0717.

Closes gh-32341
pull/33048/head
rstoyanchev 2 years ago
parent
commit
1a7a6f421f
  1. 263
      spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java
  2. 5
      spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java
  3. 357
      spring-web/src/test/java/org/springframework/web/context/request/async/AsyncRequestNotUsableTests.java
  4. 44
      spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapterTests.java

263
spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java

@ -21,7 +21,6 @@ import java.io.PrintWriter; @@ -21,7 +21,6 @@ import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
@ -189,7 +188,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -189,7 +188,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
public void onError(AsyncEvent event) throws IOException {
this.stateLock.lock();
try {
transitionToErrorState();
this.state = State.ERROR;
Throwable ex = event.getThrowable();
this.exceptionHandlers.forEach(consumer -> consumer.accept(ex));
}
@ -198,12 +197,6 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -198,12 +197,6 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
}
}
private void transitionToErrorState() {
if (!isAsyncComplete()) {
this.state = State.ERROR;
}
}
@Override
public void onComplete(AsyncEvent event) throws IOException {
this.stateLock.lock();
@ -218,8 +211,17 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -218,8 +211,17 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
}
/**
* Package private access for testing only.
*/
ReentrantLock stateLock() {
return this.stateLock;
}
/**
* Response wrapper to wrap the output stream with {@link LifecycleServletOutputStream}.
* @since 5.3.33
*/
private static final class LifecycleHttpServletResponse extends HttpServletResponseWrapper {
@ -241,145 +243,180 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -241,145 +243,180 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
}
@Override
public ServletOutputStream getOutputStream() {
if (this.outputStream == null) {
Assert.notNull(this.asyncWebRequest, "Not initialized");
this.outputStream = new LifecycleServletOutputStream(
(HttpServletResponse) getResponse(), this.asyncWebRequest);
public ServletOutputStream getOutputStream() throws IOException {
int level = obtainLockAndCheckState();
try {
if (this.outputStream == null) {
Assert.notNull(this.asyncWebRequest, "Not initialized");
ServletOutputStream delegate = getResponse().getOutputStream();
this.outputStream = new LifecycleServletOutputStream(delegate, this);
}
}
catch (IOException ex) {
handleIOException(ex, "Failed to get ServletResponseOutput");
}
finally {
releaseLock(level);
}
return this.outputStream;
}
@Override
public PrintWriter getWriter() throws IOException {
if (this.writer == null) {
Assert.notNull(this.asyncWebRequest, "Not initialized");
this.writer = new LifecyclePrintWriter(getResponse().getWriter(), this.asyncWebRequest);
int level = obtainLockAndCheckState();
try {
if (this.writer == null) {
Assert.notNull(this.asyncWebRequest, "Not initialized");
this.writer = new LifecyclePrintWriter(getResponse().getWriter(), this.asyncWebRequest);
}
}
catch (IOException ex) {
handleIOException(ex, "Failed to get PrintWriter");
}
finally {
releaseLock(level);
}
return this.writer;
}
@Override
public void flushBuffer() throws IOException {
int level = obtainLockAndCheckState();
try {
getResponse().flushBuffer();
}
catch (IOException ex) {
handleIOException(ex, "ServletResponse failed to flushBuffer");
}
finally {
releaseLock(level);
}
}
/**
* Return 0 if checks passed and lock is not needed, 1 if checks passed
* and lock is held, or raise AsyncRequestNotUsableException.
*/
private int obtainLockAndCheckState() throws AsyncRequestNotUsableException {
Assert.notNull(this.asyncWebRequest, "Not initialized");
if (this.asyncWebRequest.state == State.NEW) {
return 0;
}
this.asyncWebRequest.stateLock.lock();
if (this.asyncWebRequest.state == State.ASYNC) {
return 1;
}
this.asyncWebRequest.stateLock.unlock();
throw new AsyncRequestNotUsableException("Response not usable after " +
(this.asyncWebRequest.state == State.COMPLETED ?
"async request completion" : "response errors") + ".");
}
void handleIOException(IOException ex, String msg) throws AsyncRequestNotUsableException {
Assert.notNull(this.asyncWebRequest, "Not initialized");
this.asyncWebRequest.state = State.ERROR;
throw new AsyncRequestNotUsableException(msg + ": " + ex.getMessage(), ex);
}
void releaseLock(int level) {
Assert.notNull(this.asyncWebRequest, "Not initialized");
if (level > 0) {
this.asyncWebRequest.stateLock.unlock();
}
}
}
/**
* Wraps a ServletOutputStream to prevent use after Servlet container onError
* notifications, and after async request completion.
* @since 5.3.33
*/
private static final class LifecycleServletOutputStream extends ServletOutputStream {
private final HttpServletResponse delegate;
private final StandardServletAsyncWebRequest asyncWebRequest;
private final ServletOutputStream delegate;
private LifecycleServletOutputStream(
HttpServletResponse delegate, StandardServletAsyncWebRequest asyncWebRequest) {
private final LifecycleHttpServletResponse response;
private LifecycleServletOutputStream(ServletOutputStream delegate, LifecycleHttpServletResponse response) {
this.delegate = delegate;
this.asyncWebRequest = asyncWebRequest;
this.response = response;
}
@Override
public boolean isReady() {
return false;
return this.delegate.isReady();
}
@Override
public void setWriteListener(WriteListener writeListener) {
throw new UnsupportedOperationException();
this.delegate.setWriteListener(writeListener);
}
@Override
public void write(int b) throws IOException {
obtainLockAndCheckState();
int level = this.response.obtainLockAndCheckState();
try {
this.delegate.getOutputStream().write(b);
this.delegate.write(b);
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to write");
this.response.handleIOException(ex, "ServletOutputStream failed to write");
}
finally {
releaseLock();
this.response.releaseLock(level);
}
}
public void write(byte[] buf, int offset, int len) throws IOException {
obtainLockAndCheckState();
int level = this.response.obtainLockAndCheckState();
try {
this.delegate.getOutputStream().write(buf, offset, len);
this.delegate.write(buf, offset, len);
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to write");
this.response.handleIOException(ex, "ServletOutputStream failed to write");
}
finally {
releaseLock();
this.response.releaseLock(level);
}
}
@Override
public void flush() throws IOException {
obtainLockAndCheckState();
int level = this.response.obtainLockAndCheckState();
try {
this.delegate.getOutputStream().flush();
this.delegate.flush();
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to flush");
this.response.handleIOException(ex, "ServletOutputStream failed to flush");
}
finally {
releaseLock();
this.response.releaseLock(level);
}
}
@Override
public void close() throws IOException {
obtainLockAndCheckState();
int level = this.response.obtainLockAndCheckState();
try {
this.delegate.getOutputStream().close();
this.delegate.close();
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to close");
this.response.handleIOException(ex, "ServletOutputStream failed to close");
}
finally {
releaseLock();
this.response.releaseLock(level);
}
}
private void obtainLockAndCheckState() throws AsyncRequestNotUsableException {
if (state() != State.NEW) {
stateLock().lock();
if (state() != State.ASYNC) {
stateLock().unlock();
throw new AsyncRequestNotUsableException("Response not usable after " +
(state() == State.COMPLETED ?
"async request completion" : "onError notification") + ".");
}
}
}
private void handleIOException(IOException ex, String msg) throws AsyncRequestNotUsableException {
this.asyncWebRequest.transitionToErrorState();
throw new AsyncRequestNotUsableException(msg, ex);
}
private void releaseLock() {
if (state() != State.NEW) {
stateLock().unlock();
}
}
private State state() {
return this.asyncWebRequest.state;
}
private Lock stateLock() {
return this.asyncWebRequest.stateLock;
}
}
/**
* Wraps a PrintWriter to prevent use after Servlet container onError
* notifications, and after async request completion.
* @since 5.3.33
*/
private static final class LifecyclePrintWriter extends PrintWriter {
@ -395,24 +432,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -395,24 +432,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public void flush() {
if (tryObtainLockAndCheckState()) {
int level = tryObtainLockAndCheckState();
if (level > -1) {
try {
this.delegate.flush();
}
finally {
releaseLock();
releaseLock(level);
}
}
}
@Override
public void close() {
if (tryObtainLockAndCheckState()) {
int level = tryObtainLockAndCheckState();
if (level > -1) {
try {
this.delegate.close();
}
finally {
releaseLock();
releaseLock(level);
}
}
}
@ -424,24 +463,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -424,24 +463,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public void write(int c) {
if (tryObtainLockAndCheckState()) {
int level = tryObtainLockAndCheckState();
if (level > -1) {
try {
this.delegate.write(c);
}
finally {
releaseLock();
releaseLock(level);
}
}
}
@Override
public void write(char[] buf, int off, int len) {
if (tryObtainLockAndCheckState()) {
int level = tryObtainLockAndCheckState();
if (level > -1) {
try {
this.delegate.write(buf, off, len);
}
finally {
releaseLock();
releaseLock(level);
}
}
}
@ -453,12 +494,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -453,12 +494,13 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public void write(String s, int off, int len) {
if (tryObtainLockAndCheckState()) {
int level = tryObtainLockAndCheckState();
if (level > -1) {
try {
this.delegate.write(s, off, len);
}
finally {
releaseLock();
releaseLock(level);
}
}
}
@ -468,33 +510,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -468,33 +510,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
this.delegate.write(s);
}
private boolean tryObtainLockAndCheckState() {
if (state() == State.NEW) {
return true;
/**
* Return 0 if checks passed and lock is not needed, 1 if checks passed
* and lock is held, and -1 if checks did not pass.
*/
private int tryObtainLockAndCheckState() {
if (this.asyncWebRequest.state == State.NEW) {
return 0;
}
if (stateLock().tryLock()) {
if (state() == State.ASYNC) {
return true;
}
stateLock().unlock();
this.asyncWebRequest.stateLock.lock();
if (this.asyncWebRequest.state == State.ASYNC) {
return 1;
}
return false;
this.asyncWebRequest.stateLock.unlock();
return -1;
}
private void releaseLock() {
if (state() != State.NEW) {
stateLock().unlock();
private void releaseLock(int level) {
if (level > 0) {
this.asyncWebRequest.stateLock.unlock();
}
}
private State state() {
return this.asyncWebRequest.state;
}
private Lock stateLock() {
return this.asyncWebRequest.stateLock;
}
// Plain delegates
@Override
@ -632,28 +669,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements @@ -632,28 +669,28 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
/**
* Represents a state for {@link StandardServletAsyncWebRequest} to be in.
* <p><pre>
* NEW
* |
* v
* ASYNC----> +
* | |
* v |
* ERROR |
* | |
* v |
* COMPLETED <--+
* +------ NEW
* | |
* | v
* | ASYNC ----> +
* | | |
* | v |
* +----> ERROR |
* | |
* v |
* COMPLETED <---+
* </pre>
* @since 5.3.33
*/
private enum State {
/** New request (thas may not do async handling). */
/** New request (may not start async handling). */
NEW,
/** Async handling has started. */
ASYNC,
/** onError notification received, or ServletOutputStream failed. */
/** ServletOutputStream failed, or onError notification received. */
ERROR,
/** onComplete notification received. */

5
spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java

@ -514,14 +514,11 @@ public final class WebAsyncManager { @@ -514,14 +514,11 @@ public final class WebAsyncManager {
}
this.asyncWebRequest.startAsync();
if (logger.isDebugEnabled()) {
logger.debug("Started async request");
}
}
private static String formatUri(AsyncWebRequest asyncWebRequest) {
HttpServletRequest request = asyncWebRequest.getNativeRequest(HttpServletRequest.class);
return (request != null ? request.getRequestURI() : "servlet container");
return (request != null ? "\"" + request.getRequestURI() + "\"" : "servlet container");
}

357
spring-web/src/test/java/org/springframework/web/context/request/async/AsyncRequestNotUsableTests.java

@ -0,0 +1,357 @@ @@ -0,0 +1,357 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.context.request.async;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.concurrent.atomic.AtomicInteger;
import jakarta.servlet.AsyncEvent;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.mockito.stubbing.Answer;
import org.springframework.web.testfixture.servlet.MockAsyncContext;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.BDDMockito.doAnswer;
import static org.mockito.BDDMockito.doThrow;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.verify;
import static org.mockito.BDDMockito.verifyNoInteractions;
/**
* {@link StandardServletAsyncWebRequest} tests related to response wrapping in
* order to enforce thread safety and prevent use after errors.
*
* @author Rossen Stoyanchev
*/
public class AsyncRequestNotUsableTests {
private final MockHttpServletRequest request = new MockHttpServletRequest();
private final HttpServletResponse response = mock();
private final ServletOutputStream outputStream = mock();
private final PrintWriter writer = mock();
private StandardServletAsyncWebRequest asyncRequest;
@BeforeEach
void setup() throws IOException {
this.request.setAsyncSupported(true);
given(this.response.getOutputStream()).willReturn(this.outputStream);
given(this.response.getWriter()).willReturn(this.writer);
this.asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response);
}
@AfterEach
void tearDown() {
assertThat(this.asyncRequest.stateLock().isLocked()).isFalse();
}
@SuppressWarnings("DataFlowIssue")
private ServletOutputStream getWrappedOutputStream() throws IOException {
return this.asyncRequest.getResponse().getOutputStream();
}
@SuppressWarnings("DataFlowIssue")
private PrintWriter getWrappedWriter() throws IOException {
return this.asyncRequest.getResponse().getWriter();
}
@Nested
class ResponseTests {
@Test
void notUsableAfterError() throws IOException {
asyncRequest.startAsync();
asyncRequest.onError(new AsyncEvent(new MockAsyncContext(request, response), new Exception()));
HttpServletResponse wrapped = asyncRequest.getResponse();
assertThat(wrapped).isNotNull();
assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after response errors.");
assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after response errors.");
assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after response errors.");
}
@Test
void notUsableAfterCompletion() throws IOException {
asyncRequest.startAsync();
asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response)));
HttpServletResponse wrapped = asyncRequest.getResponse();
assertThat(wrapped).isNotNull();
assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after async request completion.");
}
@Test
void notUsableWhenRecreatedAfterCompletion() throws IOException {
asyncRequest.startAsync();
asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response)));
StandardServletAsyncWebRequest newWebRequest =
new StandardServletAsyncWebRequest(request, response, asyncRequest);
HttpServletResponse wrapped = newWebRequest.getResponse();
assertThat(wrapped).isNotNull();
assertThatThrownBy(wrapped::getOutputStream).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(wrapped::getWriter).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(wrapped::flushBuffer).hasMessage("Response not usable after async request completion.");
}
}
@Nested
class OutputStreamTests {
@Test
void use() throws IOException {
testUseOutputStream();
}
@Test
void useInAsyncState() throws IOException {
asyncRequest.startAsync();
testUseOutputStream();
}
private void testUseOutputStream() throws IOException {
ServletOutputStream wrapped = getWrappedOutputStream();
wrapped.write('a');
wrapped.write(new byte[0], 1, 2);
wrapped.flush();
wrapped.close();
verify(outputStream).write('a');
verify(outputStream).write(new byte[0], 1, 2);
verify(outputStream).flush();
verify(outputStream).close();
}
@Test
void notUsableAfterCompletion() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response)));
assertThatThrownBy(() -> wrapped.write('a')).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(() -> wrapped.write(new byte[0])).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(() -> wrapped.write(new byte[0], 0, 0)).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(wrapped::flush).hasMessage("Response not usable after async request completion.");
assertThatThrownBy(wrapped::close).hasMessage("Response not usable after async request completion.");
}
@Test
void lockingNotUsed() throws IOException {
AtomicInteger count = new AtomicInteger(-1);
doAnswer((Answer<Void>) invocation -> {
count.set(asyncRequest.stateLock().getHoldCount());
return null;
}).when(outputStream).write('a');
// Access ServletOutputStream in NEW state (no async handling) without locking
getWrappedOutputStream().write('a');
assertThat(count.get()).isEqualTo(0);
}
@Test
void lockingUsedInAsyncState() throws IOException {
AtomicInteger count = new AtomicInteger(-1);
doAnswer((Answer<Void>) invocation -> {
count.set(asyncRequest.stateLock().getHoldCount());
return null;
}).when(outputStream).write('a');
// Access ServletOutputStream in ASYNC state with locking
asyncRequest.startAsync();
getWrappedOutputStream().write('a');
assertThat(count.get()).isEqualTo(1);
}
}
@Nested
class OutputStreamErrorTests {
@Test
void writeInt() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
doThrow(new IOException("Broken pipe")).when(outputStream).write('a');
assertThatThrownBy(() -> wrapped.write('a')).hasMessage("ServletOutputStream failed to write: Broken pipe");
}
@Test
void writeBytesFull() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
byte[] bytes = new byte[0];
doThrow(new IOException("Broken pipe")).when(outputStream).write(bytes, 0, 0);
assertThatThrownBy(() -> wrapped.write(bytes)).hasMessage("ServletOutputStream failed to write: Broken pipe");
}
@Test
void writeBytes() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
byte[] bytes = new byte[0];
doThrow(new IOException("Broken pipe")).when(outputStream).write(bytes, 0, 0);
assertThatThrownBy(() -> wrapped.write(bytes, 0, 0)).hasMessage("ServletOutputStream failed to write: Broken pipe");
}
@Test
void flush() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
doThrow(new IOException("Broken pipe")).when(outputStream).flush();
assertThatThrownBy(wrapped::flush).hasMessage("ServletOutputStream failed to flush: Broken pipe");
}
@Test
void close() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
doThrow(new IOException("Broken pipe")).when(outputStream).close();
assertThatThrownBy(wrapped::close).hasMessage("ServletOutputStream failed to close: Broken pipe");
}
@Test
void writeErrorPreventsFurtherWriting() throws IOException {
ServletOutputStream wrapped = getWrappedOutputStream();
doThrow(new IOException("Broken pipe")).when(outputStream).write('a');
assertThatThrownBy(() -> wrapped.write('a')).hasMessage("ServletOutputStream failed to write: Broken pipe");
assertThatThrownBy(() -> wrapped.write('a')).hasMessage("Response not usable after response errors.");
}
@Test
void writeErrorInAsyncStatePreventsFurtherWriting() throws IOException {
asyncRequest.startAsync();
ServletOutputStream wrapped = getWrappedOutputStream();
doThrow(new IOException("Broken pipe")).when(outputStream).write('a');
assertThatThrownBy(() -> wrapped.write('a')).hasMessage("ServletOutputStream failed to write: Broken pipe");
assertThatThrownBy(() -> wrapped.write('a')).hasMessage("Response not usable after response errors.");
}
}
@Nested
class WriterTests {
@Test
void useWriter() throws IOException {
testUseWriter();
}
@Test
void useWriterInAsyncState() throws IOException {
asyncRequest.startAsync();
testUseWriter();
}
private void testUseWriter() throws IOException {
PrintWriter wrapped = getWrappedWriter();
wrapped.write('a');
wrapped.write(new char[0], 1, 2);
wrapped.write("abc", 1, 2);
wrapped.flush();
wrapped.close();
verify(writer).write('a');
verify(writer).write(new char[0], 1, 2);
verify(writer).write("abc", 1, 2);
verify(writer).flush();
verify(writer).close();
}
@Test
void writerNotUsableAfterCompletion() throws IOException {
asyncRequest.startAsync();
PrintWriter wrapped = getWrappedWriter();
asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(request, response)));
char[] chars = new char[0];
wrapped.write('a');
wrapped.write(chars, 1, 2);
wrapped.flush();
wrapped.close();
verifyNoInteractions(writer);
}
@Test
void lockingNotUsed() throws IOException {
AtomicInteger count = new AtomicInteger(-1);
doAnswer((Answer<Void>) invocation -> {
count.set(asyncRequest.stateLock().getHoldCount());
return null;
}).when(writer).write('a');
// Use Writer in NEW state (no async handling) without locking
PrintWriter wrapped = getWrappedWriter();
wrapped.write('a');
assertThat(count.get()).isEqualTo(0);
}
@Test
void lockingUsedInAsyncState() throws IOException {
AtomicInteger count = new AtomicInteger(-1);
doAnswer((Answer<Void>) invocation -> {
count.set(asyncRequest.stateLock().getHoldCount());
return null;
}).when(writer).write('a');
// Use Writer in ASYNC state with locking
asyncRequest.startAsync();
PrintWriter wrapped = getWrappedWriter();
wrapped.write('a');
assertThat(count.get()).isEqualTo(1);
}
}
}

44
spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapterTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,8 +16,11 @@ @@ -16,8 +16,11 @@
package org.springframework.web.servlet.mvc.method.annotation;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@ -25,6 +28,7 @@ import java.util.LinkedHashMap; @@ -25,6 +28,7 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import jakarta.servlet.AsyncEvent;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -46,6 +50,10 @@ import org.springframework.ui.Model; @@ -46,6 +50,10 @@ import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.SessionAttributes;
import org.springframework.web.context.request.async.AsyncRequestNotUsableException;
import org.springframework.web.context.request.async.StandardServletAsyncWebRequest;
import org.springframework.web.context.request.async.WebAsyncManager;
import org.springframework.web.context.request.async.WebAsyncUtils;
import org.springframework.web.context.support.StaticWebApplicationContext;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.method.annotation.ModelMethodProcessor;
@ -55,13 +63,15 @@ import org.springframework.web.method.support.InvocableHandlerMethod; @@ -55,13 +63,15 @@ import org.springframework.web.method.support.InvocableHandlerMethod;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.FlashMap;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.testfixture.servlet.MockAsyncContext;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Unit tests for {@link RequestMappingHandlerAdapter}.
* Tests for {@link RequestMappingHandlerAdapter}.
*
* @author Rossen Stoyanchev
* @author Sam Brannen
@ -249,9 +259,7 @@ public class RequestMappingHandlerAdapterTests { @@ -249,9 +259,7 @@ public class RequestMappingHandlerAdapterTests {
assertThat(mav.getModel().get("attr3")).isNull();
}
// SPR-10859
@Test
@Test // gh-15486
public void responseBodyAdvice() throws Exception {
List<HttpMessageConverter<?>> converters = new ArrayList<>();
converters.add(new MappingJackson2HttpMessageConverter());
@ -271,6 +279,26 @@ public class RequestMappingHandlerAdapterTests { @@ -271,6 +279,26 @@ public class RequestMappingHandlerAdapterTests {
assertThat(this.response.getContentAsString()).isEqualTo("{\"status\":400,\"message\":\"body\"}");
}
@Test
void asyncRequestNotUsable() throws Exception {
// Put AsyncWebRequest in ERROR state
StandardServletAsyncWebRequest asyncRequest = new StandardServletAsyncWebRequest(this.request, this.response);
asyncRequest.onError(new AsyncEvent(new MockAsyncContext(this.request, this.response), new Exception()));
// Set it as the current AsyncWebRequest, from the initial REQUEST dispatch
WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(this.request);
asyncManager.setAsyncWebRequest(asyncRequest);
// AsyncWebRequest created for current dispatch should inherit state
HandlerMethod handlerMethod = handlerMethod(new SimpleController(), "handleOutputStream", OutputStream.class);
this.handlerAdapter.afterPropertiesSet();
// Use of response should be rejected
assertThatThrownBy(() -> this.handlerAdapter.handle(this.request, this.response, handlerMethod))
.isInstanceOf(AsyncRequestNotUsableException.class);
}
private HandlerMethod handlerMethod(Object handler, String methodName, Class<?>... paramTypes) throws Exception {
Method method = handler.getClass().getDeclaredMethod(methodName, paramTypes);
return new InvocableHandlerMethod(handler, method);
@ -296,14 +324,16 @@ public class RequestMappingHandlerAdapterTests { @@ -296,14 +324,16 @@ public class RequestMappingHandlerAdapterTests {
}
public ResponseEntity<Map<String, String>> handleWithResponseEntity() {
return new ResponseEntity<>(Collections.singletonMap(
"foo", "bar"), HttpStatus.OK);
return new ResponseEntity<>(Collections.singletonMap("foo", "bar"), HttpStatus.OK);
}
public ResponseEntity<String> handleBadRequest() {
return new ResponseEntity<>("body", HttpStatus.BAD_REQUEST);
}
public void handleOutputStream(OutputStream outputStream) throws IOException {
outputStream.write("body".getBytes(StandardCharsets.UTF_8));
}
}

Loading…
Cancel
Save