diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java index 8726036ab2b..ea64176d0ad 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java @@ -17,13 +17,8 @@ package org.springframework.http.server.reactive; import java.io.IOException; -import java.io.InputStream; import javax.servlet.AsyncContext; -import javax.servlet.ReadListener; import javax.servlet.ServletException; -import javax.servlet.ServletInputStream; -import javax.servlet.ServletOutputStream; -import javax.servlet.WriteListener; import javax.servlet.annotation.WebServlet; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; @@ -33,9 +28,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; -import reactor.core.publisher.Mono; -import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.util.Assert; @@ -80,23 +73,15 @@ public class ServletHttpHandlerAdapter extends HttpServlet { AsyncContext asyncContext = servletRequest.startAsync(); - RequestBodyPublisher requestBody = - new RequestBodyPublisher(servletRequest.getInputStream(), - this.dataBufferFactory, this.bufferSize); - requestBody.registerListener(); ServletServerHttpRequest request = - new ServletServerHttpRequest(servletRequest, requestBody); - - ResponseBodyProcessor responseBody = - new ResponseBodyProcessor(servletResponse.getOutputStream(), + new ServletServerHttpRequest(servletRequest, this.dataBufferFactory, this.bufferSize); - responseBody.registerListener(); + request.registerListener(); + ServletServerHttpResponse response = new ServletServerHttpResponse(servletResponse, this.dataBufferFactory, - publisher -> Mono.from(subscriber -> { - publisher.subscribe(responseBody); - responseBody.subscribe(subscriber); - })); + this.bufferSize); + response.registerListener(); HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(asyncContext); @@ -137,173 +122,5 @@ public class ServletHttpHandlerAdapter extends HttpServlet { } } - private static class RequestBodyPublisher extends AbstractRequestBodyPublisher { - - private final RequestBodyPublisher.RequestBodyReadListener readListener = - new RequestBodyPublisher.RequestBodyReadListener(); - - private final ServletInputStream inputStream; - - private final DataBufferFactory dataBufferFactory; - - private final byte[] buffer; - - public RequestBodyPublisher(ServletInputStream inputStream, - DataBufferFactory dataBufferFactory, int bufferSize) { - this.inputStream = inputStream; - this.dataBufferFactory = dataBufferFactory; - this.buffer = new byte[bufferSize]; - } - - public void registerListener() throws IOException { - this.inputStream.setReadListener(this.readListener); - } - - @Override - protected void checkOnDataAvailable() { - if (!this.inputStream.isFinished() && this.inputStream.isReady()) { - onDataAvailable(); - } - } - - @Override - protected DataBuffer read() throws IOException { - if (this.inputStream.isReady()) { - int read = this.inputStream.read(this.buffer); - if (logger.isTraceEnabled()) { - logger.trace("read:" + read); - } - - if (read > 0) { - DataBuffer dataBuffer = this.dataBufferFactory.allocateBuffer(read); - dataBuffer.write(this.buffer, 0, read); - return dataBuffer; - } - } - return null; - } - - private class RequestBodyReadListener implements ReadListener { - - @Override - public void onDataAvailable() throws IOException { - RequestBodyPublisher.this.onDataAvailable(); - } - - @Override - public void onAllDataRead() throws IOException { - RequestBodyPublisher.this.onAllDataRead(); - } - - @Override - public void onError(Throwable throwable) { - RequestBodyPublisher.this.onError(throwable); - - } - } - } - - private static class ResponseBodyProcessor extends AbstractResponseBodyProcessor { - - private final ResponseBodyWriteListener writeListener = - new ResponseBodyWriteListener(); - - private final ServletOutputStream outputStream; - - private final int bufferSize; - - private volatile boolean flushOnNext; - - public ResponseBodyProcessor(ServletOutputStream outputStream, int bufferSize) { - this.outputStream = outputStream; - this.bufferSize = bufferSize; - } - - public void registerListener() throws IOException { - this.outputStream.setWriteListener(this.writeListener); - } - - @Override - protected boolean isWritePossible() { - return this.outputStream.isReady(); - } - - @Override - protected boolean write(DataBuffer dataBuffer) throws IOException { - if (this.flushOnNext) { - flush(); - } - - boolean ready = this.outputStream.isReady(); - - if (this.logger.isTraceEnabled()) { - this.logger.trace("write: " + dataBuffer + " ready: " + ready); - } - - if (ready) { - int total = dataBuffer.readableByteCount(); - int written = writeDataBuffer(dataBuffer); - - if (this.logger.isTraceEnabled()) { - this.logger.trace("written: " + written + " total: " + total); - } - return written == total; - } - else { - return false; - } - } - - @Override - protected void flush() throws IOException { - if (this.outputStream.isReady()) { - if (logger.isTraceEnabled()) { - logger.trace("flush"); - } - try { - this.outputStream.flush(); - this.flushOnNext = false; - return; - } - catch (IOException ignored) { - } - } - this.flushOnNext = true; - - } - - private int writeDataBuffer(DataBuffer dataBuffer) throws IOException { - InputStream input = dataBuffer.asInputStream(); - - int bytesWritten = 0; - byte[] buffer = new byte[this.bufferSize]; - int bytesRead = -1; - - while (this.outputStream.isReady() && - (bytesRead = input.read(buffer)) != -1) { - this.outputStream.write(buffer, 0, bytesRead); - bytesWritten += bytesRead; - } - - return bytesWritten; - } - - private class ResponseBodyWriteListener implements WriteListener { - - @Override - public void onWritePossible() throws IOException { - ResponseBodyProcessor.this.onWritePossible(); - } - - @Override - public void onError(Throwable ex) { - // Error on writing to the HTTP stream, so any further writes will probably - // fail. Let's log instead of calling {@link #writeError}. - ResponseBodyProcessor.this.logger - .error("ResponseBodyWriteListener error", ex); - } - } - } - } \ No newline at end of file diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java index 568942bde90..3027dea3109 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java @@ -16,18 +16,21 @@ package org.springframework.http.server.reactive; +import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.Charset; import java.util.Enumeration; import java.util.Map; +import javax.servlet.ReadListener; +import javax.servlet.ServletInputStream; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; -import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.http.HttpCookie; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -40,24 +43,24 @@ import org.springframework.util.StringUtils; /** * Adapt {@link ServerHttpRequest} to the Servlet {@link HttpServletRequest}. - * * @author Rossen Stoyanchev */ public class ServletServerHttpRequest extends AbstractServerHttpRequest { private final HttpServletRequest request; - private final Flux requestBodyPublisher; + private final RequestBodyPublisher bodyPublisher; public ServletServerHttpRequest(HttpServletRequest request, - Publisher body) { + DataBufferFactory dataBufferFactory, int bufferSize) throws IOException { Assert.notNull(request, "'request' must not be null."); - Assert.notNull(body, "'body' must not be null."); + Assert.notNull(dataBufferFactory, "'dataBufferFactory' must not be null"); this.request = request; - this.requestBodyPublisher = Flux.from(body); + this.bodyPublisher = + new RequestBodyPublisher(request.getInputStream(), dataBufferFactory, + bufferSize); } - public HttpServletRequest getServletRequest() { return this.request; } @@ -80,9 +83,11 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { @Override protected HttpHeaders initHeaders() { HttpHeaders headers = new HttpHeaders(); - for (Enumeration names = getServletRequest().getHeaderNames(); names.hasMoreElements(); ) { + for (Enumeration names = getServletRequest().getHeaderNames(); + names.hasMoreElements(); ) { String name = (String) names.nextElement(); - for (Enumeration values = getServletRequest().getHeaders(name); values.hasMoreElements(); ) { + for (Enumeration values = getServletRequest().getHeaders(name); + values.hasMoreElements(); ) { headers.add(name, (String) values.nextElement()); } } @@ -101,7 +106,9 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { Map params = new LinkedCaseInsensitiveMap<>(); params.putAll(contentType.getParameters()); params.put("charset", charset.toString()); - headers.setContentType(new MediaType(contentType.getType(), contentType.getSubtype(), params)); + headers.setContentType( + new MediaType(contentType.getType(), contentType.getSubtype(), + params)); } } if (headers.getContentLength() == -1) { @@ -129,7 +136,76 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { @Override public Flux getBody() { - return this.requestBodyPublisher; + return Flux.from(this.bodyPublisher); + } + + public void registerListener() throws IOException { + this.bodyPublisher.registerListener(); } + private static class RequestBodyPublisher extends AbstractRequestBodyPublisher { + + private final RequestBodyPublisher.RequestBodyReadListener readListener = + new RequestBodyPublisher.RequestBodyReadListener(); + + private final ServletInputStream inputStream; + + private final DataBufferFactory dataBufferFactory; + + private final byte[] buffer; + + public RequestBodyPublisher(ServletInputStream inputStream, + DataBufferFactory dataBufferFactory, int bufferSize) { + this.inputStream = inputStream; + this.dataBufferFactory = dataBufferFactory; + this.buffer = new byte[bufferSize]; + } + + public void registerListener() throws IOException { + this.inputStream.setReadListener(this.readListener); + } + + @Override + protected void checkOnDataAvailable() { + if (!this.inputStream.isFinished() && this.inputStream.isReady()) { + onDataAvailable(); + } + } + + @Override + protected DataBuffer read() throws IOException { + if (this.inputStream.isReady()) { + int read = this.inputStream.read(this.buffer); + if (logger.isTraceEnabled()) { + logger.trace("read:" + read); + } + + if (read > 0) { + DataBuffer dataBuffer = this.dataBufferFactory.allocateBuffer(read); + dataBuffer.write(this.buffer, 0, read); + return dataBuffer; + } + } + return null; + } + + private class RequestBodyReadListener implements ReadListener { + + @Override + public void onDataAvailable() throws IOException { + RequestBodyPublisher.this.onDataAvailable(); + } + + @Override + public void onAllDataRead() throws IOException { + RequestBodyPublisher.this.onAllDataRead(); + } + + @Override + public void onError(Throwable throwable) { + RequestBodyPublisher.this.onError(throwable); + + } + } + } } diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java index 7cdad1512ae..a54f2b08089 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -16,10 +16,13 @@ package org.springframework.http.server.reactive; +import java.io.IOException; +import java.io.InputStream; import java.nio.charset.Charset; import java.util.List; import java.util.Map; -import java.util.function.Function; +import javax.servlet.ServletOutputStream; +import javax.servlet.WriteListener; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletResponse; @@ -35,32 +38,27 @@ import org.springframework.util.Assert; /** * Adapt {@link ServerHttpResponse} to the Servlet {@link HttpServletResponse}. - * * @author Rossen Stoyanchev */ public class ServletServerHttpResponse extends AbstractServerHttpResponse { private final HttpServletResponse response; - private final Function, Mono> responseBodyWriter; - + private ResponseBodyProcessor bodyProcessor; public ServletServerHttpResponse(HttpServletResponse response, - DataBufferFactory dataBufferFactory, - Function, Mono> responseBodyWriter) { + DataBufferFactory dataBufferFactory, int bufferSize) throws IOException { super(dataBufferFactory); Assert.notNull(response, "'response' must not be null"); - Assert.notNull(responseBodyWriter, "'responseBodyWriter' must not be null"); this.response = response; - this.responseBodyWriter = responseBodyWriter; + this.bodyProcessor = + new ResponseBodyProcessor(response.getOutputStream(), bufferSize); } - public HttpServletResponse getServletResponse() { return this.response; } - @Override protected void writeStatusCode() { HttpStatus statusCode = this.getStatusCode(); @@ -71,7 +69,10 @@ public class ServletServerHttpResponse extends AbstractServerHttpResponse { @Override protected Mono writeWithInternal(Publisher publisher) { - return this.responseBodyWriter.apply(publisher); + return Mono.from(subscriber -> { + publisher.subscribe(this.bodyProcessor); + this.bodyProcessor.subscribe(subscriber); + }); } @Override @@ -109,4 +110,109 @@ public class ServletServerHttpResponse extends AbstractServerHttpResponse { } } + public void registerListener() throws IOException { + this.bodyProcessor.registerListener(); + } + + private static class ResponseBodyProcessor extends AbstractResponseBodyProcessor { + + private final ResponseBodyWriteListener writeListener = + new ResponseBodyWriteListener(); + + private final ServletOutputStream outputStream; + + private final int bufferSize; + + private volatile boolean flushOnNext; + + public ResponseBodyProcessor(ServletOutputStream outputStream, int bufferSize) { + this.outputStream = outputStream; + this.bufferSize = bufferSize; + } + + public void registerListener() throws IOException { + this.outputStream.setWriteListener(this.writeListener); + } + + @Override + protected boolean isWritePossible() { + return this.outputStream.isReady(); + } + + @Override + protected boolean write(DataBuffer dataBuffer) throws IOException { + if (this.flushOnNext) { + flush(); + } + + boolean ready = this.outputStream.isReady(); + + if (this.logger.isTraceEnabled()) { + this.logger.trace("write: " + dataBuffer + " ready: " + ready); + } + + if (ready) { + int total = dataBuffer.readableByteCount(); + int written = writeDataBuffer(dataBuffer); + + if (this.logger.isTraceEnabled()) { + this.logger.trace("written: " + written + " total: " + total); + } + return written == total; + } + else { + return false; + } + } + + @Override + protected void flush() throws IOException { + if (this.outputStream.isReady()) { + if (logger.isTraceEnabled()) { + logger.trace("flush"); + } + try { + this.outputStream.flush(); + this.flushOnNext = false; + return; + } + catch (IOException ignored) { + } + } + this.flushOnNext = true; + + } + + private int writeDataBuffer(DataBuffer dataBuffer) throws IOException { + InputStream input = dataBuffer.asInputStream(); + + int bytesWritten = 0; + byte[] buffer = new byte[this.bufferSize]; + int bytesRead = -1; + + while (this.outputStream.isReady() && + (bytesRead = input.read(buffer)) != -1) { + this.outputStream.write(buffer, 0, bytesRead); + bytesWritten += bytesRead; + } + + return bytesWritten; + } + + private class ResponseBodyWriteListener implements WriteListener { + + @Override + public void onWritePossible() throws IOException { + ResponseBodyProcessor.this.onWritePossible(); + } + + @Override + public void onError(Throwable ex) { + // Error on writing to the HTTP stream, so any further writes will probably + // fail. Let's log instead of calling {@link #writeError}. + ResponseBodyProcessor.this.logger + .error("ResponseBodyWriteListener error", ex); + } + } + } } \ No newline at end of file diff --git a/spring-web-reactive/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java b/spring-web-reactive/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java index 6de90e0cffa..b6743cd77e0 100644 --- a/spring-web-reactive/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.http.server.reactive; import java.util.Arrays; @@ -20,8 +21,8 @@ import java.util.Collections; import javax.servlet.http.HttpServletRequest; import org.junit.Test; -import reactor.core.publisher.Flux; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.util.MultiValueMap; @@ -70,10 +71,10 @@ public class ServerHttpRequestTests { assertEquals(Collections.singletonList(null), params.get("a")); } - - private ServerHttpRequest createHttpRequest(String path) { + private ServerHttpRequest createHttpRequest(String path) throws Exception { HttpServletRequest servletRequest = new MockHttpServletRequest("GET", path); - return new ServletServerHttpRequest(servletRequest, Flux.empty()); + return new ServletServerHttpRequest(servletRequest, + new DefaultDataBufferFactory(), 1024); } }