@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2019 the original author or authors .
* Copyright 2019 - 2021 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -19,10 +19,14 @@ package org.springframework.security.rsocket.core;
@@ -19,10 +19,14 @@ package org.springframework.security.rsocket.core;
import java.util.Arrays ;
import java.util.Collections ;
import java.util.List ;
import java.util.concurrent.ExecutorService ;
import java.util.concurrent.Executors ;
import io.rsocket.Payload ;
import io.rsocket.RSocket ;
import io.rsocket.metadata.WellKnownMimeType ;
import io.rsocket.util.ByteBufPayload ;
import io.rsocket.util.DefaultPayload ;
import io.rsocket.util.RSocketProxy ;
import org.junit.Test ;
import org.junit.runner.RunWith ;
@ -32,13 +36,17 @@ import org.mockito.Mock;
@@ -32,13 +36,17 @@ import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner ;
import org.mockito.stubbing.Answer ;
import org.reactivestreams.Publisher ;
import org.reactivestreams.Subscription ;
import reactor.core.CoreSubscriber ;
import reactor.core.publisher.Flux ;
import reactor.core.publisher.Mono ;
import reactor.test.StepVerifier ;
import reactor.test.publisher.PublisherProbe ;
import reactor.test.publisher.TestPublisher ;
import reactor.util.context.Context ;
import org.springframework.http.MediaType ;
import org.springframework.security.access.AccessDeniedException ;
import org.springframework.security.authentication.TestingAuthenticationToken ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.core.context.ReactiveSecurityContextHolder ;
@ -56,6 +64,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
@@ -56,6 +64,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.ArgumentMatchers.eq ;
import static org.mockito.BDDMockito.given ;
import static org.mockito.Mockito.times ;
import static org.mockito.Mockito.verify ;
import static org.mockito.Mockito.verifyZeroInteractions ;
@ -265,6 +274,57 @@ public class PayloadInterceptorRSocketTests {
@@ -265,6 +274,57 @@ public class PayloadInterceptorRSocketTests {
verify ( this . delegate ) . requestChannel ( any ( ) ) ;
}
// gh-9345
@Test
public void requestChannelWhenInterceptorCompletesThenAllPayloadsRetained ( ) {
ExecutorService executors = Executors . newSingleThreadExecutor ( ) ;
Payload payload = ByteBufPayload . create ( "data" ) ;
Payload payloadTwo = ByteBufPayload . create ( "moredata" ) ;
Payload payloadThree = ByteBufPayload . create ( "stillmoredata" ) ;
Context ctx = Context . empty ( ) ;
Flux < Payload > payloads = this . payloadResult . flux ( ) ;
given ( this . interceptor . intercept ( any ( ) , any ( ) ) ) . willReturn ( Mono . empty ( ) )
. willReturn ( Mono . error ( ( ) - > new AccessDeniedException ( "Access Denied" ) ) ) ;
given ( this . delegate . requestChannel ( any ( ) ) ) . willAnswer ( ( invocation ) - > {
Flux < Payload > input = invocation . getArgument ( 0 ) ;
return Flux . from ( input ) . switchOnFirst ( ( signal , innerFlux ) - > innerFlux . map ( Payload : : getDataUtf8 )
. transform ( ( data ) - > Flux . < String > create ( ( emitter ) - > {
Runnable run = ( ) - > data . subscribe ( new CoreSubscriber < String > ( ) {
@Override
public void onSubscribe ( Subscription s ) {
s . request ( 3 ) ;
}
@Override
public void onNext ( String s ) {
emitter . next ( s ) ;
}
@Override
public void onError ( Throwable t ) {
emitter . error ( t ) ;
}
@Override
public void onComplete ( ) {
emitter . complete ( ) ;
}
} ) ;
executors . execute ( run ) ;
} ) ) . map ( DefaultPayload : : create ) ) ;
} ) ;
PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket ( this . delegate ,
Arrays . asList ( this . interceptor ) , this . metadataMimeType , this . dataMimeType , ctx ) ;
StepVerifier . create ( interceptor . requestChannel ( payloads ) . doOnDiscard ( Payload . class , Payload : : release ) )
. then ( ( ) - > this . payloadResult . assertSubscribers ( ) )
. then ( ( ) - > this . payloadResult . emit ( payload , payloadTwo , payloadThree ) )
. assertNext ( ( next ) - > assertThat ( next . getDataUtf8 ( ) ) . isEqualTo ( payload . getDataUtf8 ( ) ) )
. verifyError ( AccessDeniedException . class ) ;
verify ( this . interceptor , times ( 2 ) ) . intercept ( this . exchange . capture ( ) , any ( ) ) ;
assertThat ( this . exchange . getValue ( ) . getPayload ( ) ) . isEqualTo ( payloadTwo ) ;
verify ( this . delegate ) . requestChannel ( any ( ) ) ;
}
@Test
public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed ( ) {
RuntimeException expected = new RuntimeException ( "Oops" ) ;