diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java index 15eb1d4a68c..773c0677cd4 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHttpHandlerAdapter.java @@ -37,7 +37,7 @@ import org.springframework.util.Assert; public class UndertowHttpHandlerAdapter extends HttpHandlerAdapterSupport implements io.undertow.server.HttpHandler { - private DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory(false); + private DataBufferFactory bufferFactory = new DefaultDataBufferFactory(false); public UndertowHttpHandlerAdapter(HttpHandler httpHandler) { @@ -49,41 +49,59 @@ public class UndertowHttpHandlerAdapter extends HttpHandlerAdapterSupport } - public void setDataBufferFactory(DataBufferFactory dataBufferFactory) { - Assert.notNull(dataBufferFactory, "DataBufferFactory must not be null"); - this.dataBufferFactory = dataBufferFactory; + public void setDataBufferFactory(DataBufferFactory bufferFactory) { + Assert.notNull(bufferFactory, "DataBufferFactory must not be null"); + this.bufferFactory = bufferFactory; + } + + public DataBufferFactory getDataBufferFactory() { + return this.bufferFactory; } @Override public void handleRequest(HttpServerExchange exchange) throws Exception { - UndertowServerHttpRequest request = new UndertowServerHttpRequest(exchange, this.dataBufferFactory); - ServerHttpResponse response = new UndertowServerHttpResponse(exchange, this.dataBufferFactory); + ServerHttpRequest request = new UndertowServerHttpRequest(exchange, getDataBufferFactory()); + ServerHttpResponse response = new UndertowServerHttpResponse(exchange, getDataBufferFactory()); - getHttpHandler().handle(request, response).subscribe(new Subscriber() { - @Override - public void onSubscribe(Subscription subscription) { - subscription.request(Long.MAX_VALUE); - } - @Override - public void onNext(Void aVoid) { - // no op - } - @Override - public void onError(Throwable ex) { - logger.error("Could not complete request", ex); - if (!exchange.isResponseStarted() && exchange.getStatusCode() <= 500) { - exchange.setStatusCode(500); - } - exchange.endExchange(); - } - @Override - public void onComplete() { - logger.debug("Successfully completed request"); - exchange.endExchange(); - } - }); + HandlerResultSubscriber resultSubscriber = new HandlerResultSubscriber(exchange); + getHttpHandler().handle(request, response).subscribe(resultSubscriber); } + + private class HandlerResultSubscriber implements Subscriber { + + private final HttpServerExchange exchange; + + + public HandlerResultSubscriber(HttpServerExchange exchange) { + this.exchange = exchange; + } + + @Override + public void onSubscribe(Subscription subscription) { + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void aVoid) { + // no op + } + + @Override + public void onError(Throwable ex) { + logger.error("Could not complete request", ex); + if (!this.exchange.isResponseStarted() && this.exchange.getStatusCode() < 500) { + this.exchange.setStatusCode(500); + } + this.exchange.endExchange(); + } + + @Override + public void onComplete() { + logger.debug("Successfully completed request"); + this.exchange.endExchange(); + } + } } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java index 9ece475de54..25ac26fadbd 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java @@ -52,17 +52,16 @@ public class UndertowServerHttpRequest extends AbstractServerHttpRequest { private final RequestBodyPublisher body; - public UndertowServerHttpRequest(HttpServerExchange exchange, - DataBufferFactory dataBufferFactory) { + public UndertowServerHttpRequest(HttpServerExchange exchange, DataBufferFactory bufferFactory) { super(initUri(exchange), initHeaders(exchange)); this.exchange = exchange; - this.body = new RequestBodyPublisher(exchange, dataBufferFactory); - this.body.registerListener(exchange); + this.body = new RequestBodyPublisher(exchange, bufferFactory); + this.body.registerListeners(exchange); } private static URI initUri(HttpServerExchange exchange) { - Assert.notNull(exchange, "'exchange' is required."); + Assert.notNull(exchange, "HttpServerExchange is required."); try { return new URI(exchange.getRequestScheme(), null, exchange.getHostName(), exchange.getHostPort(), @@ -110,35 +109,29 @@ public class UndertowServerHttpRequest extends AbstractServerHttpRequest { private static class RequestBodyPublisher extends AbstractListenerReadPublisher { - private final ChannelListener readListener = - new ReadListener(); + private final StreamSourceChannel channel; - private final ChannelListener closeListener = - new CloseListener(); - - private final StreamSourceChannel requestChannel; - - private final DataBufferFactory dataBufferFactory; + private final DataBufferFactory bufferFactory; private final ByteBufferPool byteBufferPool; private PooledByteBuffer pooledByteBuffer; - public RequestBodyPublisher(HttpServerExchange exchange, - DataBufferFactory dataBufferFactory) { - this.requestChannel = exchange.getRequestChannel(); + + public RequestBodyPublisher(HttpServerExchange exchange, DataBufferFactory bufferFactory) { + this.channel = exchange.getRequestChannel(); + this.bufferFactory = bufferFactory; this.byteBufferPool = exchange.getConnection().getByteBufferPool(); - this.dataBufferFactory = dataBufferFactory; } - private void registerListener(HttpServerExchange exchange) { + private void registerListeners(HttpServerExchange exchange) { exchange.addExchangeCompleteListener((ex, next) -> { onAllDataRead(); next.proceed(); }); - this.requestChannel.getReadSetter().set(this.readListener); - this.requestChannel.getCloseSetter().set(this.closeListener); - this.requestChannel.resumeReads(); + this.channel.getReadSetter().set((ChannelListener) c -> onDataAvailable()); + this.channel.getCloseSetter().set((ChannelListener) c -> onAllDataRead()); + this.channel.resumeReads(); } @Override @@ -152,14 +145,14 @@ public class UndertowServerHttpRequest extends AbstractServerHttpRequest { this.pooledByteBuffer = this.byteBufferPool.allocate(); } ByteBuffer byteBuffer = this.pooledByteBuffer.getBuffer(); - int read = this.requestChannel.read(byteBuffer); + int read = this.channel.read(byteBuffer); if (logger.isTraceEnabled()) { logger.trace("read:" + read); } if (read > 0) { byteBuffer.flip(); - return this.dataBufferFactory.wrap(byteBuffer); + return this.bufferFactory.wrap(byteBuffer); } else if (read == -1) { onAllDataRead(); @@ -174,21 +167,5 @@ public class UndertowServerHttpRequest extends AbstractServerHttpRequest { } super.onAllDataRead(); } - - private class ReadListener implements ChannelListener { - - @Override - public void handleEvent(StreamSourceChannel channel) { - onDataAvailable(); - } - } - - private class CloseListener implements ChannelListener { - - @Override - public void handleEvent(StreamSourceChannel channel) { - onAllDataRead(); - } - } } } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java index 0e1c950c21d..005adbe337e 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java @@ -141,22 +141,20 @@ public class UndertowServerHttpResponse extends AbstractListenerServerHttpRespon private static class ResponseBodyProcessor extends AbstractListenerWriteProcessor { - private final ChannelListener listener = new WriteListener(); - - private final StreamSinkChannel responseChannel; + private final StreamSinkChannel channel; private volatile ByteBuffer byteBuffer; - public ResponseBodyProcessor(StreamSinkChannel responseChannel) { - Assert.notNull(responseChannel, "'responseChannel' must not be null"); - this.responseChannel = responseChannel; + public ResponseBodyProcessor(StreamSinkChannel channel) { + Assert.notNull(channel, "StreamSinkChannel must not be null"); + this.channel = channel; } public void registerListener() { - this.responseChannel.getWriteSetter().set(this.listener); - this.responseChannel.resumeWrites(); + this.channel.getWriteSetter().set((ChannelListener) c -> onWritePossible()); + this.channel.resumeWrites(); } @Override @@ -180,7 +178,7 @@ public class UndertowServerHttpResponse extends AbstractListenerServerHttpRespon int written; int totalWritten = 0; do { - written = this.responseChannel.write(byteBuffer); + written = this.channel.write(byteBuffer); totalWritten += written; } while (byteBuffer.hasRemaining() && written > 0); @@ -208,14 +206,6 @@ public class UndertowServerHttpResponse extends AbstractListenerServerHttpRespon protected boolean isDataEmpty(DataBuffer dataBuffer) { return dataBuffer.readableByteCount() == 0; } - - private class WriteListener implements ChannelListener { - - @Override - public void handleEvent(StreamSinkChannel channel) { - onWritePossible(); - } - } } private class ResponseBodyFlushProcessor extends AbstractListenerFlushProcessor { diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/UndertowHttpServer.java b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/UndertowHttpServer.java index ee3ab6dafa4..4307ab5f4e2 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/UndertowHttpServer.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/bootstrap/UndertowHttpServer.java @@ -17,7 +17,6 @@ package org.springframework.http.server.reactive.bootstrap; import io.undertow.Undertow; -import io.undertow.server.HttpHandler; import org.springframework.http.server.reactive.UndertowHttpHandlerAdapter; import org.springframework.util.Assert; @@ -34,10 +33,19 @@ public class UndertowHttpServer extends HttpServerSupport implements HttpServer @Override public void afterPropertiesSet() throws Exception { - Assert.notNull(getHttpHandler()); - HttpHandler handler = new UndertowHttpHandlerAdapter(getHttpHandler()); this.server = Undertow.builder().addHttpListener(getPort(), getHost()) - .setHandler(handler).build(); + .setHandler(initUndertowHttpHandlerAdapter()) + .build(); + } + + private UndertowHttpHandlerAdapter initUndertowHttpHandlerAdapter() { + if (getHttpHandlerMap() != null) { + return new UndertowHttpHandlerAdapter(getHttpHandlerMap()); + } + else { + Assert.notNull(getHttpHandler()); + return new UndertowHttpHandlerAdapter(getHttpHandler()); + } } @Override