diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoOperations.java index 6b481de4f..8f199e8ff 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoOperations.java @@ -47,6 +47,7 @@ import com.mongodb.WriteResult; * @author Thomas Risberg * @author Mark Pollack * @author Oliver Gierke + * @author Tobias Trelle */ public interface MongoOperations { @@ -306,13 +307,15 @@ public interface MongoOperations { /** * Execute an aggregation operation. The raw results will be mapped to the given entity class. * - * @param inputCollectionName the collection there the aggregation operation will read from. - * @param pipeline The pipeline holding the aggregation operations. - * @param entityClass The parameterized type of the returned list. + * @param inputCollectionName the collection there the aggregation operation will read from, must not be + * {@literal null} or empty. + * @param pipeline The pipeline holding the aggregation operations, must not be {@literal null}. + * @param entityClass The parameterized type of the returned list, must not be {@literal null}. * @return The results of the aggregation operation. + * @since 1.3 */ AggregationResults aggregate(String inputCollectionName, AggregationPipeline pipeline, Class entityClass); - + /** * Execute a map-reduce operation. The map-reduce operation will be formed with an output type of INLINE * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 1782564da..7fed406f1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -15,8 +15,8 @@ */ package org.springframework.data.mongodb.core; -import static org.springframework.data.mongodb.core.query.Criteria.where; -import static org.springframework.data.mongodb.core.query.SerializationUtils.serializeToJsonSafely; +import static org.springframework.data.mongodb.core.query.Criteria.*; +import static org.springframework.data.mongodb.core.query.SerializationUtils.*; import java.io.IOException; import java.util.ArrayList; @@ -116,6 +116,7 @@ import com.mongodb.util.JSONParseException; * @author Oliver Gierke * @author Amol Nayak * @author Patryk Wasik + * @author Tobias Trelle */ public class MongoTemplate implements MongoOperations, ApplicationContextAware { @@ -1210,19 +1211,21 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { } - public AggregationResults aggregate(String inputCollectionName, AggregationPipeline pipeline, Class entityClass) { + public AggregationResults aggregate(String inputCollectionName, AggregationPipeline pipeline, + Class entityClass) { + Assert.notNull(inputCollectionName, "Collection name is missing"); Assert.notNull(pipeline, "Aggregation pipeline is missing"); Assert.notNull(entityClass, "Entity class is missing"); // prepare command - DBObject command = new BasicDBObject("aggregate", inputCollectionName ); - command.put( "pipeline", pipeline.getOperations() ); - + DBObject command = new BasicDBObject("aggregate", inputCollectionName); + command.put("pipeline", pipeline.getOperations()); + // execute command - CommandResult commandResult = executeCommand(command); + CommandResult commandResult = executeCommand(command); handleCommandError(commandResult, command); - + // map results @SuppressWarnings("unchecked") Iterable resultSet = (Iterable) commandResult.get("result"); @@ -1230,11 +1233,11 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware { DbObjectCallback callback = new ReadDbObjectCallback(mongoConverter, entityClass); for (DBObject dbObject : resultSet) { mappedResults.add(callback.doWith(dbObject)); - } - + } + return new AggregationResults(mappedResults, commandResult); } - + protected String replaceWithResourceIfNecessary(String function) { String func = function; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java index cdaf46165..81e229fae 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java @@ -1,5 +1,5 @@ /* - * Copyright 2011-2012 the original author or authors. + * Copyright 2013 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,17 +31,18 @@ import com.mongodb.util.JSONParseException; * Holds the operations of an aggregation pipeline. * * @author Tobias Trelle + * @since 1.3 */ public class AggregationPipeline { - + private static final String OPERATOR_PREFIX = "$"; - - private List operations = new ArrayList(); + + private final List operations = new ArrayList(); /** * Adds a projection operation to the pipeline. * - * @param projection JSON string holding the projection. + * @param projection JSON string holding the projection, must not be {@literal null} or empty. * @return The pipeline. */ public AggregationPipeline project(String projection) { @@ -51,33 +52,36 @@ public class AggregationPipeline { /** * Adds a projection operation to the pipeline. * - * @param projection Type safe projection object. + * @param projection Type safe projection object, must not be {@literal null}. * @return The pipeline. */ public AggregationPipeline project(Projection projection) { - return addOperation("project", projection.toDBObject() ); + + Assert.notNull(projection, "Projection must not be null!"); + return addOperation("project", projection.toDBObject()); } - + /** * Adds an unwind operation to the pipeline. * - * @param field Name of the field to unwind (should be an array). + * @param field Name of the field to unwind (should be an array), must not be {@literal null} or empty. * @return The pipeline. */ public AggregationPipeline unwind(String field) { - Assert.notNull(field, "Missing field name"); - + + Assert.hasText(field, "Missing field name"); + if (!field.startsWith(OPERATOR_PREFIX)) { field = OPERATOR_PREFIX + field; } - + return addOperation("unwind", field); } /** * Adds a group operation to the pipeline. * - * @param projection JSON string holding the group. + * @param projection JSON string holding the group, must not be {@literal null} or empty. * @return The pipeline. */ public AggregationPipeline group(String group) { @@ -87,7 +91,7 @@ public class AggregationPipeline { /** * Adds a sort operation to the pipeline. * - * @param sort JSON string holding the sorting. + * @param sort JSON string holding the sorting, must not be {@literal null} or empty. * @return The pipeline. */ public AggregationPipeline sort(String sort) { @@ -97,12 +101,12 @@ public class AggregationPipeline { /** * Adds a sort operation to the pipeline. * - * @param sort Type safe sort operation. + * @param sort Type safe sort operation, must not be {@literal null}. * @return The pipeline. */ public AggregationPipeline sort(Sort sort) { + Assert.notNull(sort); - DBObject dbo = new BasicDBObject(); for (org.springframework.data.domain.Sort.Order order : sort) { @@ -112,9 +116,9 @@ public class AggregationPipeline { } /** - * Adds a match operation to the pipeline that is basically a query on the collection.s + * Adds a match operation to the pipeline that is basically a query on the collections. * - * @param projection JSON string holding the criteria. + * @param projection JSON string holding the criteria, must not be {@literal null} or empty. * @return The pipeline. */ public AggregationPipeline match(String match) { @@ -124,12 +128,12 @@ public class AggregationPipeline { /** * Adds a match operation to the pipeline that is basically a query on the collection.s * - * @param criteria Type safe criteria to filter documents from the collection. + * @param criteria Type safe criteria to filter documents from the collection, must not be {@literal null}. * @return The pipeline. */ public AggregationPipeline match(Criteria criteria) { + Assert.notNull(criteria); - return addOperation("match", criteria.getCriteriaObject()); } @@ -158,7 +162,8 @@ public class AggregationPipeline { } private AggregationPipeline addDocumentOperation(String opName, String operation) { - Assert.notNull(operation, "Missing " + opName); + + Assert.hasText(operation, "Missing operation name!"); return addOperation(opName, parseJson(operation)); } @@ -174,5 +179,4 @@ public class AggregationPipeline { throw new IllegalArgumentException("Not a valid JSON document: " + json, e); } } - -} \ No newline at end of file +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java index c8b178584..ac48fe553 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java @@ -1,5 +1,5 @@ /* - * Copyright 2011-2012 the original author or authors. + * Copyright 2013 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ */ package org.springframework.data.mongodb.core.aggregation; -import java.util.ArrayList; +import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -27,49 +27,61 @@ import com.mongodb.DBObject; * Collects the results of executing an aggregation operation. * * @author Tobias Trelle - * + * @author Oliver Gierke * @param The class in which the results are mapped onto. + * @since 1.3 */ public class AggregationResults implements Iterable { private final List mappedResults; private final DBObject rawResults; + private final String serverUsed; - private String serverUsed; - + /** + * Creates a new {@link AggregationResults} instance from the given mapped and raw results. + * + * @param mappedResults must not be {@literal null}. + * @param rawResults must not be {@literal null}. + */ public AggregationResults(List mappedResults, DBObject rawResults) { + Assert.notNull(mappedResults); Assert.notNull(rawResults); - this.mappedResults = mappedResults; + + this.mappedResults = Collections.unmodifiableList(mappedResults); this.rawResults = rawResults; - parseServerUsed(); + this.serverUsed = parseServerUsed(); } + /** + * Returns the aggregation results. + * + * @return + */ public List getAggregationResult() { - List result = new ArrayList(); - Iterator it = iterator(); - - while (it.hasNext()) { - result.add(it.next()); - } - - return result; + return mappedResults; } - @Override + /* + * (non-Javadoc) + * @see java.lang.Iterable#iterator() + */ public Iterator iterator() { return mappedResults.iterator(); } - + + /** + * Returns the server that has been used to perform the aggregation. + * + * @return + */ public String getServerUsed() { return serverUsed; } - private void parseServerUsed() { + private String parseServerUsed() { + Object object = rawResults.get("serverUsed"); - if (object instanceof String) { - serverUsed = (String) object; - } + return object instanceof String ? (String) object : null; } - } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Projection.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Projection.java index a596199bc..c4ef22891 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Projection.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Projection.java @@ -1,5 +1,5 @@ /* - * Copyright 2011-2012 the original author or authors. + * Copyright 2013 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,33 +32,37 @@ import com.mongodb.DBObject; * Projection of field to be used in an {@link AggregationPipeline}. *

* A projection is similar to a {@link Field} inclusion/exclusion but more powerful. It can generate new fields, change - * values of given field etc. + * values of given field etc. * * @author Tobias Trelle + * @since 1.3 */ public class Projection { private static final String REFERENCE_PREFIX = "$"; - - private DBObject document = new BasicDBObject(); - private DBObject rightHandExpression; - /** Stack of key names. Size is 0 or 1. */ - private Stack reference = new Stack(); + private final Stack reference = new Stack(); + private final DBObject document = new BasicDBObject(); + + private DBObject rightHandExpression; - /** Create an empty projection. */ + /** + * Create an empty projection. + */ public Projection() { } /** - * This convenience constructor excludes the field _id and includes the given fields. + * This convenience constructor excludes the field {@code _id} and includes the given fields. * - * @param includes Keys of field to include. + * @param includes Keys of field to include, must not be {@literal null} or empty. */ public Projection(String... includes) { + Assert.notEmpty(includes); exclude("_id"); + for (String key : includes) { include(key); } @@ -69,73 +73,82 @@ public class Projection { * * @param key The key of the field. */ - public final void exclude(String key) { - Assert.notNull(key, "Missing key"); + public final Projection exclude(String key) { + + Assert.hasText(key, "Missing key"); document.put(key, 0); + return this; } /** * Includes a given field. * - * @param key The key of the field. + * @param key The key of the field, must not be {@literal null} or empty. */ public final Projection include(String key) { - Assert.notNull(key, "Missing key"); + + Assert.hasText(key, "Missing key"); safePop(); reference.push(key); - + return this; } /** * Sets the key for a computed field. * + * @param key must not be {@literal null} or empty. */ public final Projection as(String key) { - Assert.notNull(key, "Missing key"); + + Assert.hasText(key, "Missing key"); try { - document.put(key, rightHandSide(safeReference(reference.pop())) ); + document.put(key, rightHandSide(safeReference(reference.pop()))); } catch (EmptyStackException e) { throw new InvalidDataAccessApiUsageException("Invalid use of as()", e); } + return this; } public final Projection plus(Number n) { return arithmeticOperation("add", n); } - + public final Projection minus(Number n) { return arithmeticOperation("substract", n); - } - + } + private Projection arithmeticOperation(String op, Number n) { + Assert.notNull(n, "Missing number"); - + rightHandExpression = createArrayObject(op, safeReference(reference.peek()), n); - - return this; + return this; } - + private DBObject createArrayObject(String op, Object... items) { + List list = new ArrayList(); Collections.addAll(list, items); - - return new BasicDBObject( safeReference(op), list ); + + return new BasicDBObject(safeReference(op), list); } - + private void safePop() { - if ( !reference.empty() ) { - document.put( reference.pop(), rightHandSide(1) ); + + if (!reference.empty()) { + document.put(reference.pop(), rightHandSide(1)); } } - + private String safeReference(String key) { - Assert.notNull(key); - - if ( !key.startsWith(REFERENCE_PREFIX) ) { + + Assert.hasText(key); + + if (!key.startsWith(REFERENCE_PREFIX)) { return REFERENCE_PREFIX + key; } else { return key; @@ -147,10 +160,9 @@ public class Projection { rightHandExpression = null; return value; } - + DBObject toDBObject() { safePop(); return document; } - } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationPipelineTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationPipelineTests.java index 756d193a6..01ee98b04 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationPipelineTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationPipelineTests.java @@ -1,8 +1,24 @@ +/* + * Copyright 2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.springframework.data.mongodb.core.aggregation; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.notNullValue; -import static org.junit.Assert.assertThat; +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; +import static org.springframework.data.mongodb.core.DBObjectUtils.*; + import java.util.List; import org.junit.Before; @@ -16,122 +32,112 @@ import com.mongodb.DBObject; /** * Tests of the {@link AggregationPipeline}. * + * @see DATAMONGO-586 * @author Tobias Trelle */ public class AggregationPipelineTests { - /** Unit under test. */ - private AggregationPipeline pipeline; - - @Before public void setUp() { + AggregationPipeline pipeline; + + @Before + public void setUp() { pipeline = new AggregationPipeline(); } - - @Test public void limitOperation() { - // given + + @Test + public void limitOperation() { + pipeline.limit(42); - - // when + List rawPipeline = pipeline.getOperations(); - - // then assertDBObject("$limit", 42L, rawPipeline); } - @Test public void skipOperation() { - // given + @Test + public void skipOperation() { + pipeline.skip(5); - - // when + List rawPipeline = pipeline.getOperations(); - - // then assertDBObject("$skip", 5L, rawPipeline); } - @Test public void unwindOperation() { - // given + @Test + public void unwindOperation() { + pipeline.unwind("$field"); - - // when + List rawPipeline = pipeline.getOperations(); - - // then assertDBObject("$unwind", "$field", rawPipeline); } - @Test public void unwindOperationWithAddedPrefix() { - // given + @Test + public void unwindOperationWithAddedPrefix() { + pipeline.unwind("field"); - - // when + List rawPipeline = pipeline.getOperations(); - - // then assertDBObject("$unwind", "$field", rawPipeline); } - - - @Test public void matchOperation() { - // given + + @Test + public void matchOperation() { + Criteria criteria = new Criteria("title").is("Doc 1"); - pipeline.match( criteria ); - - // when + pipeline.match(criteria); + List rawPipeline = pipeline.getOperations(); - - // then assertOneDocument(rawPipeline); + DBObject match = rawPipeline.get(0); - DBObject criteriaDoc = (DBObject)match.get("$match"); - assertThat( criteriaDoc, notNullValue() ); - assertSingleDBObject( "title" , "Doc 1", criteriaDoc ); + DBObject criteriaDoc = getAsDBObject(match, "$match"); + assertThat(criteriaDoc, is(notNullValue())); + assertSingleDBObject("title", "Doc 1", criteriaDoc); } - @Test public void sortOperation() { - // given + @Test + public void sortOperation() { + Sort sort = new Sort(new Sort.Order(Direction.ASC, "n")); - pipeline.sort( sort ); - - // when + pipeline.sort(sort); + List rawPipeline = pipeline.getOperations(); - - // then assertOneDocument(rawPipeline); + DBObject sortDoc = rawPipeline.get(0); - DBObject orderDoc = (DBObject)sortDoc.get("$sort"); - assertThat( orderDoc, notNullValue() ); - assertSingleDBObject( "n" , 1, orderDoc ); + DBObject orderDoc = getAsDBObject(sortDoc, "$sort"); + assertThat(orderDoc, is(notNullValue())); + assertSingleDBObject("n", 1, orderDoc); } - - @Test public void projectOperation() { - // given + + @Test + public void projectOperation() { + Projection projection = new Projection("a"); pipeline.project(projection); - - // when + List rawPipeline = pipeline.getOperations(); - - // then assertOneDocument(rawPipeline); + DBObject projectionDoc = rawPipeline.get(0); - DBObject fields = (DBObject)projectionDoc.get("$project"); - assertThat( fields, notNullValue() ); - assertSingleDBObject( "a" , 1, fields ); + DBObject fields = getAsDBObject(projectionDoc, "$project"); + assertThat(fields, is(notNullValue())); + assertSingleDBObject("a", 1, fields); } - + private static void assertOneDocument(List result) { - assertThat( result, notNullValue() ); - assertThat( result.size(), is(1) ); + + assertThat(result, is(notNullValue())); + assertThat(result.size(), is(1)); } - + private static void assertDBObject(String key, Object value, List result) { + assertOneDocument(result); - assertSingleDBObject( key, value, result.get(0) ); + assertSingleDBObject(key, value, result.get(0)); } private static void assertSingleDBObject(String key, Object value, DBObject doc) { - assertThat( doc.get(key), is(value) ); + assertThat(doc.get(key), is(value)); } - } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index f70ff7d98..b08f6e05f 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2011-2012 the original author or authors. + * Copyright 2013 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,9 +15,8 @@ */ package org.springframework.data.mongodb.core.aggregation; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.notNullValue; -import static org.junit.Assert.assertThat; +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; import java.util.ArrayList; import java.util.List; @@ -41,6 +40,7 @@ import com.mongodb.DBObject; /** * Tests for {@link MongoTemplate#aggregate(String, AggregationPipeline, Class)}. * + * @see DATAMONGO-586 * @author Tobias Trelle */ @RunWith(SpringJUnit4ClassRunner.class) @@ -79,35 +79,26 @@ public class AggregationTests { @Test(expected = IllegalArgumentException.class) public void shouldDetectIllegalJsonInOperation() { - // given - AggregationPipeline pipeline = new AggregationPipeline().project("{ foo bar"); - // when + AggregationPipeline pipeline = new AggregationPipeline().project("{ foo bar"); mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class); - - // then: throw expected exception } @Test public void shouldAggregate() { - // given + createDocuments(); - AggregationPipeline pipeline = new AggregationPipeline() - .project("{_id:0,tags:1}}") - .unwind("tags") - .group("{_id:\"$tags\", n:{$sum:1}}") - .project("{tag: \"$_id\", n:1, _id:0}") - .sort( new Sort(new Sort.Order(Direction.DESC, "n")) ); - - // when - AggregationResults results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class); - // then - assertThat(results, notNullValue()); + AggregationPipeline pipeline = new AggregationPipeline().project("{_id:0,tags:1}}").unwind("tags") + .group("{_id:\"$tags\", n:{$sum:1}}").project("{tag: \"$_id\", n:1, _id:0}") + .sort(new Sort(new Sort.Order(Direction.DESC, "n"))); + + AggregationResults results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class); + assertThat(results, is(notNullValue())); assertThat(results.getServerUsed(), is("/127.0.0.1:27017")); List tagCount = results.getAggregationResult(); - assertThat(tagCount, notNullValue()); + assertThat(tagCount, is(notNullValue())); assertThat(tagCount.size(), is(3)); assertTagCount("spring", 3, tagCount.get(0)); assertTagCount("mongodb", 2, tagCount.get(1)); @@ -116,87 +107,73 @@ public class AggregationTests { @Test(expected = InvalidDataAccessApiUsageException.class) public void shouldDetectIllegalAggregationOperation() { - // given + createDocuments(); AggregationPipeline pipeline = new AggregationPipeline().project("{$foobar:{_id:0,tags:1}}"); - - // when mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class); - - // then: throw expected exception } @Test public void shouldAggregateEmptyCollection() { - // given - AggregationPipeline pipeline = new AggregationPipeline() - .project("{_id:0,tags:1}}") - .unwind("$tags") - .group("{_id:\"$tags\", n:{$sum:1}}") - .project("{tag: \"$_id\", n:1, _id:0}") - .sort("{n:-1}"); - - // when - AggregationResults results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class); - // then - assertThat(results, notNullValue()); + AggregationPipeline pipeline = new AggregationPipeline().project("{_id:0,tags:1}}").unwind("$tags") + .group("{_id:\"$tags\", n:{$sum:1}}").project("{tag: \"$_id\", n:1, _id:0}").sort("{n:-1}"); + + AggregationResults results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class); + assertThat(results, is(notNullValue())); assertThat(results.getServerUsed(), is("/127.0.0.1:27017")); List tagCount = results.getAggregationResult(); - assertThat(tagCount, notNullValue()); + assertThat(tagCount, is(notNullValue())); assertThat(tagCount.size(), is(0)); } @Test public void shouldDetectResultMismatch() { - // given - createDocuments(); - AggregationPipeline pipeline = new AggregationPipeline() - .project("{_id:0,tags:1}}") - .unwind("$tags") - .group("{_id:\"$tags\", count:{$sum:1}}") - .limit(2); - - // when - AggregationResults results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class); - // then - assertThat(results, notNullValue()); + createDocuments(); + AggregationPipeline pipeline = new AggregationPipeline().project("{_id:0,tags:1}}").unwind("$tags") + .group("{_id:\"$tags\", count:{$sum:1}}").limit(2); + + AggregationResults results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class); + assertThat(results, is(notNullValue())); assertThat(results.getServerUsed(), is("/127.0.0.1:27017")); List tagCount = results.getAggregationResult(); - assertThat(tagCount, notNullValue()); + assertThat(tagCount, is(notNullValue())); assertThat(tagCount.size(), is(2)); assertTagCount(null, 0, tagCount.get(0)); assertTagCount(null, 0, tagCount.get(1)); } - + protected void cleanDb() { mongoTemplate.dropCollection(INPUT_COLLECTION); } private void createDocuments() { + DBCollection coll = mongoTemplate.getCollection(INPUT_COLLECTION); coll.insert(createDocument("Doc1", "spring", "mongodb", "nosql")); coll.insert(createDocument("Doc2", "spring", "mongodb")); coll.insert(createDocument("Doc3", "spring")); } - private DBObject createDocument(String title, String... tags) { + private static DBObject createDocument(String title, String... tags) { + DBObject doc = new BasicDBObject("title", title); List tagList = new ArrayList(); + for (String tag : tags) { tagList.add(tag); } - doc.put("tags", tagList); + doc.put("tags", tagList); return doc; } - private void assertTagCount(String tag, int n, TagCount tagCount) { + private static void assertTagCount(String tag, int n, TagCount tagCount) { + assertThat(tagCount.getTag(), is(tag)); assertThat(tagCount.getN(), is(n)); } - } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionTests.java index 25b537633..42982d868 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionTests.java @@ -1,8 +1,22 @@ +/* + * Copyright 2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.springframework.data.mongodb.core.aggregation; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.notNullValue; -import static org.junit.Assert.assertThat; +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; import java.util.List; @@ -15,12 +29,12 @@ import com.mongodb.DBObject; /** * Tests of {@link Projection}. * + * @see DATAMONGO-586 * @author Tobias Trelle */ public class ProjectionTests { - /** Unit under test. */ - private Projection projection; + Projection projection; @Before public void setUp() { @@ -29,31 +43,24 @@ public class ProjectionTests { @Test public void emptyProjection() { - // when - DBObject raw = projection.toDBObject(); - // then - assertThat(raw, notNullValue()); + DBObject raw = projection.toDBObject(); + assertThat(raw, is(notNullValue())); assertThat(raw.toMap().isEmpty(), is(true)); } @Test(expected = IllegalArgumentException.class) public void shouldDetectNullIncludesInConstructor() { - // when new Projection((String[]) null); - // then: throw expected exception } @Test public void includesWithConstructor() { - // given + projection = new Projection("a", "b"); - // when DBObject raw = projection.toDBObject(); - - // then - assertThat(raw, notNullValue()); + assertThat(raw, is(notNullValue())); assertThat(raw.toMap().size(), is(3)); assertThat((Integer) raw.get("_id"), is(0)); assertThat((Integer) raw.get("a"), is(1)); @@ -62,102 +69,88 @@ public class ProjectionTests { @Test public void include() { - // given + projection.include("a"); - // when DBObject raw = projection.toDBObject(); - - // then assertSingleDBObject("a", 1, raw); } @Test public void exclude() { - // given + projection.exclude("a"); - // when DBObject raw = projection.toDBObject(); - - // then assertSingleDBObject("a", 0, raw); } @Test public void includeAlias() { - // given + projection.include("a").as("b"); - // when DBObject raw = projection.toDBObject(); - - // then assertSingleDBObject("b", "$a", raw); } @Test(expected = InvalidDataAccessApiUsageException.class) public void shouldDetectAliasWithoutInclude() { - // when projection.as("b"); - // then: throw expected exception } @Test(expected = InvalidDataAccessApiUsageException.class) public void shouldDetectDuplicateAlias() { - // when projection.include("a").as("b").as("c"); - // then: throw expected exception } @Test + @SuppressWarnings("unchecked") public void plus() { - // given + projection.include("a").plus(10); - // when DBObject raw = projection.toDBObject(); - - // then assertNotNullDBObject(raw); - DBObject addition = (DBObject)raw.get("a"); - assertNotNullDBObject(addition); - @SuppressWarnings("unchecked") - List summands = (List)addition.get("$add"); - assertThat( summands, notNullValue() ); - assertThat( summands.size(), is(2) ); - assertThat( (String)summands.get(0), is("$a") ); - assertThat( (Integer)summands.get(1), is (10) ); + + DBObject addition = (DBObject) raw.get("a"); + assertNotNullDBObject(addition); + + List summands = (List) addition.get("$add"); + assertThat(summands, is(notNullValue())); + assertThat(summands.size(), is(2)); + assertThat((String) summands.get(0), is("$a")); + assertThat((Integer) summands.get(1), is(10)); } @Test + @SuppressWarnings("unchecked") public void plusWithAlias() { - // given + projection.include("a").plus(10).as("b"); - // when DBObject raw = projection.toDBObject(); - - // then assertNotNullDBObject(raw); - DBObject addition = (DBObject)raw.get("b"); - assertNotNullDBObject(addition); - @SuppressWarnings("unchecked") - List summands = (List)addition.get("$add"); - assertThat( summands, notNullValue() ); - assertThat( summands.size(), is(2) ); - assertThat( (String)summands.get(0), is("$a") ); - assertThat( (Integer)summands.get(1), is (10) ); + + DBObject addition = (DBObject) raw.get("b"); + assertNotNullDBObject(addition); + + List summands = (List) addition.get("$add"); + assertThat(summands, is(notNullValue())); + assertThat(summands.size(), is(2)); + assertThat((String) summands.get(0), is("$a")); + assertThat((Integer) summands.get(1), is(10)); } - - + private static void assertSingleDBObject(String key, Object value, DBObject doc) { + assertNotNullDBObject(doc); assertThat(doc.get(key), is(value)); } private static void assertNotNullDBObject(DBObject doc) { - assertThat(doc, notNullValue()); + + assertThat(doc, is(notNullValue())); assertThat(doc.toMap().size(), is(1)); } }