Browse Source

DATAMONGO-586 - First round of polish.

Fixed or added copyright headers where necessary. Added Tobias as author where necessary. Added @since tags to newly introduced classes and methods. Documented non-nullability of parameters. Polished test cases a bit.
pull/58/merge
Oliver Gierke 13 years ago committed by Oliver Gierke
parent
commit
4d65aa7207
  1. 11
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoOperations.java
  2. 25
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java
  3. 50
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java
  4. 56
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java
  5. 82
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Projection.java
  6. 148
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationPipelineTests.java
  7. 93
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java
  8. 111
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionTests.java

11
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoOperations.java

@ -47,6 +47,7 @@ import com.mongodb.WriteResult;
* @author Thomas Risberg * @author Thomas Risberg
* @author Mark Pollack * @author Mark Pollack
* @author Oliver Gierke * @author Oliver Gierke
* @author Tobias Trelle
*/ */
public interface MongoOperations { public interface MongoOperations {
@ -306,13 +307,15 @@ public interface MongoOperations {
/** /**
* Execute an aggregation operation. The raw results will be mapped to the given entity class. * Execute an aggregation operation. The raw results will be mapped to the given entity class.
* *
* @param inputCollectionName the collection there the aggregation operation will read from. * @param inputCollectionName the collection there the aggregation operation will read from, must not be
* @param pipeline The pipeline holding the aggregation operations. * {@literal null} or empty.
* @param entityClass The parameterized type of the returned list. * @param pipeline The pipeline holding the aggregation operations, must not be {@literal null}.
* @param entityClass The parameterized type of the returned list, must not be {@literal null}.
* @return The results of the aggregation operation. * @return The results of the aggregation operation.
* @since 1.3
*/ */
<T> AggregationResults<T> aggregate(String inputCollectionName, AggregationPipeline pipeline, Class<T> entityClass); <T> AggregationResults<T> aggregate(String inputCollectionName, AggregationPipeline pipeline, Class<T> entityClass);
/** /**
* Execute a map-reduce operation. The map-reduce operation will be formed with an output type of INLINE * Execute a map-reduce operation. The map-reduce operation will be formed with an output type of INLINE
* *

25
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java

@ -15,8 +15,8 @@
*/ */
package org.springframework.data.mongodb.core; package org.springframework.data.mongodb.core;
import static org.springframework.data.mongodb.core.query.Criteria.where; import static org.springframework.data.mongodb.core.query.Criteria.*;
import static org.springframework.data.mongodb.core.query.SerializationUtils.serializeToJsonSafely; import static org.springframework.data.mongodb.core.query.SerializationUtils.*;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
@ -116,6 +116,7 @@ import com.mongodb.util.JSONParseException;
* @author Oliver Gierke * @author Oliver Gierke
* @author Amol Nayak * @author Amol Nayak
* @author Patryk Wasik * @author Patryk Wasik
* @author Tobias Trelle
*/ */
public class MongoTemplate implements MongoOperations, ApplicationContextAware { public class MongoTemplate implements MongoOperations, ApplicationContextAware {
@ -1210,19 +1211,21 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
} }
public <T> AggregationResults<T> aggregate(String inputCollectionName, AggregationPipeline pipeline, Class<T> entityClass) { public <T> AggregationResults<T> aggregate(String inputCollectionName, AggregationPipeline pipeline,
Class<T> entityClass) {
Assert.notNull(inputCollectionName, "Collection name is missing"); Assert.notNull(inputCollectionName, "Collection name is missing");
Assert.notNull(pipeline, "Aggregation pipeline is missing"); Assert.notNull(pipeline, "Aggregation pipeline is missing");
Assert.notNull(entityClass, "Entity class is missing"); Assert.notNull(entityClass, "Entity class is missing");
// prepare command // prepare command
DBObject command = new BasicDBObject("aggregate", inputCollectionName ); DBObject command = new BasicDBObject("aggregate", inputCollectionName);
command.put( "pipeline", pipeline.getOperations() ); command.put("pipeline", pipeline.getOperations());
// execute command // execute command
CommandResult commandResult = executeCommand(command); CommandResult commandResult = executeCommand(command);
handleCommandError(commandResult, command); handleCommandError(commandResult, command);
// map results // map results
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Iterable<DBObject> resultSet = (Iterable<DBObject>) commandResult.get("result"); Iterable<DBObject> resultSet = (Iterable<DBObject>) commandResult.get("result");
@ -1230,11 +1233,11 @@ public class MongoTemplate implements MongoOperations, ApplicationContextAware {
DbObjectCallback<T> callback = new ReadDbObjectCallback<T>(mongoConverter, entityClass); DbObjectCallback<T> callback = new ReadDbObjectCallback<T>(mongoConverter, entityClass);
for (DBObject dbObject : resultSet) { for (DBObject dbObject : resultSet) {
mappedResults.add(callback.doWith(dbObject)); mappedResults.add(callback.doWith(dbObject));
} }
return new AggregationResults<T>(mappedResults, commandResult); return new AggregationResults<T>(mappedResults, commandResult);
} }
protected String replaceWithResourceIfNecessary(String function) { protected String replaceWithResourceIfNecessary(String function) {
String func = function; String func = function;

50
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2011-2012 the original author or authors. * Copyright 2013 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -31,17 +31,18 @@ import com.mongodb.util.JSONParseException;
* Holds the operations of an aggregation pipeline. * Holds the operations of an aggregation pipeline.
* *
* @author Tobias Trelle * @author Tobias Trelle
* @since 1.3
*/ */
public class AggregationPipeline { public class AggregationPipeline {
private static final String OPERATOR_PREFIX = "$"; private static final String OPERATOR_PREFIX = "$";
private List<DBObject> operations = new ArrayList<DBObject>(); private final List<DBObject> operations = new ArrayList<DBObject>();
/** /**
* Adds a projection operation to the pipeline. * Adds a projection operation to the pipeline.
* *
* @param projection JSON string holding the projection. * @param projection JSON string holding the projection, must not be {@literal null} or empty.
* @return The pipeline. * @return The pipeline.
*/ */
public AggregationPipeline project(String projection) { public AggregationPipeline project(String projection) {
@ -51,33 +52,36 @@ public class AggregationPipeline {
/** /**
* Adds a projection operation to the pipeline. * Adds a projection operation to the pipeline.
* *
* @param projection Type safe projection object. * @param projection Type safe projection object, must not be {@literal null}.
* @return The pipeline. * @return The pipeline.
*/ */
public AggregationPipeline project(Projection projection) { public AggregationPipeline project(Projection projection) {
return addOperation("project", projection.toDBObject() );
Assert.notNull(projection, "Projection must not be null!");
return addOperation("project", projection.toDBObject());
} }
/** /**
* Adds an unwind operation to the pipeline. * Adds an unwind operation to the pipeline.
* *
* @param field Name of the field to unwind (should be an array). * @param field Name of the field to unwind (should be an array), must not be {@literal null} or empty.
* @return The pipeline. * @return The pipeline.
*/ */
public AggregationPipeline unwind(String field) { public AggregationPipeline unwind(String field) {
Assert.notNull(field, "Missing field name");
Assert.hasText(field, "Missing field name");
if (!field.startsWith(OPERATOR_PREFIX)) { if (!field.startsWith(OPERATOR_PREFIX)) {
field = OPERATOR_PREFIX + field; field = OPERATOR_PREFIX + field;
} }
return addOperation("unwind", field); return addOperation("unwind", field);
} }
/** /**
* Adds a group operation to the pipeline. * Adds a group operation to the pipeline.
* *
* @param projection JSON string holding the group. * @param projection JSON string holding the group, must not be {@literal null} or empty.
* @return The pipeline. * @return The pipeline.
*/ */
public AggregationPipeline group(String group) { public AggregationPipeline group(String group) {
@ -87,7 +91,7 @@ public class AggregationPipeline {
/** /**
* Adds a sort operation to the pipeline. * Adds a sort operation to the pipeline.
* *
* @param sort JSON string holding the sorting. * @param sort JSON string holding the sorting, must not be {@literal null} or empty.
* @return The pipeline. * @return The pipeline.
*/ */
public AggregationPipeline sort(String sort) { public AggregationPipeline sort(String sort) {
@ -97,12 +101,12 @@ public class AggregationPipeline {
/** /**
* Adds a sort operation to the pipeline. * Adds a sort operation to the pipeline.
* *
* @param sort Type safe sort operation. * @param sort Type safe sort operation, must not be {@literal null}.
* @return The pipeline. * @return The pipeline.
*/ */
public AggregationPipeline sort(Sort sort) { public AggregationPipeline sort(Sort sort) {
Assert.notNull(sort); Assert.notNull(sort);
DBObject dbo = new BasicDBObject(); DBObject dbo = new BasicDBObject();
for (org.springframework.data.domain.Sort.Order order : sort) { for (org.springframework.data.domain.Sort.Order order : sort) {
@ -112,9 +116,9 @@ public class AggregationPipeline {
} }
/** /**
* Adds a match operation to the pipeline that is basically a query on the collection.s * Adds a match operation to the pipeline that is basically a query on the collections.
* *
* @param projection JSON string holding the criteria. * @param projection JSON string holding the criteria, must not be {@literal null} or empty.
* @return The pipeline. * @return The pipeline.
*/ */
public AggregationPipeline match(String match) { public AggregationPipeline match(String match) {
@ -124,12 +128,12 @@ public class AggregationPipeline {
/** /**
* Adds a match operation to the pipeline that is basically a query on the collection.s * Adds a match operation to the pipeline that is basically a query on the collection.s
* *
* @param criteria Type safe criteria to filter documents from the collection. * @param criteria Type safe criteria to filter documents from the collection, must not be {@literal null}.
* @return The pipeline. * @return The pipeline.
*/ */
public AggregationPipeline match(Criteria criteria) { public AggregationPipeline match(Criteria criteria) {
Assert.notNull(criteria); Assert.notNull(criteria);
return addOperation("match", criteria.getCriteriaObject()); return addOperation("match", criteria.getCriteriaObject());
} }
@ -158,7 +162,8 @@ public class AggregationPipeline {
} }
private AggregationPipeline addDocumentOperation(String opName, String operation) { private AggregationPipeline addDocumentOperation(String opName, String operation) {
Assert.notNull(operation, "Missing " + opName);
Assert.hasText(operation, "Missing operation name!");
return addOperation(opName, parseJson(operation)); return addOperation(opName, parseJson(operation));
} }
@ -174,5 +179,4 @@ public class AggregationPipeline {
throw new IllegalArgumentException("Not a valid JSON document: " + json, e); throw new IllegalArgumentException("Not a valid JSON document: " + json, e);
} }
} }
}
}

56
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2011-2012 the original author or authors. * Copyright 2013 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,7 +15,7 @@
*/ */
package org.springframework.data.mongodb.core.aggregation; package org.springframework.data.mongodb.core.aggregation;
import java.util.ArrayList; import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
@ -27,49 +27,61 @@ import com.mongodb.DBObject;
* Collects the results of executing an aggregation operation. * Collects the results of executing an aggregation operation.
* *
* @author Tobias Trelle * @author Tobias Trelle
* * @author Oliver Gierke
* @param <T> The class in which the results are mapped onto. * @param <T> The class in which the results are mapped onto.
* @since 1.3
*/ */
public class AggregationResults<T> implements Iterable<T> { public class AggregationResults<T> implements Iterable<T> {
private final List<T> mappedResults; private final List<T> mappedResults;
private final DBObject rawResults; private final DBObject rawResults;
private final String serverUsed;
private String serverUsed; /**
* Creates a new {@link AggregationResults} instance from the given mapped and raw results.
*
* @param mappedResults must not be {@literal null}.
* @param rawResults must not be {@literal null}.
*/
public AggregationResults(List<T> mappedResults, DBObject rawResults) { public AggregationResults(List<T> mappedResults, DBObject rawResults) {
Assert.notNull(mappedResults); Assert.notNull(mappedResults);
Assert.notNull(rawResults); Assert.notNull(rawResults);
this.mappedResults = mappedResults;
this.mappedResults = Collections.unmodifiableList(mappedResults);
this.rawResults = rawResults; this.rawResults = rawResults;
parseServerUsed(); this.serverUsed = parseServerUsed();
} }
/**
* Returns the aggregation results.
*
* @return
*/
public List<T> getAggregationResult() { public List<T> getAggregationResult() {
List<T> result = new ArrayList<T>(); return mappedResults;
Iterator<T> it = iterator();
while (it.hasNext()) {
result.add(it.next());
}
return result;
} }
@Override /*
* (non-Javadoc)
* @see java.lang.Iterable#iterator()
*/
public Iterator<T> iterator() { public Iterator<T> iterator() {
return mappedResults.iterator(); return mappedResults.iterator();
} }
/**
* Returns the server that has been used to perform the aggregation.
*
* @return
*/
public String getServerUsed() { public String getServerUsed() {
return serverUsed; return serverUsed;
} }
private void parseServerUsed() { private String parseServerUsed() {
Object object = rawResults.get("serverUsed"); Object object = rawResults.get("serverUsed");
if (object instanceof String) { return object instanceof String ? (String) object : null;
serverUsed = (String) object;
}
} }
} }

82
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Projection.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2011-2012 the original author or authors. * Copyright 2013 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -32,33 +32,37 @@ import com.mongodb.DBObject;
* Projection of field to be used in an {@link AggregationPipeline}. * Projection of field to be used in an {@link AggregationPipeline}.
* <p/> * <p/>
* A projection is similar to a {@link Field} inclusion/exclusion but more powerful. It can generate new fields, change * A projection is similar to a {@link Field} inclusion/exclusion but more powerful. It can generate new fields, change
* values of given field etc. * values of given field etc.
* *
* @author Tobias Trelle * @author Tobias Trelle
* @since 1.3
*/ */
public class Projection { public class Projection {
private static final String REFERENCE_PREFIX = "$"; private static final String REFERENCE_PREFIX = "$";
private DBObject document = new BasicDBObject();
private DBObject rightHandExpression;
/** Stack of key names. Size is 0 or 1. */ /** Stack of key names. Size is 0 or 1. */
private Stack<String> reference = new Stack<String>(); private final Stack<String> reference = new Stack<String>();
private final DBObject document = new BasicDBObject();
private DBObject rightHandExpression;
/** Create an empty projection. */ /**
* Create an empty projection.
*/
public Projection() { public Projection() {
} }
/** /**
* This convenience constructor excludes the field <code>_id</code> and includes the given fields. * This convenience constructor excludes the field {@code _id} and includes the given fields.
* *
* @param includes Keys of field to include. * @param includes Keys of field to include, must not be {@literal null} or empty.
*/ */
public Projection(String... includes) { public Projection(String... includes) {
Assert.notEmpty(includes); Assert.notEmpty(includes);
exclude("_id"); exclude("_id");
for (String key : includes) { for (String key : includes) {
include(key); include(key);
} }
@ -69,73 +73,82 @@ public class Projection {
* *
* @param key The key of the field. * @param key The key of the field.
*/ */
public final void exclude(String key) { public final Projection exclude(String key) {
Assert.notNull(key, "Missing key");
Assert.hasText(key, "Missing key");
document.put(key, 0); document.put(key, 0);
return this;
} }
/** /**
* Includes a given field. * Includes a given field.
* *
* @param key The key of the field. * @param key The key of the field, must not be {@literal null} or empty.
*/ */
public final Projection include(String key) { public final Projection include(String key) {
Assert.notNull(key, "Missing key");
Assert.hasText(key, "Missing key");
safePop(); safePop();
reference.push(key); reference.push(key);
return this; return this;
} }
/** /**
* Sets the key for a computed field. * Sets the key for a computed field.
* *
* @param key must not be {@literal null} or empty.
*/ */
public final Projection as(String key) { public final Projection as(String key) {
Assert.notNull(key, "Missing key");
Assert.hasText(key, "Missing key");
try { try {
document.put(key, rightHandSide(safeReference(reference.pop())) ); document.put(key, rightHandSide(safeReference(reference.pop())));
} catch (EmptyStackException e) { } catch (EmptyStackException e) {
throw new InvalidDataAccessApiUsageException("Invalid use of as()", e); throw new InvalidDataAccessApiUsageException("Invalid use of as()", e);
} }
return this; return this;
} }
public final Projection plus(Number n) { public final Projection plus(Number n) {
return arithmeticOperation("add", n); return arithmeticOperation("add", n);
} }
public final Projection minus(Number n) { public final Projection minus(Number n) {
return arithmeticOperation("substract", n); return arithmeticOperation("substract", n);
} }
private Projection arithmeticOperation(String op, Number n) { private Projection arithmeticOperation(String op, Number n) {
Assert.notNull(n, "Missing number"); Assert.notNull(n, "Missing number");
rightHandExpression = createArrayObject(op, safeReference(reference.peek()), n); rightHandExpression = createArrayObject(op, safeReference(reference.peek()), n);
return this;
return this;
} }
private DBObject createArrayObject(String op, Object... items) { private DBObject createArrayObject(String op, Object... items) {
List<Object> list = new ArrayList<Object>(); List<Object> list = new ArrayList<Object>();
Collections.addAll(list, items); Collections.addAll(list, items);
return new BasicDBObject( safeReference(op), list ); return new BasicDBObject(safeReference(op), list);
} }
private void safePop() { private void safePop() {
if ( !reference.empty() ) {
document.put( reference.pop(), rightHandSide(1) ); if (!reference.empty()) {
document.put(reference.pop(), rightHandSide(1));
} }
} }
private String safeReference(String key) { private String safeReference(String key) {
Assert.notNull(key);
Assert.hasText(key);
if ( !key.startsWith(REFERENCE_PREFIX) ) {
if (!key.startsWith(REFERENCE_PREFIX)) {
return REFERENCE_PREFIX + key; return REFERENCE_PREFIX + key;
} else { } else {
return key; return key;
@ -147,10 +160,9 @@ public class Projection {
rightHandExpression = null; rightHandExpression = null;
return value; return value;
} }
DBObject toDBObject() { DBObject toDBObject() {
safePop(); safePop();
return document; return document;
} }
} }

148
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationPipelineTests.java

@ -1,8 +1,24 @@
/*
* Copyright 2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core.aggregation; package org.springframework.data.mongodb.core.aggregation;
import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.CoreMatchers.notNullValue; import static org.junit.Assert.*;
import static org.junit.Assert.assertThat; import static org.springframework.data.mongodb.core.DBObjectUtils.*;
import java.util.List; import java.util.List;
import org.junit.Before; import org.junit.Before;
@ -16,122 +32,112 @@ import com.mongodb.DBObject;
/** /**
* Tests of the {@link AggregationPipeline}. * Tests of the {@link AggregationPipeline}.
* *
* @see DATAMONGO-586
* @author Tobias Trelle * @author Tobias Trelle
*/ */
public class AggregationPipelineTests { public class AggregationPipelineTests {
/** Unit under test. */ AggregationPipeline pipeline;
private AggregationPipeline pipeline;
@Before
@Before public void setUp() { public void setUp() {
pipeline = new AggregationPipeline(); pipeline = new AggregationPipeline();
} }
@Test public void limitOperation() { @Test
// given public void limitOperation() {
pipeline.limit(42); pipeline.limit(42);
// when
List<DBObject> rawPipeline = pipeline.getOperations(); List<DBObject> rawPipeline = pipeline.getOperations();
// then
assertDBObject("$limit", 42L, rawPipeline); assertDBObject("$limit", 42L, rawPipeline);
} }
@Test public void skipOperation() { @Test
// given public void skipOperation() {
pipeline.skip(5); pipeline.skip(5);
// when
List<DBObject> rawPipeline = pipeline.getOperations(); List<DBObject> rawPipeline = pipeline.getOperations();
// then
assertDBObject("$skip", 5L, rawPipeline); assertDBObject("$skip", 5L, rawPipeline);
} }
@Test public void unwindOperation() { @Test
// given public void unwindOperation() {
pipeline.unwind("$field"); pipeline.unwind("$field");
// when
List<DBObject> rawPipeline = pipeline.getOperations(); List<DBObject> rawPipeline = pipeline.getOperations();
// then
assertDBObject("$unwind", "$field", rawPipeline); assertDBObject("$unwind", "$field", rawPipeline);
} }
@Test public void unwindOperationWithAddedPrefix() { @Test
// given public void unwindOperationWithAddedPrefix() {
pipeline.unwind("field"); pipeline.unwind("field");
// when
List<DBObject> rawPipeline = pipeline.getOperations(); List<DBObject> rawPipeline = pipeline.getOperations();
// then
assertDBObject("$unwind", "$field", rawPipeline); assertDBObject("$unwind", "$field", rawPipeline);
} }
@Test
@Test public void matchOperation() { public void matchOperation() {
// given
Criteria criteria = new Criteria("title").is("Doc 1"); Criteria criteria = new Criteria("title").is("Doc 1");
pipeline.match( criteria ); pipeline.match(criteria);
// when
List<DBObject> rawPipeline = pipeline.getOperations(); List<DBObject> rawPipeline = pipeline.getOperations();
// then
assertOneDocument(rawPipeline); assertOneDocument(rawPipeline);
DBObject match = rawPipeline.get(0); DBObject match = rawPipeline.get(0);
DBObject criteriaDoc = (DBObject)match.get("$match"); DBObject criteriaDoc = getAsDBObject(match, "$match");
assertThat( criteriaDoc, notNullValue() ); assertThat(criteriaDoc, is(notNullValue()));
assertSingleDBObject( "title" , "Doc 1", criteriaDoc ); assertSingleDBObject("title", "Doc 1", criteriaDoc);
} }
@Test public void sortOperation() { @Test
// given public void sortOperation() {
Sort sort = new Sort(new Sort.Order(Direction.ASC, "n")); Sort sort = new Sort(new Sort.Order(Direction.ASC, "n"));
pipeline.sort( sort ); pipeline.sort(sort);
// when
List<DBObject> rawPipeline = pipeline.getOperations(); List<DBObject> rawPipeline = pipeline.getOperations();
// then
assertOneDocument(rawPipeline); assertOneDocument(rawPipeline);
DBObject sortDoc = rawPipeline.get(0); DBObject sortDoc = rawPipeline.get(0);
DBObject orderDoc = (DBObject)sortDoc.get("$sort"); DBObject orderDoc = getAsDBObject(sortDoc, "$sort");
assertThat( orderDoc, notNullValue() ); assertThat(orderDoc, is(notNullValue()));
assertSingleDBObject( "n" , 1, orderDoc ); assertSingleDBObject("n", 1, orderDoc);
} }
@Test public void projectOperation() { @Test
// given public void projectOperation() {
Projection projection = new Projection("a"); Projection projection = new Projection("a");
pipeline.project(projection); pipeline.project(projection);
// when
List<DBObject> rawPipeline = pipeline.getOperations(); List<DBObject> rawPipeline = pipeline.getOperations();
// then
assertOneDocument(rawPipeline); assertOneDocument(rawPipeline);
DBObject projectionDoc = rawPipeline.get(0); DBObject projectionDoc = rawPipeline.get(0);
DBObject fields = (DBObject)projectionDoc.get("$project"); DBObject fields = getAsDBObject(projectionDoc, "$project");
assertThat( fields, notNullValue() ); assertThat(fields, is(notNullValue()));
assertSingleDBObject( "a" , 1, fields ); assertSingleDBObject("a", 1, fields);
} }
private static void assertOneDocument(List<DBObject> result) { private static void assertOneDocument(List<DBObject> result) {
assertThat( result, notNullValue() );
assertThat( result.size(), is(1) ); assertThat(result, is(notNullValue()));
assertThat(result.size(), is(1));
} }
private static void assertDBObject(String key, Object value, List<DBObject> result) { private static void assertDBObject(String key, Object value, List<DBObject> result) {
assertOneDocument(result); assertOneDocument(result);
assertSingleDBObject( key, value, result.get(0) ); assertSingleDBObject(key, value, result.get(0));
} }
private static void assertSingleDBObject(String key, Object value, DBObject doc) { private static void assertSingleDBObject(String key, Object value, DBObject doc) {
assertThat( doc.get(key), is(value) ); assertThat(doc.get(key), is(value));
} }
} }

93
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2011-2012 the original author or authors. * Copyright 2013 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -15,9 +15,8 @@
*/ */
package org.springframework.data.mongodb.core.aggregation; package org.springframework.data.mongodb.core.aggregation;
import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.CoreMatchers.notNullValue; import static org.junit.Assert.*;
import static org.junit.Assert.assertThat;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -41,6 +40,7 @@ import com.mongodb.DBObject;
/** /**
* Tests for {@link MongoTemplate#aggregate(String, AggregationPipeline, Class)}. * Tests for {@link MongoTemplate#aggregate(String, AggregationPipeline, Class)}.
* *
* @see DATAMONGO-586
* @author Tobias Trelle * @author Tobias Trelle
*/ */
@RunWith(SpringJUnit4ClassRunner.class) @RunWith(SpringJUnit4ClassRunner.class)
@ -79,35 +79,26 @@ public class AggregationTests {
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void shouldDetectIllegalJsonInOperation() { public void shouldDetectIllegalJsonInOperation() {
// given
AggregationPipeline pipeline = new AggregationPipeline().project("{ foo bar");
// when AggregationPipeline pipeline = new AggregationPipeline().project("{ foo bar");
mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class); mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class);
// then: throw expected exception
} }
@Test @Test
public void shouldAggregate() { public void shouldAggregate() {
// given
createDocuments(); createDocuments();
AggregationPipeline pipeline = new AggregationPipeline()
.project("{_id:0,tags:1}}")
.unwind("tags")
.group("{_id:\"$tags\", n:{$sum:1}}")
.project("{tag: \"$_id\", n:1, _id:0}")
.sort( new Sort(new Sort.Order(Direction.DESC, "n")) );
// when
AggregationResults<TagCount> results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class);
// then AggregationPipeline pipeline = new AggregationPipeline().project("{_id:0,tags:1}}").unwind("tags")
assertThat(results, notNullValue()); .group("{_id:\"$tags\", n:{$sum:1}}").project("{tag: \"$_id\", n:1, _id:0}")
.sort(new Sort(new Sort.Order(Direction.DESC, "n")));
AggregationResults<TagCount> results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class);
assertThat(results, is(notNullValue()));
assertThat(results.getServerUsed(), is("/127.0.0.1:27017")); assertThat(results.getServerUsed(), is("/127.0.0.1:27017"));
List<TagCount> tagCount = results.getAggregationResult(); List<TagCount> tagCount = results.getAggregationResult();
assertThat(tagCount, notNullValue()); assertThat(tagCount, is(notNullValue()));
assertThat(tagCount.size(), is(3)); assertThat(tagCount.size(), is(3));
assertTagCount("spring", 3, tagCount.get(0)); assertTagCount("spring", 3, tagCount.get(0));
assertTagCount("mongodb", 2, tagCount.get(1)); assertTagCount("mongodb", 2, tagCount.get(1));
@ -116,87 +107,73 @@ public class AggregationTests {
@Test(expected = InvalidDataAccessApiUsageException.class) @Test(expected = InvalidDataAccessApiUsageException.class)
public void shouldDetectIllegalAggregationOperation() { public void shouldDetectIllegalAggregationOperation() {
// given
createDocuments(); createDocuments();
AggregationPipeline pipeline = new AggregationPipeline().project("{$foobar:{_id:0,tags:1}}"); AggregationPipeline pipeline = new AggregationPipeline().project("{$foobar:{_id:0,tags:1}}");
// when
mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class); mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class);
// then: throw expected exception
} }
@Test @Test
public void shouldAggregateEmptyCollection() { public void shouldAggregateEmptyCollection() {
// given
AggregationPipeline pipeline = new AggregationPipeline()
.project("{_id:0,tags:1}}")
.unwind("$tags")
.group("{_id:\"$tags\", n:{$sum:1}}")
.project("{tag: \"$_id\", n:1, _id:0}")
.sort("{n:-1}");
// when
AggregationResults<TagCount> results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class);
// then AggregationPipeline pipeline = new AggregationPipeline().project("{_id:0,tags:1}}").unwind("$tags")
assertThat(results, notNullValue()); .group("{_id:\"$tags\", n:{$sum:1}}").project("{tag: \"$_id\", n:1, _id:0}").sort("{n:-1}");
AggregationResults<TagCount> results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class);
assertThat(results, is(notNullValue()));
assertThat(results.getServerUsed(), is("/127.0.0.1:27017")); assertThat(results.getServerUsed(), is("/127.0.0.1:27017"));
List<TagCount> tagCount = results.getAggregationResult(); List<TagCount> tagCount = results.getAggregationResult();
assertThat(tagCount, notNullValue()); assertThat(tagCount, is(notNullValue()));
assertThat(tagCount.size(), is(0)); assertThat(tagCount.size(), is(0));
} }
@Test @Test
public void shouldDetectResultMismatch() { public void shouldDetectResultMismatch() {
// given
createDocuments();
AggregationPipeline pipeline = new AggregationPipeline()
.project("{_id:0,tags:1}}")
.unwind("$tags")
.group("{_id:\"$tags\", count:{$sum:1}}")
.limit(2);
// when
AggregationResults<TagCount> results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class);
// then createDocuments();
assertThat(results, notNullValue()); AggregationPipeline pipeline = new AggregationPipeline().project("{_id:0,tags:1}}").unwind("$tags")
.group("{_id:\"$tags\", count:{$sum:1}}").limit(2);
AggregationResults<TagCount> results = mongoTemplate.aggregate(INPUT_COLLECTION, pipeline, TagCount.class);
assertThat(results, is(notNullValue()));
assertThat(results.getServerUsed(), is("/127.0.0.1:27017")); assertThat(results.getServerUsed(), is("/127.0.0.1:27017"));
List<TagCount> tagCount = results.getAggregationResult(); List<TagCount> tagCount = results.getAggregationResult();
assertThat(tagCount, notNullValue()); assertThat(tagCount, is(notNullValue()));
assertThat(tagCount.size(), is(2)); assertThat(tagCount.size(), is(2));
assertTagCount(null, 0, tagCount.get(0)); assertTagCount(null, 0, tagCount.get(0));
assertTagCount(null, 0, tagCount.get(1)); assertTagCount(null, 0, tagCount.get(1));
} }
protected void cleanDb() { protected void cleanDb() {
mongoTemplate.dropCollection(INPUT_COLLECTION); mongoTemplate.dropCollection(INPUT_COLLECTION);
} }
private void createDocuments() { private void createDocuments() {
DBCollection coll = mongoTemplate.getCollection(INPUT_COLLECTION); DBCollection coll = mongoTemplate.getCollection(INPUT_COLLECTION);
coll.insert(createDocument("Doc1", "spring", "mongodb", "nosql")); coll.insert(createDocument("Doc1", "spring", "mongodb", "nosql"));
coll.insert(createDocument("Doc2", "spring", "mongodb")); coll.insert(createDocument("Doc2", "spring", "mongodb"));
coll.insert(createDocument("Doc3", "spring")); coll.insert(createDocument("Doc3", "spring"));
} }
private DBObject createDocument(String title, String... tags) { private static DBObject createDocument(String title, String... tags) {
DBObject doc = new BasicDBObject("title", title); DBObject doc = new BasicDBObject("title", title);
List<String> tagList = new ArrayList<String>(); List<String> tagList = new ArrayList<String>();
for (String tag : tags) { for (String tag : tags) {
tagList.add(tag); tagList.add(tag);
} }
doc.put("tags", tagList);
doc.put("tags", tagList);
return doc; return doc;
} }
private void assertTagCount(String tag, int n, TagCount tagCount) { private static void assertTagCount(String tag, int n, TagCount tagCount) {
assertThat(tagCount.getTag(), is(tag)); assertThat(tagCount.getTag(), is(tag));
assertThat(tagCount.getN(), is(n)); assertThat(tagCount.getN(), is(n));
} }
} }

111
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/ProjectionTests.java

@ -1,8 +1,22 @@
/*
* Copyright 2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core.aggregation; package org.springframework.data.mongodb.core.aggregation;
import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.CoreMatchers.notNullValue; import static org.junit.Assert.*;
import static org.junit.Assert.assertThat;
import java.util.List; import java.util.List;
@ -15,12 +29,12 @@ import com.mongodb.DBObject;
/** /**
* Tests of {@link Projection}. * Tests of {@link Projection}.
* *
* @see DATAMONGO-586
* @author Tobias Trelle * @author Tobias Trelle
*/ */
public class ProjectionTests { public class ProjectionTests {
/** Unit under test. */ Projection projection;
private Projection projection;
@Before @Before
public void setUp() { public void setUp() {
@ -29,31 +43,24 @@ public class ProjectionTests {
@Test @Test
public void emptyProjection() { public void emptyProjection() {
// when
DBObject raw = projection.toDBObject();
// then DBObject raw = projection.toDBObject();
assertThat(raw, notNullValue()); assertThat(raw, is(notNullValue()));
assertThat(raw.toMap().isEmpty(), is(true)); assertThat(raw.toMap().isEmpty(), is(true));
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void shouldDetectNullIncludesInConstructor() { public void shouldDetectNullIncludesInConstructor() {
// when
new Projection((String[]) null); new Projection((String[]) null);
// then: throw expected exception
} }
@Test @Test
public void includesWithConstructor() { public void includesWithConstructor() {
// given
projection = new Projection("a", "b"); projection = new Projection("a", "b");
// when
DBObject raw = projection.toDBObject(); DBObject raw = projection.toDBObject();
assertThat(raw, is(notNullValue()));
// then
assertThat(raw, notNullValue());
assertThat(raw.toMap().size(), is(3)); assertThat(raw.toMap().size(), is(3));
assertThat((Integer) raw.get("_id"), is(0)); assertThat((Integer) raw.get("_id"), is(0));
assertThat((Integer) raw.get("a"), is(1)); assertThat((Integer) raw.get("a"), is(1));
@ -62,102 +69,88 @@ public class ProjectionTests {
@Test @Test
public void include() { public void include() {
// given
projection.include("a"); projection.include("a");
// when
DBObject raw = projection.toDBObject(); DBObject raw = projection.toDBObject();
// then
assertSingleDBObject("a", 1, raw); assertSingleDBObject("a", 1, raw);
} }
@Test @Test
public void exclude() { public void exclude() {
// given
projection.exclude("a"); projection.exclude("a");
// when
DBObject raw = projection.toDBObject(); DBObject raw = projection.toDBObject();
// then
assertSingleDBObject("a", 0, raw); assertSingleDBObject("a", 0, raw);
} }
@Test @Test
public void includeAlias() { public void includeAlias() {
// given
projection.include("a").as("b"); projection.include("a").as("b");
// when
DBObject raw = projection.toDBObject(); DBObject raw = projection.toDBObject();
// then
assertSingleDBObject("b", "$a", raw); assertSingleDBObject("b", "$a", raw);
} }
@Test(expected = InvalidDataAccessApiUsageException.class) @Test(expected = InvalidDataAccessApiUsageException.class)
public void shouldDetectAliasWithoutInclude() { public void shouldDetectAliasWithoutInclude() {
// when
projection.as("b"); projection.as("b");
// then: throw expected exception
} }
@Test(expected = InvalidDataAccessApiUsageException.class) @Test(expected = InvalidDataAccessApiUsageException.class)
public void shouldDetectDuplicateAlias() { public void shouldDetectDuplicateAlias() {
// when
projection.include("a").as("b").as("c"); projection.include("a").as("b").as("c");
// then: throw expected exception
} }
@Test @Test
@SuppressWarnings("unchecked")
public void plus() { public void plus() {
// given
projection.include("a").plus(10); projection.include("a").plus(10);
// when
DBObject raw = projection.toDBObject(); DBObject raw = projection.toDBObject();
// then
assertNotNullDBObject(raw); assertNotNullDBObject(raw);
DBObject addition = (DBObject)raw.get("a");
assertNotNullDBObject(addition); DBObject addition = (DBObject) raw.get("a");
@SuppressWarnings("unchecked") assertNotNullDBObject(addition);
List<Object> summands = (List<Object>)addition.get("$add");
assertThat( summands, notNullValue() ); List<Object> summands = (List<Object>) addition.get("$add");
assertThat( summands.size(), is(2) ); assertThat(summands, is(notNullValue()));
assertThat( (String)summands.get(0), is("$a") ); assertThat(summands.size(), is(2));
assertThat( (Integer)summands.get(1), is (10) ); assertThat((String) summands.get(0), is("$a"));
assertThat((Integer) summands.get(1), is(10));
} }
@Test @Test
@SuppressWarnings("unchecked")
public void plusWithAlias() { public void plusWithAlias() {
// given
projection.include("a").plus(10).as("b"); projection.include("a").plus(10).as("b");
// when
DBObject raw = projection.toDBObject(); DBObject raw = projection.toDBObject();
// then
assertNotNullDBObject(raw); assertNotNullDBObject(raw);
DBObject addition = (DBObject)raw.get("b");
assertNotNullDBObject(addition); DBObject addition = (DBObject) raw.get("b");
@SuppressWarnings("unchecked") assertNotNullDBObject(addition);
List<Object> summands = (List<Object>)addition.get("$add");
assertThat( summands, notNullValue() ); List<Object> summands = (List<Object>) addition.get("$add");
assertThat( summands.size(), is(2) ); assertThat(summands, is(notNullValue()));
assertThat( (String)summands.get(0), is("$a") ); assertThat(summands.size(), is(2));
assertThat( (Integer)summands.get(1), is (10) ); assertThat((String) summands.get(0), is("$a"));
assertThat((Integer) summands.get(1), is(10));
} }
private static void assertSingleDBObject(String key, Object value, DBObject doc) { private static void assertSingleDBObject(String key, Object value, DBObject doc) {
assertNotNullDBObject(doc); assertNotNullDBObject(doc);
assertThat(doc.get(key), is(value)); assertThat(doc.get(key), is(value));
} }
private static void assertNotNullDBObject(DBObject doc) { private static void assertNotNullDBObject(DBObject doc) {
assertThat(doc, notNullValue());
assertThat(doc, is(notNullValue()));
assertThat(doc.toMap().size(), is(1)); assertThat(doc.toMap().size(), is(1));
} }
} }

Loading…
Cancel
Save