@ -17,14 +17,8 @@
@@ -17,14 +17,8 @@
package org.springframework.http.server.reactive ;
import java.io.IOException ;
import java.io.InputStream ;
import java.util.concurrent.atomic.AtomicLong ;
import javax.servlet.AsyncContext ;
import javax.servlet.ReadListener ;
import javax.servlet.ServletException ;
import javax.servlet.ServletInputStream ;
import javax.servlet.ServletOutputStream ;
import javax.servlet.WriteListener ;
import javax.servlet.annotation.WebServlet ;
import javax.servlet.http.HttpServlet ;
import javax.servlet.http.HttpServletRequest ;
@ -32,17 +26,13 @@ import javax.servlet.http.HttpServletResponse;
@@ -32,17 +26,13 @@ import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log ;
import org.apache.commons.logging.LogFactory ;
import org.reactivestreams.Publisher ;
import org.reactivestreams.Subscriber ;
import org.reactivestreams.Subscription ;
import reactor.core.publisher.Mono ;
import org.springframework.core.io.buffer.DataBuffer ;
import org.springframework.core.io.buffer.DataBufferAllocator ;
import org.springframework.core.io.buffer.DefaultDataBufferAllocator ;
import org.springframework.http.HttpStatus ;
import org.springframework.util.Assert ;
import org.springframework.util.StreamUtils ;
/ * *
* @author Arjen Poutsma
@ -51,24 +41,35 @@ import org.springframework.util.StreamUtils;
@@ -51,24 +41,35 @@ import org.springframework.util.StreamUtils;
@WebServlet ( asyncSupported = true )
public class ServletHttpHandlerAdapter extends HttpServlet {
private static final int BUFFER_SIZE = 8192 ;
private static final int DEFAULT_ BUFFER_SIZE = 8192 ;
private static Log logger = LogFactory . getLog ( ServletHttpHandlerAdapter . class ) ;
private HttpHandler handler ;
private DataBufferAllocator allocator = new DefaultDataBufferAllocator ( ) ;
// Servlet is based on blocking I/O, hence the usage of non-direct, heap-based buffers
// (i.e. 'false' as constructor argument)
private DataBufferAllocator allocator = new DefaultDataBufferAllocator ( false ) ;
private int bufferSize = DEFAULT_BUFFER_SIZE ;
public void setHandler ( HttpHandler handler ) {
Assert . notNull ( handler , "'handler' must not be null" ) ;
this . handler = handler ;
}
public void setAllocator ( DataBufferAllocator allocator ) {
Assert . notNull ( allocator , "'allocator' must not be null" ) ;
this . allocator = allocator ;
}
public void setBufferSize ( int bufferSize ) {
Assert . isTrue ( bufferSize > 0 ) ;
this . bufferSize = bufferSize ;
}
@Override
protected void service ( HttpServletRequest servletRequest , HttpServletResponse servletResponse )
throws ServletException , IOException {
@ -76,299 +77,25 @@ public class ServletHttpHandlerAdapter extends HttpServlet {
@@ -76,299 +77,25 @@ public class ServletHttpHandlerAdapter extends HttpServlet {
AsyncContext context = servletRequest . startAsync ( ) ;
ServletAsyncContextSynchronizer synchronizer = new ServletAsyncContextSynchronizer ( context ) ;
RequestBodyPublisher requestBody =
new RequestBodyPublisher ( synchronizer , allocator , BUFFER_SIZE ) ;
ServletServerHttpRequest request = new ServletServerHttpRequest ( servletRequest , requestBody ) ;
servletRequest . getInputStream ( ) . setReadListener ( requestBody ) ;
ResponseBodySubscriber responseBodySubscriber =
new ResponseBodySubscriber ( synchronizer ) ;
ServletServerHttpResponse response = new ServletServerHttpResponse ( servletResponse ,
publisher - > Mono . from ( subscriber - > publisher . subscribe ( responseBodySubscriber ) ) ) ;
servletResponse . getOutputStream ( ) . setWriteListener ( responseBodySubscriber ) ;
HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber ( synchronizer , response ) ;
this . handler . handle ( request , response ) . subscribe ( resultSubscriber ) ;
}
private static class RequestBodyPublisher
implements ReadListener , Publisher < DataBuffer > {
private final ServletAsyncContextSynchronizer synchronizer ;
private final DataBufferAllocator allocator ;
private final byte [ ] buffer ;
private final DemandCounter demand = new DemandCounter ( ) ;
private Subscriber < ? super DataBuffer > subscriber ;
private boolean stalled ;
private boolean cancelled ;
public RequestBodyPublisher ( ServletAsyncContextSynchronizer synchronizer ,
DataBufferAllocator allocator , int bufferSize ) {
this . synchronizer = synchronizer ;
this . allocator = allocator ;
this . buffer = new byte [ bufferSize ] ;
}
@Override
public void subscribe ( Subscriber < ? super DataBuffer > subscriber ) {
if ( subscriber = = null ) {
throw new NullPointerException ( ) ;
}
else if ( this . subscriber ! = null ) {
subscriber . onError ( new IllegalStateException ( "Only one subscriber allowed" ) ) ;
}
this . subscriber = subscriber ;
this . subscriber . onSubscribe ( new RequestBodySubscription ( ) ) ;
}
@Override
public void onDataAvailable ( ) throws IOException {
if ( cancelled ) {
return ;
}
ServletInputStream input = this . synchronizer . getInputStream ( ) ;
logger . debug ( "onDataAvailable: " + input ) ;
while ( true ) {
logger . debug ( "Demand: " + this . demand ) ;
if ( ! demand . hasDemand ( ) ) {
stalled = true ;
break ;
}
boolean ready = input . isReady ( ) ;
logger . debug ( "Input ready: " + ready + " finished: " + input . isFinished ( ) ) ;
if ( ! ready ) {
break ;
}
int read = input . read ( buffer ) ;
logger . debug ( "Input read:" + read ) ;
if ( read = = - 1 ) {
break ;
}
else if ( read > 0 ) {
this . demand . decrement ( ) ;
DataBuffer dataBuffer = allocator . allocateBuffer ( read ) ;
dataBuffer . write ( this . buffer , 0 , read ) ;
this . subscriber . onNext ( dataBuffer ) ;
}
}
}
@Override
public void onAllDataRead ( ) throws IOException {
if ( cancelled ) {
return ;
}
logger . debug ( "All data read" ) ;
this . synchronizer . readComplete ( ) ;
if ( this . subscriber ! = null ) {
this . subscriber . onComplete ( ) ;
}
}
@Override
public void onError ( Throwable t ) {
if ( cancelled ) {
return ;
}
logger . error ( "RequestBodyPublisher Error" , t ) ;
this . synchronizer . readComplete ( ) ;
if ( this . subscriber ! = null ) {
this . subscriber . onError ( t ) ;
}
}
private class RequestBodySubscription implements Subscription {
@Override
public void request ( long n ) {
if ( cancelled ) {
return ;
}
logger . debug ( "Updating demand " + demand + " by " + n ) ;
demand . increase ( n ) ;
logger . debug ( "Stalled: " + stalled ) ;
if ( stalled ) {
stalled = false ;
try {
onDataAvailable ( ) ;
}
catch ( IOException ex ) {
onError ( ex ) ;
}
}
}
ServletServerHttpRequest request =
new ServletServerHttpRequest ( synchronizer , this . allocator ,
this . bufferSize ) ;
@Override
public void cancel ( ) {
if ( cancelled ) {
return ;
}
cancelled = true ;
synchronizer . readComplete ( ) ;
demand . reset ( ) ;
}
}
/ * *
* Small utility class for keeping track of Reactive Streams demand .
* /
private static final class DemandCounter {
private final AtomicLong demand = new AtomicLong ( ) ;
/ * *
* Increases the demand by the given number
* @param n the positive number to increase demand by
* @return the increased demand
* @see org . reactivestreams . Subscription # request ( long )
* /
public long increase ( long n ) {
Assert . isTrue ( n > 0 , "'n' must be higher than 0" ) ;
return demand . updateAndGet ( d - > d ! = Long . MAX_VALUE ? d + n : Long . MAX_VALUE ) ;
}
/ * *
* Decreases the demand by one .
* @return the decremented demand
* /
public long decrement ( ) {
return demand . updateAndGet ( d - > d ! = Long . MAX_VALUE ? d - 1 : Long . MAX_VALUE ) ;
}
ServletServerHttpResponse response =
new ServletServerHttpResponse ( synchronizer , this . bufferSize ) ;
/ * *
* Indicates whether this counter has demand , i . e . whether it is higher than 0 .
* @return { @code true } if this counter has demand ; { @code false } otherwise
* /
public boolean hasDemand ( ) {
return this . demand . get ( ) > 0 ;
}
HandlerResultSubscriber resultSubscriber =
new HandlerResultSubscriber ( synchronizer ) ;
/ * *
* Resets this counter to 0 .
* @see org . reactivestreams . Subscription # cancel ( )
* /
public void reset ( ) {
this . demand . set ( 0 ) ;
}
@Override
public String toString ( ) {
return demand . toString ( ) ;
}
}
}
private static class ResponseBodySubscriber
implements WriteListener , Subscriber < DataBuffer > {
private final ServletAsyncContextSynchronizer synchronizer ;
private Subscription subscription ;
private DataBuffer dataBuffer ;
private volatile boolean subscriberComplete = false ;
public ResponseBodySubscriber ( ServletAsyncContextSynchronizer synchronizer ) {
this . synchronizer = synchronizer ;
}
@Override
public void onSubscribe ( Subscription subscription ) {
this . subscription = subscription ;
this . subscription . request ( 1 ) ;
}
@Override
public void onNext ( DataBuffer bytes ) {
Assert . isNull ( dataBuffer ) ;
this . dataBuffer = bytes ;
try {
onWritePossible ( ) ;
}
catch ( IOException e ) {
onError ( e ) ;
}
}
@Override
public void onComplete ( ) {
logger . debug ( "Complete buffer: " + ( dataBuffer = = null ) ) ;
this . subscriberComplete = true ;
if ( dataBuffer = = null ) {
this . synchronizer . writeComplete ( ) ;
}
}
@Override
public void onWritePossible ( ) throws IOException {
ServletOutputStream output = this . synchronizer . getOutputStream ( ) ;
boolean ready = output . isReady ( ) ;
logger . debug ( "Output: " + ready + " buffer: " + ( dataBuffer = = null ) ) ;
if ( ready ) {
if ( this . dataBuffer ! = null ) {
InputStream in = this . dataBuffer . asInputStream ( ) ;
byte [ ] buffer = new byte [ BUFFER_SIZE ] ;
int bytesRead ;
while ( ( bytesRead = in . read ( buffer ) ) ! = - 1 ) {
output . write ( buffer , 0 , bytesRead ) ;
}
if ( ! subscriberComplete ) {
this . subscription . request ( 1 ) ;
}
else {
this . synchronizer . writeComplete ( ) ;
}
}
else {
this . subscription . request ( 1 ) ;
}
}
}
@Override
public void onError ( Throwable t ) {
logger . error ( "ResponseBodySubscriber error" , t ) ;
}
this . handler . handle ( request , response ) . subscribe ( resultSubscriber ) ;
}
private static class HandlerResultSubscriber implements Subscriber < Void > {
private final ServletAsyncContextSynchronizer synchronizer ;
private final ServletServerHttpResponse response ;
public HandlerResultSubscriber ( ServletAsyncContextSynchronizer synchronizer ,
ServletServerHttpResponse response ) {
public HandlerResultSubscriber ( ServletAsyncContextSynchronizer synchronizer ) {
this . synchronizer = synchronizer ;
this . response = response ;
}
@ -385,7 +112,9 @@ public class ServletHttpHandlerAdapter extends HttpServlet {
@@ -385,7 +112,9 @@ public class ServletHttpHandlerAdapter extends HttpServlet {
@Override
public void onError ( Throwable ex ) {
logger . error ( "Error from request handling. Completing the request." , ex ) ;
this . response . setStatusCode ( HttpStatus . INTERNAL_SERVER_ERROR ) ;
HttpServletResponse response =
( HttpServletResponse ) this . synchronizer . getResponse ( ) ;
response . setStatus ( HttpStatus . INTERNAL_SERVER_ERROR . value ( ) ) ;
this . synchronizer . complete ( ) ;
}