diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java index dfc9d5837bd..a856a7a1019 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java @@ -16,29 +16,19 @@ package org.springframework.http.server.reactive; -import java.io.IOException; -import java.nio.ByteBuffer; - -import io.undertow.connector.PooledByteBuffer; import io.undertow.server.HttpServerExchange; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; -import org.xnio.ChannelListener; -import org.xnio.ChannelListeners; -import org.xnio.IoUtils; -import org.xnio.channels.StreamSinkChannel; -import org.xnio.channels.StreamSourceChannel; -import reactor.core.publisher.Mono; -import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.util.Assert; /** * @author Marek Hawrylczak * @author Rossen Stoyanchev + * @author Arjen Poutsma */ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandler { @@ -60,20 +50,11 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle @Override public void handleRequest(HttpServerExchange exchange) throws Exception { - RequestBodyPublisher requestBody = - new RequestBodyPublisher(exchange, this.dataBufferFactory); - requestBody.registerListener(); - ServerHttpRequest request = new UndertowServerHttpRequest(exchange, requestBody); + ServerHttpRequest request = + new UndertowServerHttpRequest(exchange, this.dataBufferFactory); - StreamSinkChannel responseChannel = exchange.getResponseChannel(); - ResponseBodySubscriber responseBody = - new ResponseBodySubscriber(exchange, responseChannel); - responseBody.registerListener(); ServerHttpResponse response = - new UndertowServerHttpResponse(exchange, responseChannel, - publisher -> Mono - .from(subscriber -> publisher.subscribe(responseBody)), - this.dataBufferFactory); + new UndertowServerHttpResponse(exchange, this.dataBufferFactory); this.delegate.handle(request, response).subscribe(new Subscriber() { @@ -106,183 +87,4 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle }); } - private static class RequestBodyPublisher extends AbstractRequestBodyPublisher { - - private final ChannelListener readListener = - new ReadListener(); - - private final ChannelListener closeListener = - new CloseListener(); - - private final StreamSourceChannel requestChannel; - - private final DataBufferFactory dataBufferFactory; - - private final PooledByteBuffer pooledByteBuffer; - - public RequestBodyPublisher(HttpServerExchange exchange, - DataBufferFactory dataBufferFactory) { - this.requestChannel = exchange.getRequestChannel(); - this.pooledByteBuffer = - exchange.getConnection().getByteBufferPool().allocate(); - this.dataBufferFactory = dataBufferFactory; - } - - public void registerListener() { - this.requestChannel.getReadSetter().set(this.readListener); - this.requestChannel.getCloseSetter().set(this.closeListener); - this.requestChannel.resumeReads(); - } - - @Override - protected DataBuffer read() throws IOException { - ByteBuffer byteBuffer = this.pooledByteBuffer.getBuffer(); - int read = this.requestChannel.read(byteBuffer); - if (logger.isTraceEnabled()) { - logger.trace("read:" + read); - } - - if (read > 0) { - byteBuffer.flip(); - return this.dataBufferFactory.wrap(byteBuffer); - } - else if (read == -1) { - onAllDataRead(); - } - return null; - } - - @Override - protected void close() { - if (this.pooledByteBuffer != null) { - IoUtils.safeClose(this.pooledByteBuffer); - } - if (this.requestChannel != null) { - IoUtils.safeClose(this.requestChannel); - } - } - - private class ReadListener implements ChannelListener { - - @Override - public void handleEvent(StreamSourceChannel channel) { - onDataAvailable(); - } - } - - private class CloseListener implements ChannelListener { - - @Override - public void handleEvent(StreamSourceChannel channel) { - onAllDataRead(); - } - } - } - - private static class ResponseBodySubscriber extends AbstractResponseBodySubscriber { - - private final ChannelListener listener = - new ResponseBodyListener(); - - private final HttpServerExchange exchange; - - private final StreamSinkChannel responseChannel; - - private volatile ByteBuffer byteBuffer; - - public ResponseBodySubscriber(HttpServerExchange exchange, - StreamSinkChannel responseChannel) { - this.exchange = exchange; - this.responseChannel = responseChannel; - } - - public void registerListener() { - this.responseChannel.getWriteSetter().set(this.listener); - this.responseChannel.resumeWrites(); - } - - @Override - protected void writeError(Throwable t) { - if (!this.exchange.isResponseStarted() && - this.exchange.getStatusCode() < 500) { - this.exchange.setStatusCode(500); - } - } - - @Override - protected void flush() throws IOException { - if (logger.isTraceEnabled()) { - logger.trace("flush"); - } - this.responseChannel.flush(); - } - - @Override - protected boolean write(DataBuffer dataBuffer) throws IOException { - if (this.byteBuffer == null) { - return false; - } - if (logger.isTraceEnabled()) { - logger.trace("write: " + dataBuffer); - } - int total = this.byteBuffer.remaining(); - int written = writeByteBuffer(this.byteBuffer); - - if (logger.isTraceEnabled()) { - logger.trace("written: " + written + " total: " + total); - } - return written == total; - } - - private int writeByteBuffer(ByteBuffer byteBuffer) throws IOException { - int written; - int totalWritten = 0; - do { - written = this.responseChannel.write(byteBuffer); - totalWritten += written; - } - while (byteBuffer.hasRemaining() && written > 0); - return totalWritten; - } - - @Override - protected void receiveBuffer(DataBuffer dataBuffer) { - super.receiveBuffer(dataBuffer); - this.byteBuffer = dataBuffer.asByteBuffer(); - } - - @Override - protected void releaseBuffer() { - super.releaseBuffer(); - this.byteBuffer = null; - } - - @Override - protected void close() { - try { - this.responseChannel.shutdownWrites(); - - if (!this.responseChannel.flush()) { - this.responseChannel.getWriteSetter().set(ChannelListeners - .flushingChannelListener( - o -> IoUtils.safeClose(this.responseChannel), - ChannelListeners.closingChannelExceptionHandler())); - this.responseChannel.resumeWrites(); - } - } - catch (IOException ignored) { - } - } - - private class ResponseBodyListener implements ChannelListener { - - @Override - public void handleEvent(StreamSinkChannel channel) { - onWritePossible(); - } - - } - - } - } \ No newline at end of file diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java index 51ef11d4453..19f56d596dc 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java @@ -16,16 +16,22 @@ package org.springframework.http.server.reactive; +import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import io.undertow.connector.PooledByteBuffer; import io.undertow.server.HttpServerExchange; import io.undertow.server.handlers.Cookie; import io.undertow.util.HeaderValues; -import org.reactivestreams.Publisher; +import org.xnio.ChannelListener; +import org.xnio.IoUtils; +import org.xnio.channels.StreamSourceChannel; import reactor.core.publisher.Flux; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.http.HttpCookie; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -43,14 +49,14 @@ public class UndertowServerHttpRequest extends AbstractServerHttpRequest { private final HttpServerExchange exchange; - private final Flux body; + private final RequestBodyPublisher body; public UndertowServerHttpRequest(HttpServerExchange exchange, - Publisher body) { + DataBufferFactory dataBufferFactory) { Assert.notNull(exchange, "'exchange' is required."); - Assert.notNull(exchange, "'body' is required."); this.exchange = exchange; - this.body = Flux.from(body); + this.body = new RequestBodyPublisher(exchange, dataBufferFactory); + this.body.registerListener(); } @@ -92,7 +98,79 @@ public class UndertowServerHttpRequest extends AbstractServerHttpRequest { @Override public Flux getBody() { - return this.body; + return Flux.from(this.body); } + private static class RequestBodyPublisher extends AbstractRequestBodyPublisher { + + private final ChannelListener readListener = + new ReadListener(); + + private final ChannelListener closeListener = + new CloseListener(); + + private final StreamSourceChannel requestChannel; + + private final DataBufferFactory dataBufferFactory; + + private final PooledByteBuffer pooledByteBuffer; + + public RequestBodyPublisher(HttpServerExchange exchange, + DataBufferFactory dataBufferFactory) { + this.requestChannel = exchange.getRequestChannel(); + this.pooledByteBuffer = + exchange.getConnection().getByteBufferPool().allocate(); + this.dataBufferFactory = dataBufferFactory; + } + + private void registerListener() { + this.requestChannel.getReadSetter().set(this.readListener); + this.requestChannel.getCloseSetter().set(this.closeListener); + this.requestChannel.resumeReads(); + } + + @Override + protected DataBuffer read() throws IOException { + ByteBuffer byteBuffer = this.pooledByteBuffer.getBuffer(); + int read = this.requestChannel.read(byteBuffer); + if (logger.isTraceEnabled()) { + logger.trace("read:" + read); + } + + if (read > 0) { + byteBuffer.flip(); + return this.dataBufferFactory.wrap(byteBuffer); + } + else if (read == -1) { + onAllDataRead(); + } + return null; + } + + @Override + protected void close() { + if (this.pooledByteBuffer != null) { + IoUtils.safeClose(this.pooledByteBuffer); + } + if (this.requestChannel != null) { + IoUtils.safeClose(this.requestChannel); + } + } + + private class ReadListener implements ChannelListener { + + @Override + public void handleEvent(StreamSourceChannel channel) { + onDataAvailable(); + } + } + + private class CloseListener implements ChannelListener { + + @Override + public void handleEvent(StreamSourceChannel channel) { + onAllDataRead(); + } + } + } } diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java index 7ce3a5d89aa..0dfb54f26b5 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java @@ -19,16 +19,19 @@ package org.springframework.http.server.reactive; import java.io.File; import java.io.FileInputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.util.List; import java.util.Map; -import java.util.function.Function; import io.undertow.server.HttpServerExchange; import io.undertow.server.handlers.Cookie; import io.undertow.server.handlers.CookieImpl; import io.undertow.util.HttpString; import org.reactivestreams.Publisher; +import org.xnio.ChannelListener; +import org.xnio.ChannelListeners; +import org.xnio.IoUtils; import org.xnio.channels.StreamSinkChannel; import reactor.core.publisher.Mono; @@ -44,27 +47,18 @@ import org.springframework.util.Assert; * * @author Marek Hawrylczak * @author Rossen Stoyanchev + * @author Arjen Poutsma */ public class UndertowServerHttpResponse extends AbstractServerHttpResponse implements ZeroCopyHttpOutputMessage { private final HttpServerExchange exchange; - private final StreamSinkChannel responseChannel; - - private final Function, Mono> responseBodyWriter; - public UndertowServerHttpResponse(HttpServerExchange exchange, - StreamSinkChannel responseChannel, - Function, Mono> responseBodyWriter, DataBufferFactory dataBufferFactory) { super(dataBufferFactory); Assert.notNull(exchange, "'exchange' is required."); - Assert.notNull(responseChannel, "'responseChannel' must not be null"); - Assert.notNull(responseBodyWriter, "'responseBodyWriter' must not be null"); this.exchange = exchange; - this.responseChannel = responseChannel; - this.responseBodyWriter = responseBodyWriter; } @@ -80,16 +74,26 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse @Override protected Mono writeWithInternal(Publisher publisher) { - return this.responseBodyWriter.apply(publisher); + return Mono.from(s -> { + // lazily create Subscriber, since calling + // {@link HttpServerExchange#getResponseChannel} as done in the + // ResponseBodySubscriber constructor commits the response status and headers + ResponseBodySubscriber subscriber = new ResponseBodySubscriber(this.exchange); + subscriber.registerListener(); + publisher.subscribe(subscriber); + }); } @Override public Mono writeWith(File file, long position, long count) { writeHeaders(); writeCookies(); + try { + StreamSinkChannel responseChannel = + getUndertowExchange().getResponseChannel(); FileChannel in = new FileInputStream(file).getChannel(); - long result = this.responseChannel.transferFrom(in, position, count); + long result = responseChannel.transferFrom(in, position, count); if (result < count) { return Mono.error(new IOException("Could only write " + result + " out of " + count + " bytes")); @@ -128,4 +132,107 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse } } + private static class ResponseBodySubscriber extends AbstractResponseBodySubscriber { + + private final ChannelListener listener = new WriteListener(); + + private final HttpServerExchange exchange; + + private final StreamSinkChannel responseChannel; + + private volatile ByteBuffer byteBuffer; + + public ResponseBodySubscriber(HttpServerExchange exchange) { + this.exchange = exchange; + this.responseChannel = exchange.getResponseChannel(); + } + + public void registerListener() { + this.responseChannel.getWriteSetter().set(this.listener); + this.responseChannel.resumeWrites(); + } + + @Override + protected void writeError(Throwable t) { + if (!this.exchange.isResponseStarted() && + this.exchange.getStatusCode() < 500) { + this.exchange.setStatusCode(500); + } + } + + @Override + protected void flush() throws IOException { + if (logger.isTraceEnabled()) { + logger.trace("flush"); + } + this.responseChannel.flush(); + } + + @Override + protected boolean write(DataBuffer dataBuffer) throws IOException { + if (this.byteBuffer == null) { + return false; + } + if (logger.isTraceEnabled()) { + logger.trace("write: " + dataBuffer); + } + int total = this.byteBuffer.remaining(); + int written = writeByteBuffer(this.byteBuffer); + + if (logger.isTraceEnabled()) { + logger.trace("written: " + written + " total: " + total); + } + return written == total; + } + + private int writeByteBuffer(ByteBuffer byteBuffer) throws IOException { + int written; + int totalWritten = 0; + do { + written = this.responseChannel.write(byteBuffer); + totalWritten += written; + } + while (byteBuffer.hasRemaining() && written > 0); + return totalWritten; + } + + @Override + protected void receiveBuffer(DataBuffer dataBuffer) { + super.receiveBuffer(dataBuffer); + this.byteBuffer = dataBuffer.asByteBuffer(); + } + + @Override + protected void releaseBuffer() { + super.releaseBuffer(); + this.byteBuffer = null; + } + + @Override + protected void close() { + try { + this.responseChannel.shutdownWrites(); + + if (!this.responseChannel.flush()) { + this.responseChannel.getWriteSetter().set(ChannelListeners + .flushingChannelListener( + o -> IoUtils.safeClose(this.responseChannel), + ChannelListeners.closingChannelExceptionHandler())); + this.responseChannel.resumeWrites(); + } + } + catch (IOException ignored) { + } + } + + private class WriteListener implements ChannelListener { + + @Override + public void handleEvent(StreamSinkChannel channel) { + onWritePossible(); + } + + } + + } } diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingIntegrationTests.java index 5e5338a33f2..b04b0204ef3 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingIntegrationTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingIntegrationTests.java @@ -67,9 +67,7 @@ import org.springframework.web.reactive.config.WebReactiveConfiguration; import org.springframework.web.reactive.result.view.freemarker.FreeMarkerConfigurer; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** @@ -173,7 +171,6 @@ public class RequestMappingIntegrationTests extends AbstractHttpHandlerIntegrati } @Test - @Ignore // Issue #119 public void serializeAsMonoResponseEntity() throws Exception { serializeAsPojo("http://localhost:" + port + "/monoResponseEntity"); }