diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Field.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Field.java index 4940a857e..eeaed6114 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Field.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Field.java @@ -15,10 +15,13 @@ */ package org.springframework.data.mongodb.core.query; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; +import lombok.EqualsAndHashCode; + import org.bson.Document; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -30,6 +33,7 @@ import org.springframework.util.ObjectUtils; * @author Patryk Wasik * @author Christoph Strobl * @author Mark Paluch + * @author Owen Q */ public class Field { @@ -44,11 +48,27 @@ public class Field { return this; } + public Field includes(String... keys) { + Assert.notNull(keys, "Keys must not be null!"); + Assert.notEmpty(keys, "Keys must not be empty!"); + + Arrays.asList(keys).stream().forEach(this::include); + return this; + } + public Field exclude(String key) { criteria.put(key, Integer.valueOf(0)); return this; } + public Field excludes(String... keys) { + Assert.notNull(keys, "Keys must not be null!"); + Assert.notEmpty(keys, "Keys must not be empty!"); + + Arrays.asList(keys).stream().forEach(this::exclude); + return this; + } + public Field slice(String key, int size) { slices.put(key, Integer.valueOf(size)); return this; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java index f7c308a2f..96b6e1001 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateTests.java @@ -15,18 +15,18 @@ */ package org.springframework.data.mongodb.core; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.*; +import static org.junit.Assume.*; import static org.assertj.core.api.Assertions.*; import static org.springframework.data.mongodb.core.query.Criteria.*; import static org.springframework.data.mongodb.core.query.Query.*; import static org.springframework.data.mongodb.core.query.Update.*; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import lombok.Value; -import lombok.experimental.Wither; - import java.lang.reflect.InvocationTargetException; import java.math.BigDecimal; import java.math.BigInteger; @@ -38,6 +38,13 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import java.util.stream.IntStream; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.Value; +import lombok.experimental.Wither; + import org.bson.types.ObjectId; import org.joda.time.DateTime; import org.junit.jupiter.api.AfterEach; @@ -91,6 +98,7 @@ import org.springframework.util.StringUtils; import com.mongodb.BasicDBObject; import com.mongodb.DBObject; import com.mongodb.DBRef; +import com.mongodb.MongoClient; import com.mongodb.MongoException; import com.mongodb.ReadPreference; import com.mongodb.WriteConcern; @@ -3669,6 +3677,27 @@ public class MongoTemplateTests { assertThat(target.inner.id).isEqualTo(innerId); } + @Test // DATAMONGO-2294 + public void shouldProjectWithCollections() { + + MyPerson person = new MyPerson("Walter"); + person.address = new Address("TX", "Austin"); + template.save(person); + + Query queryByChainedInclude = query(where("name").is("Walter")); + queryByChainedInclude.fields().include("id").include("name"); + + Query queryByCollectionInclude = query(where("name").is("Walter")); + queryByCollectionInclude.fields().includes("id", "name"); + + MyPerson first = template.findAndReplace(queryByChainedInclude, new MyPerson("Walter")); + MyPerson second = template.findAndReplace(queryByCollectionInclude, new MyPerson("Walter")); + + assertThat(first).isEqualTo(second); + assertThat(first.address).isNull(); + assertThat(second.address).isNull(); + } + @Test // DATAMONGO-2451 public void sortOnIdFieldWithExplicitTypeShouldWork() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/FieldUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/FieldUnitTests.java index ee312f784..e48dae245 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/FieldUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/FieldUnitTests.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; * Unit tests for {@link DocumentField}. * * @author Oliver Gierke + * @author Owen Q */ public class FieldUnitTests { @@ -36,6 +37,16 @@ public class FieldUnitTests { assertThat(right).isEqualTo(left); } + @Test // DATAMONGO-2294 + public void sameObjectSetupCreatesEqualFieldByCollections() { + + Field left = new Field().includes("foo", "bar"); + Field right = new Field().include("foo").include("bar"); + + assertThat(left, is(right)); + assertThat(right, is(left)); + } + @Test public void differentObjectSetupCreatesEqualField() { @@ -45,4 +56,14 @@ public class FieldUnitTests { assertThat(left).isNotEqualTo(right); assertThat(right).isNotEqualTo(left); } + + @Test // DATAMONGO-2294 + public void differentObjectSetupCreatesEqualFieldByCollections() { + + Field left = new Field().includes("foo", "bar"); + Field right = new Field().include("foo").include("zoo"); + + assertThat(left, is(not(right))); + assertThat(right, is(not(left))); + } }