From b189e0370a799f948d1a26f990a3b5b6fc4ec1fc Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Fri, 28 May 2021 12:18:15 -0600 Subject: [PATCH] PayloadInterceptorRSocket retains all payloads Flux#skip discards its corresponding elements, meaning that they aren't intended for reuse. When using RSocket's ByteBufPayloads, this means that the bytes are releaseed back into RSocket's pool. Since the downstream request may still need the skipped payload, we should construct the publisher in a different way so as to avoid the preemptive release. Deferring Spring JavaFormat to clarify what changed. Closes gh-9345 --- .../core/PayloadInterceptorRSocket.java | 11 ++-- .../core/PayloadInterceptorRSocketTests.java | 62 ++++++++++++++++++- 2 files changed, 68 insertions(+), 5 deletions(-) diff --git a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java index 418fb67121..0c146ea9d2 100644 --- a/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java +++ b/rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 the original author or authors. + * Copyright 2019-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -104,15 +104,18 @@ class PayloadInterceptorRSocket extends RSocketProxy implements ResponderRSocket return intercept(PayloadExchangeType.REQUEST_CHANNEL, firstPayload) .flatMapMany(context -> innerFlux - .skip(1) - .flatMap(p -> intercept(PayloadExchangeType.PAYLOAD, p).thenReturn(p)) - .transform(securedPayloads -> Flux.concat(Flux.just(firstPayload), securedPayloads)) + .index() + .concatMap(tuple -> justOrIntercept(tuple.getT1(), tuple.getT2())) .transform(securedPayloads -> this.source.requestChannel(securedPayloads)) .subscriberContext(context) ); }); } + private Mono justOrIntercept(Long index, Payload payload) { + return (index == 0) ? Mono.just(payload) : intercept(PayloadExchangeType.PAYLOAD, payload).thenReturn(payload); + } + @Override public Mono metadataPush(Payload payload) { return intercept(PayloadExchangeType.METADATA_PUSH, payload) diff --git a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java index a925ac676e..fa149e453a 100644 --- a/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java +++ b/rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 the original author or authors. + * Copyright 2019-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,8 @@ package org.springframework.security.rsocket.core; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; import io.rsocket.util.RSocketProxy; import org.junit.Test; import org.junit.runner.RunWith; @@ -28,7 +30,9 @@ import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; import org.mockito.stubbing.Answer; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; import org.springframework.http.MediaType; +import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.ReactiveSecurityContextHolder; @@ -41,6 +45,8 @@ import org.springframework.security.rsocket.core.DefaultPayloadExchange; import org.springframework.security.rsocket.core.PayloadInterceptorRSocket; import org.springframework.util.MimeType; import org.springframework.util.MimeTypeUtils; +import reactor.util.context.Context; +import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -50,10 +56,13 @@ import reactor.test.publisher.TestPublisher; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.ExecutorService; import static org.assertj.core.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; @@ -315,6 +324,57 @@ public class PayloadInterceptorRSocketTests { verify(this.delegate).requestChannel(any()); } + // gh-9345 + @Test + public void requestChannelWhenInterceptorCompletesThenAllPayloadsRetained() { + ExecutorService executors = Executors.newSingleThreadExecutor(); + Payload payload = ByteBufPayload.create("data"); + Payload payloadTwo = ByteBufPayload.create("moredata"); + Payload payloadThree = ByteBufPayload.create("stillmoredata"); + Context ctx = Context.empty(); + Flux payloads = this.payloadResult.flux(); + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()) + .thenReturn(Mono.error(() -> new AccessDeniedException("Access Denied"))); + when(this.delegate.requestChannel(any())).thenAnswer((invocation) -> { + Flux input = invocation.getArgument(0); + return Flux.from(input).switchOnFirst((signal, innerFlux) -> innerFlux.map(Payload::getDataUtf8) + .transform((data) -> Flux.create((emitter) -> { + Runnable run = () -> data.subscribe(new CoreSubscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(3); + } + + @Override + public void onNext(String s) { + emitter.next(s); + } + + @Override + public void onError(Throwable t) { + emitter.error(t); + } + + @Override + public void onComplete() { + emitter.complete(); + } + }); + executors.execute(run); + })).map(DefaultPayload::create)); + }); + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType, ctx); + StepVerifier.create(interceptor.requestChannel(payloads).doOnDiscard(Payload.class, Payload::release)) + .then(() -> this.payloadResult.assertSubscribers()) + .then(() -> this.payloadResult.emit(payload, payloadTwo, payloadThree)) + .assertNext((next) -> assertThat(next.getDataUtf8()).isEqualTo(payload.getDataUtf8())) + .verifyError(AccessDeniedException.class); + verify(this.interceptor, times(2)).intercept(this.exchange.capture(), any()); + assertThat(this.exchange.getValue().getPayload()).isEqualTo(payloadTwo); + verify(this.delegate).requestChannel(any()); + } + @Test public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() { RuntimeException expected = new RuntimeException("Oops");