diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java index e83dc84cfc1..edff8bcad8e 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java @@ -130,6 +130,14 @@ public abstract class AbstractListenerReadPublisher implements Publisher { */ protected abstract void readingPaused(); + /** + * Invoked after an I/O read error from the underlying server or after a + * cancellation signal from the downstream consumer to allow sub-classes + * to discard any current cached data they might have. + * @since 5.1.2 + */ + protected abstract void discardData(); + // Private methods for use in State... @@ -382,7 +390,10 @@ public abstract class AbstractListenerReadPublisher implements Publisher { } void cancel(AbstractListenerReadPublisher publisher) { - if (!publisher.changeState(this, COMPLETED)) { + if (publisher.changeState(this, COMPLETED)) { + publisher.discardData(); + } + else { publisher.state.get().cancel(publisher); } } @@ -405,6 +416,7 @@ public abstract class AbstractListenerReadPublisher implements Publisher { void onError(AbstractListenerReadPublisher publisher, Throwable t) { if (publisher.changeState(this, COMPLETED)) { + publisher.discardData(); Subscriber s = publisher.subscriber; if (s != null) { s.onError(t); diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteProcessor.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteProcessor.java index 50a245dc30c..0a4d703a0ea 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteProcessor.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerWriteProcessor.java @@ -118,6 +118,9 @@ public abstract class AbstractListenerWriteProcessor implements Processor subscriber) { + // Technically, cancellation from the result subscriber should be propagated + // to the upstream subscription. In practice, HttpHandler server adapters + // don't have a reason to cancel the result subscription. this.resultPublisher.subscribe(subscriber); } @@ -136,8 +139,14 @@ public abstract class AbstractListenerWriteProcessor implements Processor implements Processor implements Processor implements Processor implements Processor void onNext(AbstractListenerWriteProcessor processor, T data) { - throw new IllegalStateException(toString()); + processor.discardData(data); + processor.cancel(); + processor.onError(new IllegalStateException("Illegal onNext without demand")); } public void onError(AbstractListenerWriteProcessor processor, Throwable ex) { if (processor.changeState(this, COMPLETED)) { + processor.discardCurrentData(); processor.writingComplete(); processor.resultPublisher.publishError(ex); } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java index 26235a83d6e..972b05e371a 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java @@ -290,6 +290,11 @@ class ServletServerHttpRequest extends AbstractServerHttpRequest { // no-op } + @Override + protected void discardData() { + // Nothing to discard since we pass data buffers on immediately.. + } + private class RequestBodyPublisherReadListener implements ReadListener { diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java index 143667a2dbc..ab28e2cd399 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -318,6 +318,7 @@ class ServletServerHttpResponse extends AbstractListenerServerHttpResponse { } int remaining = dataBuffer.readableByteCount(); if (ready && remaining > 0) { + // In case of IOException, onError handling should call discardData(DataBuffer).. int written = writeToOutputStream(dataBuffer); if (this.logger.isTraceEnabled()) { this.logger.trace("written: " + written + " total: " + remaining); @@ -337,6 +338,11 @@ class ServletServerHttpResponse extends AbstractListenerServerHttpResponse { protected void writingComplete() { bodyProcessor = null; } + + @Override + protected void discardData(DataBuffer dataBuffer) { + DataBufferUtils.release(dataBuffer); + } } } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java index 8ef9eb90c3f..662a8e4cabb 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java @@ -194,6 +194,10 @@ class UndertowServerHttpRequest extends AbstractServerHttpRequest { } } + @Override + protected void discardData() { + // Nothing to discard since we pass data buffers on immediately.. + } } private static class UndertowDataBuffer implements PooledDataBuffer { diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java index 58aed42f296..bf47e4cbf22 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java @@ -174,6 +174,7 @@ class UndertowServerHttpResponse extends AbstractListenerServerHttpResponse impl // Track write listener calls from here on.. this.writePossible = false; + // In case of IOException, onError handling should call discardData(DataBuffer).. int total = buffer.remaining(); int written = writeByteBuffer(buffer); @@ -228,6 +229,11 @@ class UndertowServerHttpResponse extends AbstractListenerServerHttpResponse impl cancel(); onError(ex); } + + @Override + protected void discardData(DataBuffer dataBuffer) { + DataBufferUtils.release(dataBuffer); + } } diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java index 81260eb63f5..f34ed849764 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java @@ -16,54 +16,90 @@ package org.springframework.http.server.reactive; -import java.io.IOException; - +import org.junit.Before; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import org.springframework.core.io.buffer.DataBuffer; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.isA; -import static org.mockito.Mockito.mock; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; /** - * Unit tests for {@link AbstractListenerReadPublisher} - * + * Unit tests for {@link AbstractListenerReadPublisher}. + * * @author Violeta Georgieva - * @since 5.0 + * @author Rossen Stoyanchev */ public class ListenerReadPublisherTests { + private final TestListenerReadPublisher publisher = new TestListenerReadPublisher(); + + private final TestSubscriber subscriber = new TestSubscriber(); + + + @Before + public void setup() { + this.publisher.subscribe(this.subscriber); + } + + @Test - @SuppressWarnings("unchecked") - public void testReceiveTwoRequestCallsWhenOnSubscribe() { - Subscriber subscriber = mock(Subscriber.class); - doAnswer(new SubscriptionAnswer()).when(subscriber).onSubscribe(isA(Subscription.class)); + public void twoReads() { + + this.subscriber.getSubscription().request(2); + this.publisher.onDataAvailable(); + + assertEquals(2, this.publisher.getReadCalls()); + } + + @Test // SPR-17410 + public void discardDataOnError() { - TestListenerReadPublisher publisher = new TestListenerReadPublisher(); - publisher.subscribe(subscriber); - publisher.onDataAvailable(); + this.subscriber.getSubscription().request(2); + this.publisher.onDataAvailable(); + this.publisher.onError(new IllegalStateException()); - assertTrue(publisher.getReadCalls() == 2); + assertEquals(2, this.publisher.getReadCalls()); + assertEquals(1, this.publisher.getDiscardCalls()); } - private static final class TestListenerReadPublisher extends AbstractListenerReadPublisher { + @Test // SPR-17410 + public void discardDataOnCancel() { + + this.subscriber.getSubscription().request(2); + this.subscriber.setCancelOnNext(true); + this.publisher.onDataAvailable(); + + assertEquals(1, this.publisher.getReadCalls()); + assertEquals(1, this.publisher.getDiscardCalls()); + } + + + private static final class TestListenerReadPublisher extends AbstractListenerReadPublisher { private int readCalls = 0; + private int discardCalls = 0; + + + public int getReadCalls() { + return this.readCalls; + } + + public int getDiscardCalls() { + return this.discardCalls; + } + @Override protected void checkOnDataAvailable() { // no-op } @Override - protected DataBuffer read() throws IOException { - readCalls++; + protected DataBuffer read() { + this.readCalls++; return mock(DataBuffer.class); } @@ -72,22 +108,48 @@ public class ListenerReadPublisherTests { // No-op } - public int getReadCalls() { - return this.readCalls; + @Override + protected void discardData() { + this.discardCalls++; } - } - private static final class SubscriptionAnswer implements Answer { + + private static final class TestSubscriber implements Subscriber { + + private Subscription subscription; + + private boolean cancelOnNext; + + + public Subscription getSubscription() { + return this.subscription; + } + + public void setCancelOnNext(boolean cancelOnNext) { + this.cancelOnNext = cancelOnNext; + } + + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + } @Override - public Subscription answer(InvocationOnMock invocation) throws Throwable { - Subscription arg = (Subscription) invocation.getArguments()[0]; - arg.request(1); - arg.request(1); - return arg; + public void onNext(DataBuffer dataBuffer) { + if (this.cancelOnNext) { + this.subscription.cancel(); + } } + @Override + public void onError(Throwable t) { + } + + @Override + public void onComplete() { + } } } diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerWriteProcessorTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerWriteProcessorTests.java new file mode 100644 index 00000000000..80348355bfc --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerWriteProcessorTests.java @@ -0,0 +1,206 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.server.reactive; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Before; +import org.junit.Test; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import org.springframework.core.io.buffer.DataBuffer; + +import static junit.framework.TestCase.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link AbstractListenerWriteProcessor}. + * + * @author Rossen Stoyanchev + */ +public class ListenerWriteProcessorTests { + + private final TestListenerWriteProcessor processor = new TestListenerWriteProcessor(); + + private final TestResultSubscriber resultSubscriber = new TestResultSubscriber(); + + private final TestSubscription subscription = new TestSubscription(); + + + @Before + public void setup() { + this.processor.subscribe(this.resultSubscriber); + this.processor.onSubscribe(this.subscription); + assertEquals(1, subscription.getDemand()); + } + + + @Test // SPR-17410 + public void writePublisherError() { + + // Turn off writing so next item will be cached + this.processor.setWritePossible(false); + DataBuffer buffer = mock(DataBuffer.class); + this.processor.onNext(buffer); + + // Send error while item cached + this.processor.onError(new IllegalStateException()); + + assertNotNull("Error should flow to result publisher", this.resultSubscriber.getError()); + assertEquals(1, this.processor.getDiscardedBuffers().size()); + assertSame(buffer, this.processor.getDiscardedBuffers().get(0)); + } + + @Test // SPR-17410 + public void ioExceptionDuringWrite() { + + // Fail on next write + this.processor.setWritePossible(true); + this.processor.setFailOnWrite(true); + + // Write + DataBuffer buffer = mock(DataBuffer.class); + this.processor.onNext(buffer); + + assertNotNull("Error should flow to result publisher", this.resultSubscriber.getError()); + assertEquals(1, this.processor.getDiscardedBuffers().size()); + assertSame(buffer, this.processor.getDiscardedBuffers().get(0)); + } + + @Test // SPR-17410 + public void onNextWithoutDemand() { + + // Disable writing: next item will be cached.. + this.processor.setWritePossible(false); + DataBuffer buffer1 = mock(DataBuffer.class); + this.processor.onNext(buffer1); + + // Send more data illegally + DataBuffer buffer2 = mock(DataBuffer.class); + this.processor.onNext(buffer2); + + assertNotNull("Error should flow to result publisher", this.resultSubscriber.getError()); + assertEquals(2, this.processor.getDiscardedBuffers().size()); + assertSame(buffer2, this.processor.getDiscardedBuffers().get(0)); + assertSame(buffer1, this.processor.getDiscardedBuffers().get(1)); + } + + + private static final class TestListenerWriteProcessor extends AbstractListenerWriteProcessor { + + private final List discardedBuffers = new ArrayList<>(); + + private boolean writePossible; + + private boolean failOnWrite; + + + public List getDiscardedBuffers() { + return this.discardedBuffers; + } + + public void setWritePossible(boolean writePossible) { + this.writePossible = writePossible; + } + + public void setFailOnWrite(boolean failOnWrite) { + this.failOnWrite = failOnWrite; + } + + + @Override + protected boolean isDataEmpty(DataBuffer dataBuffer) { + return false; + } + + @Override + protected boolean isWritePossible() { + return this.writePossible; + } + + @Override + protected boolean write(DataBuffer dataBuffer) throws IOException { + if (this.failOnWrite) { + throw new IOException("write failed"); + } + return true; + } + + @Override + protected void writingFailed(Throwable ex) { + cancel(); + onError(ex); + } + + @Override + protected void discardData(DataBuffer dataBuffer) { + this.discardedBuffers.add(dataBuffer); + } + } + + + private static final class TestSubscription implements Subscription { + + private long demand; + + + public long getDemand() { + return this.demand; + } + + + @Override + public void request(long n) { + this.demand = (n == Long.MAX_VALUE ? n : this.demand + n); + } + + @Override + public void cancel() { + } + } + + private static final class TestResultSubscriber implements Subscriber { + + private Throwable error; + + + public Throwable getError() { + return this.error; + } + + + @Override + public void onSubscribe(Subscription subscription) { + } + + @Override + public void onNext(Void aVoid) { + } + + @Override + public void onError(Throwable ex) { + this.error = ex; + } + + @Override + public void onComplete() { + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java index 48633db50c4..02ec6115c53 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java @@ -250,11 +250,23 @@ public abstract class AbstractListenerWebSocketSession extends AbstractWebSoc logger.trace("Received message: " + message); } if (!this.pendingMessages.offer(message)) { - throw new IllegalStateException("Too many messages received. " + - "Please ensure WebSocketSession.receive() is subscribed to."); + discardData(); + throw new IllegalStateException( + "Too many messages. Please ensure WebSocketSession.receive() is subscribed to."); } onDataAvailable(); } + + @Override + protected void discardData() { + while (true) { + WebSocketMessage message = (WebSocketMessage) this.pendingMessages.poll(); + if (message == null) { + return; + } + message.release(); + } + } } @@ -267,6 +279,7 @@ public abstract class AbstractListenerWebSocketSession extends AbstractWebSoc if (logger.isTraceEnabled()) { logger.trace("Sending message " + message); } + // In case of IOException, onError handling should call discardData(WebSocketMessage).. return sendMessage(message); } @@ -291,6 +304,11 @@ public abstract class AbstractListenerWebSocketSession extends AbstractWebSoc } this.isReady = ready; } + + @Override + protected void discardData(WebSocketMessage message) { + message.release(); + } } }