diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java index 1ea7e4cd0c2..97bf0d7f4b1 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java @@ -17,8 +17,15 @@ package org.springframework.http.server.reactive; import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicLong; import javax.servlet.AsyncContext; +import javax.servlet.ReadListener; import javax.servlet.ServletException; +import javax.servlet.ServletInputStream; +import javax.servlet.ServletOutputStream; +import javax.servlet.WriteListener; import javax.servlet.annotation.WebServlet; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; @@ -26,10 +33,12 @@ import javax.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import org.springframework.http.HttpStatus; +import org.springframework.util.Assert; /** * @author Arjen Poutsma @@ -38,6 +47,8 @@ import org.springframework.http.HttpStatus; @WebServlet(asyncSupported = true) public class ServletHttpHandlerAdapter extends HttpServlet { + private static final int BUFFER_SIZE = 8192; + private static Log logger = LogFactory.getLog(ServletHttpHandlerAdapter.class); @@ -50,23 +61,288 @@ public class ServletHttpHandlerAdapter extends HttpServlet { @Override - protected void service(HttpServletRequest request, HttpServletResponse response) + protected void service(HttpServletRequest servletRequest, HttpServletResponse servletResponse) throws ServletException, IOException { - AsyncContext context = request.startAsync(); + AsyncContext context = servletRequest.startAsync(); ServletAsyncContextSynchronizer synchronizer = new ServletAsyncContextSynchronizer(context); - ServletServerHttpRequest httpRequest = new ServletServerHttpRequest(request, synchronizer); - request.getInputStream().setReadListener(httpRequest.getReadListener()); + RequestBodyPublisher requestBody = new RequestBodyPublisher(synchronizer, BUFFER_SIZE); + ServletServerHttpRequest request = new ServletServerHttpRequest(servletRequest, requestBody); + servletRequest.getInputStream().setReadListener(requestBody); - ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response, synchronizer); - response.getOutputStream().setWriteListener(httpResponse.getWriteListener()); + ResponseBodySubscriber responseBodySubscriber = new ResponseBodySubscriber(synchronizer); + ServletServerHttpResponse response = new ServletServerHttpResponse(servletResponse, + publisher -> subscriber -> publisher.subscribe(responseBodySubscriber)); + servletResponse.getOutputStream().setWriteListener(responseBodySubscriber); - HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(synchronizer, httpResponse); - this.handler.handle(httpRequest, httpResponse).subscribe(resultSubscriber); + HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(synchronizer, response); + this.handler.handle(request, response).subscribe(resultSubscriber); } + private static class RequestBodyPublisher implements ReadListener, Publisher { + + private final ServletAsyncContextSynchronizer synchronizer; + + private final byte[] buffer; + + private final DemandCounter demand = new DemandCounter(); + + private Subscriber subscriber; + + private boolean stalled; + + private boolean cancelled; + + + public RequestBodyPublisher(ServletAsyncContextSynchronizer synchronizer, int bufferSize) { + this.synchronizer = synchronizer; + this.buffer = new byte[bufferSize]; + } + + + @Override + public void subscribe(Subscriber subscriber) { + if (subscriber == null) { + throw new NullPointerException(); + } + else if (this.subscriber != null) { + subscriber.onError(new IllegalStateException("Only one subscriber allowed")); + } + this.subscriber = subscriber; + this.subscriber.onSubscribe(new RequestBodySubscription()); + } + + @Override + public void onDataAvailable() throws IOException { + if (cancelled) { + return; + } + ServletInputStream input = this.synchronizer.getInputStream(); + logger.debug("onDataAvailable: " + input); + + while (true) { + logger.debug("Demand: " + this.demand); + + if (!demand.hasDemand()) { + stalled = true; + break; + } + + boolean ready = input.isReady(); + logger.debug("Input ready: " + ready + " finished: " + input.isFinished()); + + if (!ready) { + break; + } + + int read = input.read(buffer); + logger.debug("Input read:" + read); + + if (read == -1) { + break; + } + else if (read > 0) { + this.demand.decrement(); + byte[] copy = Arrays.copyOf(this.buffer, read); + +// logger.debug("Next: " + new String(copy, UTF_8)); + + this.subscriber.onNext(ByteBuffer.wrap(copy)); + + } + } + } + + @Override + public void onAllDataRead() throws IOException { + if (cancelled) { + return; + } + logger.debug("All data read"); + this.synchronizer.readComplete(); + if (this.subscriber != null) { + this.subscriber.onComplete(); + } + } + + @Override + public void onError(Throwable t) { + if (cancelled) { + return; + } + logger.error("RequestBodyPublisher Error", t); + this.synchronizer.readComplete(); + if (this.subscriber != null) { + this.subscriber.onError(t); + } + } + + private class RequestBodySubscription implements Subscription { + + @Override + public void request(long n) { + if (cancelled) { + return; + } + logger.debug("Updating demand " + demand + " by " + n); + + demand.increase(n); + + logger.debug("Stalled: " + stalled); + + if (stalled) { + stalled = false; + try { + onDataAvailable(); + } + catch (IOException ex) { + onError(ex); + } + } + } + + @Override + public void cancel() { + if (cancelled) { + return; + } + cancelled = true; + synchronizer.readComplete(); + demand.reset(); + } + } + + + /** + * Small utility class for keeping track of Reactive Streams demand. + */ + private static final class DemandCounter { + + private final AtomicLong demand = new AtomicLong(); + + /** + * Increases the demand by the given number + * @param n the positive number to increase demand by + * @return the increased demand + * @see org.reactivestreams.Subscription#request(long) + */ + public long increase(long n) { + Assert.isTrue(n > 0, "'n' must be higher than 0"); + return demand.updateAndGet(d -> d != Long.MAX_VALUE ? d + n : Long.MAX_VALUE); + } + + /** + * Decreases the demand by one. + * @return the decremented demand + */ + public long decrement() { + return demand.updateAndGet(d -> d != Long.MAX_VALUE ? d - 1 : Long.MAX_VALUE); + } + + /** + * Indicates whether this counter has demand, i.e. whether it is higher than 0. + * @return {@code true} if this counter has demand; {@code false} otherwise + */ + public boolean hasDemand() { + return this.demand.get() > 0; + } + + /** + * Resets this counter to 0. + * @see org.reactivestreams.Subscription#cancel() + */ + public void reset() { + this.demand.set(0); + } + + @Override + public String toString() { + return demand.toString(); + } + } + } + + private static class ResponseBodySubscriber implements WriteListener, Subscriber { + + private final ServletAsyncContextSynchronizer synchronizer; + + private Subscription subscription; + + private ByteBuffer buffer; + + private volatile boolean subscriberComplete = false; + + + public ResponseBodySubscriber(ServletAsyncContextSynchronizer synchronizer) { + this.synchronizer = synchronizer; + } + + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + this.subscription.request(1); + } + + @Override + public void onNext(ByteBuffer bytes) { + + Assert.isNull(buffer); + + this.buffer = bytes; + try { + onWritePossible(); + } + catch (IOException e) { + onError(e); + } + } + + @Override + public void onComplete() { + logger.debug("Complete buffer: " + (buffer == null)); + + this.subscriberComplete = true; + + if (buffer == null) { + this.synchronizer.writeComplete(); + } + } + + @Override + public void onWritePossible() throws IOException { + ServletOutputStream output = this.synchronizer.getOutputStream(); + + boolean ready = output.isReady(); + logger.debug("Output: " + ready + " buffer: " + (buffer == null)); + + if (ready) { + if (this.buffer != null) { + byte[] bytes = new byte[this.buffer.remaining()]; + this.buffer.get(bytes); + this.buffer = null; + output.write(bytes); + if (!subscriberComplete) { + this.subscription.request(1); + } + else { + this.synchronizer.writeComplete(); + } + } + else { + this.subscription.request(1); + } + } + } + + @Override + public void onError(Throwable t) { + logger.error("ResponseBodySubscriber error", t); + } + } + private static class HandlerResultSubscriber implements Subscriber { private final ServletAsyncContextSynchronizer synchronizer; diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java index 477f2b6a860..3f1eb1c69ca 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java @@ -16,24 +16,15 @@ package org.springframework.http.server.reactive; -import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.nio.ByteBuffer; import java.nio.charset.Charset; -import java.util.Arrays; import java.util.Enumeration; import java.util.Map; -import java.util.concurrent.atomic.AtomicLong; -import javax.servlet.ReadListener; -import javax.servlet.ServletInputStream; import javax.servlet.http.HttpServletRequest; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -49,24 +40,20 @@ import org.springframework.util.StringUtils; */ public class ServletServerHttpRequest implements ServerHttpRequest { - private static final int BUFFER_SIZE = 8192; - - private static final Log logger = LogFactory.getLog(ServletServerHttpRequest.class); - - private final HttpServletRequest request; private URI uri; private HttpHeaders headers; - private final RequestBodyPublisher requestBodyPublisher; + private final Publisher requestBodyPublisher; - public ServletServerHttpRequest(HttpServletRequest request, ServletAsyncContextSynchronizer synchronizer) { + public ServletServerHttpRequest(HttpServletRequest request, Publisher body) { Assert.notNull(request, "'request' must not be null."); + Assert.notNull(body, "'body' must not be null."); this.request = request; - this.requestBodyPublisher = new RequestBodyPublisher(synchronizer, BUFFER_SIZE); + this.requestBodyPublisher = body; } @@ -143,192 +130,4 @@ public class ServletServerHttpRequest implements ServerHttpRequest { return this.requestBodyPublisher; } - ReadListener getReadListener() { - return this.requestBodyPublisher; - } - - - private static class RequestBodyPublisher implements ReadListener, Publisher { - - private final ServletAsyncContextSynchronizer synchronizer; - - private final byte[] buffer; - - private final DemandCounter demand = new DemandCounter(); - - private Subscriber subscriber; - - private boolean stalled; - - private boolean cancelled; - - - public RequestBodyPublisher(ServletAsyncContextSynchronizer synchronizer, int bufferSize) { - this.synchronizer = synchronizer; - this.buffer = new byte[bufferSize]; - } - - - @Override - public void subscribe(Subscriber subscriber) { - if (subscriber == null) { - throw new NullPointerException(); - } - else if (this.subscriber != null) { - subscriber.onError(new IllegalStateException("Only one subscriber allowed")); - } - this.subscriber = subscriber; - this.subscriber.onSubscribe(new RequestBodySubscription()); - } - - @Override - public void onDataAvailable() throws IOException { - if (cancelled) { - return; - } - ServletInputStream input = this.synchronizer.getInputStream(); - logger.debug("onDataAvailable: " + input); - - while (true) { - logger.debug("Demand: " + this.demand); - - if (!demand.hasDemand()) { - stalled = true; - break; - } - - boolean ready = input.isReady(); - logger.debug("Input ready: " + ready + " finished: " + input.isFinished()); - - if (!ready) { - break; - } - - int read = input.read(buffer); - logger.debug("Input read:" + read); - - if (read == -1) { - break; - } - else if (read > 0) { - this.demand.decrement(); - byte[] copy = Arrays.copyOf(this.buffer, read); - -// logger.debug("Next: " + new String(copy, UTF_8)); - - this.subscriber.onNext(ByteBuffer.wrap(copy)); - - } - } - } - - @Override - public void onAllDataRead() throws IOException { - if (cancelled) { - return; - } - logger.debug("All data read"); - this.synchronizer.readComplete(); - if (this.subscriber != null) { - this.subscriber.onComplete(); - } - } - - @Override - public void onError(Throwable t) { - if (cancelled) { - return; - } - logger.error("RequestBodyPublisher Error", t); - this.synchronizer.readComplete(); - if (this.subscriber != null) { - this.subscriber.onError(t); - } - } - - private class RequestBodySubscription implements Subscription { - - @Override - public void request(long n) { - if (cancelled) { - return; - } - logger.debug("Updating demand " + demand + " by " + n); - - demand.increase(n); - - logger.debug("Stalled: " + stalled); - - if (stalled) { - stalled = false; - try { - onDataAvailable(); - } - catch (IOException ex) { - onError(ex); - } - } - } - - @Override - public void cancel() { - if (cancelled) { - return; - } - cancelled = true; - synchronizer.readComplete(); - demand.reset(); - } - } - - - /** - * Small utility class for keeping track of Reactive Streams demand. - */ - private static final class DemandCounter { - - private final AtomicLong demand = new AtomicLong(); - - /** - * Increases the demand by the given number - * @param n the positive number to increase demand by - * @return the increased demand - * @see org.reactivestreams.Subscription#request(long) - */ - public long increase(long n) { - Assert.isTrue(n > 0, "'n' must be higher than 0"); - return demand.updateAndGet(d -> d != Long.MAX_VALUE ? d + n : Long.MAX_VALUE); - } - - /** - * Decreases the demand by one. - * @return the decremented demand - */ - public long decrement() { - return demand.updateAndGet(d -> d != Long.MAX_VALUE ? d - 1 : Long.MAX_VALUE); - } - - /** - * Indicates whether this counter has demand, i.e. whether it is higher than 0. - * @return {@code true} if this counter has demand; {@code false} otherwise - */ - public boolean hasDemand() { - return this.demand.get() > 0; - } - - /** - * Resets this counter to 0. - * @see org.reactivestreams.Subscription#cancel() - */ - public void reset() { - this.demand.set(0); - } - - @Override - public String toString() { - return demand.toString(); - } - } - } - } diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java index 55dfee31a88..19e10d31289 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -16,18 +16,12 @@ package org.springframework.http.server.reactive; -import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; -import javax.servlet.ServletOutputStream; -import javax.servlet.WriteListener; +import java.util.function.Function; import javax.servlet.http.HttpServletResponse; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import reactor.Publishers; import org.springframework.http.ExtendedHttpHeaders; @@ -42,21 +36,21 @@ import org.springframework.util.Assert; */ public class ServletServerHttpResponse implements ServerHttpResponse { - private static final Log logger = LogFactory.getLog(ServletServerHttpResponse.class); - - private final HttpServletResponse response; + private final Function, Publisher> responseBodyWriter; + private final HttpHeaders headers; - private final ResponseBodySubscriber subscriber; + public ServletServerHttpResponse(HttpServletResponse response, + Function, Publisher> responseBodyWriter) { - public ServletServerHttpResponse(HttpServletResponse response, ServletAsyncContextSynchronizer synchronizer) { Assert.notNull(response, "'response' must not be null"); + Assert.notNull(responseBodyWriter, "'responseBodyWriter' must not be null"); this.response = response; + this.responseBodyWriter = responseBodyWriter; this.headers = new ExtendedHttpHeaders(new ServletHeaderChangeListener()); - this.subscriber = new ResponseBodySubscriber(synchronizer); } @@ -74,17 +68,13 @@ public class ServletServerHttpResponse implements ServerHttpResponse { return this.headers; } - WriteListener getWriteListener() { - return this.subscriber; - } - @Override public Publisher setBody(final Publisher publisher) { return Publishers.lift(publisher, new WriteWithOperator<>(this::setBodyInternal)); } protected Publisher setBodyInternal(Publisher publisher) { - return s -> publisher.subscribe(subscriber); + return this.responseBodyWriter.apply(publisher); } @@ -109,84 +99,4 @@ public class ServletServerHttpResponse implements ServerHttpResponse { } } - - private static class ResponseBodySubscriber implements WriteListener, Subscriber { - - private final ServletAsyncContextSynchronizer synchronizer; - - private Subscription subscription; - - private ByteBuffer buffer; - - private volatile boolean subscriberComplete = false; - - - public ResponseBodySubscriber(ServletAsyncContextSynchronizer synchronizer) { - this.synchronizer = synchronizer; - } - - - @Override - public void onSubscribe(Subscription subscription) { - this.subscription = subscription; - this.subscription.request(1); - } - - @Override - public void onNext(ByteBuffer bytes) { - - Assert.isNull(buffer); - - this.buffer = bytes; - try { - onWritePossible(); - } - catch (IOException e) { - onError(e); - } - } - - @Override - public void onComplete() { - logger.debug("Complete buffer: " + (buffer == null)); - - this.subscriberComplete = true; - - if (buffer == null) { - this.synchronizer.writeComplete(); - } - } - - @Override - public void onWritePossible() throws IOException { - ServletOutputStream output = this.synchronizer.getOutputStream(); - - boolean ready = output.isReady(); - logger.debug("Output: " + ready + " buffer: " + (buffer == null)); - - if (ready) { - if (this.buffer != null) { - byte[] bytes = new byte[this.buffer.remaining()]; - this.buffer.get(bytes); - this.buffer = null; - output.write(bytes); - if (!subscriberComplete) { - this.subscription.request(1); - } - else { - this.synchronizer.writeComplete(); - } - } - else { - this.subscription.request(1); - } - } - } - - @Override - public void onError(Throwable t) { - logger.error("ResponseBodySubscriber error", t); - } - } - } diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java index dce2694e756..db61960ef66 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java @@ -16,13 +16,34 @@ package org.springframework.http.server.reactive; -import org.springframework.util.Assert; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import io.undertow.connector.PooledByteBuffer; import io.undertow.server.HttpServerExchange; +import io.undertow.util.SameThreadExecutor; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import org.xnio.ChannelListener; +import org.xnio.channels.StreamSinkChannel; +import org.xnio.channels.StreamSourceChannel; +import reactor.core.error.SpecificationExceptions; +import reactor.core.subscriber.BaseSubscriber; +import reactor.core.support.BackpressureUtils; + +import org.springframework.util.Assert; + +import static org.xnio.ChannelListeners.closingChannelExceptionHandler; +import static org.xnio.ChannelListeners.flushingChannelListener; +import static org.xnio.IoUtils.safeClose; /** @@ -46,8 +67,12 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle @Override public void handleRequest(HttpServerExchange exchange) throws Exception { - ServerHttpRequest request = new UndertowServerHttpRequest(exchange); - ServerHttpResponse response = new UndertowServerHttpResponse(exchange); + RequestBodyPublisher requestBody = new RequestBodyPublisher(exchange); + ServerHttpRequest request = new UndertowServerHttpRequest(exchange, requestBody); + + ResponseBodySubscriber responseBodySubscriber = new ResponseBodySubscriber(exchange); + ServerHttpResponse response = new UndertowServerHttpResponse(exchange, + publisher -> subscriber -> publisher.subscribe(responseBodySubscriber)); exchange.dispatch(); @@ -81,4 +106,370 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle }); } + + private static class RequestBodyPublisher implements Publisher { + + private static final AtomicLongFieldUpdater DEMAND = + AtomicLongFieldUpdater.newUpdater(RequestBodySubscription.class, "demand"); + + + private final HttpServerExchange exchange; + + private Subscriber subscriber; + + + public RequestBodyPublisher(HttpServerExchange exchange) { + this.exchange = exchange; + } + + + @Override + public void subscribe(Subscriber subscriber) { + if (subscriber == null) { + throw SpecificationExceptions.spec_2_13_exception(); + } + if (this.subscriber != null) { + subscriber.onError(new IllegalStateException("Only one subscriber allowed")); + } + + this.subscriber = subscriber; + this.subscriber.onSubscribe(new RequestBodySubscription()); + } + + + private class RequestBodySubscription implements Subscription, Runnable, + ChannelListener { + + volatile long demand; + + private PooledByteBuffer pooledBuffer; + + private StreamSourceChannel channel; + + private boolean subscriptionClosed; + + private boolean draining; + + + @Override + public void request(long n) { + BackpressureUtils.checkRequest(n, subscriber); + if (this.subscriptionClosed) { + return; + } + BackpressureUtils.getAndAdd(DEMAND, this, n); + scheduleNextMessage(); + } + + private void scheduleNextMessage() { + exchange.dispatch(exchange.isInIoThread() ? SameThreadExecutor.INSTANCE : + exchange.getIoThread(), this); + } + + @Override + public void cancel() { + this.subscriptionClosed = true; + close(); + } + + private void close() { + if (this.pooledBuffer != null) { + safeClose(this.pooledBuffer); + this.pooledBuffer = null; + } + if (this.channel != null) { + safeClose(this.channel); + this.channel = null; + } + } + + @Override + public void run() { + if (this.subscriptionClosed || this.draining) { + return; + } + if (0 == BackpressureUtils.getAndSub(DEMAND, this, 1)) { + return; + } + + this.draining = true; + + if (this.channel == null) { + this.channel = exchange.getRequestChannel(); + + if (this.channel == null) { + if (exchange.isRequestComplete()) { + return; + } + else { + throw new IllegalStateException("Failed to acquire channel!"); + } + } + } + if (this.pooledBuffer == null) { + this.pooledBuffer = exchange.getConnection().getByteBufferPool().allocate(); + } + else { + this.pooledBuffer.getBuffer().clear(); + } + + try { + ByteBuffer buffer = this.pooledBuffer.getBuffer(); + int count; + do { + count = this.channel.read(buffer); + if (count == 0) { + this.channel.getReadSetter().set(this); + this.channel.resumeReads(); + } + else if (count == -1) { + if (buffer.position() > 0) { + doOnNext(buffer); + } + doOnComplete(); + } + else { + if (buffer.remaining() == 0) { + if (this.demand == 0) { + this.channel.suspendReads(); + } + doOnNext(buffer); + if (this.demand > 0) { + scheduleNextMessage(); + } + break; + } + } + } while (count > 0); + } + catch (IOException e) { + doOnError(e); + } + } + + private void doOnNext(ByteBuffer buffer) { + this.draining = false; + buffer.flip(); + subscriber.onNext(buffer); + } + + private void doOnComplete() { + this.subscriptionClosed = true; + try { + subscriber.onComplete(); + } + finally { + close(); + } + } + + private void doOnError(Throwable t) { + this.subscriptionClosed = true; + try { + subscriber.onError(t); + } + finally { + close(); + } + } + + @Override + public void handleEvent(StreamSourceChannel channel) { + if (this.subscriptionClosed) { + return; + } + + try { + ByteBuffer buffer = this.pooledBuffer.getBuffer(); + int count; + do { + count = channel.read(buffer); + if (count == 0) { + return; + } + else if (count == -1) { + if (buffer.position() > 0) { + doOnNext(buffer); + } + doOnComplete(); + } + else { + if (buffer.remaining() == 0) { + if (this.demand == 0) { + channel.suspendReads(); + } + doOnNext(buffer); + if (this.demand > 0) { + scheduleNextMessage(); + } + break; + } + } + } while (count > 0); + } + catch (IOException e) { + doOnError(e); + } + } + } + } + + private static class ResponseBodySubscriber extends BaseSubscriber + implements ChannelListener { + + private final HttpServerExchange exchange; + + private Subscription subscription; + + private final Queue buffers = new ConcurrentLinkedQueue<>(); + + private final AtomicInteger writing = new AtomicInteger(); + + private final AtomicBoolean closing = new AtomicBoolean(); + + private StreamSinkChannel responseChannel; + + + public ResponseBodySubscriber(HttpServerExchange exchange) { + this.exchange = exchange; + } + + @Override + public void onSubscribe(Subscription subscription) { + super.onSubscribe(subscription); + this.subscription = subscription; + this.subscription.request(1); + } + + @Override + public void onNext(ByteBuffer buffer) { + super.onNext(buffer); + + if (this.responseChannel == null) { + this.responseChannel = exchange.getResponseChannel(); + } + + this.writing.incrementAndGet(); + try { + int c; + do { + c = this.responseChannel.write(buffer); + } while (buffer.hasRemaining() && c > 0); + + if (buffer.hasRemaining()) { + this.writing.incrementAndGet(); + enqueue(buffer); + this.responseChannel.getWriteSetter().set(this); + this.responseChannel.resumeWrites(); + } + else { + this.subscription.request(1); + } + + } + catch (IOException ex) { + onError(ex); + } + finally { + this.writing.decrementAndGet(); + if (this.closing.get()) { + closeIfDone(); + } + } + } + + private void enqueue(ByteBuffer src) { + do { + PooledByteBuffer buffer = exchange.getConnection().getByteBufferPool().allocate(); + ByteBuffer dst = buffer.getBuffer(); + copy(dst, src); + dst.flip(); + this.buffers.add(buffer); + } while (src.remaining() > 0); + } + + private void copy(ByteBuffer dst, ByteBuffer src) { + int n = Math.min(dst.capacity(), src.remaining()); + for (int i = 0; i < n; i++) { + dst.put(src.get()); + } + } + + @Override + public void handleEvent(StreamSinkChannel channel) { + try { + int c; + do { + ByteBuffer buffer = this.buffers.peek().getBuffer(); + do { + c = channel.write(buffer); + } while (buffer.hasRemaining() && c > 0); + + if (!buffer.hasRemaining()) { + safeClose(this.buffers.remove()); + } + } while (!this.buffers.isEmpty() && c > 0); + + if (!this.buffers.isEmpty()) { + channel.resumeWrites(); + } + else { + this.writing.decrementAndGet(); + + if (this.closing.get()) { + closeIfDone(); + } + else { + this.subscription.request(1); + } + } + } + catch (IOException ex) { + onError(ex); + } + } + + @Override + public void onError(Throwable ex) { + super.onError(ex); + logger.error("ResponseBodySubscriber error", ex); + if (!exchange.isResponseStarted() && exchange.getStatusCode() < 500) { + exchange.setStatusCode(500); + } + } + + @Override + public void onComplete() { + super.onComplete(); + if (this.responseChannel != null) { + this.closing.set(true); + closeIfDone(); + } + } + + private void closeIfDone() { + if (this.writing.get() == 0) { + if (this.closing.compareAndSet(true, false)) { + closeChannel(); + } + } + } + + private void closeChannel() { + try { + this.responseChannel.shutdownWrites(); + + if (!this.responseChannel.flush()) { + this.responseChannel.getWriteSetter().set(flushingChannelListener( + o -> safeClose(this.responseChannel), closingChannelExceptionHandler())); + this.responseChannel.resumeWrites(); + } + this.responseChannel = null; + } + catch (IOException ex) { + onError(ex); + } + } + } + } diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java index e8133f8bc87..9b710945fd3 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java @@ -16,30 +16,18 @@ package org.springframework.http.server.reactive; -import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; import java.nio.ByteBuffer; -import java.util.concurrent.atomic.AtomicLongFieldUpdater; -import io.undertow.connector.PooledByteBuffer; import io.undertow.server.HttpServerExchange; import io.undertow.util.HeaderValues; -import io.undertow.util.SameThreadExecutor; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import org.xnio.ChannelListener; -import org.xnio.channels.StreamSourceChannel; -import reactor.core.error.SpecificationExceptions; -import reactor.core.support.BackpressureUtils; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.util.Assert; -import static org.xnio.IoUtils.safeClose; - /** * Adapt {@link ServerHttpRequest} to the Underow {@link HttpServerExchange}. * @@ -54,12 +42,14 @@ public class UndertowServerHttpRequest implements ServerHttpRequest { private HttpHeaders headers; - private final Publisher body = new RequestBodyPublisher(); + private final Publisher body; - public UndertowServerHttpRequest(HttpServerExchange exchange) { + public UndertowServerHttpRequest(HttpServerExchange exchange, Publisher body) { Assert.notNull(exchange, "'exchange' is required."); + Assert.notNull(exchange, "'body' is required."); this.exchange = exchange; + this.body = body; } @@ -105,204 +95,4 @@ public class UndertowServerHttpRequest implements ServerHttpRequest { return this.body; } - - private static final AtomicLongFieldUpdater DEMAND = - AtomicLongFieldUpdater.newUpdater(RequestBodyPublisher.RequestBodySubscription.class, "demand"); - - private class RequestBodyPublisher implements Publisher { - - private Subscriber subscriber; - - - @Override - public void subscribe(Subscriber subscriber) { - if (subscriber == null) { - throw SpecificationExceptions.spec_2_13_exception(); - } - if (this.subscriber != null) { - subscriber.onError(new IllegalStateException("Only one subscriber allowed")); - } - - this.subscriber = subscriber; - this.subscriber.onSubscribe(new RequestBodySubscription()); - } - - - private class RequestBodySubscription implements Subscription, Runnable, - ChannelListener { - - volatile long demand; - - private PooledByteBuffer pooledBuffer; - - private StreamSourceChannel channel; - - private boolean subscriptionClosed; - - private boolean draining; - - - @Override - public void request(long n) { - BackpressureUtils.checkRequest(n, subscriber); - if (this.subscriptionClosed) { - return; - } - BackpressureUtils.getAndAdd(DEMAND, this, n); - scheduleNextMessage(); - } - - private void scheduleNextMessage() { - exchange.dispatch(exchange.isInIoThread() ? SameThreadExecutor.INSTANCE : - exchange.getIoThread(), this); - } - - @Override - public void cancel() { - this.subscriptionClosed = true; - close(); - } - - private void close() { - if (this.pooledBuffer != null) { - safeClose(this.pooledBuffer); - this.pooledBuffer = null; - } - if (this.channel != null) { - safeClose(this.channel); - this.channel = null; - } - } - - @Override - public void run() { - if (this.subscriptionClosed || this.draining) { - return; - } - if (0 == BackpressureUtils.getAndSub(DEMAND, this, 1)) { - return; - } - - this.draining = true; - - if (this.channel == null) { - this.channel = exchange.getRequestChannel(); - - if (this.channel == null) { - if (exchange.isRequestComplete()) { - return; - } - else { - throw new IllegalStateException("Failed to acquire channel!"); - } - } - } - if (this.pooledBuffer == null) { - this.pooledBuffer = exchange.getConnection().getByteBufferPool().allocate(); - } - else { - this.pooledBuffer.getBuffer().clear(); - } - - try { - ByteBuffer buffer = this.pooledBuffer.getBuffer(); - int count; - do { - count = this.channel.read(buffer); - if (count == 0) { - this.channel.getReadSetter().set(this); - this.channel.resumeReads(); - } - else if (count == -1) { - if (buffer.position() > 0) { - doOnNext(buffer); - } - doOnComplete(); - } - else { - if (buffer.remaining() == 0) { - if (this.demand == 0) { - this.channel.suspendReads(); - } - doOnNext(buffer); - if (this.demand > 0) { - scheduleNextMessage(); - } - break; - } - } - } while (count > 0); - } - catch (IOException e) { - doOnError(e); - } - } - - private void doOnNext(ByteBuffer buffer) { - this.draining = false; - buffer.flip(); - subscriber.onNext(buffer); - } - - private void doOnComplete() { - this.subscriptionClosed = true; - try { - subscriber.onComplete(); - } - finally { - close(); - } - } - - private void doOnError(Throwable t) { - this.subscriptionClosed = true; - try { - subscriber.onError(t); - } - finally { - close(); - } - } - - @Override - public void handleEvent(StreamSourceChannel channel) { - if (this.subscriptionClosed) { - return; - } - - try { - ByteBuffer buffer = this.pooledBuffer.getBuffer(); - int count; - do { - count = channel.read(buffer); - if (count == 0) { - return; - } - else if (count == -1) { - if (buffer.position() > 0) { - doOnNext(buffer); - } - doOnComplete(); - } - else { - if (buffer.remaining() == 0) { - if (this.demand == 0) { - channel.suspendReads(); - } - doOnNext(buffer); - if (this.demand > 0) { - scheduleNextMessage(); - } - break; - } - } - } while (count > 0); - } - catch (IOException e) { - doOnError(e); - } - } - } - } - } diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java index 316257bf8b5..d13b80eda2a 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java @@ -16,35 +16,20 @@ package org.springframework.http.server.reactive; -import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; -import java.util.Queue; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; -import io.undertow.connector.PooledByteBuffer; import io.undertow.server.HttpServerExchange; import io.undertow.util.HttpString; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscription; -import org.xnio.ChannelListener; -import org.xnio.channels.StreamSinkChannel; import reactor.Publishers; -import reactor.core.subscriber.BaseSubscriber; import org.springframework.http.ExtendedHttpHeaders; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.util.Assert; -import static org.xnio.ChannelListeners.closingChannelExceptionHandler; -import static org.xnio.ChannelListeners.flushingChannelListener; -import static org.xnio.IoUtils.safeClose; - /** * Adapt {@link ServerHttpResponse} to the Undertow {@link HttpServerExchange}. * @@ -53,19 +38,20 @@ import static org.xnio.IoUtils.safeClose; */ public class UndertowServerHttpResponse implements ServerHttpResponse { - private static final Log logger = LogFactory.getLog(UndertowServerHttpResponse.class); - - private final HttpServerExchange exchange; + private final Function, Publisher> responseBodyWriter; + private final HttpHeaders headers; - private final ResponseBodySubscriber bodySubscriber = new ResponseBodySubscriber(); + public UndertowServerHttpResponse(HttpServerExchange exchange, + Function, Publisher> responseBodyWriter) { - public UndertowServerHttpResponse(HttpServerExchange exchange) { Assert.notNull(exchange, "'exchange' is required."); + Assert.notNull(responseBodyWriter, "'responseBodyWriter' must not be null"); this.exchange = exchange; + this.responseBodyWriter = responseBodyWriter; this.headers = new ExtendedHttpHeaders(new UndertowHeaderChangeListener()); } @@ -90,8 +76,8 @@ public class UndertowServerHttpResponse implements ServerHttpResponse { return Publishers.lift(publisher, new WriteWithOperator<>(this::setBodyInternal)); } - protected Publisher setBodyInternal(Publisher writePublisher) { - return subscriber -> writePublisher.subscribe(bodySubscriber); + protected Publisher setBodyInternal(Publisher publisher) { + return this.responseBodyWriter.apply(publisher); } @@ -115,156 +101,4 @@ public class UndertowServerHttpResponse implements ServerHttpResponse { } } - private class ResponseBodySubscriber extends BaseSubscriber - implements ChannelListener { - - private Subscription subscription; - - private final Queue buffers = new ConcurrentLinkedQueue<>(); - - private final AtomicInteger writing = new AtomicInteger(); - - private final AtomicBoolean closing = new AtomicBoolean(); - - private StreamSinkChannel responseChannel; - - - @Override - public void onSubscribe(Subscription subscription) { - super.onSubscribe(subscription); - this.subscription = subscription; - this.subscription.request(1); - } - - @Override - public void onNext(ByteBuffer buffer) { - super.onNext(buffer); - - if (this.responseChannel == null) { - this.responseChannel = exchange.getResponseChannel(); - } - - this.writing.incrementAndGet(); - try { - int c; - do { - c = this.responseChannel.write(buffer); - } while (buffer.hasRemaining() && c > 0); - - if (buffer.hasRemaining()) { - this.writing.incrementAndGet(); - enqueue(buffer); - this.responseChannel.getWriteSetter().set(this); - this.responseChannel.resumeWrites(); - } - else { - this.subscription.request(1); - } - - } - catch (IOException ex) { - onError(ex); - } - finally { - this.writing.decrementAndGet(); - if (this.closing.get()) { - closeIfDone(); - } - } - } - - private void enqueue(ByteBuffer src) { - do { - PooledByteBuffer buffer = exchange.getConnection().getByteBufferPool().allocate(); - ByteBuffer dst = buffer.getBuffer(); - copy(dst, src); - dst.flip(); - this.buffers.add(buffer); - } while (src.remaining() > 0); - } - - private void copy(ByteBuffer dst, ByteBuffer src) { - int n = Math.min(dst.capacity(), src.remaining()); - for (int i = 0; i < n; i++) { - dst.put(src.get()); - } - } - - @Override - public void handleEvent(StreamSinkChannel channel) { - try { - int c; - do { - ByteBuffer buffer = this.buffers.peek().getBuffer(); - do { - c = channel.write(buffer); - } while (buffer.hasRemaining() && c > 0); - - if (!buffer.hasRemaining()) { - safeClose(this.buffers.remove()); - } - } while (!this.buffers.isEmpty() && c > 0); - - if (!this.buffers.isEmpty()) { - channel.resumeWrites(); - } - else { - this.writing.decrementAndGet(); - - if (this.closing.get()) { - closeIfDone(); - } - else { - this.subscription.request(1); - } - } - } - catch (IOException ex) { - onError(ex); - } - } - - @Override - public void onError(Throwable ex) { - super.onError(ex); - logger.error("ResponseBodySubscriber error", ex); - if (!exchange.isResponseStarted() && exchange.getStatusCode() < 500) { - exchange.setStatusCode(500); - } - } - - @Override - public void onComplete() { - super.onComplete(); - if (this.responseChannel != null) { - this.closing.set(true); - closeIfDone(); - } - } - - private void closeIfDone() { - if (this.writing.get() == 0) { - if (this.closing.compareAndSet(true, false)) { - closeChannel(); - } - } - } - - private void closeChannel() { - try { - this.responseChannel.shutdownWrites(); - - if (!this.responseChannel.flush()) { - this.responseChannel.getWriteSetter().set(flushingChannelListener( - o -> safeClose(this.responseChannel), closingChannelExceptionHandler())); - this.responseChannel.resumeWrites(); - } - this.responseChannel = null; - } - catch (IOException ex) { - onError(ex); - } - } - } - }