diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java index 4fec00cf6..0d74a6d0a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java @@ -20,11 +20,14 @@ import static org.springframework.data.mongodb.core.query.SerializationUtils.*; import lombok.AccessLevel; import lombok.NonNull; import lombok.RequiredArgsConstructor; +import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.util.context.Context; import reactor.util.function.Tuple2; import reactor.util.function.Tuples; +import java.lang.reflect.Field; import java.util.*; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; @@ -70,7 +73,16 @@ import org.springframework.data.mongodb.core.aggregation.AggregationOptions; import org.springframework.data.mongodb.core.aggregation.PrefixingDelegatingAggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; -import org.springframework.data.mongodb.core.convert.*; +import org.springframework.data.mongodb.core.convert.DbRefResolver; +import org.springframework.data.mongodb.core.convert.JsonSchemaMapper; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.convert.MongoConverter; +import org.springframework.data.mongodb.core.convert.MongoCustomConversions; +import org.springframework.data.mongodb.core.convert.MongoJsonSchemaMapper; +import org.springframework.data.mongodb.core.convert.MongoWriter; +import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; +import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.core.convert.UpdateMapper; import org.springframework.data.mongodb.core.index.MongoMappingEventPublisher; import org.springframework.data.mongodb.core.index.ReactiveIndexOperations; import org.springframework.data.mongodb.core.index.ReactiveMongoPersistentEntityIndexCreator; @@ -102,6 +114,7 @@ import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; import org.springframework.util.ObjectUtils; +import org.springframework.util.ReflectionUtils; import org.springframework.util.ResourceUtils; import org.springframework.util.StringUtils; @@ -113,11 +126,29 @@ import com.mongodb.Mongo; import com.mongodb.MongoException; import com.mongodb.ReadPreference; import com.mongodb.WriteConcern; -import com.mongodb.client.model.*; +import com.mongodb.client.model.CountOptions; +import com.mongodb.client.model.CreateCollectionOptions; +import com.mongodb.client.model.DeleteOptions; +import com.mongodb.client.model.FindOneAndDeleteOptions; +import com.mongodb.client.model.FindOneAndReplaceOptions; +import com.mongodb.client.model.FindOneAndUpdateOptions; +import com.mongodb.client.model.ReplaceOptions; +import com.mongodb.client.model.ReturnDocument; +import com.mongodb.client.model.UpdateOptions; +import com.mongodb.client.model.ValidationOptions; import com.mongodb.client.model.changestream.FullDocument; import com.mongodb.client.result.DeleteResult; import com.mongodb.client.result.UpdateResult; -import com.mongodb.reactivestreams.client.*; +import com.mongodb.reactivestreams.client.AggregatePublisher; +import com.mongodb.reactivestreams.client.ChangeStreamPublisher; +import com.mongodb.reactivestreams.client.ClientSession; +import com.mongodb.reactivestreams.client.DistinctPublisher; +import com.mongodb.reactivestreams.client.FindPublisher; +import com.mongodb.reactivestreams.client.MapReducePublisher; +import com.mongodb.reactivestreams.client.MongoClient; +import com.mongodb.reactivestreams.client.MongoCollection; +import com.mongodb.reactivestreams.client.MongoDatabase; +import com.mongodb.reactivestreams.client.Success; /** * Primary implementation of {@link ReactiveMongoOperations}. It simplifies the use of Reactive MongoDB usage and helps @@ -581,9 +612,98 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati Mono> collectionPublisher = Mono .fromCallable(() -> getAndPrepareCollection(doGetDatabase(), collectionName)); - return collectionPublisher.flatMapMany(callback::doInCollection).onErrorMap(translateException()); + Flux source = collectionPublisher.flatMapMany(callback::doInCollection).onErrorMap(translateException()); + + + return new Flux() { + + @Override + public void subscribe(CoreSubscriber actual) { + + Long skip = extractSkip(actual); + Long take = extractLimit(actual); + + System.out.println(String.format("Setting offset %s and limit: %s", skip, take)); + + Context context = Context.empty(); + + // and here we use the original Flux and evaluate skip / take in the template + if (skip != null && skip > 0L) { + context = context.put("skip", skip); + } + if (take != null && take > 0L) { + context = context.put("take", take); + } + + + source.subscriberContext(context).subscribe(actual); + } + }; + } + + // --> HACKING + + @Nullable + static Long extractSkip(Subscriber subscriber) { + + if (subscriber == null || !ClassUtils.getShortName(subscriber.getClass()).endsWith("SkipSubscriber")) { + return null; + } + + java.lang.reflect.Field field = ReflectionUtils.findField(subscriber.getClass(), "remaining"); + if (field == null) { + return null; + } + + ReflectionUtils.makeAccessible(field); + Long skip = (Long) ReflectionUtils.getField(field, subscriber); + if (skip != null && skip > 0L) { + + // reset the field, otherwise we'd skip stuff in the code. + ReflectionUtils.setField(field, subscriber, 0L); + } + + return skip; + } + + @Nullable + static Long extractLimit(Subscriber subscriber) { + + if (subscriber == null) { + return null; + } + + if (!ClassUtils.getShortName(subscriber.getClass()).endsWith("TakeSubscriber")) { + return extractLimit(extractPotentialTakeSubscriber(subscriber)); + } + + java.lang.reflect.Field field = ReflectionUtils.findField(subscriber.getClass(), "n"); + if (field == null) { + return null; + } + + ReflectionUtils.makeAccessible(field); + return (Long) ReflectionUtils.getField(field, subscriber); + } + + @Nullable + static Subscriber extractPotentialTakeSubscriber(Subscriber subscriber) { + + if (!ClassUtils.getShortName(subscriber.getClass()).endsWith("SkipSubscriber")) { + return null; + } + + Field field = ReflectionUtils.findField(subscriber.getClass(), "actual"); + if (field == null) { + return null; + } + + ReflectionUtils.makeAccessible(field); + return (Subscriber) ReflectionUtils.getField(field, subscriber); } + // <--- HACKING + /** * Create a reusable {@link Mono} for the {@code collectionName} and {@link ReactiveCollectionCallback}. * @@ -2539,12 +2659,33 @@ public class ReactiveMongoTemplate implements ReactiveMongoOperations, Applicati return createFlux(collectionName, collection -> { - FindPublisher findPublisher = collectionCallback.doInCollection(collection); + return Mono.subscriberContext().flatMapMany(context -> { + + FindPublisher findPublisher = collectionCallback.doInCollection(collection); + + if (preparer != null) { + findPublisher = preparer.prepare(findPublisher); + } + + Long skip = context.getOrDefault("skip", null); + Long take = context.getOrDefault("take", null); + + System.out.println(String.format("Using offset: %s and limit: %s", skip, take)); + + if(skip != null && skip > 0L) { + findPublisher = findPublisher.skip(skip.intValue()); + } + + if(take != null && take > 0L) { + findPublisher = findPublisher.limit(take.intValue()); + } + + return Flux.from(findPublisher).doOnNext(System.out::println).map(objectCallback::doWith); + + }); + + - if (preparer != null) { - findPublisher = preparer.prepare(findPublisher); - } - return Flux.from(findPublisher).map(objectCallback::doWith); }); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/Fluxperiment.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/Fluxperiment.java new file mode 100644 index 000000000..8544b8111 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/Fluxperiment.java @@ -0,0 +1,240 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core; + +import static org.assertj.core.api.Assertions.*; + +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; +import reactor.util.context.Context; + +import java.lang.reflect.Field; +import java.util.stream.Stream; + +import org.junit.Test; +import org.reactivestreams.Subscriber; +import org.springframework.lang.Nullable; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; + +/** + * @author Christoph Strobl + */ +public class Fluxperiment { + + @Test + public void applySkipFromFlux() { + + hackedFlux().skip(3) // + .as(StepVerifier::create) // + .expectAccessibleContext().assertThat(ctx -> { + + assertThat(ctx.getOrEmpty("skip")).contains(3L); + assertThat(ctx.getOrEmpty("take")).isEmpty(); + }).then() // + .expectNext("4") // + .expectNext("5") // + .verifyComplete(); + } + + @Test + public void applyTakeFromFlux() { + + hackedFlux().limitRequest(3) // + .as(StepVerifier::create) // + .expectAccessibleContext().assertThat(ctx -> { + + assertThat(ctx.getOrEmpty("skip")).isEmpty(); + assertThat(ctx.getOrEmpty("take")).contains(3L); + }).then() // + .expectNext("1") // + .expectNext("2") // + .expectNext("3") // + .verifyComplete(); + } + + @Test + public void applySkipAndLimitFromFlux/* in that order */() { + + hackedFlux().skip(1) /* in DB */.limitRequest(2) /* in DB */ // + .as(StepVerifier::create) // + .expectAccessibleContext().assertThat(ctx -> { + + assertThat(ctx.getOrEmpty("skip")).contains(1L); + assertThat(ctx.getOrEmpty("take")).contains(2L); + }).then() // + .expectNext("2") // + .expectNext("3") // + .verifyComplete(); + } + + @Test + public void applyTakeButNotSkipFromFlux/* cause order matters */() { + + hackedFlux().limitRequest(3)/* in DB */.skip(1) /* in memory */ // + .as(StepVerifier::create) // + .expectAccessibleContext().assertThat(ctx -> { + + assertThat(ctx.getOrEmpty("skip")).isEmpty(); + assertThat(ctx.getOrEmpty("take")).contains(3L); + }).then() // + .expectNext("2") // + .expectNext("3") // + .verifyComplete(); + } + + @Test + public void justApplySkipButNotTakeIfTheyDoNotFollowOneAnother() { + + hackedFlux().skip(1)/* in DB */.map(v -> v).limitRequest(2) /* in memory */ // + .as(StepVerifier::create) // + .expectAccessibleContext().assertThat(ctx -> { + + assertThat(ctx.getOrEmpty("skip")).contains(1L); + assertThat(ctx.getOrEmpty("take")).isEmpty(); + }).then() // + .expectNext("2") // + .expectNext("3") // + .verifyComplete(); + } + + @Test + public void applyNeitherSkipNorTakeIfPrecededWithOtherOperator() { + + hackedFlux().map(v -> v).skip(1).limitRequest(2) // + .as(StepVerifier::create) // + .expectAccessibleContext().assertThat(ctx -> { + + assertThat(ctx.getOrEmpty("skip")).isEmpty(); + assertThat(ctx.getOrEmpty("take")).isEmpty(); + }).then() // + .expectNext("2") // + .expectNext("3") // + .verifyComplete(); + } + + @Test + public void applyOnlyFirstSkip/* toDatabase */() { + + hackedFlux().skip(3)/* in DB */.skip(1)/* in memory */ // + .as(StepVerifier::create) // + .expectAccessibleContext().assertThat(ctx -> { + + assertThat(ctx.getOrEmpty("skip")).contains(3L); + assertThat(ctx.getOrEmpty("take")).isEmpty(); + }).then() // + .expectNext("5") // + .verifyComplete(); + } + + Flux hackedFlux() { + + return new Flux() { + + @Override + public void subscribe(CoreSubscriber actual) { + + Long skip = extractSkip(actual); + Long take = extractLimit(actual); + + System.out.println(String.format("Using offset: %s and limit: %s", skip, take)); + + // and here we use the original Flux and evaluate skip / take in the template + Stream source = Stream.of("1", "2", "3", "4", "5"); + Context context = Context.empty(); + + // and here we use the original Flux and evaluate skip / take in the template + if (skip != null && skip > 0L) { + context = context.put("skip", skip); + source = source.skip(skip); + } + if (take != null && take > 0L) { + + context = context.put("take", take); + source = source.limit(take); + } + + Flux.fromStream(source).subscriberContext(context).subscribe(actual); + + } + }; + } + + @Nullable + static Long extractSkip(Subscriber subscriber) { + + if (subscriber == null || !ClassUtils.getShortName(subscriber.getClass()).endsWith("SkipSubscriber")) { + return null; + } + + java.lang.reflect.Field field = ReflectionUtils.findField(subscriber.getClass(), "remaining"); + if (field == null) { + return null; + } + + ReflectionUtils.makeAccessible(field); + Long skip = (Long) ReflectionUtils.getField(field, subscriber); + if (skip != null && skip > 0L) { + + // reset the field, otherwise we'd skip stuff in the code. + ReflectionUtils.setField(field, subscriber, 0L); + } + + return skip; + } + + @Nullable + static Long extractLimit(Subscriber subscriber) { + + if (subscriber == null) { + return null; + } + + if (!ClassUtils.getShortName(subscriber.getClass()).endsWith("TakeSubscriber") + && !ClassUtils.getShortName(subscriber.getClass()).endsWith("FluxLimitRequestSubscriber")) { + return extractLimit(extractPotentialTakeSubscriber(subscriber)); + } + + java.lang.reflect.Field field = ReflectionUtils.findField(subscriber.getClass(), "n"); // from TakeSubscriber + if (field == null) { + + field = ReflectionUtils.findField(subscriber.getClass(), "toProduce"); // from FluxLimitRequestSubscriber + if (field == null) { + return null; + } + } + + ReflectionUtils.makeAccessible(field); + return (Long) ReflectionUtils.getField(field, subscriber); + } + + @Nullable + static Subscriber extractPotentialTakeSubscriber(Subscriber subscriber) { + + if (!ClassUtils.getShortName(subscriber.getClass()).endsWith("SkipSubscriber")) { + return null; + } + + Field field = ReflectionUtils.findField(subscriber.getClass(), "actual"); + if (field == null) { + return null; + } + + ReflectionUtils.makeAccessible(field); + return (Subscriber) ReflectionUtils.getField(field, subscriber); + } +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateTests.java index 40bba381e..21172638b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ReactiveMongoTemplateTests.java @@ -1751,6 +1751,22 @@ public class ReactiveMongoTemplateTests { .verify(); } + @Test + public void fluxperiment() { + + List people = Arrays.asList(new Person("Dick", 22), new Person("Harry", 23), new Person("Tom", 21)); + + StepVerifier.create(template.insertAll(people)).expectNextCount(3).verifyComplete(); + + template.find(new Query().skip(2).limit(1), Person.class); + + template.findAll(Person.class).skip(2).take(1) // + .as(StepVerifier::create) // + .consumeNextWith(System.out::println) // + .verifyComplete(); + + } + private PersonWithAList createPersonWithAList(String firstname, int age) { PersonWithAList p = new PersonWithAList();