diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java index 3d58304385e..b14f95b30e7 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java @@ -26,7 +26,9 @@ import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; /** - * Abstract base class for listener-based server responses, i.e. Servlet 3.1 and Undertow. + * Abstract base class for listener-based server responses, e.g. Servlet 3.1 + * and Undertow. + * * @author Arjen Poutsma * @since 5.0 */ @@ -34,51 +36,37 @@ public abstract class AbstractListenerServerHttpResponse extends AbstractServerH private final AtomicBoolean writeCalled = new AtomicBoolean(); + public AbstractListenerServerHttpResponse(DataBufferFactory dataBufferFactory) { super(dataBufferFactory); } + @Override protected final Mono writeWithInternal(Publisher body) { - if (this.writeCalled.compareAndSet(false, true)) { - Processor bodyProcessor = createBodyProcessor(); - return Mono.from(subscriber -> { - body.subscribe(bodyProcessor); - bodyProcessor.subscribe(subscriber); - }); - - } else { - return Mono.error(new IllegalStateException( - "writeWith() or writeAndFlushWith() has already been called")); - } + return writeAndFlushWithInternal(Mono.just(body)); } @Override protected final Mono writeAndFlushWithInternal(Publisher> body) { if (this.writeCalled.compareAndSet(false, true)) { - Processor, Void> bodyProcessor = - createBodyFlushProcessor(); + Processor, Void> bodyProcessor = createBodyFlushProcessor(); return Mono.from(subscriber -> { body.subscribe(bodyProcessor); bodyProcessor.subscribe(subscriber); }); - } else { + } + else { return Mono.error(new IllegalStateException( "writeWith() or writeAndFlushWith() has already been called")); } } - /** - * Abstract template method to create a {@code Processor} that - * will write the response body to the underlying output. Called from - * {@link #writeWithInternal(Publisher)}. - */ - protected abstract Processor createBodyProcessor(); - /** * Abstract template method to create a {@code Processor, Void>} * that will write the response body with flushes to the underlying output. Called from * {@link #writeAndFlushWithInternal(Publisher)}. */ protected abstract Processor, Void> createBodyFlushProcessor(); + } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodyFlushProcessor.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodyFlushProcessor.java index 71adf24dd16..ef63b70377e 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodyFlushProcessor.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodyFlushProcessor.java @@ -35,26 +35,25 @@ import org.springframework.core.io.buffer.DataBuffer; * Servlet 3.1 and Undertow support. * * @author Arjen Poutsma + * @author Violeta Georgieva * @since 5.0 * @see ServletServerHttpRequest * @see UndertowHttpHandlerAdapter * @see ServerHttpResponse#writeAndFlushWith(Publisher) */ -abstract class AbstractResponseBodyFlushProcessor - implements Processor, Void> { +abstract class AbstractResponseBodyFlushProcessor implements Processor, Void> { protected final Log logger = LogFactory.getLog(getClass()); - private final ResponseBodyWriteResultPublisher publisherDelegate = - new ResponseBodyWriteResultPublisher(); + private final ResponseBodyWriteResultPublisher resultPublisher = new ResponseBodyWriteResultPublisher(); - private final AtomicReference state = - new AtomicReference<>(State.UNSUBSCRIBED); + private final AtomicReference state = new AtomicReference<>(State.UNSUBSCRIBED); private volatile boolean subscriberCompleted; private Subscription subscription; + // Subscriber @Override @@ -89,13 +88,15 @@ abstract class AbstractResponseBodyFlushProcessor this.state.get().onComplete(this); } + // Publisher @Override public final void subscribe(Subscriber subscriber) { - this.publisherDelegate.subscribe(subscriber); + this.resultPublisher.subscribe(subscriber); } + /** * Creates a new processor for subscribing to a body chunk. */ @@ -106,6 +107,11 @@ abstract class AbstractResponseBodyFlushProcessor */ protected abstract void flush() throws IOException; + + private boolean changeState(State oldState, State newState) { + return this.state.compareAndSet(oldState, newState); + } + private void writeComplete() { if (logger.isTraceEnabled()) { logger.trace(this.state + " writeComplete"); @@ -114,17 +120,19 @@ abstract class AbstractResponseBodyFlushProcessor } - private boolean changeState(State oldState, State newState) { - return this.state.compareAndSet(oldState, newState); + private void cancel() { + this.subscription.cancel(); } + private enum State { + UNSUBSCRIBED { + @Override - public void onSubscribe(AbstractResponseBodyFlushProcessor processor, - Subscription subscription) { + public void onSubscribe(AbstractResponseBodyFlushProcessor processor, Subscription subscription) { Objects.requireNonNull(subscription, "Subscription cannot be null"); - if (processor.changeState(this, SUBSCRIBED)) { + if (processor.changeState(this, REQUESTED)) { processor.subscription = subscription; subscription.request(1); } @@ -132,86 +140,97 @@ abstract class AbstractResponseBodyFlushProcessor super.onSubscribe(processor, subscription); } } - }, SUBSCRIBED { + }, + REQUESTED { + @Override - public void onNext(AbstractResponseBodyFlushProcessor processor, - Publisher chunk) { - Processor chunkProcessor = - processor.createBodyProcessor(); - chunk.subscribe(chunkProcessor); - chunkProcessor.subscribe(new WriteSubscriber(processor)); + public void onNext(AbstractResponseBodyFlushProcessor processor, Publisher chunk) { + if (processor.changeState(this, RECEIVED)) { + Processor chunkProcessor = processor.createBodyProcessor(); + chunk.subscribe(chunkProcessor); + chunkProcessor.subscribe(new WriteSubscriber(processor)); + } } @Override - void onComplete(AbstractResponseBodyFlushProcessor processor) { - processor.subscriberCompleted = true; + public void onComplete(AbstractResponseBodyFlushProcessor processor) { + if (processor.changeState(this, COMPLETED)) { + processor.resultPublisher.publishComplete(); + } } + }, + RECEIVED { @Override public void writeComplete(AbstractResponseBodyFlushProcessor processor) { + try { + processor.flush(); + } + catch (IOException ex) { + processor.cancel(); + processor.onError(ex); + } + if (processor.subscriberCompleted) { if (processor.changeState(this, COMPLETED)) { - processor.publisherDelegate.publishComplete(); + processor.resultPublisher.publishComplete(); } } else { - try { - processor.flush(); - } - catch (IOException ex) { - processor.onError(ex); + if (processor.changeState(this, REQUESTED)) { + processor.subscription.request(1); } - processor.subscription.request(1); } } - }, COMPLETED { - @Override - public void onNext(AbstractResponseBodyFlushProcessor processor, - Publisher publisher) { - // ignore + @Override + public void onComplete(AbstractResponseBodyFlushProcessor processor) { + processor.subscriberCompleted = true; } + }, + COMPLETED { @Override - void onError(AbstractResponseBodyFlushProcessor processor, Throwable t) { + public void onNext(AbstractResponseBodyFlushProcessor processor, + Publisher publisher) { // ignore + } @Override - void onComplete(AbstractResponseBodyFlushProcessor processor) { + public void onError(AbstractResponseBodyFlushProcessor processor, Throwable t) { // ignore } @Override - public void writeComplete(AbstractResponseBodyFlushProcessor processor) { + public void onComplete(AbstractResponseBodyFlushProcessor processor) { // ignore } }; - public void onSubscribe(AbstractResponseBodyFlushProcessor processor, - Subscription subscription) { + public void onSubscribe(AbstractResponseBodyFlushProcessor processor, Subscription subscription) { subscription.cancel(); } - public void onNext(AbstractResponseBodyFlushProcessor processor, - Publisher publisher) { + public void onNext(AbstractResponseBodyFlushProcessor processor, Publisher publisher) { throw new IllegalStateException(toString()); } - void onError(AbstractResponseBodyFlushProcessor processor, Throwable t) { + public void onError(AbstractResponseBodyFlushProcessor processor, Throwable ex) { if (processor.changeState(this, COMPLETED)) { - processor.publisherDelegate.publishError(t); + processor.resultPublisher.publishError(ex); } } - void onComplete(AbstractResponseBodyFlushProcessor processor) { + public void onComplete(AbstractResponseBodyFlushProcessor processor) { throw new IllegalStateException(toString()); } public void writeComplete(AbstractResponseBodyFlushProcessor processor) { - throw new IllegalStateException(toString()); + // ignore } + private static class WriteSubscriber implements Subscriber { private final AbstractResponseBodyFlushProcessor processor; @@ -221,8 +240,8 @@ abstract class AbstractResponseBodyFlushProcessor } @Override - public void onSubscribe(Subscription s) { - s.request(Long.MAX_VALUE); + public void onSubscribe(Subscription subscription) { + subscription.request(Long.MAX_VALUE); } @Override @@ -230,13 +249,14 @@ abstract class AbstractResponseBodyFlushProcessor } @Override - public void onError(Throwable t) { - processor.onError(t); + public void onError(Throwable ex) { + this.processor.cancel(); + this.processor.onError(ex); } @Override public void onComplete() { - processor.writeComplete(); + this.processor.writeComplete(); } } } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodyProcessor.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodyProcessor.java index 1a0268cd07d..3e62264dd13 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodyProcessor.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractResponseBodyProcessor.java @@ -48,11 +48,9 @@ abstract class AbstractResponseBodyProcessor implements Processor state = - new AtomicReference<>(State.UNSUBSCRIBED); + private final AtomicReference state = new AtomicReference<>(State.UNSUBSCRIBED); private volatile DataBuffer currentBuffer; @@ -60,6 +58,7 @@ abstract class AbstractResponseBodyProcessor implements Processor subscriber) { - this.publisherDelegate.subscribe(subscriber); + this.resultPublisher.subscribe(subscriber); } + // listener methods /** @@ -159,10 +160,15 @@ abstract class AbstractResponseBodyProcessor implements Processor, Void> createBodyFlushProcessor() { + Processor, Void> processor = new ResponseBodyFlushProcessor(); + registerListener(); + return processor; + } + + private void registerListener() { + try { + outputStream().setWriteListener(writeListener); + } + catch (IOException e) { + throw new UncheckedIOException(e); } } + private ServletOutputStream outputStream() throws IOException { + return this.response.getOutputStream(); + } + private void flush() throws IOException { - ServletOutputStream outputStream = this.response.getOutputStream(); + ServletOutputStream outputStream = outputStream(); if (outputStream.isReady()) { try { outputStream.flush(); @@ -136,23 +152,6 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons } } - @Override - protected ResponseBodyProcessor createBodyProcessor() { - try { - registerListener(); - this.bodyProcessor = new ResponseBodyProcessor(this.response.getOutputStream(), - this.bufferSize); - return this.bodyProcessor; - } - catch (IOException ex) { - throw new UncheckedIOException(ex); - } - } - - @Override - protected AbstractResponseBodyFlushProcessor createBodyFlushProcessor() { - return new ResponseBodyFlushProcessor(); - } private class ResponseBodyProcessor extends AbstractResponseBodyProcessor { @@ -160,11 +159,13 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons private final int bufferSize; + public ResponseBodyProcessor(ServletOutputStream outputStream, int bufferSize) { this.outputStream = outputStream; this.bufferSize = bufferSize; } + @Override protected boolean isWritePossible() { return this.outputStream.isReady(); @@ -206,8 +207,7 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons byte[] buffer = new byte[this.bufferSize]; int bytesRead = -1; - while (this.outputStream.isReady() && - (bytesRead = input.read(buffer)) != -1) { + while (this.outputStream.isReady() && (bytesRead = input.read(buffer)) != -1) { this.outputStream.write(buffer, 0, bytesRead); bytesWritten += bytesRead; } @@ -229,6 +229,7 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons @Override public void onError(Throwable ex) { if (bodyProcessor != null) { + bodyProcessor.cancel(); bodyProcessor.onError(ex); } } @@ -238,7 +239,13 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons @Override protected Processor createBodyProcessor() { - return ServletServerHttpResponse.this.createBodyProcessor(); + try { + bodyProcessor = new ResponseBodyProcessor(outputStream(), bufferSize); + return bodyProcessor; + } + catch (IOException ex) { + throw new UncheckedIOException(ex); + } } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java index b4483aea480..30102b6e2d5 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java @@ -55,13 +55,14 @@ public class UndertowServerHttpResponse extends AbstractListenerServerHttpRespon private StreamSinkChannel responseChannel; - public UndertowServerHttpResponse(HttpServerExchange exchange, - DataBufferFactory dataBufferFactory) { - super(dataBufferFactory); + + public UndertowServerHttpResponse(HttpServerExchange exchange, DataBufferFactory bufferFactory) { + super(bufferFactory); Assert.notNull(exchange, "'exchange' is required."); this.exchange = exchange; } + public HttpServerExchange getUndertowExchange() { return this.exchange; } @@ -78,10 +79,8 @@ public class UndertowServerHttpResponse extends AbstractListenerServerHttpRespon public Mono writeWith(File file, long position, long count) { writeHeaders(); writeCookies(); - try { - StreamSinkChannel responseChannel = - getUndertowExchange().getResponseChannel(); + StreamSinkChannel responseChannel = getUndertowExchange().getResponseChannel(); @SuppressWarnings("resource") FileChannel in = new FileInputStream(file).getChannel(); long result = responseChannel.transferFrom(in, position, count); @@ -124,20 +123,19 @@ public class UndertowServerHttpResponse extends AbstractListenerServerHttpRespon } @Override - protected ResponseBodyProcessor createBodyProcessor() { + protected AbstractResponseBodyFlushProcessor createBodyFlushProcessor() { + return new ResponseBodyFlushProcessor(); + } + + private ResponseBodyProcessor createBodyProcessor() { if (this.responseChannel == null) { this.responseChannel = this.exchange.getResponseChannel(); } - ResponseBodyProcessor bodyProcessor = - new ResponseBodyProcessor( this.responseChannel); + ResponseBodyProcessor bodyProcessor = new ResponseBodyProcessor( this.responseChannel); bodyProcessor.registerListener(); return bodyProcessor; } - @Override - protected AbstractResponseBodyFlushProcessor createBodyFlushProcessor() { - return new ResponseBodyFlushProcessor(); - } private static class ResponseBodyProcessor extends AbstractResponseBodyProcessor { @@ -147,11 +145,13 @@ public class UndertowServerHttpResponse extends AbstractListenerServerHttpRespon private volatile ByteBuffer byteBuffer; + public ResponseBodyProcessor(StreamSinkChannel responseChannel) { Assert.notNull(responseChannel, "'responseChannel' must not be null"); this.responseChannel = responseChannel; } + public void registerListener() { this.responseChannel.getWriteSetter().set(this.listener); this.responseChannel.resumeWrites(); @@ -203,9 +203,7 @@ public class UndertowServerHttpResponse extends AbstractListenerServerHttpRespon public void handleEvent(StreamSinkChannel channel) { onWritePossible(); } - } - } private class ResponseBodyFlushProcessor extends AbstractResponseBodyFlushProcessor { @@ -224,6 +222,6 @@ public class UndertowServerHttpResponse extends AbstractListenerServerHttpRespon UndertowServerHttpResponse.this.responseChannel.flush(); } } - } + } diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/WriteOnlyHandlerIntegrationTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/WriteOnlyHandlerIntegrationTests.java new file mode 100644 index 00000000000..b7bafdfe420 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/WriteOnlyHandlerIntegrationTests.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Random; + +import org.junit.Test; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.web.client.RestTemplate; + +import static org.junit.Assert.*; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * @author Violeta Georgieva + * @since 5.0 + */ +public class WriteOnlyHandlerIntegrationTests extends AbstractHttpHandlerIntegrationTests { + + private static final int REQUEST_SIZE = 4096 * 3; + + private Random rnd = new Random(); + + private byte[] body; + + + @Override + protected WriteOnlyHandler createHttpHandler() { + return new WriteOnlyHandler(); + } + + + @Test + public void writeOnly() throws Exception { + RestTemplate restTemplate = new RestTemplate(); + + this.body = randomBytes(); + RequestEntity request = RequestEntity.post( + new URI("http://localhost:" + port)).body( + "".getBytes(StandardCharsets.UTF_8)); + ResponseEntity response = restTemplate.exchange(request, byte[].class); + + assertArrayEquals(body, response.getBody()); + } + + + private byte[] randomBytes() { + byte[] buffer = new byte[REQUEST_SIZE]; + rnd.nextBytes(buffer); + return buffer; + } + + + public class WriteOnlyHandler implements HttpHandler { + + @Override + public Mono handle(ServerHttpRequest request, ServerHttpResponse response) { + DataBuffer buffer = response.bufferFactory().allocateBuffer(body.length); + buffer.write(body); + return response.writeAndFlushWith(Flux.just(Flux.just(buffer))); + } + } +}