diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQuery.java index ebeb5350e..b8a34f118 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQuery.java @@ -19,10 +19,14 @@ import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; +import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.repository.query.RepositoryQuery; import org.springframework.data.repository.query.parser.PartTree; +import org.springframework.util.StringUtils; + +import com.mongodb.util.JSONParseException; /** * {@link RepositoryQuery} implementation for Mongo. @@ -67,7 +71,24 @@ public class PartTreeMongoQuery extends AbstractMongoQuery { protected Query createQuery(ConvertingParameterAccessor accessor) { MongoQueryCreator creator = new MongoQueryCreator(tree, accessor, context, isGeoNearQuery); - return creator.createQuery(); + Query query = creator.createQuery(); + + String fieldSpec = this.getQueryMethod().getFieldSpecification(); + + if (!StringUtils.hasText(fieldSpec)) { + return query; + } + + try { + + BasicQuery result = new BasicQuery(query.getQueryObject().toString(), fieldSpec); + result.setSortObject(query.getSortObject()); + return result; + + } catch (JSONParseException o_O) { + throw new IllegalStateException(String.format("Invalid query or field specification in %s!", getQueryMethod(), + o_O)); + } } /* diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQueryUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQueryUnitTests.java new file mode 100644 index 000000000..edeecbd5f --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQueryUnitTests.java @@ -0,0 +1,166 @@ +/* + * Copyright 2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.query; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import java.lang.reflect.Method; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.mockito.Matchers; +import org.mockito.Mock; +import org.mockito.runners.MockitoJUnitRunner; +import org.springframework.data.mongodb.MongoDbFactory; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.convert.DbRefResolver; +import org.springframework.data.mongodb.core.convert.DefaultDbRefResolver; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.convert.MongoConverter; +import org.springframework.data.mongodb.core.mapping.MongoMappingContext; +import org.springframework.data.mongodb.repository.MongoRepository; +import org.springframework.data.mongodb.repository.Person; +import org.springframework.data.mongodb.repository.Query; +import org.springframework.data.repository.core.RepositoryMetadata; + +import com.mongodb.BasicDBObjectBuilder; + +/** + * Unit tests for {@link PartTreeMongoQuery}. + * + * @author Christoph Strobl + * @author Oliver Gierke + */ +@RunWith(MockitoJUnitRunner.class) +public class PartTreeMongoQueryUnitTests { + + @Mock RepositoryMetadata metadataMock; + @Mock MongoOperations mongoOperationsMock; + + MongoMappingContext mappingContext; + + public @Rule ExpectedException exception = ExpectedException.none(); + + @Before + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void setUp() { + + when(metadataMock.getDomainType()).thenReturn((Class) Person.class); + when(metadataMock.getReturnedDomainClass(Matchers.any(Method.class))).thenReturn((Class) Person.class); + mappingContext = new MongoMappingContext(); + DbRefResolver dbRefResolver = new DefaultDbRefResolver(mock(MongoDbFactory.class)); + MongoConverter converter = new MappingMongoConverter(dbRefResolver, mappingContext); + + when(mongoOperationsMock.getConverter()).thenReturn(converter); + } + + /** + * @see DATAMOGO-952 + */ + @Test + public void rejectsInvalidFieldSpecification() { + + exception.expect(IllegalStateException.class); + exception.expectMessage("findByLastname"); + + deriveQueryFromMethod("findByLastname", new Object[] { "foo" }); + } + + /** + * @see DATAMOGO-952 + */ + @Test + public void singleFieldJsonIncludeRestrictionShouldBeConsidered() { + + org.springframework.data.mongodb.core.query.Query query = deriveQueryFromMethod("findByFirstname", + new Object[] { "foo" }); + + assertThat(query.getFieldsObject(), is(new BasicDBObjectBuilder().add("firstname", 1).get())); + } + + /** + * @see DATAMOGO-952 + */ + @Test + public void multiFieldJsonIncludeRestrictionShouldBeConsidered() { + + org.springframework.data.mongodb.core.query.Query query = deriveQueryFromMethod("findByFirstnameAndLastname", + new Object[] { "foo", "bar" }); + + assertThat(query.getFieldsObject(), is(new BasicDBObjectBuilder().add("firstname", 1).add("lastname", 1).get())); + } + + /** + * @see DATAMOGO-952 + */ + @Test + public void multiFieldJsonExcludeRestrictionShouldBeConsidered() { + + org.springframework.data.mongodb.core.query.Query query = deriveQueryFromMethod("findPersonByFirstnameAndLastname", + new Object[] { "foo", "bar" }); + + assertThat(query.getFieldsObject(), is(new BasicDBObjectBuilder().add("firstname", 0).add("lastname", 0).get())); + } + + private org.springframework.data.mongodb.core.query.Query deriveQueryFromMethod(String method, Object[] args) { + + Class[] types = new Class[args.length]; + + for (int i = 0; i < args.length; i++) { + types[i] = args[i].getClass(); + } + + PartTreeMongoQuery partTreeQuery = createQueryForMethod(method, types); + + MongoParameterAccessor accessor = new MongoParametersParameterAccessor(partTreeQuery.getQueryMethod(), args); + return partTreeQuery.createQuery(new ConvertingParameterAccessor(mongoOperationsMock.getConverter(), accessor)); + } + + private PartTreeMongoQuery createQueryForMethod(String methodName, Class... paramTypes) { + + try { + + Method method = Repo.class.getMethod(methodName, paramTypes); + MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadataMock, mappingContext); + + return new PartTreeMongoQuery(queryMethod, mongoOperationsMock); + } catch (NoSuchMethodException e) { + throw new IllegalArgumentException(e.getMessage(), e); + } catch (SecurityException e) { + throw new IllegalArgumentException(e.getMessage(), e); + } + } + + interface Repo extends MongoRepository { + + @Query(fields = "firstname") + Person findByLastname(String lastname); + + @Query(fields = "{ 'firstname' : 1 }") + Person findByFirstname(String lastname); + + @Query(fields = "{ 'firstname' : 1, 'lastname' : 1 }") + Person findByFirstnameAndLastname(String firstname, String lastname); + + @Query(fields = "{ 'firstname' : 0, 'lastname' : 0 }") + Person findPersonByFirstnameAndLastname(String firstname, String lastname); + } +}