From 52325a21ffb76d30d8f9253934fec621f4e3f121 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Fri, 24 Jun 2016 13:42:46 +0200 Subject: [PATCH] Fixed Undertow flush support Reactored Servlet 3.1 and Undertow response support into an AbstractResponseBodySubscriber that uses an internal state machine, making thread-safity a lot easier. --- .../AbstractResponseBodySubscriber.java | 284 ++++++++++++++++++ .../reactive/ServletHttpHandlerAdapter.java | 220 ++++++-------- .../reactive/UndertowHttpHandlerAdapter.java | 159 ++++------ .../annotation/SseIntegrationTests.java | 21 +- 4 files changed, 442 insertions(+), 242 deletions(-) create mode 100644 spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodySubscriber.java diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodySubscriber.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodySubscriber.java new file mode 100644 index 00000000000..7c174edd122 --- /dev/null +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodySubscriber.java @@ -0,0 +1,284 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.nio.channels.Channel; +import java.util.concurrent.atomic.AtomicReference; +import javax.servlet.WriteListener; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.util.BackpressureUtils; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.FlushingDataBuffer; +import org.springframework.core.io.buffer.support.DataBufferUtils; +import org.springframework.util.Assert; + +/** + * Abstract base class for {@code Subscriber} implementations that bridge between + * event-listener APIs and Reactive Streams. Specifically, base class for the Servlet 3.1 + * and Undertow support. + * @author Arjen Poutsma + * @see ServletServerHttpRequest + * @see UndertowHttpHandlerAdapter + */ +abstract class AbstractResponseBodySubscriber implements Subscriber { + + protected final Log logger = LogFactory.getLog(getClass()); + + private final AtomicReference state = + new AtomicReference<>(State.UNSUBSCRIBED); + + private volatile DataBuffer currentBuffer; + + private volatile boolean subscriptionCompleted; + + private Subscription subscription; + + @Override + public final void onSubscribe(Subscription subscription) { + if (logger.isTraceEnabled()) { + logger.trace(this.state + " onSubscribe: " + subscription); + } + this.state.get().onSubscribe(this, subscription); + } + + @Override + public final void onNext(DataBuffer dataBuffer) { + if (logger.isTraceEnabled()) { + logger.trace(this.state + " onNext: " + dataBuffer); + } + this.state.get().onNext(this, dataBuffer); + } + + @Override + public final void onError(Throwable t) { + if (logger.isErrorEnabled()) { + logger.error(this.state + " onError: " + t, t); + } + this.state.get().onError(this, t); + } + + @Override + public final void onComplete() { + if (logger.isTraceEnabled()) { + logger.trace(this.state + " onComplete"); + } + this.state.get().onComplete(this); + } + + /** + * Called via a listener interface to indicate that writing is possible. + * @see WriteListener#onWritePossible() + * @see org.xnio.ChannelListener#handleEvent(Channel) + */ + protected final void onWritePossible() { + this.state.get().onWritePossible(this); + } + + /** + * Called when a {@link DataBuffer} is received via {@link Subscriber#onNext(Object)} + * @param dataBuffer the buffer that was received. + */ + protected void receiveBuffer(DataBuffer dataBuffer) { + Assert.state(this.currentBuffer == null); + this.currentBuffer = dataBuffer; + } + + /** + * Called when the current buffer should be + * {@linkplain DataBufferUtils#release(DataBuffer) released}. + */ + protected void releaseBuffer() { + if (logger.isTraceEnabled()) { + logger.trace("releaseBuffer: " + this.currentBuffer); + } + DataBufferUtils.release(this.currentBuffer); + this.currentBuffer = null; + } + + /** + * Writes the given data buffer to the output, indicating if the entire buffer was + * written. + * @param dataBuffer the data buffer to write + * @return {@code true} if {@code dataBuffer} was fully written and a new buffer + * can be requested; {@code false} otherwise + */ + protected abstract boolean write(DataBuffer dataBuffer) throws IOException; + + /** + * Writes the given exception to the output. + */ + protected abstract void writeError(Throwable t); + + /** + * Flushes the output. + */ + protected abstract void flush() throws IOException; + + /** + * Closes the output. + */ + protected abstract void close(); + + private void changeState(State oldState, State newState) { + this.state.compareAndSet(oldState, newState); + } + + /** + * Represents a state for the {@link Subscriber} to be in. The following figure + * indicate the four different states that exist, and the relationships between them. + * + *
+	 *       UNSUBSCRIBED
+	 *        |
+	 *        v
+	 * REQUESTED <---> RECEIVED
+	 *         |       |
+	 *         v       v
+	 *         COMPLETED
+	 * 
+ * Refer to the individual states for more information. + */ + private enum State { + + /** + * The initial unsubscribed state. Will respond to {@code onSubscribe} by + * requesting 1 buffer from the subscription, and change state to {@link + * #REQUESTED}. + */ + UNSUBSCRIBED { + @Override + void onSubscribe(AbstractResponseBodySubscriber subscriber, + Subscription subscription) { + if (BackpressureUtils.validate(subscriber.subscription, subscription)) { + subscriber.subscription = subscription; + subscriber.changeState(this, REQUESTED); + subscription.request(1); + } + } + }, + /** + * State that gets entered after a buffer has been + * {@linkplain Subscription#request(long) requested}. Responds to {@code onNext} + * by changing state to {@link #RECEIVED}, and responds to {@code onComplete} by + * changing state to {@link #COMPLETED}. + */ + REQUESTED { + @Override + void onNext(AbstractResponseBodySubscriber subscriber, + DataBuffer dataBuffer) { + subscriber.changeState(this, RECEIVED); + subscriber.receiveBuffer(dataBuffer); + } + + @Override + void onComplete(AbstractResponseBodySubscriber subscriber) { + subscriber.subscriptionCompleted = true; + subscriber.changeState(this, COMPLETED); + subscriber.close(); + } + }, + /** + * State that gets entered after a buffer has been + * {@linkplain Subscriber#onNext(Object) received}. Responds to + * {@code onWritePossible} by writing the current buffer, and if it can be + * written completely, changes state to either {@link #REQUESTED} if the + * subscription has not been completed; or {@link #COMPLETED} if it has. + */ + RECEIVED { + @Override + void onWritePossible(AbstractResponseBodySubscriber subscriber) { + DataBuffer dataBuffer = subscriber.currentBuffer; + try { + boolean writeCompleted = subscriber.write(dataBuffer); + if (writeCompleted) { + if (dataBuffer instanceof FlushingDataBuffer) { + subscriber.flush(); + } + subscriber.releaseBuffer(); + boolean subscriptionCompleted = subscriber.subscriptionCompleted; + if (!subscriptionCompleted) { + subscriber.changeState(this, REQUESTED); + subscriber.subscription.request(1); + } + else { + subscriber.changeState(this, COMPLETED); + subscriber.close(); + } + } + } + catch (IOException ex) { + subscriber.onError(ex); + } + } + + @Override + void onComplete(AbstractResponseBodySubscriber subscriber) { + subscriber.subscriptionCompleted = true; + } + }, + /** + * The terminal completed state. Does not respond to any events. + */ + COMPLETED { + @Override + void onNext(AbstractResponseBodySubscriber subscriber, + DataBuffer dataBuffer) { + // ignore + } + + @Override + void onError(AbstractResponseBodySubscriber subscriber, Throwable t) { + // ignore + } + + @Override + void onComplete(AbstractResponseBodySubscriber subscriber) { + // ignore + } + }; + + void onSubscribe(AbstractResponseBodySubscriber subscriber, Subscription s) { + throw new IllegalStateException(toString()); + } + + void onNext(AbstractResponseBodySubscriber subscriber, DataBuffer dataBuffer) { + throw new IllegalStateException(toString()); + } + + void onError(AbstractResponseBodySubscriber subscriber, Throwable t) { + subscriber.changeState(this, COMPLETED); + subscriber.writeError(t); + subscriber.close(); + } + + void onComplete(AbstractResponseBodySubscriber subscriber) { + throw new IllegalStateException(toString()); + } + + void onWritePossible(AbstractResponseBodySubscriber subscriber) { + // ignore + } + + } + +} diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java index 6f8bcf8ce51..aa1edf35cd0 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java @@ -34,13 +34,10 @@ import org.apache.commons.logging.LogFactory; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import reactor.core.publisher.Mono; -import reactor.core.util.BackpressureUtils; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DefaultDataBufferFactory; -import org.springframework.core.io.buffer.FlushingDataBuffer; -import org.springframework.core.io.buffer.support.DataBufferUtils; import org.springframework.http.HttpStatus; import org.springframework.util.Assert; @@ -88,16 +85,17 @@ public class ServletHttpHandlerAdapter extends HttpServlet { ServletAsyncContextSynchronizer synchronizer = new ServletAsyncContextSynchronizer(context); RequestBodyPublisher requestBody = - new RequestBodyPublisher(synchronizer, dataBufferFactory, bufferSize); + new RequestBodyPublisher(synchronizer, this.dataBufferFactory, + this.bufferSize); requestBody.registerListener(); ServletServerHttpRequest request = new ServletServerHttpRequest(servletRequest, requestBody); ResponseBodySubscriber responseBody = - new ResponseBodySubscriber(synchronizer, bufferSize); + new ResponseBodySubscriber(synchronizer, this.bufferSize); responseBody.registerListener(); ServletServerHttpResponse response = - new ServletServerHttpResponse(servletResponse, dataBufferFactory, + new ServletServerHttpResponse(servletResponse, this.dataBufferFactory, publisher -> Mono .from(subscriber -> publisher.subscribe(responseBody))); @@ -162,16 +160,17 @@ public class ServletHttpHandlerAdapter extends HttpServlet { } public void registerListener() throws IOException { - this.synchronizer.getRequest().getInputStream().setReadListener(readListener); + this.synchronizer.getRequest().getInputStream() + .setReadListener(this.readListener); } @Override protected void noLongerStalled() { try { - readListener.onDataAvailable(); + this.readListener.onDataAvailable(); } catch (IOException ex) { - readListener.onError(ex); + this.readListener.onError(ex); } } @@ -183,7 +182,9 @@ public class ServletHttpHandlerAdapter extends HttpServlet { return; } logger.trace("onDataAvailable"); - ServletInputStream input = synchronizer.getRequest().getInputStream(); + ServletInputStream input = + RequestBodyPublisher.this.synchronizer.getRequest() + .getInputStream(); while (true) { if (!checkSubscriptionForDemand()) { @@ -198,15 +199,17 @@ public class ServletHttpHandlerAdapter extends HttpServlet { break; } - int read = input.read(buffer); + int read = input.read(RequestBodyPublisher.this.buffer); logger.trace("Input read:" + read); if (read == -1) { break; } else if (read > 0) { - DataBuffer dataBuffer = dataBufferFactory.allocateBuffer(read); - dataBuffer.write(buffer, 0, read); + DataBuffer dataBuffer = + RequestBodyPublisher.this.dataBufferFactory + .allocateBuffer(read); + dataBuffer.write(RequestBodyPublisher.this.buffer, 0, read); publishOnNext(dataBuffer); } @@ -216,7 +219,7 @@ public class ServletHttpHandlerAdapter extends HttpServlet { @Override public void onAllDataRead() throws IOException { logger.trace("All data read"); - synchronizer.readComplete(); + RequestBodyPublisher.this.synchronizer.readComplete(); publishOnComplete(); } @@ -224,7 +227,7 @@ public class ServletHttpHandlerAdapter extends HttpServlet { @Override public void onError(Throwable t) { logger.trace("RequestBodyReadListener Error", t); - synchronizer.readComplete(); + RequestBodyPublisher.this.synchronizer.readComplete(); publishOnError(t); } @@ -232,9 +235,7 @@ public class ServletHttpHandlerAdapter extends HttpServlet { } - private static class ResponseBodySubscriber implements Subscriber { - - private static final Log logger = LogFactory.getLog(ResponseBodySubscriber.class); + private static class ResponseBodySubscriber extends AbstractResponseBodySubscriber { private final ResponseBodyWriteListener writeListener = new ResponseBodyWriteListener(); @@ -243,14 +244,7 @@ public class ServletHttpHandlerAdapter extends HttpServlet { private final int bufferSize; - private volatile DataBuffer dataBuffer; - - private volatile boolean completed = false; - - private volatile boolean flushOnNext = false; - - private Subscription subscription; - + private volatile boolean flushOnNext; public ResponseBodySubscriber(ServletAsyncContextSynchronizer synchronizer, int bufferSize) { @@ -259,145 +253,119 @@ public class ServletHttpHandlerAdapter extends HttpServlet { } public void registerListener() throws IOException { - synchronizer.getResponse().getOutputStream().setWriteListener(writeListener); + outputStream().setWriteListener(this.writeListener); + } + + private ServletOutputStream outputStream() throws IOException { + return this.synchronizer.getResponse().getOutputStream(); } @Override - public void onSubscribe(Subscription subscription) { - logger.trace("onSubscribe. Subscription: " + subscription); - if (BackpressureUtils.validate(this.subscription, subscription)) { - this.subscription = subscription; - this.subscription.request(1); + protected void receiveBuffer(DataBuffer dataBuffer) { + super.receiveBuffer(dataBuffer); + + try { + if (outputStream().isReady()) { + onWritePossible(); + } + } + catch (IOException ignored) { } } @Override - public void onNext(DataBuffer dataBuffer) { - Assert.state(this.dataBuffer == null); + protected boolean write(DataBuffer dataBuffer) throws IOException { + ServletOutputStream output = outputStream(); - logger.trace("onNext. buffer: " + dataBuffer); + boolean ready = output.isReady(); - this.dataBuffer = dataBuffer; - try { - this.writeListener.onWritePossible(); + if (this.flushOnNext) { + flush(); + ready = output.isReady(); } - catch (IOException e) { - onError(e); + + if (this.logger.isTraceEnabled()) { + this.logger.trace("write: " + dataBuffer + " ready: " + ready); + } + + if (ready) { + int total = dataBuffer.readableByteCount(); + int written = writeDataBuffer(dataBuffer); + + if (this.logger.isTraceEnabled()) { + this.logger.trace("written: " + written + " total: " + total); + } + return written == total; + } + else { + return false; } } @Override - public void onError(Throwable t) { - logger.error("onError", t); + protected void writeError(Throwable t) { HttpServletResponse response = (HttpServletResponse) this.synchronizer.getResponse(); - response.setStatus(HttpStatus.INTERNAL_SERVER_ERROR.value()); - this.synchronizer.complete(); - + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); } @Override - public void onComplete() { - logger.trace("onComplete. buffer: " + this.dataBuffer); - - this.completed = true; - - if (this.dataBuffer != null) { + protected void flush() throws IOException { + ServletOutputStream output = outputStream(); + if (output.isReady()) { + if (logger.isTraceEnabled()) { + this.logger.trace("flush"); + } try { - this.writeListener.onWritePossible(); + output.flush(); + this.flushOnNext = false; } - catch (IOException ex) { - onError(ex); + catch (IOException ignored) { } } - - if (this.dataBuffer == null) { - this.synchronizer.writeComplete(); + else { + this.flushOnNext = true; } - } - private class ResponseBodyWriteListener implements WriteListener { + } - @Override - public void onWritePossible() throws IOException { - logger.trace("onWritePossible"); - ServletOutputStream output = synchronizer.getResponse().getOutputStream(); + @Override + protected void close() { + this.synchronizer.writeComplete(); + } - boolean ready = output.isReady(); + private int writeDataBuffer(DataBuffer dataBuffer) throws IOException { + InputStream input = dataBuffer.asInputStream(); + ServletOutputStream output = outputStream(); - if (flushOnNext) { - flush(output); - ready = output.isReady(); - } + int bytesWritten = 0; + byte[] buffer = new byte[this.bufferSize]; + int bytesRead = -1; - logger.trace("ready: " + ready + " buffer: " + dataBuffer); - - if (ready) { - if (dataBuffer != null) { - - int total = dataBuffer.readableByteCount(); - int written = writeDataBuffer(); - - logger.trace("written: " + written + " total: " + total); - if (written == total) { - if (dataBuffer instanceof FlushingDataBuffer) { - flush(output); - } - releaseBuffer(); - if (!completed) { - subscription.request(1); - } - else { - synchronizer.writeComplete(); - } - } - } - else if (subscription != null) { - subscription.request(1); - } - } + while (output.isReady() && (bytesRead = input.read(buffer)) != -1) { + output.write(buffer, 0, bytesRead); + bytesWritten += bytesRead; } - private int writeDataBuffer() throws IOException { - InputStream input = dataBuffer.asInputStream(); - ServletOutputStream output = synchronizer.getResponse().getOutputStream(); - - int bytesWritten = 0; - byte[] buffer = new byte[bufferSize]; - int bytesRead = -1; - - while (output.isReady() && (bytesRead = input.read(buffer)) != -1) { - output.write(buffer, 0, bytesRead); - bytesWritten += bytesRead; - } - - return bytesWritten; - } + return bytesWritten; + } - private void flush(ServletOutputStream output) { - if (output.isReady()) { - logger.trace("Flushing"); - try { - output.flush(); - flushOnNext = false; - } - catch (IOException ignored) { - } - } else { - flushOnNext = true; - } - } + private class ResponseBodyWriteListener implements WriteListener { - private void releaseBuffer() { - DataBufferUtils.release(dataBuffer); - dataBuffer = null; + @Override + public void onWritePossible() throws IOException { + ResponseBodySubscriber.this.onWritePossible(); } @Override public void onError(Throwable ex) { - logger.error("ResponseBodyWriteListener error", ex); + // Error on writing to the HTTP stream, so any further writes will probably + // fail. Let's log instead of calling {@link #writeError}. + ResponseBodySubscriber.this.logger + .error("ResponseBodyWriteListener error", ex); } } } + } \ No newline at end of file diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java index bc2c89ab1a4..4d3f4ac7aed 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java @@ -31,12 +31,9 @@ import org.xnio.IoUtils; import org.xnio.channels.StreamSinkChannel; import org.xnio.channels.StreamSourceChannel; import reactor.core.publisher.Mono; -import reactor.core.util.BackpressureUtils; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; -import org.springframework.core.io.buffer.FlushingDataBuffer; -import org.springframework.core.io.buffer.support.DataBufferUtils; import org.springframework.util.Assert; /** @@ -65,7 +62,7 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle public void handleRequest(HttpServerExchange exchange) throws Exception { RequestBodyPublisher requestBody = - new RequestBodyPublisher(exchange, dataBufferFactory); + new RequestBodyPublisher(exchange, this.dataBufferFactory); requestBody.registerListener(); ServerHttpRequest request = new UndertowServerHttpRequest(exchange, requestBody); @@ -76,7 +73,7 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle ServerHttpResponse response = new UndertowServerHttpResponse(exchange, responseChannel, publisher -> Mono.from(subscriber -> publisher.subscribe(responseBody)), - dataBufferFactory); + this.dataBufferFactory); this.delegate.handle(request, response).subscribe(new Subscriber() { @@ -130,7 +127,7 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle } public void registerListener() { - this.requestChannel.getReadSetter().set(listener); + this.requestChannel.getReadSetter().set(this.listener); this.requestChannel.resumeReads(); } @@ -145,7 +142,7 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle @Override protected void noLongerStalled() { - listener.handleEvent(requestChannel); + this.listener.handleEvent(this.requestChannel); } private class RequestBodyListener @@ -157,7 +154,8 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle return; } logger.trace("handleEvent"); - ByteBuffer byteBuffer = pooledByteBuffer.getBuffer(); + ByteBuffer byteBuffer = + RequestBodyPublisher.this.pooledByteBuffer.getBuffer(); try { while (true) { if (!checkSubscriptionForDemand()) { @@ -177,7 +175,9 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle } else { byteBuffer.flip(); - DataBuffer dataBuffer = dataBufferFactory.wrap(byteBuffer); + DataBuffer dataBuffer = + RequestBodyPublisher.this.dataBufferFactory + .wrap(byteBuffer); publishOnNext(dataBuffer); } } @@ -190,9 +190,7 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle } - private static class ResponseBodySubscriber implements Subscriber { - - private static final Log logger = LogFactory.getLog(ResponseBodySubscriber.class); + private static class ResponseBodySubscriber extends AbstractResponseBodySubscriber { private final ChannelListener listener = new ResponseBodyListener(); @@ -203,12 +201,6 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle private volatile ByteBuffer byteBuffer; - private volatile DataBuffer dataBuffer; - - private volatile boolean completed = false; - - private Subscription subscription; - public ResponseBodySubscriber(HttpServerExchange exchange, StreamSinkChannel responseChannel) { this.exchange = exchange; @@ -216,58 +208,77 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle } public void registerListener() { - this.responseChannel.getWriteSetter().set(listener); + this.responseChannel.getWriteSetter().set(this.listener); this.responseChannel.resumeWrites(); } + @Override + protected void writeError(Throwable t) { + if (!this.exchange.isResponseStarted() && + this.exchange.getStatusCode() < 500) { + this.exchange.setStatusCode(500); + } + } @Override - public void onSubscribe(Subscription subscription) { - logger.trace("onSubscribe. Subscription: " + subscription); - if (BackpressureUtils.validate(this.subscription, subscription)) { - this.subscription = subscription; - this.subscription.request(1); + protected void flush() throws IOException { + if (logger.isTraceEnabled()) { + logger.trace("flush"); } + this.responseChannel.flush(); } @Override - public void onNext(DataBuffer dataBuffer) { - Assert.state(this.byteBuffer == null); - logger.trace("onNext. buffer: " + dataBuffer); + protected boolean write(DataBuffer dataBuffer) throws IOException { + if (this.byteBuffer == null) { + return false; + } + if (logger.isTraceEnabled()) { + logger.trace("write: " + dataBuffer); + } + int total = this.byteBuffer.remaining(); + int written = writeByteBuffer(this.byteBuffer); - this.byteBuffer = dataBuffer.asByteBuffer(); - this.dataBuffer = dataBuffer; + if (logger.isTraceEnabled()) { + logger.trace("written: " + written + " total: " + total); + } + return written == total; } - @Override - public void onError(Throwable t) { - logger.error("onError", t); - if (!exchange.isResponseStarted() && exchange.getStatusCode() < 500) { - exchange.setStatusCode(500); + private int writeByteBuffer(ByteBuffer byteBuffer) throws IOException { + int written; + int totalWritten = 0; + do { + written = this.responseChannel.write(byteBuffer); + totalWritten += written; } - closeChannel(responseChannel); + while (byteBuffer.hasRemaining() && written > 0); + return totalWritten; } @Override - public void onComplete() { - logger.trace("onComplete. buffer: " + this.byteBuffer); - - this.completed = true; + protected void receiveBuffer(DataBuffer dataBuffer) { + super.receiveBuffer(dataBuffer); + this.byteBuffer = dataBuffer.asByteBuffer(); + } - if (this.byteBuffer == null) { - closeChannel(responseChannel); - } + @Override + protected void releaseBuffer() { + super.releaseBuffer(); + this.byteBuffer = null; } - private void closeChannel(StreamSinkChannel channel) { + @Override + protected void close() { try { - channel.shutdownWrites(); + this.responseChannel.shutdownWrites(); - if (!channel.flush()) { - channel.getWriteSetter().set(ChannelListeners - .flushingChannelListener(o -> IoUtils.safeClose(channel), + if (!this.responseChannel.flush()) { + this.responseChannel.getWriteSetter().set(ChannelListeners + .flushingChannelListener( + o -> IoUtils.safeClose(this.responseChannel), ChannelListeners.closingChannelExceptionHandler())); - channel.resumeWrites(); + this.responseChannel.resumeWrites(); } } catch (IOException ignored) { @@ -278,60 +289,12 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle @Override public void handleEvent(StreamSinkChannel channel) { - if (byteBuffer != null) { - try { - int total = byteBuffer.remaining(); - int written = writeByteBuffer(channel); - - logger.trace("written: " + written + " total: " + total); - - if (written == total) { - if (dataBuffer instanceof FlushingDataBuffer) { - flush(channel); - } - releaseBuffer(); - if (!completed) { - subscription.request(1); - } - else { - closeChannel(channel); - } - } - } - catch (IOException ex) { - onError(ex); - } - } - else if (subscription != null) { - subscription.request(1); - } - - } - - private int writeByteBuffer(StreamSinkChannel channel) throws IOException { - int written; - int totalWritten = 0; - do { - written = channel.write(byteBuffer); - totalWritten += written; - } - while (byteBuffer.hasRemaining() && written > 0); - return totalWritten; - } - - private void flush(StreamSinkChannel channel) throws IOException { - logger.trace("Flushing"); - channel.flush(); - } - - private void releaseBuffer() { - DataBufferUtils.release(dataBuffer); - dataBuffer = null; - byteBuffer = null; + onWritePossible(); } } } + } \ No newline at end of file diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/SseIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/SseIntegrationTests.java index 2c98530d8ce..1c9b514a29f 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/SseIntegrationTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/SseIntegrationTests.java @@ -22,9 +22,6 @@ import java.util.List; import org.junit.Before; import org.junit.Test; -import org.junit.runners.Parameterized; -import static org.springframework.web.client.reactive.HttpRequestBuilders.get; -import static org.springframework.web.client.reactive.WebResponseExtractors.bodyStream; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.test.TestSubscriber; @@ -43,10 +40,6 @@ import org.springframework.http.converter.reactive.HttpMessageConverter; import org.springframework.http.converter.reactive.SseHttpMessageConverter; import org.springframework.http.server.reactive.AbstractHttpHandlerIntegrationTests; import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.boot.JettyHttpServer; -import org.springframework.http.server.reactive.boot.ReactorHttpServer; -import org.springframework.http.server.reactive.boot.RxNettyHttpServer; -import org.springframework.http.server.reactive.boot.TomcatHttpServer; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.client.reactive.WebClient; @@ -55,22 +48,14 @@ import org.springframework.web.reactive.config.WebReactiveConfiguration; import org.springframework.web.reactive.sse.SseEvent; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; +import static org.springframework.web.client.reactive.HttpRequestBuilders.get; +import static org.springframework.web.client.reactive.WebResponseExtractors.bodyStream; + /** * @author Sebastien Deleuze */ public class SseIntegrationTests extends AbstractHttpHandlerIntegrationTests { - // TODO Fix Undertow support and remove this method - @Parameterized.Parameters(name = "server [{0}]") - public static Object[][] arguments() { - return new Object[][] { - {new JettyHttpServer()}, - {new RxNettyHttpServer()}, - {new ReactorHttpServer()}, - {new TomcatHttpServer()}, - }; - } - private AnnotationConfigApplicationContext wac; private WebClient webClient;