From 1d48e7c5b9a45b3c8a37437b259c4a894d07bd07 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Thu, 30 Jun 2016 15:26:17 +0200 Subject: [PATCH] Allow to set response status on Undertow Refactored Undertow support to register a response listener only when the body is written to, as opposed to registering it at startup. The reason for this is that getting the response channel from the HttpServerExchange commits the status and response, making it impossible to change them after the fact. Fixed issue #119. --- .../reactive/UndertowHttpHandlerAdapter.java | 206 +----------------- .../reactive/UndertowServerHttpRequest.java | 90 +++++++- .../reactive/UndertowServerHttpResponse.java | 133 +++++++++-- .../RequestMappingIntegrationTests.java | 5 +- 4 files changed, 209 insertions(+), 225 deletions(-) diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java index dfc9d5837bd..a856a7a1019 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java @@ -16,29 +16,19 @@ package org.springframework.http.server.reactive; -import java.io.IOException; -import java.nio.ByteBuffer; - -import io.undertow.connector.PooledByteBuffer; import io.undertow.server.HttpServerExchange; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; -import org.xnio.ChannelListener; -import org.xnio.ChannelListeners; -import org.xnio.IoUtils; -import org.xnio.channels.StreamSinkChannel; -import org.xnio.channels.StreamSourceChannel; -import reactor.core.publisher.Mono; -import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.util.Assert; /** * @author Marek Hawrylczak * @author Rossen Stoyanchev + * @author Arjen Poutsma */ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandler { @@ -60,20 +50,11 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle @Override public void handleRequest(HttpServerExchange exchange) throws Exception { - RequestBodyPublisher requestBody = - new RequestBodyPublisher(exchange, this.dataBufferFactory); - requestBody.registerListener(); - ServerHttpRequest request = new UndertowServerHttpRequest(exchange, requestBody); + ServerHttpRequest request = + new UndertowServerHttpRequest(exchange, this.dataBufferFactory); - StreamSinkChannel responseChannel = exchange.getResponseChannel(); - ResponseBodySubscriber responseBody = - new ResponseBodySubscriber(exchange, responseChannel); - responseBody.registerListener(); ServerHttpResponse response = - new UndertowServerHttpResponse(exchange, responseChannel, - publisher -> Mono - .from(subscriber -> publisher.subscribe(responseBody)), - this.dataBufferFactory); + new UndertowServerHttpResponse(exchange, this.dataBufferFactory); this.delegate.handle(request, response).subscribe(new Subscriber() { @@ -106,183 +87,4 @@ public class UndertowHttpHandlerAdapter implements io.undertow.server.HttpHandle }); } - private static class RequestBodyPublisher extends AbstractRequestBodyPublisher { - - private final ChannelListener readListener = - new ReadListener(); - - private final ChannelListener closeListener = - new CloseListener(); - - private final StreamSourceChannel requestChannel; - - private final DataBufferFactory dataBufferFactory; - - private final PooledByteBuffer pooledByteBuffer; - - public RequestBodyPublisher(HttpServerExchange exchange, - DataBufferFactory dataBufferFactory) { - this.requestChannel = exchange.getRequestChannel(); - this.pooledByteBuffer = - exchange.getConnection().getByteBufferPool().allocate(); - this.dataBufferFactory = dataBufferFactory; - } - - public void registerListener() { - this.requestChannel.getReadSetter().set(this.readListener); - this.requestChannel.getCloseSetter().set(this.closeListener); - this.requestChannel.resumeReads(); - } - - @Override - protected DataBuffer read() throws IOException { - ByteBuffer byteBuffer = this.pooledByteBuffer.getBuffer(); - int read = this.requestChannel.read(byteBuffer); - if (logger.isTraceEnabled()) { - logger.trace("read:" + read); - } - - if (read > 0) { - byteBuffer.flip(); - return this.dataBufferFactory.wrap(byteBuffer); - } - else if (read == -1) { - onAllDataRead(); - } - return null; - } - - @Override - protected void close() { - if (this.pooledByteBuffer != null) { - IoUtils.safeClose(this.pooledByteBuffer); - } - if (this.requestChannel != null) { - IoUtils.safeClose(this.requestChannel); - } - } - - private class ReadListener implements ChannelListener { - - @Override - public void handleEvent(StreamSourceChannel channel) { - onDataAvailable(); - } - } - - private class CloseListener implements ChannelListener { - - @Override - public void handleEvent(StreamSourceChannel channel) { - onAllDataRead(); - } - } - } - - private static class ResponseBodySubscriber extends AbstractResponseBodySubscriber { - - private final ChannelListener listener = - new ResponseBodyListener(); - - private final HttpServerExchange exchange; - - private final StreamSinkChannel responseChannel; - - private volatile ByteBuffer byteBuffer; - - public ResponseBodySubscriber(HttpServerExchange exchange, - StreamSinkChannel responseChannel) { - this.exchange = exchange; - this.responseChannel = responseChannel; - } - - public void registerListener() { - this.responseChannel.getWriteSetter().set(this.listener); - this.responseChannel.resumeWrites(); - } - - @Override - protected void writeError(Throwable t) { - if (!this.exchange.isResponseStarted() && - this.exchange.getStatusCode() < 500) { - this.exchange.setStatusCode(500); - } - } - - @Override - protected void flush() throws IOException { - if (logger.isTraceEnabled()) { - logger.trace("flush"); - } - this.responseChannel.flush(); - } - - @Override - protected boolean write(DataBuffer dataBuffer) throws IOException { - if (this.byteBuffer == null) { - return false; - } - if (logger.isTraceEnabled()) { - logger.trace("write: " + dataBuffer); - } - int total = this.byteBuffer.remaining(); - int written = writeByteBuffer(this.byteBuffer); - - if (logger.isTraceEnabled()) { - logger.trace("written: " + written + " total: " + total); - } - return written == total; - } - - private int writeByteBuffer(ByteBuffer byteBuffer) throws IOException { - int written; - int totalWritten = 0; - do { - written = this.responseChannel.write(byteBuffer); - totalWritten += written; - } - while (byteBuffer.hasRemaining() && written > 0); - return totalWritten; - } - - @Override - protected void receiveBuffer(DataBuffer dataBuffer) { - super.receiveBuffer(dataBuffer); - this.byteBuffer = dataBuffer.asByteBuffer(); - } - - @Override - protected void releaseBuffer() { - super.releaseBuffer(); - this.byteBuffer = null; - } - - @Override - protected void close() { - try { - this.responseChannel.shutdownWrites(); - - if (!this.responseChannel.flush()) { - this.responseChannel.getWriteSetter().set(ChannelListeners - .flushingChannelListener( - o -> IoUtils.safeClose(this.responseChannel), - ChannelListeners.closingChannelExceptionHandler())); - this.responseChannel.resumeWrites(); - } - } - catch (IOException ignored) { - } - } - - private class ResponseBodyListener implements ChannelListener { - - @Override - public void handleEvent(StreamSinkChannel channel) { - onWritePossible(); - } - - } - - } - } \ No newline at end of file diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java index 51ef11d4453..19f56d596dc 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java @@ -16,16 +16,22 @@ package org.springframework.http.server.reactive; +import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import io.undertow.connector.PooledByteBuffer; import io.undertow.server.HttpServerExchange; import io.undertow.server.handlers.Cookie; import io.undertow.util.HeaderValues; -import org.reactivestreams.Publisher; +import org.xnio.ChannelListener; +import org.xnio.IoUtils; +import org.xnio.channels.StreamSourceChannel; import reactor.core.publisher.Flux; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.http.HttpCookie; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -43,14 +49,14 @@ public class UndertowServerHttpRequest extends AbstractServerHttpRequest { private final HttpServerExchange exchange; - private final Flux body; + private final RequestBodyPublisher body; public UndertowServerHttpRequest(HttpServerExchange exchange, - Publisher body) { + DataBufferFactory dataBufferFactory) { Assert.notNull(exchange, "'exchange' is required."); - Assert.notNull(exchange, "'body' is required."); this.exchange = exchange; - this.body = Flux.from(body); + this.body = new RequestBodyPublisher(exchange, dataBufferFactory); + this.body.registerListener(); } @@ -92,7 +98,79 @@ public class UndertowServerHttpRequest extends AbstractServerHttpRequest { @Override public Flux getBody() { - return this.body; + return Flux.from(this.body); } + private static class RequestBodyPublisher extends AbstractRequestBodyPublisher { + + private final ChannelListener readListener = + new ReadListener(); + + private final ChannelListener closeListener = + new CloseListener(); + + private final StreamSourceChannel requestChannel; + + private final DataBufferFactory dataBufferFactory; + + private final PooledByteBuffer pooledByteBuffer; + + public RequestBodyPublisher(HttpServerExchange exchange, + DataBufferFactory dataBufferFactory) { + this.requestChannel = exchange.getRequestChannel(); + this.pooledByteBuffer = + exchange.getConnection().getByteBufferPool().allocate(); + this.dataBufferFactory = dataBufferFactory; + } + + private void registerListener() { + this.requestChannel.getReadSetter().set(this.readListener); + this.requestChannel.getCloseSetter().set(this.closeListener); + this.requestChannel.resumeReads(); + } + + @Override + protected DataBuffer read() throws IOException { + ByteBuffer byteBuffer = this.pooledByteBuffer.getBuffer(); + int read = this.requestChannel.read(byteBuffer); + if (logger.isTraceEnabled()) { + logger.trace("read:" + read); + } + + if (read > 0) { + byteBuffer.flip(); + return this.dataBufferFactory.wrap(byteBuffer); + } + else if (read == -1) { + onAllDataRead(); + } + return null; + } + + @Override + protected void close() { + if (this.pooledByteBuffer != null) { + IoUtils.safeClose(this.pooledByteBuffer); + } + if (this.requestChannel != null) { + IoUtils.safeClose(this.requestChannel); + } + } + + private class ReadListener implements ChannelListener { + + @Override + public void handleEvent(StreamSourceChannel channel) { + onDataAvailable(); + } + } + + private class CloseListener implements ChannelListener { + + @Override + public void handleEvent(StreamSourceChannel channel) { + onAllDataRead(); + } + } + } } diff --git a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java index 7ce3a5d89aa..0dfb54f26b5 100644 --- a/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java +++ b/spring-web-reactive/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java @@ -19,16 +19,19 @@ package org.springframework.http.server.reactive; import java.io.File; import java.io.FileInputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.util.List; import java.util.Map; -import java.util.function.Function; import io.undertow.server.HttpServerExchange; import io.undertow.server.handlers.Cookie; import io.undertow.server.handlers.CookieImpl; import io.undertow.util.HttpString; import org.reactivestreams.Publisher; +import org.xnio.ChannelListener; +import org.xnio.ChannelListeners; +import org.xnio.IoUtils; import org.xnio.channels.StreamSinkChannel; import reactor.core.publisher.Mono; @@ -44,27 +47,18 @@ import org.springframework.util.Assert; * * @author Marek Hawrylczak * @author Rossen Stoyanchev + * @author Arjen Poutsma */ public class UndertowServerHttpResponse extends AbstractServerHttpResponse implements ZeroCopyHttpOutputMessage { private final HttpServerExchange exchange; - private final StreamSinkChannel responseChannel; - - private final Function, Mono> responseBodyWriter; - public UndertowServerHttpResponse(HttpServerExchange exchange, - StreamSinkChannel responseChannel, - Function, Mono> responseBodyWriter, DataBufferFactory dataBufferFactory) { super(dataBufferFactory); Assert.notNull(exchange, "'exchange' is required."); - Assert.notNull(responseChannel, "'responseChannel' must not be null"); - Assert.notNull(responseBodyWriter, "'responseBodyWriter' must not be null"); this.exchange = exchange; - this.responseChannel = responseChannel; - this.responseBodyWriter = responseBodyWriter; } @@ -80,16 +74,26 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse @Override protected Mono writeWithInternal(Publisher publisher) { - return this.responseBodyWriter.apply(publisher); + return Mono.from(s -> { + // lazily create Subscriber, since calling + // {@link HttpServerExchange#getResponseChannel} as done in the + // ResponseBodySubscriber constructor commits the response status and headers + ResponseBodySubscriber subscriber = new ResponseBodySubscriber(this.exchange); + subscriber.registerListener(); + publisher.subscribe(subscriber); + }); } @Override public Mono writeWith(File file, long position, long count) { writeHeaders(); writeCookies(); + try { + StreamSinkChannel responseChannel = + getUndertowExchange().getResponseChannel(); FileChannel in = new FileInputStream(file).getChannel(); - long result = this.responseChannel.transferFrom(in, position, count); + long result = responseChannel.transferFrom(in, position, count); if (result < count) { return Mono.error(new IOException("Could only write " + result + " out of " + count + " bytes")); @@ -128,4 +132,107 @@ public class UndertowServerHttpResponse extends AbstractServerHttpResponse } } + private static class ResponseBodySubscriber extends AbstractResponseBodySubscriber { + + private final ChannelListener listener = new WriteListener(); + + private final HttpServerExchange exchange; + + private final StreamSinkChannel responseChannel; + + private volatile ByteBuffer byteBuffer; + + public ResponseBodySubscriber(HttpServerExchange exchange) { + this.exchange = exchange; + this.responseChannel = exchange.getResponseChannel(); + } + + public void registerListener() { + this.responseChannel.getWriteSetter().set(this.listener); + this.responseChannel.resumeWrites(); + } + + @Override + protected void writeError(Throwable t) { + if (!this.exchange.isResponseStarted() && + this.exchange.getStatusCode() < 500) { + this.exchange.setStatusCode(500); + } + } + + @Override + protected void flush() throws IOException { + if (logger.isTraceEnabled()) { + logger.trace("flush"); + } + this.responseChannel.flush(); + } + + @Override + protected boolean write(DataBuffer dataBuffer) throws IOException { + if (this.byteBuffer == null) { + return false; + } + if (logger.isTraceEnabled()) { + logger.trace("write: " + dataBuffer); + } + int total = this.byteBuffer.remaining(); + int written = writeByteBuffer(this.byteBuffer); + + if (logger.isTraceEnabled()) { + logger.trace("written: " + written + " total: " + total); + } + return written == total; + } + + private int writeByteBuffer(ByteBuffer byteBuffer) throws IOException { + int written; + int totalWritten = 0; + do { + written = this.responseChannel.write(byteBuffer); + totalWritten += written; + } + while (byteBuffer.hasRemaining() && written > 0); + return totalWritten; + } + + @Override + protected void receiveBuffer(DataBuffer dataBuffer) { + super.receiveBuffer(dataBuffer); + this.byteBuffer = dataBuffer.asByteBuffer(); + } + + @Override + protected void releaseBuffer() { + super.releaseBuffer(); + this.byteBuffer = null; + } + + @Override + protected void close() { + try { + this.responseChannel.shutdownWrites(); + + if (!this.responseChannel.flush()) { + this.responseChannel.getWriteSetter().set(ChannelListeners + .flushingChannelListener( + o -> IoUtils.safeClose(this.responseChannel), + ChannelListeners.closingChannelExceptionHandler())); + this.responseChannel.resumeWrites(); + } + } + catch (IOException ignored) { + } + } + + private class WriteListener implements ChannelListener { + + @Override + public void handleEvent(StreamSinkChannel channel) { + onWritePossible(); + } + + } + + } } diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingIntegrationTests.java index 5e5338a33f2..b04b0204ef3 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingIntegrationTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/RequestMappingIntegrationTests.java @@ -67,9 +67,7 @@ import org.springframework.web.reactive.config.WebReactiveConfiguration; import org.springframework.web.reactive.result.view.freemarker.FreeMarkerConfigurer; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** @@ -173,7 +171,6 @@ public class RequestMappingIntegrationTests extends AbstractHttpHandlerIntegrati } @Test - @Ignore // Issue #119 public void serializeAsMonoResponseEntity() throws Exception { serializeAsPojo("http://localhost:" + port + "/monoResponseEntity"); }