diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/CrudMethodMetadata.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/CrudMethodMetadata.java new file mode 100644 index 000000000..44c6c97ce --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/CrudMethodMetadata.java @@ -0,0 +1,45 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.support; + +import java.lang.reflect.Method; +import java.util.Optional; + +import com.mongodb.ReadPreference; + +/** + * Interface to abstract {@link CrudMethodMetadata} that provide the {@link ReadPreference} to be used for query + * execution. + * + * @author Mark Paluch + * @since 4.2 + */ +public interface CrudMethodMetadata { + + /** + * Returns the {@link ReadPreference} to be used. + * + * @return the {@link ReadPreference} to be used. + */ + Optional getReadPreference(); + + /** + * Returns the {@link Method} that this metadata applies to. + * + * @return the {@link Method} that this metadata applies to. + */ + Method getMethod(); +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/CrudMethodMetadataPostProcessor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/CrudMethodMetadataPostProcessor.java new file mode 100644 index 000000000..1c4cb75bc --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/CrudMethodMetadataPostProcessor.java @@ -0,0 +1,233 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.support; + +import java.lang.reflect.Method; +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import org.aopalliance.intercept.MethodInterceptor; +import org.aopalliance.intercept.MethodInvocation; +import org.springframework.aop.TargetSource; +import org.springframework.aop.framework.ProxyFactory; +import org.springframework.beans.factory.BeanClassLoaderAware; +import org.springframework.core.NamedThreadLocal; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.core.support.RepositoryProxyPostProcessor; +import org.springframework.lang.Nullable; +import org.springframework.transaction.support.TransactionSynchronizationManager; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; + +import com.mongodb.ReadPreference; + +/** + * {@link RepositoryProxyPostProcessor} that sets up interceptors to read metadata information from the invoked method. + * This is necessary to allow redeclaration of CRUD methods in repository interfaces and configure read preference + * information or query hints on them. + * + * @author Mark Paluch + */ +class CrudMethodMetadataPostProcessor implements RepositoryProxyPostProcessor, BeanClassLoaderAware { + + private @Nullable ClassLoader classLoader = ClassUtils.getDefaultClassLoader(); + + @Override + public void setBeanClassLoader(ClassLoader classLoader) { + this.classLoader = classLoader; + } + + @Override + public void postProcess(ProxyFactory factory, RepositoryInformation repositoryInformation) { + factory.addAdvice(new CrudMethodMetadataPopulatingMethodInterceptor(repositoryInformation)); + } + + /** + * Returns a {@link CrudMethodMetadata} proxy that will lookup the actual target object by obtaining a thread bound + * instance from the {@link TransactionSynchronizationManager} later. + */ + CrudMethodMetadata getCrudMethodMetadata() { + + ProxyFactory factory = new ProxyFactory(); + + factory.addInterface(CrudMethodMetadata.class); + factory.setTargetSource(new ThreadBoundTargetSource()); + + return (CrudMethodMetadata) factory.getProxy(this.classLoader); + } + + /** + * {@link MethodInterceptor} to build and cache {@link DefaultCrudMethodMetadata} instances for the invoked methods. + * Will bind the found information to a {@link TransactionSynchronizationManager} for later lookup. + * + * @see DefaultCrudMethodMetadata + */ + static class CrudMethodMetadataPopulatingMethodInterceptor implements MethodInterceptor { + + private static final ThreadLocal currentInvocation = new NamedThreadLocal<>( + "Current AOP method invocation"); + + private final ConcurrentMap metadataCache = new ConcurrentHashMap<>(); + private final Set implementations = new HashSet<>(); + + CrudMethodMetadataPopulatingMethodInterceptor(RepositoryInformation repositoryInformation) { + + ReflectionUtils.doWithMethods(repositoryInformation.getRepositoryInterface(), implementations::add, + method -> !repositoryInformation.isQueryMethod(method)); + } + + /** + * Return the AOP Alliance {@link MethodInvocation} object associated with the current invocation. + * + * @return the invocation object associated with the current invocation. + * @throws IllegalStateException if there is no AOP invocation in progress, or if the + * {@link CrudMethodMetadataPopulatingMethodInterceptor} was not added to this interceptor chain. + */ + static MethodInvocation currentInvocation() throws IllegalStateException { + + MethodInvocation mi = currentInvocation.get(); + + if (mi == null) + throw new IllegalStateException( + "No MethodInvocation found: Check that an AOP invocation is in progress, and that the " + + "CrudMethodMetadataPopulatingMethodInterceptor is upfront in the interceptor chain."); + return mi; + } + + @Override + public Object invoke(MethodInvocation invocation) throws Throwable { + + Method method = invocation.getMethod(); + + if (!implementations.contains(method)) { + return invocation.proceed(); + } + + MethodInvocation oldInvocation = currentInvocation.get(); + currentInvocation.set(invocation); + + try { + + CrudMethodMetadata metadata = (CrudMethodMetadata) TransactionSynchronizationManager.getResource(method); + + if (metadata != null) { + return invocation.proceed(); + } + + CrudMethodMetadata methodMetadata = metadataCache.get(method); + + if (methodMetadata == null) { + + methodMetadata = new DefaultCrudMethodMetadata(method); + CrudMethodMetadata tmp = metadataCache.putIfAbsent(method, methodMetadata); + + if (tmp != null) { + methodMetadata = tmp; + } + } + + TransactionSynchronizationManager.bindResource(method, methodMetadata); + + try { + return invocation.proceed(); + } finally { + TransactionSynchronizationManager.unbindResource(method); + } + } finally { + currentInvocation.set(oldInvocation); + } + } + } + + /** + * Default implementation of {@link CrudMethodMetadata} that will inspect the backing method for annotations. + */ + static class DefaultCrudMethodMetadata implements CrudMethodMetadata { + + private final Optional readPreference; + private final Method method; + + /** + * Creates a new {@link DefaultCrudMethodMetadata} for the given {@link Method}. + * + * @param method must not be {@literal null}. + */ + DefaultCrudMethodMetadata(Method method) { + + Assert.notNull(method, "Method must not be null"); + + this.readPreference = findReadPreference(method); + this.method = method; + } + + private Optional findReadPreference(Method method) { + + org.springframework.data.mongodb.repository.ReadPreference preference = AnnotatedElementUtils + .findMergedAnnotation(method, org.springframework.data.mongodb.repository.ReadPreference.class); + + if (preference == null) { + + preference = AnnotatedElementUtils.findMergedAnnotation(method.getDeclaringClass(), + org.springframework.data.mongodb.repository.ReadPreference.class); + } + + if (preference == null) { + return Optional.empty(); + } + + return Optional.of(com.mongodb.ReadPreference.valueOf(preference.value())); + + } + + @Override + public Optional getReadPreference() { + return readPreference; + } + + @Override + public Method getMethod() { + return method; + } + } + + private static class ThreadBoundTargetSource implements TargetSource { + + @Override + public Class getTargetClass() { + return CrudMethodMetadata.class; + } + + @Override + public boolean isStatic() { + return false; + } + + @Override + public Object getTarget() { + + MethodInvocation invocation = CrudMethodMetadataPopulatingMethodInterceptor.currentInvocation(); + return TransactionSynchronizationManager.getResource(invocation.getMethod()); + } + + @Override + public void releaseTarget(Object target) {} + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MappingMongoEntityInformation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MappingMongoEntityInformation.java index ce5cda800..261b8274a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MappingMongoEntityInformation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MappingMongoEntityInformation.java @@ -94,7 +94,7 @@ public class MappingMongoEntityInformation extends PersistentEntityInform } public String getIdAttribute() { - return entityMetadata.getRequiredIdProperty().getName(); + return entityMetadata.hasIdProperty() ? entityMetadata.getRequiredIdProperty().getName() : "_id"; } @Override diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactory.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactory.java index f768d2873..967eb7666 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactory.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactory.java @@ -61,6 +61,7 @@ public class MongoRepositoryFactory extends RepositoryFactorySupport { private static final SpelExpressionParser EXPRESSION_PARSER = new SpelExpressionParser(); + private final CrudMethodMetadataPostProcessor crudMethodMetadataPostProcessor = new CrudMethodMetadataPostProcessor(); private final MongoOperations operations; private final MappingContext, MongoPersistentProperty> mappingContext; @@ -75,6 +76,15 @@ public class MongoRepositoryFactory extends RepositoryFactorySupport { this.operations = mongoOperations; this.mappingContext = mongoOperations.getConverter().getMappingContext(); + + addRepositoryProxyPostProcessor(crudMethodMetadataPostProcessor); + } + + @Override + public void setBeanClassLoader(ClassLoader classLoader) { + + super.setBeanClassLoader(classLoader); + crudMethodMetadataPostProcessor.setBeanClassLoader(classLoader); } @Override @@ -127,7 +137,13 @@ public class MongoRepositoryFactory extends RepositoryFactorySupport { MongoEntityInformation entityInformation = getEntityInformation(information.getDomainType(), information); - return getTargetRepositoryViaReflection(information, information, entityInformation, operations); + Object targetRepository = getTargetRepositoryViaReflection(information, entityInformation, operations); + + if (targetRepository instanceof SimpleMongoRepository repository) { + repository.setRepositoryMethodMetadata(crudMethodMetadataPostProcessor.getCrudMethodMetadata()); + } + + return targetRepository; } @Override diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactory.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactory.java index b8dd2cc99..a0526d034 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactory.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactory.java @@ -62,6 +62,7 @@ public class ReactiveMongoRepositoryFactory extends ReactiveRepositoryFactorySup private static final SpelExpressionParser EXPRESSION_PARSER = new SpelExpressionParser(); + private final CrudMethodMetadataPostProcessor crudMethodMetadataPostProcessor = new CrudMethodMetadataPostProcessor(); private final ReactiveMongoOperations operations; private final MappingContext, MongoPersistentProperty> mappingContext; @@ -76,7 +77,16 @@ public class ReactiveMongoRepositoryFactory extends ReactiveRepositoryFactorySup this.operations = mongoOperations; this.mappingContext = mongoOperations.getConverter().getMappingContext(); + setEvaluationContextProvider(ReactiveQueryMethodEvaluationContextProvider.DEFAULT); + addRepositoryProxyPostProcessor(crudMethodMetadataPostProcessor); + } + + @Override + public void setBeanClassLoader(ClassLoader classLoader) { + + super.setBeanClassLoader(classLoader); + crudMethodMetadataPostProcessor.setBeanClassLoader(classLoader); } @Override @@ -114,7 +124,13 @@ public class ReactiveMongoRepositoryFactory extends ReactiveRepositoryFactorySup MongoEntityInformation entityInformation = getEntityInformation(information.getDomainType(), information); - return getTargetRepositoryViaReflection(information, information, entityInformation, operations); + Object targetRepository = getTargetRepositoryViaReflection(information, entityInformation, operations); + + if (targetRepository instanceof SimpleReactiveMongoRepository repository) { + repository.setRepositoryMethodMetadata(crudMethodMetadataPostProcessor.getCrudMethodMetadata()); + } + + return targetRepository; } @Override diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepository.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepository.java index 4b236a6c9..5671801d7 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepository.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepository.java @@ -27,7 +27,6 @@ import java.util.function.UnaryOperator; import java.util.stream.Collectors; import java.util.stream.Stream; -import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.data.domain.Example; import org.springframework.data.domain.Page; @@ -42,16 +41,14 @@ import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.repository.MongoRepository; -import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.query.MongoEntityInformation; -import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.support.PageableExecutionUtils; -import org.springframework.data.util.Lazy; import org.springframework.data.util.StreamUtils; import org.springframework.data.util.Streamable; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import com.mongodb.ReadPreference; import com.mongodb.client.result.DeleteResult; /** @@ -66,9 +63,9 @@ import com.mongodb.client.result.DeleteResult; */ public class SimpleMongoRepository implements MongoRepository { + private @Nullable CrudMethodMetadata crudMethodMetadata; private final MongoEntityInformation entityInformation; private final MongoOperations mongoOperations; - private final Lazy readPreference; /** * Creates a new {@link SimpleMongoRepository} for the given {@link MongoEntityInformation} and {@link MongoTemplate}. @@ -77,40 +74,12 @@ public class SimpleMongoRepository implements MongoRepository { * @param mongoOperations must not be {@literal null}. */ public SimpleMongoRepository(MongoEntityInformation metadata, MongoOperations mongoOperations) { - this(null, metadata, mongoOperations, null); - } - - /** - * Creates a new {@link SimpleMongoRepository} for the given {@link MongoEntityInformation} and {@link MongoTemplate}. - * - * @param repositoryMetadata must not be {@literal null}. - * @param metadata must not be {@literal null}. - * @param mongoOperations must not be {@literal null}. - * @since 4.2 - */ - public SimpleMongoRepository(RepositoryMetadata repositoryMetadata, MongoEntityInformation metadata, - MongoOperations mongoOperations) { - this(repositoryMetadata, metadata, mongoOperations, null); - } - - private SimpleMongoRepository(@Nullable RepositoryMetadata repositoryMetadata, MongoEntityInformation metadata, - MongoOperations mongoOperations, @Nullable Object marker) { Assert.notNull(metadata, "MongoEntityInformation must not be null"); Assert.notNull(mongoOperations, "MongoOperations must not be null"); this.entityInformation = metadata; this.mongoOperations = mongoOperations; - - this.readPreference = repositoryMetadata == null ? Lazy.empty() : Lazy.of(() -> { - ReadPreference preference = AnnotatedElementUtils - .findMergedAnnotation(repositoryMetadata.getRepositoryInterface(), ReadPreference.class); - - if (preference == null) { - return null; - } - return com.mongodb.ReadPreference.valueOf(preference.value()); - }); } // ------------------------------------------------------------------------- @@ -151,8 +120,11 @@ public class SimpleMongoRepository implements MongoRepository { Assert.notNull(id, "The given id must not be null"); + Query query = getIdQuery(id); + getReadPreference().ifPresent(query::withReadPreference); + return Optional.ofNullable( - mongoOperations.findById(id, entityInformation.getJavaType(), entityInformation.getCollectionName())); + mongoOperations.findOne(query, entityInformation.getJavaType(), entityInformation.getCollectionName())); } @Override @@ -160,7 +132,10 @@ public class SimpleMongoRepository implements MongoRepository { Assert.notNull(id, "The given id must not be null"); - return mongoOperations.exists(getIdQuery(id), entityInformation.getJavaType(), + Query query = getIdQuery(id); + getReadPreference().ifPresent(query::withReadPreference); + + return mongoOperations.exists(query, entityInformation.getJavaType(), entityInformation.getCollectionName()); } @@ -179,7 +154,10 @@ public class SimpleMongoRepository implements MongoRepository { @Override public long count() { - return mongoOperations.count(new Query(), entityInformation.getCollectionName()); + + Query query = new Query(); + getReadPreference().ifPresent(query::withReadPreference); + return mongoOperations.count(query, entityInformation.getCollectionName()); } @Override @@ -187,7 +165,9 @@ public class SimpleMongoRepository implements MongoRepository { Assert.notNull(id, "The given id must not be null"); - mongoOperations.remove(getIdQuery(id), entityInformation.getJavaType(), entityInformation.getCollectionName()); + Query query = getIdQuery(id); + getReadPreference().ifPresent(query::withReadPreference); + mongoOperations.remove(query, entityInformation.getJavaType(), entityInformation.getCollectionName()); } @Override @@ -210,7 +190,9 @@ public class SimpleMongoRepository implements MongoRepository { Assert.notNull(ids, "The given Iterable of ids must not be null"); - mongoOperations.remove(getIdQuery(ids), entityInformation.getJavaType(), entityInformation.getCollectionName()); + Query query = getIdQuery(ids); + getReadPreference().ifPresent(query::withReadPreference); + mongoOperations.remove(query, entityInformation.getJavaType(), entityInformation.getCollectionName()); } @Override @@ -223,7 +205,11 @@ public class SimpleMongoRepository implements MongoRepository { @Override public void deleteAll() { - mongoOperations.remove(new Query(), entityInformation.getCollectionName()); + + Query query = new Query(); + getReadPreference().ifPresent(query::withReadPreference); + + mongoOperations.remove(query, entityInformation.getCollectionName()); } // ------------------------------------------------------------------------- @@ -287,7 +273,7 @@ public class SimpleMongoRepository implements MongoRepository { Query query = new Query(new Criteria().alike(example)) // .collation(entityInformation.getCollation()); - readPreference.getOptional().ifPresent(query::withReadPreference); + getReadPreference().ifPresent(query::withReadPreference); return Optional .ofNullable(mongoOperations.findOne(query, example.getProbeType(), entityInformation.getCollectionName())); @@ -307,7 +293,7 @@ public class SimpleMongoRepository implements MongoRepository { Query query = new Query(new Criteria().alike(example)) // .collation(entityInformation.getCollation()) // .with(sort); - readPreference.getOptional().ifPresent(query::withReadPreference); + getReadPreference().ifPresent(query::withReadPreference); return mongoOperations.find(query, example.getProbeType(), entityInformation.getCollectionName()); } @@ -320,7 +306,7 @@ public class SimpleMongoRepository implements MongoRepository { Query query = new Query(new Criteria().alike(example)) // .collation(entityInformation.getCollation()).with(pageable); // - readPreference.getOptional().ifPresent(query::withReadPreference); + getReadPreference().ifPresent(query::withReadPreference); List list = mongoOperations.find(query, example.getProbeType(), entityInformation.getCollectionName()); @@ -335,6 +321,7 @@ public class SimpleMongoRepository implements MongoRepository { Query query = new Query(new Criteria().alike(example)) // .collation(entityInformation.getCollation()); + getReadPreference().ifPresent(query::withReadPreference); return mongoOperations.count(query, example.getProbeType(), entityInformation.getCollectionName()); } @@ -346,6 +333,7 @@ public class SimpleMongoRepository implements MongoRepository { Query query = new Query(new Criteria().alike(example)) // .collation(entityInformation.getCollation()); + getReadPreference().ifPresent(query::withReadPreference); return mongoOperations.exists(query, example.getProbeType(), entityInformation.getCollectionName()); } @@ -364,6 +352,25 @@ public class SimpleMongoRepository implements MongoRepository { // Utility methods // ------------------------------------------------------------------------- + /** + * Configures a custom {@link CrudMethodMetadata} to be used to detect {@link ReadPreference}s and query hints to be + * applied to queries. + * + * @param crudMethodMetadata + */ + public void setRepositoryMethodMetadata(CrudMethodMetadata crudMethodMetadata) { + this.crudMethodMetadata = crudMethodMetadata; + } + + private Optional getReadPreference() { + + if (crudMethodMetadata == null) { + return Optional.empty(); + } + + return crudMethodMetadata.getReadPreference(); + } + private Query getIdQuery(Object id) { return new Query(getIdCriteria(id)); } @@ -375,7 +382,7 @@ public class SimpleMongoRepository implements MongoRepository { private Query getIdQuery(Iterable ids) { Query query = new Query(new Criteria(entityInformation.getIdAttribute()).in(toCollection(ids))); - readPreference.getOptional().ifPresent(query::withReadPreference); + getReadPreference().ifPresent(query::withReadPreference); return query; } @@ -390,7 +397,7 @@ public class SimpleMongoRepository implements MongoRepository { return Collections.emptyList(); } - readPreference.getOptional().ifPresent(query::withReadPreference); + getReadPreference().ifPresent(query::withReadPreference); return mongoOperations.find(query, entityInformation.getJavaType(), entityInformation.getCollectionName()); } @@ -480,7 +487,7 @@ public class SimpleMongoRepository implements MongoRepository { query.fields().include(getFieldsToInclude().toArray(new String[0])); } - readPreference.getOptional().ifPresent(query::withReadPreference); + getReadPreference().ifPresent(query::withReadPreference); query = queryCustomizer.apply(query); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/SimpleReactiveMongoRepository.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/SimpleReactiveMongoRepository.java index 698aec3e1..e5361a34d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/SimpleReactiveMongoRepository.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/SimpleReactiveMongoRepository.java @@ -24,12 +24,12 @@ import java.io.Serializable; import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.function.Function; import java.util.function.UnaryOperator; import java.util.stream.Collectors; import org.reactivestreams.Publisher; -import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.data.domain.Example; @@ -44,16 +44,14 @@ import org.springframework.data.mongodb.core.ReactiveMongoOperations; import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.repository.ReactiveMongoRepository; -import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.query.MongoEntityInformation; -import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.query.FluentQuery; -import org.springframework.data.util.Lazy; import org.springframework.data.util.StreamUtils; import org.springframework.data.util.Streamable; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import com.mongodb.ReadPreference; import com.mongodb.client.result.DeleteResult; /** @@ -69,48 +67,25 @@ import com.mongodb.client.result.DeleteResult; */ public class SimpleReactiveMongoRepository implements ReactiveMongoRepository { + private @Nullable CrudMethodMetadata crudMethodMetadata; private final MongoEntityInformation entityInformation; private final ReactiveMongoOperations mongoOperations; - private final Lazy readPreference; - - public SimpleReactiveMongoRepository(MongoEntityInformation entityInformation, - ReactiveMongoOperations mongoOperations) { - this(null, entityInformation, mongoOperations, null); - } /** * Creates a new {@link SimpleReactiveMongoRepository} for the given {@link MongoEntityInformation} and * {@link MongoTemplate}. * - * @param repositoryMetadata must not be {@literal null}. * @param entityInformation must not be {@literal null}. * @param mongoOperations must not be {@literal null}. - * @since 4.2 */ - public SimpleReactiveMongoRepository(RepositoryMetadata repositoryMetadata, - MongoEntityInformation entityInformation, ReactiveMongoOperations mongoOperations) { - this(repositoryMetadata, entityInformation, mongoOperations, null); - } - - private SimpleReactiveMongoRepository(@Nullable RepositoryMetadata repositoryMetadata, - MongoEntityInformation entityInformation, ReactiveMongoOperations mongoOperations, - @Nullable Object marker) { + public SimpleReactiveMongoRepository(MongoEntityInformation entityInformation, + ReactiveMongoOperations mongoOperations) { Assert.notNull(entityInformation, "EntityInformation must not be null"); Assert.notNull(mongoOperations, "MongoOperations must not be null"); this.entityInformation = entityInformation; this.mongoOperations = mongoOperations; - - this.readPreference = repositoryMetadata == null ? Lazy.empty() : Lazy.of(() -> { - - ReadPreference preference = AnnotatedElementUtils - .findMergedAnnotation(repositoryMetadata.getRepositoryInterface(), ReadPreference.class); - if (preference == null) { - return null; - } - return com.mongodb.ReadPreference.valueOf(preference.value()); - }); } // ------------------------------------------------------------------------- @@ -156,16 +131,22 @@ public class SimpleReactiveMongoRepository implement Assert.notNull(id, "The given id must not be null"); - return mongoOperations.findById(id, entityInformation.getJavaType(), entityInformation.getCollectionName()); + Query query = getIdQuery(id); + getReadPreference().ifPresent(query::withReadPreference); + return mongoOperations.findOne(query, entityInformation.getJavaType(), entityInformation.getCollectionName()); } @Override public Mono findById(Publisher publisher) { Assert.notNull(publisher, "The given id must not be null"); + Optional readPreference = getReadPreference(); - return Mono.from(publisher).flatMap( - id -> mongoOperations.findById(id, entityInformation.getJavaType(), entityInformation.getCollectionName())); + return Mono.from(publisher).flatMap(id -> { + Query query = getIdQuery(id); + readPreference.ifPresent(query::withReadPreference); + return mongoOperations.findOne(query, entityInformation.getJavaType(), entityInformation.getCollectionName()); + }); } @Override @@ -173,17 +154,22 @@ public class SimpleReactiveMongoRepository implement Assert.notNull(id, "The given id must not be null"); - return mongoOperations.exists(getIdQuery(id), entityInformation.getJavaType(), - entityInformation.getCollectionName()); + Query query = getIdQuery(id); + getReadPreference().ifPresent(query::withReadPreference); + return mongoOperations.exists(query, entityInformation.getJavaType(), entityInformation.getCollectionName()); } @Override public Mono existsById(Publisher publisher) { Assert.notNull(publisher, "The given id must not be null"); + Optional readPreference = getReadPreference(); - return Mono.from(publisher).flatMap(id -> mongoOperations.exists(getIdQuery(id), entityInformation.getJavaType(), - entityInformation.getCollectionName())); + return Mono.from(publisher).flatMap(id -> { + Query query = getIdQuery(id); + readPreference.ifPresent(query::withReadPreference); + return mongoOperations.exists(query, entityInformation.getJavaType(), entityInformation.getCollectionName()); + }); } @Override @@ -204,12 +190,20 @@ public class SimpleReactiveMongoRepository implement Assert.notNull(ids, "The given Publisher of Id's must not be null"); - return Flux.from(ids).buffer().flatMap(this::findAllById); + Optional readPreference = getReadPreference(); + return Flux.from(ids).buffer().flatMap(listOfIds -> { + Query query = getIdQuery(listOfIds); + readPreference.ifPresent(query::withReadPreference); + return mongoOperations.find(query, entityInformation.getJavaType(), entityInformation.getCollectionName()); + }); } @Override public Mono count() { - return mongoOperations.count(new Query(), entityInformation.getCollectionName()); + + Query query = new Query(); + getReadPreference().ifPresent(query::withReadPreference); + return mongoOperations.count(query, entityInformation.getCollectionName()); } @Override @@ -217,8 +211,16 @@ public class SimpleReactiveMongoRepository implement Assert.notNull(id, "The given id must not be null"); - return mongoOperations - .remove(getIdQuery(id), entityInformation.getJavaType(), entityInformation.getCollectionName()).then(); + return deleteById(id, getReadPreference()); + } + + private Mono deleteById(ID id, Optional readPreference) { + + Assert.notNull(id, "The given id must not be null"); + + Query query = getIdQuery(id); + readPreference.ifPresent(query::withReadPreference); + return mongoOperations.remove(query, entityInformation.getJavaType(), entityInformation.getCollectionName()).then(); } @Override @@ -226,8 +228,13 @@ public class SimpleReactiveMongoRepository implement Assert.notNull(publisher, "Id must not be null"); - return Mono.from(publisher).flatMap(id -> mongoOperations.remove(getIdQuery(id), entityInformation.getJavaType(), - entityInformation.getCollectionName())).then(); + Optional readPreference = getReadPreference(); + + return Mono.from(publisher).flatMap(id -> { + Query query = getIdQuery(id); + readPreference.ifPresent(query::withReadPreference); + return mongoOperations.remove(query, entityInformation.getJavaType(), entityInformation.getCollectionName()); + }).then(); } @Override @@ -260,8 +267,9 @@ public class SimpleReactiveMongoRepository implement Assert.notNull(ids, "The given Iterable of Id's must not be null"); - return mongoOperations - .remove(getIdQuery(ids), entityInformation.getJavaType(), entityInformation.getCollectionName()).then(); + Query query = getIdQuery(ids); + getReadPreference().ifPresent(query::withReadPreference); + return mongoOperations.remove(query, entityInformation.getJavaType(), entityInformation.getCollectionName()).then(); } @Override @@ -274,9 +282,9 @@ public class SimpleReactiveMongoRepository implement Criteria idsInCriteria = where(entityInformation.getIdAttribute()).in(idCollection); - return mongoOperations - .remove(new Query(idsInCriteria), entityInformation.getJavaType(), entityInformation.getCollectionName()) - .then(); + Query query = new Query(idsInCriteria); + getReadPreference().ifPresent(query::withReadPreference); + return mongoOperations.remove(query, entityInformation.getJavaType(), entityInformation.getCollectionName()).then(); } @Override @@ -284,15 +292,18 @@ public class SimpleReactiveMongoRepository implement Assert.notNull(entityStream, "The given Publisher of entities must not be null"); + Optional readPreference = getReadPreference(); return Flux.from(entityStream)// .map(entityInformation::getRequiredId)// - .flatMap(this::deleteById)// + .flatMap(id -> deleteById(id, readPreference))// .then(); } @Override public Mono deleteAll() { - return mongoOperations.remove(new Query(), entityInformation.getCollectionName()).then(Mono.empty()); + Query query = new Query(); + getReadPreference().ifPresent(query::withReadPreference); + return mongoOperations.remove(query, entityInformation.getCollectionName()).then(Mono.empty()); } // ------------------------------------------------------------------------- @@ -349,7 +360,7 @@ public class SimpleReactiveMongoRepository implement Query query = new Query(new Criteria().alike(example)) // .collation(entityInformation.getCollation()) // .limit(2); - readPreference.getOptional().ifPresent(query::withReadPreference); + getReadPreference().ifPresent(query::withReadPreference); return mongoOperations.find(query, example.getProbeType(), entityInformation.getCollectionName()).buffer(2) .map(vals -> { @@ -378,7 +389,7 @@ public class SimpleReactiveMongoRepository implement Query query = new Query(new Criteria().alike(example)) // .collation(entityInformation.getCollation()) // .with(sort); - readPreference.getOptional().ifPresent(query::withReadPreference); + getReadPreference().ifPresent(query::withReadPreference); return mongoOperations.find(query, example.getProbeType(), entityInformation.getCollectionName()); } @@ -390,6 +401,7 @@ public class SimpleReactiveMongoRepository implement Query query = new Query(new Criteria().alike(example)) // .collation(entityInformation.getCollation()); + getReadPreference().ifPresent(query::withReadPreference); return mongoOperations.count(query, example.getProbeType(), entityInformation.getCollectionName()); } @@ -401,6 +413,7 @@ public class SimpleReactiveMongoRepository implement Query query = new Query(new Criteria().alike(example)) // .collation(entityInformation.getCollation()); + getReadPreference().ifPresent(query::withReadPreference); return mongoOperations.exists(query, example.getProbeType(), entityInformation.getCollectionName()); } @@ -412,7 +425,27 @@ public class SimpleReactiveMongoRepository implement Assert.notNull(example, "Sample must not be null"); Assert.notNull(queryFunction, "Query function must not be null"); - return queryFunction.apply(new ReactiveFluentQueryByExample<>(example, example.getProbeType())); + return queryFunction + .apply(new ReactiveFluentQueryByExample<>(example, example.getProbeType(), getReadPreference())); + } + + /** + * Configures a custom {@link CrudMethodMetadata} to be used to detect {@link ReadPreference}s and query hints to be + * applied to queries. + * + * @param crudMethodMetadata + */ + public void setRepositoryMethodMetadata(CrudMethodMetadata crudMethodMetadata) { + this.crudMethodMetadata = crudMethodMetadata; + } + + private Optional getReadPreference() { + + if (crudMethodMetadata == null) { + return Optional.empty(); + } + + return crudMethodMetadata.getReadPreference(); } private Query getIdQuery(Object id) { @@ -434,7 +467,7 @@ public class SimpleReactiveMongoRepository implement private Flux findAll(Query query) { - readPreference.getOptional().ifPresent(query::withReadPreference); + getReadPreference().ifPresent(query::withReadPreference); return mongoOperations.find(query, entityInformation.getJavaType(), entityInformation.getCollectionName()); } @@ -446,19 +479,22 @@ public class SimpleReactiveMongoRepository implement */ class ReactiveFluentQueryByExample extends ReactiveFluentQuerySupport, T> { - ReactiveFluentQueryByExample(Example example, Class resultType) { - this(example, Sort.unsorted(), 0, resultType, Collections.emptyList()); + private final Optional readPreference; + + ReactiveFluentQueryByExample(Example example, Class resultType, Optional readPreference) { + this(example, Sort.unsorted(), 0, resultType, Collections.emptyList(), readPreference); } ReactiveFluentQueryByExample(Example example, Sort sort, int limit, Class resultType, - List fieldsToInclude) { + List fieldsToInclude, Optional readPreference) { super(example, sort, limit, resultType, fieldsToInclude); + this.readPreference = readPreference; } @Override protected ReactiveFluentQueryByExample create(Example predicate, Sort sort, int limit, Class resultType, List fieldsToInclude) { - return new ReactiveFluentQueryByExample<>(predicate, sort, limit, resultType, fieldsToInclude); + return new ReactiveFluentQueryByExample<>(predicate, sort, limit, resultType, fieldsToInclude, readPreference); } @Override @@ -520,7 +556,7 @@ public class SimpleReactiveMongoRepository implement query.fields().include(getFieldsToInclude().toArray(new String[0])); } - readPreference.getOptional().ifPresent(query::withReadPreference); + readPreference.ifPresent(query::withReadPreference); query = queryCustomizer.apply(query); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactoryUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactoryUnitTests.java index 1a95379fe..c61930ec2 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactoryUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactoryUnitTests.java @@ -19,22 +19,25 @@ import static org.assertj.core.api.Assertions.*; import static org.mockito.Mockito.*; import java.io.Serializable; +import java.util.Optional; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; - -import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.core.MongoTemplate; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.convert.MongoConverter; -import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; +import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; +import org.springframework.data.mongodb.core.mapping.MongoMappingContext; +import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.repository.Person; +import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.query.MongoEntityInformation; -import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.repository.Repository; /** @@ -48,43 +51,46 @@ public class MongoRepositoryFactoryUnitTests { @Mock MongoTemplate template; - @Mock MongoConverter converter; - - @Mock @SuppressWarnings("rawtypes") MappingContext mappingContext; - - @Mock @SuppressWarnings("rawtypes") MongoPersistentEntity entity; + MongoConverter converter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, new MongoMappingContext()); @BeforeEach - @SuppressWarnings("unchecked") public void setUp() { when(template.getConverter()).thenReturn(converter); - when(converter.getMappingContext()).thenReturn(mappingContext); - when(converter.getProjectionFactory()).thenReturn(new SpelAwareProxyProjectionFactory()); } @Test - @SuppressWarnings("unchecked") public void usesMappingMongoEntityInformationIfMappingContextSet() { - when(mappingContext.getRequiredPersistentEntity(Person.class)).thenReturn(entity); - MongoRepositoryFactory factory = new MongoRepositoryFactory(template); MongoEntityInformation entityInformation = factory.getEntityInformation(Person.class); assertThat(entityInformation instanceof MappingMongoEntityInformation).isTrue(); } @Test // DATAMONGO-385 - @SuppressWarnings("unchecked") public void createsRepositoryWithIdTypeLong() { - when(mappingContext.getRequiredPersistentEntity(Person.class)).thenReturn(entity); - MongoRepositoryFactory factory = new MongoRepositoryFactory(template); MyPersonRepository repository = factory.getRepository(MyPersonRepository.class); assertThat(repository).isNotNull(); } + @Test // GH-2971 + void considersCrudMethodMetadata() { + + MongoRepositoryFactory factory = new MongoRepositoryFactory(template); + MyPersonRepository repository = factory.getRepository(MyPersonRepository.class); + repository.findById(42L); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Query.class); + verify(template).findOne(captor.capture(), eq(Person.class), eq("person")); + + Query value = captor.getValue(); + assertThat(value.getReadPreference()).isEqualTo(com.mongodb.ReadPreference.secondary()); + } + interface MyPersonRepository extends Repository { + @ReadPreference("secondary") + Optional findById(Long id); } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactoryUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactoryUnitTests.java new file mode 100644 index 000000000..49a551324 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactoryUnitTests.java @@ -0,0 +1,80 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.support; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; + +import reactor.core.publisher.Mono; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.springframework.data.mongodb.core.ReactiveMongoTemplate; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.convert.MongoConverter; +import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; +import org.springframework.data.mongodb.core.mapping.MongoMappingContext; +import org.springframework.data.mongodb.core.query.Query; +import org.springframework.data.mongodb.repository.Person; +import org.springframework.data.mongodb.repository.ReadPreference; +import org.springframework.data.repository.Repository; + +/** + * Unit test for {@link ReactiveMongoRepositoryFactory}. + * + * @author Mark Paluch + */ +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class ReactiveMongoRepositoryFactoryUnitTests { + + @Mock ReactiveMongoTemplate template; + + MongoConverter converter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, new MongoMappingContext()); + + @BeforeEach + public void setUp() { + when(template.getConverter()).thenReturn(converter); + } + + @Test // GH-2971 + void considersCrudMethodMetadata() { + + when(template.findOne(any(), any(), anyString())).thenReturn(Mono.empty()); + + ReactiveMongoRepositoryFactory factory = new ReactiveMongoRepositoryFactory(template); + MyPersonRepository repository = factory.getRepository(MyPersonRepository.class); + repository.findById(42L); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Query.class); + verify(template).findOne(captor.capture(), eq(Person.class), eq("person")); + + Query value = captor.getValue(); + assertThat(value.getReadPreference()).isEqualTo(com.mongodb.ReadPreference.secondary()); + } + + interface MyPersonRepository extends Repository { + + @ReadPreference("secondary") + Mono findById(Long id); + } +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepositoryUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepositoryUnitTests.java index a8ac1881a..6e2d87b50 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepositoryUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepositoryUnitTests.java @@ -40,11 +40,14 @@ import org.springframework.data.mongodb.core.query.Collation; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.query.MongoEntityInformation; -import org.springframework.data.repository.core.RepositoryMetadata; +import org.springframework.data.mongodb.repository.support.CrudMethodMetadataPostProcessor.DefaultCrudMethodMetadata; import org.springframework.data.repository.query.FluentQuery.FetchableFluentQuery; /** + * Unit tests for {@link SimpleMongoRepository}. + * * @author Christoph Strobl + * @author Mark Paluch */ @ExtendWith(MockitoExtension.class) public class SimpleMongoRepositoryUnitTests { @@ -144,11 +147,12 @@ public class SimpleMongoRepositoryUnitTests { @ParameterizedTest // GH-2971 @MethodSource("findAllCalls") - void shouldAddReadPreferenceToFindAllMethods(Consumer> findCall) { + void shouldAddReadPreferenceToFindAllMethods(Consumer> findCall) + throws NoSuchMethodException { - RepositoryMetadata repositoryMetadata = mock(RepositoryMetadata.class); - doReturn(TestRepositoryWithReadPreference.class).when(repositoryMetadata).getRepositoryInterface(); - repository = new SimpleMongoRepository<>(repositoryMetadata, entityInformation, mongoOperations); + repository = new SimpleMongoRepository<>(entityInformation, mongoOperations); + repository.setRepositoryMethodMetadata( + new DefaultCrudMethodMetadata(TestRepositoryWithReadPreference.class.getMethod("dummy"))); findCall.accept(repository); @@ -159,11 +163,11 @@ public class SimpleMongoRepositoryUnitTests { } @Test // GH-2971 - void shouldAddReadPreferenceToFindOne() { + void shouldAddReadPreferenceToFindOne() throws NoSuchMethodException { - RepositoryMetadata repositoryMetadata = mock(RepositoryMetadata.class); - doReturn(TestRepositoryWithReadPreference.class).when(repositoryMetadata).getRepositoryInterface(); - repository = new SimpleMongoRepository<>(repositoryMetadata, entityInformation, mongoOperations); + repository = new SimpleMongoRepository<>(entityInformation, mongoOperations); + repository.setRepositoryMethodMetadata( + new DefaultCrudMethodMetadata(TestRepositoryWithReadPreference.class.getMethod("dummy"))); repository.findOne(Example.of(new TestDummy())); @@ -174,10 +178,7 @@ public class SimpleMongoRepositoryUnitTests { } @Test // GH-2971 - void shouldAddReadPreferenceToFluentFetchable() { - - RepositoryMetadata repositoryMetadata = mock(RepositoryMetadata.class); - doReturn(TestRepositoryWithReadPreference.class).when(repositoryMetadata).getRepositoryInterface(); + void shouldAddReadPreferenceToFluentFetchable() throws NoSuchMethodException { ExecutableFind finder = mock(ExecutableFind.class); when(mongoOperations.query(any())).thenReturn(finder); @@ -185,7 +186,9 @@ public class SimpleMongoRepositoryUnitTests { when(finder.matching(any(Query.class))).thenReturn(finder); when(finder.as(any())).thenReturn(finder); - repository = new SimpleMongoRepository<>(repositoryMetadata, entityInformation, mongoOperations); + repository = new SimpleMongoRepository<>(entityInformation, mongoOperations); + repository.setRepositoryMethodMetadata( + new DefaultCrudMethodMetadata(TestRepositoryWithReadPreferenceMethod.class.getMethod("dummy"))); repository.findBy(Example.of(new TestDummy()), FetchableFluentQuery::all); @@ -227,6 +230,13 @@ public class SimpleMongoRepositoryUnitTests { @ReadPreference("secondaryPreferred") interface TestRepositoryWithReadPreference { + void dummy(); + } + + interface TestRepositoryWithReadPreferenceMethod { + + @ReadPreference("secondaryPreferred") + void dummy(); } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/SimpleReactiveMongoRepositoryUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/SimpleReactiveMongoRepositoryUnitTests.java index 206a025ca..32462a62c 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/SimpleReactiveMongoRepositoryUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/SimpleReactiveMongoRepositoryUnitTests.java @@ -22,6 +22,8 @@ import static org.mockito.Mockito.*; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import java.lang.reflect.Method; +import java.util.Optional; import java.util.function.Function; import java.util.stream.Stream; @@ -42,11 +44,13 @@ import org.springframework.data.mongodb.core.query.Collation; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.query.MongoEntityInformation; -import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.query.FluentQuery; /** + * Unit tests for {@link SimpleReactiveMongoRepository}. + * * @author Christoph Strobl + * @author Mark Paluch */ @ExtendWith(MockitoExtension.class) class SimpleReactiveMongoRepositoryUnitTests { @@ -148,11 +152,21 @@ class SimpleReactiveMongoRepositoryUnitTests { @ParameterizedTest // GH-2971 @MethodSource("findAllCalls") - void shouldAddReadPreferenceToFindAllMethods(Function, Flux> findCall) { + void shouldAddReadPreferenceToFindAllMethods( + Function, Flux> findCall) { - RepositoryMetadata repositoryMetadata = mock(RepositoryMetadata.class); - doReturn(TestRepositoryWithReadPreference.class).when(repositoryMetadata).getRepositoryInterface(); - repository = new SimpleReactiveMongoRepository<>(repositoryMetadata, entityInformation, mongoOperations); + repository = new SimpleReactiveMongoRepository<>(entityInformation, mongoOperations); + repository.setRepositoryMethodMetadata(new CrudMethodMetadata() { + @Override + public Optional getReadPreference() { + return Optional.of(com.mongodb.ReadPreference.secondaryPreferred()); + } + + @Override + public Method getMethod() { + return null; + } + }); when(mongoOperations.find(any(), any(), any())).thenReturn(Flux.just("ok")); findCall.apply(repository).subscribe(); @@ -166,9 +180,18 @@ class SimpleReactiveMongoRepositoryUnitTests { @Test // GH-2971 void shouldAddReadPreferenceToFindOne() { - RepositoryMetadata repositoryMetadata = mock(RepositoryMetadata.class); - doReturn(TestRepositoryWithReadPreference.class).when(repositoryMetadata).getRepositoryInterface(); - repository = new SimpleReactiveMongoRepository<>(repositoryMetadata, entityInformation, mongoOperations); + repository = new SimpleReactiveMongoRepository<>(entityInformation, mongoOperations); + repository.setRepositoryMethodMetadata(new CrudMethodMetadata() { + @Override + public Optional getReadPreference() { + return Optional.of(com.mongodb.ReadPreference.secondaryPreferred()); + } + + @Override + public Method getMethod() { + return null; + } + }); when(mongoOperations.find(any(), any(), any())).thenReturn(Flux.just("ok")); repository.findOne(Example.of(new SimpleMongoRepositoryUnitTests.TestDummy())).subscribe(); @@ -182,10 +205,6 @@ class SimpleReactiveMongoRepositoryUnitTests { @Test // GH-2971 void shouldAddReadPreferenceToFluentFetchable() { - RepositoryMetadata repositoryMetadata = mock(RepositoryMetadata.class); - doReturn(SimpleMongoRepositoryUnitTests.TestRepositoryWithReadPreference.class).when(repositoryMetadata) - .getRepositoryInterface(); - ReactiveFind finder = mock(ReactiveFind.class); when(mongoOperations.query(any())).thenReturn(finder); when(finder.inCollection(any())).thenReturn(finder); @@ -193,7 +212,18 @@ class SimpleReactiveMongoRepositoryUnitTests { when(finder.as(any())).thenReturn(finder); when(finder.all()).thenReturn(Flux.just("ok")); - repository = new SimpleReactiveMongoRepository<>(repositoryMetadata, entityInformation, mongoOperations); + repository = new SimpleReactiveMongoRepository<>(entityInformation, mongoOperations); + repository.setRepositoryMethodMetadata(new CrudMethodMetadata() { + @Override + public Optional getReadPreference() { + return Optional.of(com.mongodb.ReadPreference.secondaryPreferred()); + } + + @Override + public Method getMethod() { + return null; + } + }); repository.findBy(Example.of(new TestDummy()), FluentQuery.ReactiveFluentQuery::all).subscribe();