From 361707c448f41612b018f2e50452dfb185fbbd5e Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Tue, 23 Feb 2016 12:17:14 +0100 Subject: [PATCH] Servlet 3.1 support cleanup Claneup of the Servlet 3.1 support: - moved RequestBodyPublisher to ServletServerHttpRequest - moved ResponseBodySubscribera to ServletServerHttpResponse - response body is now copied to ServletOutputStream in chunks, rather than one big byte[] --- .../ServletAsyncContextSynchronizer.java | 18 +- .../reactive/ServletHttpHandlerAdapter.java | 321 ++---------------- .../reactive/ServletServerHttpRequest.java | 216 +++++++++++- .../reactive/ServletServerHttpResponse.java | 132 ++++++- .../reactive/AsyncIntegrationTests.java | 50 +-- 5 files changed, 376 insertions(+), 361 deletions(-) diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletAsyncContextSynchronizer.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletAsyncContextSynchronizer.java index dc1e015b5ec..283b598e9ed 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletAsyncContextSynchronizer.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletAsyncContextSynchronizer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2016 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,8 @@ import java.util.concurrent.atomic.AtomicInteger; import javax.servlet.AsyncContext; import javax.servlet.ServletInputStream; import javax.servlet.ServletOutputStream; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; /** * Utility class for synchronizing between the reading and writing side of an @@ -40,10 +42,12 @@ final class ServletAsyncContextSynchronizer { private static final int COMPLETE = READ_COMPLETE | WRITE_COMPLETE; + private final AsyncContext asyncContext; private final AtomicInteger complete = new AtomicInteger(NONE_COMPLETE); + /** * Creates a new {@code AsyncContextSynchronizer} based on the given context. * @param asyncContext the context to base this synchronizer on @@ -52,13 +56,21 @@ final class ServletAsyncContextSynchronizer { this.asyncContext = asyncContext; } + public ServletRequest getRequest() { + return this.asyncContext.getRequest(); + } + + public ServletResponse getResponse() { + return this.asyncContext.getResponse(); + } + /** * Returns the input stream of this synchronizer. * @return the input stream * @throws IOException if an input or output exception occurred */ public ServletInputStream getInputStream() throws IOException { - return this.asyncContext.getRequest().getInputStream(); + return getRequest().getInputStream(); } /** @@ -67,7 +79,7 @@ final class ServletAsyncContextSynchronizer { * @throws IOException if an input or output exception occurred */ public ServletOutputStream getOutputStream() throws IOException { - return this.asyncContext.getResponse().getOutputStream(); + return getResponse().getOutputStream(); } /** diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java index 048fa38014b..80b447c43ca 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java @@ -17,14 +17,8 @@ package org.springframework.http.server.reactive; import java.io.IOException; -import java.io.InputStream; -import java.util.concurrent.atomic.AtomicLong; import javax.servlet.AsyncContext; -import javax.servlet.ReadListener; import javax.servlet.ServletException; -import javax.servlet.ServletInputStream; -import javax.servlet.ServletOutputStream; -import javax.servlet.WriteListener; import javax.servlet.annotation.WebServlet; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; @@ -32,17 +26,13 @@ import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; -import reactor.core.publisher.Mono; -import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferAllocator; import org.springframework.core.io.buffer.DefaultDataBufferAllocator; import org.springframework.http.HttpStatus; import org.springframework.util.Assert; -import org.springframework.util.StreamUtils; /** * @author Arjen Poutsma @@ -51,24 +41,35 @@ import org.springframework.util.StreamUtils; @WebServlet(asyncSupported = true) public class ServletHttpHandlerAdapter extends HttpServlet { - private static final int BUFFER_SIZE = 8192; + private static final int DEFAULT_BUFFER_SIZE = 8192; private static Log logger = LogFactory.getLog(ServletHttpHandlerAdapter.class); private HttpHandler handler; - private DataBufferAllocator allocator = new DefaultDataBufferAllocator(); + // Servlet is based on blocking I/O, hence the usage of non-direct, heap-based buffers + // (i.e. 'false' as constructor argument) + private DataBufferAllocator allocator = new DefaultDataBufferAllocator(false); + + private int bufferSize = DEFAULT_BUFFER_SIZE; public void setHandler(HttpHandler handler) { + Assert.notNull(handler, "'handler' must not be null"); this.handler = handler; } public void setAllocator(DataBufferAllocator allocator) { + Assert.notNull(allocator, "'allocator' must not be null"); this.allocator = allocator; } + public void setBufferSize(int bufferSize) { + Assert.isTrue(bufferSize > 0); + this.bufferSize = bufferSize; + } + @Override protected void service(HttpServletRequest servletRequest, HttpServletResponse servletResponse) throws ServletException, IOException { @@ -76,299 +77,25 @@ public class ServletHttpHandlerAdapter extends HttpServlet { AsyncContext context = servletRequest.startAsync(); ServletAsyncContextSynchronizer synchronizer = new ServletAsyncContextSynchronizer(context); - RequestBodyPublisher requestBody = - new RequestBodyPublisher(synchronizer, allocator, BUFFER_SIZE); - ServletServerHttpRequest request = new ServletServerHttpRequest(servletRequest, requestBody); - servletRequest.getInputStream().setReadListener(requestBody); - - ResponseBodySubscriber responseBodySubscriber = - new ResponseBodySubscriber(synchronizer); - ServletServerHttpResponse response = new ServletServerHttpResponse(servletResponse, - publisher -> Mono.from(subscriber -> publisher.subscribe(responseBodySubscriber))); - servletResponse.getOutputStream().setWriteListener(responseBodySubscriber); - - HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(synchronizer, response); - this.handler.handle(request, response).subscribe(resultSubscriber); - } - - private static class RequestBodyPublisher - implements ReadListener, Publisher { - - private final ServletAsyncContextSynchronizer synchronizer; - - private final DataBufferAllocator allocator; - - private final byte[] buffer; - - private final DemandCounter demand = new DemandCounter(); - - private Subscriber subscriber; - - private boolean stalled; - - private boolean cancelled; - - public RequestBodyPublisher(ServletAsyncContextSynchronizer synchronizer, - DataBufferAllocator allocator, int bufferSize) { - this.synchronizer = synchronizer; - this.allocator = allocator; - this.buffer = new byte[bufferSize]; - } - - @Override - public void subscribe(Subscriber subscriber) { - if (subscriber == null) { - throw new NullPointerException(); - } - else if (this.subscriber != null) { - subscriber.onError(new IllegalStateException("Only one subscriber allowed")); - } - this.subscriber = subscriber; - this.subscriber.onSubscribe(new RequestBodySubscription()); - } - - @Override - public void onDataAvailable() throws IOException { - if (cancelled) { - return; - } - ServletInputStream input = this.synchronizer.getInputStream(); - logger.debug("onDataAvailable: " + input); - - while (true) { - logger.debug("Demand: " + this.demand); - - if (!demand.hasDemand()) { - stalled = true; - break; - } - - boolean ready = input.isReady(); - logger.debug("Input ready: " + ready + " finished: " + input.isFinished()); - - if (!ready) { - break; - } - - int read = input.read(buffer); - logger.debug("Input read:" + read); - - if (read == -1) { - break; - } - else if (read > 0) { - this.demand.decrement(); - - DataBuffer dataBuffer = allocator.allocateBuffer(read); - dataBuffer.write(this.buffer, 0, read); - - this.subscriber.onNext(dataBuffer); - - } - } - } - - @Override - public void onAllDataRead() throws IOException { - if (cancelled) { - return; - } - logger.debug("All data read"); - this.synchronizer.readComplete(); - if (this.subscriber != null) { - this.subscriber.onComplete(); - } - } - - @Override - public void onError(Throwable t) { - if (cancelled) { - return; - } - logger.error("RequestBodyPublisher Error", t); - this.synchronizer.readComplete(); - if (this.subscriber != null) { - this.subscriber.onError(t); - } - } - - private class RequestBodySubscription implements Subscription { - - @Override - public void request(long n) { - if (cancelled) { - return; - } - logger.debug("Updating demand " + demand + " by " + n); - - demand.increase(n); - - logger.debug("Stalled: " + stalled); - - if (stalled) { - stalled = false; - try { - onDataAvailable(); - } - catch (IOException ex) { - onError(ex); - } - } - } + ServletServerHttpRequest request = + new ServletServerHttpRequest(synchronizer, this.allocator, + this.bufferSize); - @Override - public void cancel() { - if (cancelled) { - return; - } - cancelled = true; - synchronizer.readComplete(); - demand.reset(); - } - } - - - /** - * Small utility class for keeping track of Reactive Streams demand. - */ - private static final class DemandCounter { - - private final AtomicLong demand = new AtomicLong(); - - /** - * Increases the demand by the given number - * @param n the positive number to increase demand by - * @return the increased demand - * @see org.reactivestreams.Subscription#request(long) - */ - public long increase(long n) { - Assert.isTrue(n > 0, "'n' must be higher than 0"); - return demand.updateAndGet(d -> d != Long.MAX_VALUE ? d + n : Long.MAX_VALUE); - } - - /** - * Decreases the demand by one. - * @return the decremented demand - */ - public long decrement() { - return demand.updateAndGet(d -> d != Long.MAX_VALUE ? d - 1 : Long.MAX_VALUE); - } + ServletServerHttpResponse response = + new ServletServerHttpResponse(synchronizer, this.bufferSize); - /** - * Indicates whether this counter has demand, i.e. whether it is higher than 0. - * @return {@code true} if this counter has demand; {@code false} otherwise - */ - public boolean hasDemand() { - return this.demand.get() > 0; - } + HandlerResultSubscriber resultSubscriber = + new HandlerResultSubscriber(synchronizer); - /** - * Resets this counter to 0. - * @see org.reactivestreams.Subscription#cancel() - */ - public void reset() { - this.demand.set(0); - } - - @Override - public String toString() { - return demand.toString(); - } - } - } - - private static class ResponseBodySubscriber - implements WriteListener, Subscriber { - - private final ServletAsyncContextSynchronizer synchronizer; - - private Subscription subscription; - - private DataBuffer dataBuffer; - - private volatile boolean subscriberComplete = false; - - public ResponseBodySubscriber(ServletAsyncContextSynchronizer synchronizer) { - this.synchronizer = synchronizer; - } - - - @Override - public void onSubscribe(Subscription subscription) { - this.subscription = subscription; - this.subscription.request(1); - } - - @Override - public void onNext(DataBuffer bytes) { - Assert.isNull(dataBuffer); - - this.dataBuffer = bytes; - try { - onWritePossible(); - } - catch (IOException e) { - onError(e); - } - } - - @Override - public void onComplete() { - logger.debug("Complete buffer: " + (dataBuffer == null)); - - this.subscriberComplete = true; - - if (dataBuffer == null) { - this.synchronizer.writeComplete(); - } - } - - @Override - public void onWritePossible() throws IOException { - ServletOutputStream output = this.synchronizer.getOutputStream(); - - boolean ready = output.isReady(); - logger.debug("Output: " + ready + " buffer: " + (dataBuffer == null)); - - if (ready) { - if (this.dataBuffer != null) { - InputStream in = this.dataBuffer.asInputStream(); - byte[] buffer = new byte[BUFFER_SIZE]; - int bytesRead; - while ((bytesRead = in.read(buffer)) != -1) { - output.write(buffer, 0, bytesRead); - } - if (!subscriberComplete) { - this.subscription.request(1); - } - else { - this.synchronizer.writeComplete(); - } - } - else { - this.subscription.request(1); - } - } - } - - @Override - public void onError(Throwable t) { - logger.error("ResponseBodySubscriber error", t); - } + this.handler.handle(request, response).subscribe(resultSubscriber); } private static class HandlerResultSubscriber implements Subscriber { private final ServletAsyncContextSynchronizer synchronizer; - private final ServletServerHttpResponse response; - - - public HandlerResultSubscriber(ServletAsyncContextSynchronizer synchronizer, - ServletServerHttpResponse response) { - + public HandlerResultSubscriber(ServletAsyncContextSynchronizer synchronizer) { this.synchronizer = synchronizer; - this.response = response; } @@ -385,7 +112,9 @@ public class ServletHttpHandlerAdapter extends HttpServlet { @Override public void onError(Throwable ex) { logger.error("Error from request handling. Completing the request.", ex); - this.response.setStatusCode(HttpStatus.INTERNAL_SERVER_ERROR); + HttpServletResponse response = + (HttpServletResponse) this.synchronizer.getResponse(); + response.setStatus(HttpStatus.INTERNAL_SERVER_ERROR.value()); this.synchronizer.complete(); } diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java index cbfb6e73426..2374d862313 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java @@ -16,6 +16,7 @@ package org.springframework.http.server.reactive; +import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.Charset; @@ -23,13 +24,21 @@ import java.util.ArrayList; import java.util.Enumeration; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; +import javax.servlet.ReadListener; +import javax.servlet.ServletInputStream; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import reactor.core.publisher.Flux; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferAllocator; import org.springframework.http.HttpCookie; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -45,16 +54,22 @@ import org.springframework.util.StringUtils; */ public class ServletServerHttpRequest extends AbstractServerHttpRequest { + private static final Log logger = LogFactory.getLog(ServletServerHttpRequest.class); + private final HttpServletRequest request; private final Flux requestBodyPublisher; - public ServletServerHttpRequest(HttpServletRequest request, - Publisher body) { - Assert.notNull(request, "'request' must not be null."); - Assert.notNull(body, "'body' must not be null."); - this.request = request; - this.requestBodyPublisher = Flux.from(body); + public ServletServerHttpRequest(ServletAsyncContextSynchronizer synchronizer, + DataBufferAllocator allocator, int bufferSize) throws IOException { + Assert.notNull(synchronizer, "'synchronizer' must not be null"); + Assert.notNull(allocator, "'allocator' must not be null"); + + this.request = (HttpServletRequest) synchronizer.getRequest(); + RequestBodyPublisher bodyPublisher = + new RequestBodyPublisher(synchronizer, allocator, bufferSize); + this.requestBodyPublisher = Flux.from(bodyPublisher); + this.request.getInputStream().setReadListener(bodyPublisher); } @@ -132,4 +147,193 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { return this.requestBodyPublisher; } + private static class RequestBodyPublisher + implements ReadListener, Publisher { + + private final ServletAsyncContextSynchronizer synchronizer; + + private final DataBufferAllocator allocator; + + private final byte[] buffer; + + private final DemandCounter demand = new DemandCounter(); + + private Subscriber subscriber; + + private boolean stalled; + + private boolean cancelled; + + public RequestBodyPublisher(ServletAsyncContextSynchronizer synchronizer, + DataBufferAllocator allocator, int bufferSize) { + this.synchronizer = synchronizer; + this.allocator = allocator; + this.buffer = new byte[bufferSize]; + } + + @Override + public void subscribe(Subscriber subscriber) { + if (subscriber == null) { + throw new NullPointerException(); + } + else if (this.subscriber != null) { + subscriber.onError( + new IllegalStateException("Only one subscriber allowed")); + } + this.subscriber = subscriber; + this.subscriber.onSubscribe(new RequestBodySubscription()); + } + + @Override + public void onDataAvailable() throws IOException { + if (cancelled) { + return; + } + ServletInputStream input = this.synchronizer.getInputStream(); + logger.trace("onDataAvailable: " + input); + + while (true) { + logger.trace("Demand: " + this.demand); + + if (!demand.hasDemand()) { + stalled = true; + break; + } + + boolean ready = input.isReady(); + logger.trace( + "Input ready: " + ready + " finished: " + input.isFinished()); + + if (!ready) { + break; + } + + int read = input.read(buffer); + logger.trace("Input read:" + read); + + if (read == -1) { + break; + } + else if (read > 0) { + this.demand.decrement(); + + DataBuffer dataBuffer = allocator.allocateBuffer(read); + dataBuffer.write(this.buffer, 0, read); + + this.subscriber.onNext(dataBuffer); + + } + } + } + + @Override + public void onAllDataRead() throws IOException { + if (cancelled) { + return; + } + logger.trace("All data read"); + this.synchronizer.readComplete(); + if (this.subscriber != null) { + this.subscriber.onComplete(); + } + } + + @Override + public void onError(Throwable t) { + if (cancelled) { + return; + } + logger.trace("RequestBodyPublisher Error", t); + this.synchronizer.readComplete(); + if (this.subscriber != null) { + this.subscriber.onError(t); + } + } + + private class RequestBodySubscription implements Subscription { + + @Override + public void request(long n) { + if (cancelled) { + return; + } + logger.trace("Updating demand " + demand + " by " + n); + + demand.increase(n); + + logger.trace("Stalled: " + stalled); + + if (stalled) { + stalled = false; + try { + onDataAvailable(); + } + catch (IOException ex) { + onError(ex); + } + } + } + + @Override + public void cancel() { + if (cancelled) { + return; + } + cancelled = true; + synchronizer.readComplete(); + demand.reset(); + } + } + + /** + * Small utility class for keeping track of Reactive Streams demand. + */ + private static final class DemandCounter { + + private final AtomicLong demand = new AtomicLong(); + + /** + * Increases the demand by the given number + * @param n the positive number to increase demand by + * @return the increased demand + * @see Subscription#request(long) + */ + public long increase(long n) { + Assert.isTrue(n > 0, "'n' must be higher than 0"); + return demand + .updateAndGet(d -> d != Long.MAX_VALUE ? d + n : Long.MAX_VALUE); + } + + /** + * Decreases the demand by one. + * @return the decremented demand + */ + public long decrement() { + return demand + .updateAndGet(d -> d != Long.MAX_VALUE ? d - 1 : Long.MAX_VALUE); + } + + /** + * Indicates whether this counter has demand, i.e. whether it is higher than + * 0. + * @return {@code true} if this counter has demand; {@code false} otherwise + */ + public boolean hasDemand() { + return this.demand.get() > 0; + } + + /** + * Resets this counter to 0. + * @see Subscription#cancel() + */ + public void reset() { + this.demand.set(0); + } + + @Override + public String toString() { + return demand.toString(); + } + } + } } diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java index eeb50d96c81..2ee70f42a9a 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -16,14 +16,21 @@ package org.springframework.http.server.reactive; +import java.io.IOException; +import java.io.InputStream; import java.nio.charset.Charset; import java.util.List; import java.util.Map; -import java.util.function.Function; +import javax.servlet.ServletOutputStream; +import javax.servlet.WriteListener; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletResponse; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import reactor.core.publisher.Mono; import org.springframework.core.io.buffer.DataBuffer; @@ -39,18 +46,20 @@ import org.springframework.util.Assert; */ public class ServletServerHttpResponse extends AbstractServerHttpResponse { - private final HttpServletResponse response; + private static final Log logger = LogFactory.getLog(ServletServerHttpResponse.class); - private final Function, Mono> responseBodyWriter; + private final HttpServletResponse response; + private final ResponseBodySubscriber responseBodySubscriber; - public ServletServerHttpResponse(HttpServletResponse response, - Function, Mono> responseBodyWriter) { + public ServletServerHttpResponse(ServletAsyncContextSynchronizer synchronizer, + int bufferSize) throws IOException { + Assert.notNull(synchronizer, "'synchronizer' must not be null"); - Assert.notNull(response, "'response' must not be null"); - Assert.notNull(responseBodyWriter, "'responseBodyWriter' must not be null"); - this.response = response; - this.responseBodyWriter = responseBodyWriter; + this.response = (HttpServletResponse) synchronizer.getResponse(); + this.responseBodySubscriber = + new ResponseBodySubscriber(synchronizer, bufferSize); + this.response.getOutputStream().setWriteListener(responseBodySubscriber); } @@ -65,7 +74,8 @@ public class ServletServerHttpResponse extends AbstractServerHttpResponse { @Override protected Mono setBodyInternal(Publisher publisher) { - return this.responseBodyWriter.apply(publisher); + return Mono.from((Publisher) subscriber -> publisher + .subscribe(this.responseBodySubscriber)); } @Override @@ -107,4 +117,106 @@ public class ServletServerHttpResponse extends AbstractServerHttpResponse { } } + private static class ResponseBodySubscriber + implements WriteListener, Subscriber { + + private final ServletAsyncContextSynchronizer synchronizer; + + private final int bufferSize; + + private Subscription subscription; + + private DataBuffer dataBuffer; + + private volatile boolean subscriberComplete = false; + + public ResponseBodySubscriber(ServletAsyncContextSynchronizer synchronizer, + int bufferSize) { + this.synchronizer = synchronizer; + this.bufferSize = bufferSize; + } + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + this.subscription.request(1); + } + + @Override + public void onNext(DataBuffer dataBuffer) { + Assert.isNull(this.dataBuffer); + logger.trace("onNext. buffer: " + dataBuffer); + + this.dataBuffer = dataBuffer; + try { + onWritePossible(); + } + catch (IOException e) { + onError(e); + } + } + + @Override + public void onComplete() { + logger.trace("onComplete. buffer: " + dataBuffer); + + this.subscriberComplete = true; + + if (dataBuffer == null) { + this.synchronizer.writeComplete(); + } + } + + @Override + public void onWritePossible() throws IOException { + ServletOutputStream output = this.synchronizer.getOutputStream(); + + boolean ready = output.isReady(); + logger.trace("onWritePossible. ready: " + ready + " buffer: " + dataBuffer); + + if (ready) { + if (this.dataBuffer != null) { + int toBeWritten = this.dataBuffer.readableByteCount(); + InputStream input = this.dataBuffer.asInputStream(); + int writeCount = write(input, output); + logger.trace("written: " + writeCount + " total: " + toBeWritten); + if (writeCount == toBeWritten) { + this.dataBuffer = null; + if (!this.subscriberComplete) { + this.subscription.request(1); + } + else { + this.synchronizer.writeComplete(); + } + } + } + else if (this.subscription != null) { + this.subscription.request(1); + } + } + } + + private int write(InputStream in, ServletOutputStream output) throws IOException { + int byteCount = 0; + byte[] buffer = new byte[bufferSize]; + int bytesRead = -1; + while (output.isReady() && (bytesRead = in.read(buffer)) != -1) { + output.write(buffer, 0, bytesRead); + byteCount += bytesRead; + } + return byteCount; + } + + @Override + public void onError(Throwable ex) { + if (this.subscription != null) { + this.subscription.cancel(); + } + logger.error("ResponseBodySubscriber error", ex); + HttpServletResponse response = + (HttpServletResponse) this.synchronizer.getResponse(); + response.setStatus(HttpStatus.INTERNAL_SERVER_ERROR.value()); + this.synchronizer.complete(); + } + } } \ No newline at end of file diff --git a/spring-web-reactive/src/test/java/org/springframework/http/server/reactive/AsyncIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/http/server/reactive/AsyncIntegrationTests.java index e06acc147a2..b11739f554f 100644 --- a/spring-web-reactive/src/test/java/org/springframework/http/server/reactive/AsyncIntegrationTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/http/server/reactive/AsyncIntegrationTests.java @@ -19,11 +19,7 @@ package org.springframework.http.server.reactive; import java.net.URI; import org.hamcrest.Matchers; -import org.junit.After; -import org.junit.Before; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import reactor.core.publisher.Mono; import reactor.core.publisher.SchedulerGroup; import reactor.core.timer.Timer; @@ -33,11 +29,6 @@ import org.springframework.core.io.buffer.DataBufferAllocator; import org.springframework.core.io.buffer.DefaultDataBufferAllocator; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; -import org.springframework.http.server.reactive.boot.HttpServer; -import org.springframework.http.server.reactive.boot.ReactorHttpServer; -import org.springframework.http.server.reactive.boot.RxNettyHttpServer; -import org.springframework.http.server.reactive.boot.UndertowHttpServer; -import org.springframework.util.SocketUtils; import org.springframework.web.client.RestTemplate; import static org.junit.Assert.assertThat; @@ -47,48 +38,15 @@ import static org.junit.Assert.assertThat; * * @author Stephane Maldini */ -@RunWith(Parameterized.class) -public class AsyncIntegrationTests { +public class AsyncIntegrationTests extends AbstractHttpHandlerIntegrationTests { private final SchedulerGroup asyncGroup = SchedulerGroup.async(); private final DataBufferAllocator allocator = new DefaultDataBufferAllocator(); - protected int port; - - @Parameterized.Parameter(0) - public HttpServer server; - - private AsyncHandler asyncHandler; - - @Parameterized.Parameters(name = "server [{0}]") - public static Object[][] arguments() { - return new Object[][]{ - //{new JettyHttpServer()}, - {new RxNettyHttpServer()}, - {new ReactorHttpServer()}, - //{new TomcatHttpServer()}, - {new UndertowHttpServer()} - }; - } - - @Before - public void setup() throws Exception { - this.port = SocketUtils.findAvailableTcpPort(); - this.server.setPort(this.port); - this.server.setHandler(createHttpHandler()); - this.server.afterPropertiesSet(); - this.server.start(); - } - - protected HttpHandler createHttpHandler() { - this.asyncHandler = new AsyncHandler(); - return this.asyncHandler; - } - - @After - public void tearDown() throws Exception { - this.server.stop(); + @Override + protected AsyncHandler createHttpHandler() { + return new AsyncHandler(); } @SuppressWarnings("unchecked")