diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/MongoDatabaseUtils.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/MongoDatabaseUtils.java index 713fc73dd..80dcd802d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/MongoDatabaseUtils.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/MongoDatabaseUtils.java @@ -110,7 +110,7 @@ public class MongoDatabaseUtils { ClientSession session = doGetSession(factory, sessionSynchronization); - if(session == null) { + if (session == null) { return StringUtils.hasText(dbName) ? factory.getDb(dbName) : factory.getDb(); } @@ -118,6 +118,26 @@ public class MongoDatabaseUtils { return StringUtils.hasText(dbName) ? factoryToUse.getDb(dbName) : factoryToUse.getDb(); } + /** + * Check if the {@link MongoDbFactory} is actually bound to a {@link ClientSession} that has an active transaction, or + * if a {@link TransactionSynchronization} has been registered for the {@link MongoDbFactory resource} and if the + * associated {@link ClientSession} has an {@link ClientSession#hasActiveTransaction() active transaction}. + * + * @param dbFactory the resource to check transactions for. Must not be {@literal null}. + * @return {@literal true} if the factory has an ongoing transaction. + * @since 2.1.3 + */ + public static boolean isTransactionActive(MongoDbFactory dbFactory) { + + if (dbFactory.isTransactionActive()) { + return true; + } + + MongoResourceHolder resourceHolder = (MongoResourceHolder) TransactionSynchronizationManager.getResource(dbFactory); + return resourceHolder != null + && (resourceHolder.hasSession() && resourceHolder.getSession().hasActiveTransaction()); + } + @Nullable private static ClientSession doGetSession(MongoDbFactory dbFactory, SessionSynchronization sessionSynchronization) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/MongoDbFactory.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/MongoDbFactory.java index 15a9ae691..1a0cc19ac 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/MongoDbFactory.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/MongoDbFactory.java @@ -108,4 +108,15 @@ public interface MongoDbFactory extends CodecRegistryProvider, MongoSessionProvi * @since 2.1 */ MongoDbFactory withSession(ClientSession session); + + /** + * Returns if the given {@link MongoDbFactory} is bound to a {@link ClientSession} that has an + * {@link ClientSession#hasActiveTransaction() active transaction}. + * + * @return {@literal true} if there's an active transaction, {@literal false} otherwise. + * @since 2.1.3 + */ + default boolean isTransactionActive() { + return false; + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoDbFactorySupport.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoDbFactorySupport.java index 1aa8254f8..b6daf64c0 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoDbFactorySupport.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoDbFactorySupport.java @@ -233,6 +233,15 @@ public abstract class MongoDbFactorySupport implements MongoDbFactory { return delegate.withSession(session); } + /* + * (non-Javadoc) + * @see org.springframework.data.mongodb.MongoDbFactory#isTransactionActive() + */ + @Override + public boolean isTransactionActive() { + return session != null && session.hasActiveTransaction(); + } + private MongoDatabase proxyMongoDatabase(MongoDatabase database) { return createProxyInstance(session, database, MongoDatabase.class); } @@ -241,7 +250,8 @@ public abstract class MongoDbFactorySupport implements MongoDbFactory { return createProxyInstance(session, database, MongoDatabase.class); } - private MongoCollection proxyCollection(com.mongodb.session.ClientSession session, MongoCollection collection) { + private MongoCollection proxyCollection(com.mongodb.session.ClientSession session, + MongoCollection collection) { return createProxyInstance(session, collection, MongoCollection.class); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 408d617ae..55dfdfb56 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -1119,7 +1119,16 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, Document document = queryMapper.getMappedObject(query.getQueryObject(), Optional.ofNullable(entityClass).map(it -> mappingContext.getPersistentEntity(entityClass))); - return execute(collectionName, collection -> collection.count(document, options)); + return doCount(collectionName, document, options); + } + + protected long doCount(String collectionName, Document filter, CountOptions options) { + + if (!MongoDatabaseUtils.isTransactionActive(getMongoDbFactory())) { + return execute(collectionName, collection -> collection.count(filter, options)); + } + + return execute(collectionName, collection -> collection.countDocuments(filter, options)); } /* @@ -2820,21 +2829,23 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, } /** - * Optimized {@link CollectionCallback} that takes an already mappend query and a nullable + * Optimized {@link CollectionCallback} that takes an already mapped query and a nullable * {@link com.mongodb.client.model.Collation} to execute a count query limited to one element. * * @author Christoph Strobl * @since 2.0 */ @RequiredArgsConstructor - private static class ExistsCallback implements CollectionCallback { + private class ExistsCallback implements CollectionCallback { private final Document mappedQuery; private final com.mongodb.client.model.Collation collation; @Override public Boolean doInCollection(MongoCollection collection) throws MongoException, DataAccessException { - return collection.count(mappedQuery, new CountOptions().limit(1).collation(collation)) > 0; + + return doCount(collection.getNamespace().getCollectionName(), mappedQuery, + new CountOptions().limit(1).collation(collation)) > 0; } } @@ -3343,23 +3354,16 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware, /* * (non-Javadoc) - * @see org.springframework.data.mongodb.core.MongoTemplate#count(org.springframework.data.mongodb.core.query.Query, java.lang.Class, java.lang.String) + * @see org.springframework.data.mongodb.core.MongoTemplate#doCount(java.lang.String, org.bson.Document, com.mongodb.client.model.CountOptions) */ @Override - @SuppressWarnings("unchecked") - public long count(Query query, @Nullable Class entityClass, String collectionName) { + protected long doCount(String collectionName, Document filter, CountOptions options) { if (!session.hasActiveTransaction()) { - return super.count(query, entityClass, collectionName); + return super.doCount(collectionName, filter, options); } - CountOptions options = new CountOptions(); - query.getCollation().map(Collation::toMongoCollation).ifPresent(options::collation); - - Document document = delegate.queryMapper.getMappedObject(query.getQueryObject(), - Optional.ofNullable(entityClass).map(it -> delegate.mappingContext.getPersistentEntity(entityClass))); - - return execute(collectionName, collection -> collection.countDocuments(document, options)); + return execute(collectionName, collection -> collection.countDocuments(filter, options)); } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepository.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepository.java index 39099a18b..9c4d7b46c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepository.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepository.java @@ -137,7 +137,7 @@ public class SimpleMongoRepository implements MongoRepository { */ @Override public long count() { - return mongoOperations.getCollection(entityInformation.getCollectionName()).count(); + return mongoOperations.count(new Query(), entityInformation.getCollectionName()); } /* diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/MongoDatabaseUtilsUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/MongoDatabaseUtilsUnitTests.java index 96e9a61be..a85697811 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/MongoDatabaseUtilsUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/MongoDatabaseUtilsUnitTests.java @@ -78,6 +78,39 @@ public class MongoDatabaseUtilsUnitTests { assertFalse(TransactionSynchronizationManager.isActualTransactionActive()); } + @Test // DATAMONGO-2130 + public void isTransactionActiveShouldDetectTxViaFactory() { + + when(dbFactory.isTransactionActive()).thenReturn(true); + + assertThat(MongoDatabaseUtils.isTransactionActive(dbFactory)).isTrue(); + } + + @Test // DATAMONGO-2130 + public void isTransactionActiveShouldReturnFalseIfNoTxActive() { + + when(dbFactory.isTransactionActive()).thenReturn(false); + + assertThat(MongoDatabaseUtils.isTransactionActive(dbFactory)).isFalse(); + } + + @Test // DATAMONGO-2130 + public void isTransactionActiveShouldLookupTxForActiveTransactionSynchronizationViaTxManager() { + + when(dbFactory.isTransactionActive()).thenReturn(false); + + MongoTransactionManager txManager = new MongoTransactionManager(dbFactory); + TransactionTemplate txTemplate = new TransactionTemplate(txManager); + + txTemplate.execute(new TransactionCallbackWithoutResult() { + + @Override + protected void doInTransactionWithoutResult(TransactionStatus transactionStatus) { + assertThat(MongoDatabaseUtils.isTransactionActive(dbFactory)).isTrue(); + } + }); + } + @Test // DATAMONGO-1920 public void shouldNotStartSessionWhenNoTransactionOngoing() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java index 307e80bb0..701f31dc4 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.any; import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; import static org.springframework.data.mongodb.test.util.IsBsonObject.*; +import com.mongodb.MongoNamespace; import lombok.Data; import java.math.BigInteger; @@ -135,6 +136,7 @@ public class MongoTemplateUnitTests extends MongoOperationsUnitTests { when(collection.find(any(org.bson.Document.class), any(Class.class))).thenReturn(findIterable); when(collection.mapReduce(any(), any(), eq(Document.class))).thenReturn(mapReduceIterable); when(collection.count(any(Bson.class), any(CountOptions.class))).thenReturn(1L); + when(collection.getNamespace()).thenReturn(new MongoNamespace("db.mock-collection")); when(collection.aggregate(any(List.class), any())).thenReturn(aggregateIterable); when(collection.withReadPreference(any())).thenReturn(collection); when(findIterable.projection(any())).thenReturn(findIterable); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepositoryTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepositoryTests.java index 8f2b94c3d..04df9b5bd 100755 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepositoryTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/support/SimpleMongoRepositoryTests.java @@ -16,6 +16,7 @@ package org.springframework.data.mongodb.repository.support; import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assumptions.*; import static org.springframework.data.domain.ExampleMatcher.*; import java.util.ArrayList; @@ -28,14 +29,16 @@ import java.util.Set; import java.util.UUID; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.domain.Example; -import org.springframework.data.domain.ExampleMatcher.StringMatcher; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageRequest; +import org.springframework.data.domain.ExampleMatcher.*; import org.springframework.data.geo.Point; +import org.springframework.data.mongodb.MongoTransactionManager; import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.geo.GeoJsonPoint; import org.springframework.data.mongodb.core.mapping.Document; @@ -44,9 +47,13 @@ import org.springframework.data.mongodb.repository.Person; import org.springframework.data.mongodb.repository.Person.Sex; import org.springframework.data.mongodb.repository.User; import org.springframework.data.mongodb.repository.query.MongoEntityInformation; +import org.springframework.data.mongodb.test.util.MongoVersion; +import org.springframework.data.mongodb.test.util.MongoVersionRule; +import org.springframework.data.mongodb.test.util.ReplicaSet; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.transaction.support.TransactionTemplate; /** * @author A. B. M. Kowser @@ -59,6 +66,7 @@ import org.springframework.test.util.ReflectionTestUtils; public class SimpleMongoRepositoryTests { @Autowired private MongoTemplate template; + public @Rule MongoVersionRule mongoVersion = MongoVersionRule.any(); private Person oliver, dave, carter, boyd, stefan, leroi, alicia; private List all; @@ -383,6 +391,54 @@ public class SimpleMongoRepositoryTests { assertThat(repository.findAll()).containsExactlyInAnyOrder(first, second); } + @Test // DATAMONGO-2130 + @MongoVersion(asOf = "4.0") + public void countShouldBePossibleInTransaction() { + + assumeThat(ReplicaSet.required().runsAsReplicaSet()).isTrue(); + + MongoTransactionManager txmgr = new MongoTransactionManager(template.getMongoDbFactory()); + TransactionTemplate tt = new TransactionTemplate(txmgr); + tt.afterPropertiesSet(); + + long countPreTx = repository.count(); + + long count = tt.execute(status -> { + + Person sample = new Person(); + sample.setLastname("Matthews"); + + repository.save(sample); + + return repository.count(); + }); + + assertThat(count).isEqualTo(countPreTx + 1); + } + + @Test // DATAMONGO-2130 + @MongoVersion(asOf = "4.0") + public void existsShouldBePossibleInTransaction() { + + assumeThat(ReplicaSet.required().runsAsReplicaSet()).isTrue(); + + MongoTransactionManager txmgr = new MongoTransactionManager(template.getMongoDbFactory()); + TransactionTemplate tt = new TransactionTemplate(txmgr); + tt.afterPropertiesSet(); + + boolean exists = tt.execute(status -> { + + Person sample = new Person(); + sample.setLastname("Matthews"); + + repository.save(sample); + + return repository.existsById(sample.getId()); + }); + + assertThat(exists).isTrue(); + } + private void assertThatAllReferencePersonsWereStoredCorrectly(Map references, List saved) { for (Person person : saved) {