@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2002 - 2023 the original author or authors .
* Copyright 2002 - 2024 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -17,16 +17,22 @@
@@ -17,16 +17,22 @@
package org.springframework.web.context.request.async ;
import java.io.IOException ;
import java.io.PrintWriter ;
import java.util.ArrayList ;
import java.util.List ;
import java.util.concurrent.atomic.AtomicBoolean ;
import java.util.Locale ;
import java.util.concurrent.locks.Lock ;
import java.util.concurrent.locks.ReentrantLock ;
import java.util.function.Consumer ;
import jakarta.servlet.AsyncContext ;
import jakarta.servlet.AsyncEvent ;
import jakarta.servlet.AsyncListener ;
import jakarta.servlet.ServletOutputStream ;
import jakarta.servlet.WriteListener ;
import jakarta.servlet.http.HttpServletRequest ;
import jakarta.servlet.http.HttpServletResponse ;
import jakarta.servlet.http.HttpServletResponseWrapper ;
import org.springframework.lang.Nullable ;
import org.springframework.util.Assert ;
@ -45,8 +51,6 @@ import org.springframework.web.context.request.ServletWebRequest;
@@ -45,8 +51,6 @@ import org.springframework.web.context.request.ServletWebRequest;
* /
public class StandardServletAsyncWebRequest extends ServletWebRequest implements AsyncWebRequest , AsyncListener {
private final AtomicBoolean asyncCompleted = new AtomicBoolean ( ) ;
private final List < Runnable > timeoutHandlers = new ArrayList < > ( ) ;
private final List < Consumer < Throwable > > exceptionHandlers = new ArrayList < > ( ) ;
@ -59,6 +63,10 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@@ -59,6 +63,10 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Nullable
private AsyncContext asyncContext ;
private State state ;
private final ReentrantLock stateLock = new ReentrantLock ( ) ;
/ * *
* Create a new instance for the given request / response pair .
@ -66,7 +74,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@@ -66,7 +74,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
* @param response current HTTP response
* /
public StandardServletAsyncWebRequest ( HttpServletRequest request , HttpServletResponse response ) {
super ( request , response ) ;
this ( request , response , null ) ;
}
/ * *
* Constructor to wrap the request and response for the current dispatch that
* also picks up the state of the last ( probably the REQUEST ) dispatch .
* @param request current HTTP request
* @param response current HTTP response
* @param previousRequest the existing request from the last dispatch
* @since 5 . 3 . 33
* /
StandardServletAsyncWebRequest ( HttpServletRequest request , HttpServletResponse response ,
@Nullable StandardServletAsyncWebRequest previousRequest ) {
super ( request , new LifecycleHttpServletResponse ( response ) ) ;
this . state = ( previousRequest ! = null ? previousRequest . state : State . NEW ) ;
//noinspection DataFlowIssue
( ( LifecycleHttpServletResponse ) getResponse ( ) ) . setAsyncWebRequest ( this ) ;
}
@ -107,7 +134,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@@ -107,7 +134,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
* /
@Override
public boolean isAsyncComplete ( ) {
return this . asyncCompleted . get ( ) ;
return ( this . state = = State . COMPLETED ) ;
}
@Override
@ -117,11 +144,18 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@@ -117,11 +144,18 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
"in async request processing. This is done in Java code using the Servlet API " +
"or by adding \"<async-supported>true</async-supported>\" to servlet and " +
"filter declarations in web.xml." ) ;
Assert . state ( ! isAsyncComplete ( ) , "Async processing has already completed" ) ;
if ( isAsyncStarted ( ) ) {
return ;
}
if ( this . state = = State . NEW ) {
this . state = State . ASYNC ;
}
else {
Assert . state ( this . state = = State . ASYNC , "Cannot start async: [" + this . state + "]" ) ;
}
this . asyncContext = getRequest ( ) . startAsync ( getRequest ( ) , getResponse ( ) ) ;
this . asyncContext . addListener ( this ) ;
if ( this . timeout ! = null ) {
@ -131,9 +165,11 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@@ -131,9 +165,11 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public void dispatch ( ) {
Assert . state ( this . asyncContext ! = null , "Cannot dispatch without an AsyncContext" ) ;
Assert . state ( this . asyncContext ! = null , "AsyncContext not yet initialized" ) ;
if ( ! this . isAsyncComplete ( ) ) {
this . asyncContext . dispatch ( ) ;
}
}
// ---------------------------------------------------------------------
@ -151,14 +187,478 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@@ -151,14 +187,478 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public void onError ( AsyncEvent event ) throws IOException {
this . exceptionHandlers . forEach ( consumer - > consumer . accept ( event . getThrowable ( ) ) ) ;
this . stateLock . lock ( ) ;
try {
transitionToErrorState ( ) ;
Throwable ex = event . getThrowable ( ) ;
this . exceptionHandlers . forEach ( consumer - > consumer . accept ( ex ) ) ;
}
finally {
this . stateLock . unlock ( ) ;
}
}
private void transitionToErrorState ( ) {
if ( ! isAsyncComplete ( ) ) {
this . state = State . ERROR ;
}
}
@Override
public void onComplete ( AsyncEvent event ) throws IOException {
this . stateLock . lock ( ) ;
try {
this . completionHandlers . forEach ( Runnable : : run ) ;
this . asyncContext = null ;
this . asyncCompleted . set ( true ) ;
this . state = State . COMPLETED ;
}
finally {
this . stateLock . unlock ( ) ;
}
}
/ * *
* Response wrapper to wrap the output stream with { @link LifecycleServletOutputStream } .
* /
private static final class LifecycleHttpServletResponse extends HttpServletResponseWrapper {
@Nullable
private StandardServletAsyncWebRequest asyncWebRequest ;
@Nullable
private ServletOutputStream outputStream ;
@Nullable
private PrintWriter writer ;
public LifecycleHttpServletResponse ( HttpServletResponse response ) {
super ( response ) ;
}
public void setAsyncWebRequest ( StandardServletAsyncWebRequest asyncWebRequest ) {
this . asyncWebRequest = asyncWebRequest ;
}
@Override
public ServletOutputStream getOutputStream ( ) {
if ( this . outputStream = = null ) {
Assert . notNull ( this . asyncWebRequest , "Not initialized" ) ;
this . outputStream = new LifecycleServletOutputStream (
( HttpServletResponse ) getResponse ( ) , this . asyncWebRequest ) ;
}
return this . outputStream ;
}
@Override
public PrintWriter getWriter ( ) throws IOException {
if ( this . writer = = null ) {
Assert . notNull ( this . asyncWebRequest , "Not initialized" ) ;
this . writer = new LifecyclePrintWriter ( getResponse ( ) . getWriter ( ) , this . asyncWebRequest ) ;
}
return this . writer ;
}
}
/ * *
* Wraps a ServletOutputStream to prevent use after Servlet container onError
* notifications , and after async request completion .
* /
private static final class LifecycleServletOutputStream extends ServletOutputStream {
private final HttpServletResponse delegate ;
private final StandardServletAsyncWebRequest asyncWebRequest ;
private LifecycleServletOutputStream (
HttpServletResponse delegate , StandardServletAsyncWebRequest asyncWebRequest ) {
this . delegate = delegate ;
this . asyncWebRequest = asyncWebRequest ;
}
@Override
public boolean isReady ( ) {
return false ;
}
@Override
public void setWriteListener ( WriteListener writeListener ) {
throw new UnsupportedOperationException ( ) ;
}
@Override
public void write ( int b ) throws IOException {
obtainLockAndCheckState ( ) ;
try {
this . delegate . getOutputStream ( ) . write ( b ) ;
}
catch ( IOException ex ) {
handleIOException ( ex , "ServletOutputStream failed to write" ) ;
}
finally {
releaseLock ( ) ;
}
}
public void write ( byte [ ] buf , int offset , int len ) throws IOException {
obtainLockAndCheckState ( ) ;
try {
this . delegate . getOutputStream ( ) . write ( buf , offset , len ) ;
}
catch ( IOException ex ) {
handleIOException ( ex , "ServletOutputStream failed to write" ) ;
}
finally {
releaseLock ( ) ;
}
}
@Override
public void flush ( ) throws IOException {
obtainLockAndCheckState ( ) ;
try {
this . delegate . getOutputStream ( ) . flush ( ) ;
}
catch ( IOException ex ) {
handleIOException ( ex , "ServletOutputStream failed to flush" ) ;
}
finally {
releaseLock ( ) ;
}
}
@Override
public void close ( ) throws IOException {
obtainLockAndCheckState ( ) ;
try {
this . delegate . getOutputStream ( ) . close ( ) ;
}
catch ( IOException ex ) {
handleIOException ( ex , "ServletOutputStream failed to close" ) ;
}
finally {
releaseLock ( ) ;
}
}
private void obtainLockAndCheckState ( ) throws AsyncRequestNotUsableException {
if ( state ( ) ! = State . NEW ) {
stateLock ( ) . lock ( ) ;
if ( state ( ) ! = State . ASYNC ) {
stateLock ( ) . unlock ( ) ;
throw new AsyncRequestNotUsableException ( "Response not usable after " +
( state ( ) = = State . COMPLETED ?
"async request completion" : "onError notification" ) + "." ) ;
}
}
}
private void handleIOException ( IOException ex , String msg ) throws AsyncRequestNotUsableException {
this . asyncWebRequest . transitionToErrorState ( ) ;
throw new AsyncRequestNotUsableException ( msg , ex ) ;
}
private void releaseLock ( ) {
if ( state ( ) ! = State . NEW ) {
stateLock ( ) . unlock ( ) ;
}
}
private State state ( ) {
return this . asyncWebRequest . state ;
}
private Lock stateLock ( ) {
return this . asyncWebRequest . stateLock ;
}
}
/ * *
* Wraps a PrintWriter to prevent use after Servlet container onError
* notifications , and after async request completion .
* /
private static final class LifecyclePrintWriter extends PrintWriter {
private final PrintWriter delegate ;
private final StandardServletAsyncWebRequest asyncWebRequest ;
private LifecyclePrintWriter ( PrintWriter delegate , StandardServletAsyncWebRequest asyncWebRequest ) {
super ( delegate ) ;
this . delegate = delegate ;
this . asyncWebRequest = asyncWebRequest ;
}
@Override
public void flush ( ) {
if ( tryObtainLockAndCheckState ( ) ) {
try {
this . delegate . flush ( ) ;
}
finally {
releaseLock ( ) ;
}
}
}
@Override
public void close ( ) {
if ( tryObtainLockAndCheckState ( ) ) {
try {
this . delegate . close ( ) ;
}
finally {
releaseLock ( ) ;
}
}
}
@Override
public boolean checkError ( ) {
return this . delegate . checkError ( ) ;
}
@Override
public void write ( int c ) {
if ( tryObtainLockAndCheckState ( ) ) {
try {
this . delegate . write ( c ) ;
}
finally {
releaseLock ( ) ;
}
}
}
@Override
public void write ( char [ ] buf , int off , int len ) {
if ( tryObtainLockAndCheckState ( ) ) {
try {
this . delegate . write ( buf , off , len ) ;
}
finally {
releaseLock ( ) ;
}
}
}
@Override
public void write ( char [ ] buf ) {
this . delegate . write ( buf ) ;
}
@Override
public void write ( String s , int off , int len ) {
if ( tryObtainLockAndCheckState ( ) ) {
try {
this . delegate . write ( s , off , len ) ;
}
finally {
releaseLock ( ) ;
}
}
}
@Override
public void write ( String s ) {
this . delegate . write ( s ) ;
}
private boolean tryObtainLockAndCheckState ( ) {
if ( state ( ) = = State . NEW ) {
return true ;
}
if ( stateLock ( ) . tryLock ( ) ) {
if ( state ( ) = = State . ASYNC ) {
return true ;
}
stateLock ( ) . unlock ( ) ;
}
return false ;
}
private void releaseLock ( ) {
if ( state ( ) ! = State . NEW ) {
stateLock ( ) . unlock ( ) ;
}
}
private State state ( ) {
return this . asyncWebRequest . state ;
}
private Lock stateLock ( ) {
return this . asyncWebRequest . stateLock ;
}
// Plain delegates
@Override
public void print ( boolean b ) {
this . delegate . print ( b ) ;
}
@Override
public void print ( char c ) {
this . delegate . print ( c ) ;
}
@Override
public void print ( int i ) {
this . delegate . print ( i ) ;
}
@Override
public void print ( long l ) {
this . delegate . print ( l ) ;
}
@Override
public void print ( float f ) {
this . delegate . print ( f ) ;
}
@Override
public void print ( double d ) {
this . delegate . print ( d ) ;
}
@Override
public void print ( char [ ] s ) {
this . delegate . print ( s ) ;
}
@Override
public void print ( String s ) {
this . delegate . print ( s ) ;
}
@Override
public void print ( Object obj ) {
this . delegate . print ( obj ) ;
}
@Override
public void println ( ) {
this . delegate . println ( ) ;
}
@Override
public void println ( boolean x ) {
this . delegate . println ( x ) ;
}
@Override
public void println ( char x ) {
this . delegate . println ( x ) ;
}
@Override
public void println ( int x ) {
this . delegate . println ( x ) ;
}
@Override
public void println ( long x ) {
this . delegate . println ( x ) ;
}
@Override
public void println ( float x ) {
this . delegate . println ( x ) ;
}
@Override
public void println ( double x ) {
this . delegate . println ( x ) ;
}
@Override
public void println ( char [ ] x ) {
this . delegate . println ( x ) ;
}
@Override
public void println ( String x ) {
this . delegate . println ( x ) ;
}
@Override
public void println ( Object x ) {
this . delegate . println ( x ) ;
}
@Override
public PrintWriter printf ( String format , Object . . . args ) {
return this . delegate . printf ( format , args ) ;
}
@Override
public PrintWriter printf ( Locale l , String format , Object . . . args ) {
return this . delegate . printf ( l , format , args ) ;
}
@Override
public PrintWriter format ( String format , Object . . . args ) {
return this . delegate . format ( format , args ) ;
}
@Override
public PrintWriter format ( Locale l , String format , Object . . . args ) {
return this . delegate . format ( l , format , args ) ;
}
@Override
public PrintWriter append ( CharSequence csq ) {
return this . delegate . append ( csq ) ;
}
@Override
public PrintWriter append ( CharSequence csq , int start , int end ) {
return this . delegate . append ( csq , start , end ) ;
}
@Override
public PrintWriter append ( char c ) {
return this . delegate . append ( c ) ;
}
}
/ * *
* Represents a state for { @link StandardServletAsyncWebRequest } to be in .
* < p > < pre >
* NEW
* |
* v
* ASYNC - - - - > +
* | |
* v |
* ERROR |
* | |
* v |
* COMPLETED < - - +
* < / pre >
* @since 5 . 3 . 33
* /
private enum State {
/** New request (thas may not do async handling). */
NEW ,
/** Async handling has started. */
ASYNC ,
/** onError notification received, or ServletOutputStream failed. */
ERROR ,
/** onComplete notification received. */
COMPLETED
}
}