diff --git a/src/main/java/org/springframework/data/web/config/SpringDataWebConfiguration.java b/src/main/java/org/springframework/data/web/config/SpringDataWebConfiguration.java index c7c153046..2d3c6fb3b 100644 --- a/src/main/java/org/springframework/data/web/config/SpringDataWebConfiguration.java +++ b/src/main/java/org/springframework/data/web/config/SpringDataWebConfiguration.java @@ -23,6 +23,7 @@ import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Lazy; import org.springframework.core.convert.ConversionService; import org.springframework.data.geo.format.DistanceFormatter; import org.springframework.data.geo.format.PointFormatter; @@ -68,6 +69,17 @@ public class SpringDataWebConfiguration extends WebMvcConfigurerAdapter { return new SortHandlerMethodArgumentResolver(); } + /** + * Default QuerydslPredicateArgumentResolver. + * + * @return + */ + @Bean + @Lazy + public QuerydslPredicateArgumentResolver querydslPredicateArgumentResolver() { + return new QuerydslPredicateArgumentResolver(conversionService.getObject()); + } + /* * (non-Javadoc) * @see org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter#addFormatters(org.springframework.format.FormatterRegistry) @@ -100,7 +112,7 @@ public class SpringDataWebConfiguration extends WebMvcConfigurerAdapter { argumentResolvers.add(pageableResolver()); if (QueryDslUtils.QUERY_DSL_PRESENT) { - argumentResolvers.add(new QuerydslPredicateArgumentResolver(conversionService.getObject())); + argumentResolvers.add(querydslPredicateArgumentResolver()); } ProxyingHandlerMethodArgumentResolver resolver = new ProxyingHandlerMethodArgumentResolver( @@ -110,4 +122,5 @@ public class SpringDataWebConfiguration extends WebMvcConfigurerAdapter { argumentResolvers.add(resolver); } + } diff --git a/src/main/java/org/springframework/data/web/querydsl/QuerydslPredicateArgumentResolver.java b/src/main/java/org/springframework/data/web/querydsl/QuerydslPredicateArgumentResolver.java index a3f8a4e64..d41c083ed 100644 --- a/src/main/java/org/springframework/data/web/querydsl/QuerydslPredicateArgumentResolver.java +++ b/src/main/java/org/springframework/data/web/querydsl/QuerydslPredicateArgumentResolver.java @@ -20,6 +20,7 @@ import java.util.Map.Entry; import org.springframework.beans.BeanUtils; import org.springframework.beans.BeansException; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.config.AutowireCapableBeanFactory; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; @@ -27,6 +28,7 @@ import org.springframework.core.MethodParameter; import org.springframework.core.convert.ConversionService; import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.data.querydsl.SimpleEntityPathResolver; +import org.springframework.data.repository.support.Repositories; import org.springframework.data.util.ClassTypeInformation; import org.springframework.data.util.TypeInformation; import org.springframework.util.LinkedMultiValueMap; @@ -53,6 +55,8 @@ public class QuerydslPredicateArgumentResolver implements HandlerMethodArgumentR private AutowireCapableBeanFactory beanFactory; + private Repositories repositories; + /** * Creates a new {@link QuerydslPredicateArgumentResolver} using the given {@link ConversionService}. * @@ -69,7 +73,9 @@ public class QuerydslPredicateArgumentResolver implements HandlerMethodArgumentR */ @Override public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.beanFactory = applicationContext.getAutowireCapableBeanFactory(); + this.repositories = new Repositories(applicationContext); } /* @@ -127,17 +133,48 @@ public class QuerydslPredicateArgumentResolver implements HandlerMethodArgumentR EntityPath path = SimpleEntityPathResolver.INSTANCE.createPath(domainType); + QuerydslBinderCustomizer customizer = findCustomizerForDomainType(annotation, domainType); + QuerydslBindings bindings = new QuerydslBindings(); + if (customizer != null) { + customizer.customize(bindings, path); + } + return bindings; + } + + @SuppressWarnings("unchecked") + private QuerydslBinderCustomizer> findCustomizerForDomainType(QuerydslPredicate annotation, + Class domainType) { + + if (annotation == null || (annotation != null && annotation.bindings().equals(QuerydslBinderCustomizer.class))) { + if (repositories != null && repositories.hasRepositoryFor(domainType)) { + + Object repository = this.repositories.getRepositoryFor(domainType); + if (repository instanceof QuerydslBinderCustomizer) { + return (QuerydslBinderCustomizer>) repository; + } + } - if (annotation == null || annotation.bindings().equals(QuerydslBinderCustomizer.class)) { - return bindings; + return null; } - Class type = annotation.bindings(); - QuerydslBinderCustomizer> customizer = beanFactory != null ? beanFactory.createBean(type) - : BeanUtils.instantiateClass(type); - customizer.customize(bindings, path); + return createQuerydslBinderCustomizer(annotation.bindings()); + } - return bindings; + @SuppressWarnings({ "unchecked", "rawtypes" }) + private QuerydslBinderCustomizer> createQuerydslBinderCustomizer( + Class type) { + + if (beanFactory == null) { + return BeanUtils.instantiateClass(type); + } + + try { + return beanFactory.getBean(type); + } catch (NoSuchBeanDefinitionException e) { + + } + + return beanFactory.createBean(type); } } diff --git a/src/test/java/org/springframework/data/web/config/SpringDataWebConfigurationUnitTests.java b/src/test/java/org/springframework/data/web/config/SpringDataWebConfigurationUnitTests.java new file mode 100644 index 000000000..dd3140eb7 --- /dev/null +++ b/src/test/java/org/springframework/data/web/config/SpringDataWebConfigurationUnitTests.java @@ -0,0 +1,105 @@ +/* + * Copyright 2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.web.config; + +import static org.junit.Assert.*; +import static org.springframework.test.util.ReflectionTestUtils.*; + +import java.net.URLClassLoader; +import java.util.ArrayList; +import java.util.List; + +import org.junit.Test; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.ObjectFactory; +import org.springframework.core.convert.ConversionService; +import org.springframework.data.web.querydsl.QuerydslPredicateArgumentResolver; +import org.springframework.instrument.classloading.ShadowingClassLoader; +import org.springframework.web.method.support.HandlerMethodArgumentResolver; + +/** + * @author Christoph Strobl + */ +public class SpringDataWebConfigurationUnitTests { + + /** + * @see DATACMNS-669 + */ + @Test + public void shouldNotAddQuerydslPredicateArgumentResolverWhenQuerydslNotPresent() throws ClassNotFoundException, + InstantiationException, IllegalAccessException { + + ClassLoader classLoader = initClassLoader(); + + Object config = classLoader.loadClass("org.springframework.data.web.config.SpringDataWebConfiguration") + .newInstance(); + + setField(config, "context", + classLoader.loadClass("org.springframework.web.context.support.GenericWebApplicationContext").newInstance()); + setField( + config, + "conversionService", + classLoader.loadClass( + "org.springframework.data.web.config.SpringDataWebConfigurationUnitTests$ObjectFactoryImpl").newInstance()); + + List argumentResolvers = new ArrayList(); + + invokeMethod(config, "addArgumentResolvers", argumentResolvers); + + for (Object resolver : argumentResolvers) { + if (resolver instanceof QuerydslPredicateArgumentResolver) { + fail("QuerydslPredicateArgumentResolver should not be present when Querydsl not on path"); + } + } + } + + private ClassLoader initClassLoader() { + + ClassLoader classLoader = new ShadowingClassLoader(URLClassLoader.getSystemClassLoader()) { + + @Override + public Class loadClass(String name) throws ClassNotFoundException { + + if (name.startsWith("com.mysema")) { + throw new ClassNotFoundException(); + } + + return super.loadClass(name); + } + + @Override + protected Class findClass(String name) throws ClassNotFoundException { + + if (name.startsWith("com.mysema")) { + throw new ClassNotFoundException(); + } + + return super.findClass(name); + } + }; + + return classLoader; + } + + public static class ObjectFactoryImpl implements ObjectFactory { + + @Override + public ConversionService getObject() throws BeansException { + return null; + } + + } +} diff --git a/src/test/java/org/springframework/data/web/querydsl/QuerydslPredicateArgumentResolverUnitTests.java b/src/test/java/org/springframework/data/web/querydsl/QuerydslPredicateArgumentResolverUnitTests.java index 71e840311..0acb1aac0 100644 --- a/src/test/java/org/springframework/data/web/querydsl/QuerydslPredicateArgumentResolverUnitTests.java +++ b/src/test/java/org/springframework/data/web/querydsl/QuerydslPredicateArgumentResolverUnitTests.java @@ -17,20 +17,28 @@ package org.springframework.data.web.querydsl; import static org.hamcrest.Matchers.*; import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import java.util.Collections; import org.junit.Before; import org.junit.Test; +import org.springframework.beans.factory.config.AutowireCapableBeanFactory; import org.springframework.core.MethodParameter; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; +import org.springframework.data.mapping.PropertyPath; import org.springframework.data.querydsl.QUser; import org.springframework.data.querydsl.User; +import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.support.Repositories; import org.springframework.data.util.ClassTypeInformation; import org.springframework.data.util.TypeInformation; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.web.context.request.ServletWebRequest; +import com.mysema.query.types.Path; import com.mysema.query.types.Predicate; import com.mysema.query.types.expr.BooleanExpression; import com.mysema.query.types.path.StringPath; @@ -218,6 +226,53 @@ public class QuerydslPredicateArgumentResolverUnitTests { assertThat(type, is((TypeInformation) ClassTypeInformation.from(User.class))); } + /** + * @see DATACMNS-669 + */ + @Test + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void createBindingsShouldHonorQuerydslBinderCustomizerHookWhenPresent() { + + Repositories repositories = mock(Repositories.class); + RepositoryInformation repoInfo = mock(RepositoryInformation.class); + + when(repositories.hasRepositoryFor(User.class)).thenReturn(true); + when(repositories.getRepositoryFor(User.class)).thenReturn(new SampleRepo()); + + resolver = new QuerydslPredicateArgumentResolver(null); + ReflectionTestUtils.setField(resolver, "repositories", repositories); + + QuerydslBindings bindings = resolver.createBindings(null, User.class); + MultiValueBinding, Object> binding = bindings.getBindingForPath(PropertyPath.from("firstname", + User.class)); + + assertThat(binding.bind((Path) QUser.user.firstname, Collections.singleton("rand")), + is((Predicate) QUser.user.firstname.contains("rand"))); + } + + /** + * @see DATACMNS-669 + */ + @Test + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void shouldReuseExistingQuerydslBinderCustomizer() { + + AutowireCapableBeanFactory beanFactory = mock(AutowireCapableBeanFactory.class); + when(beanFactory.getBean(SpecificBinding.class)).thenReturn(new SpecificBinding()); + QuerydslPredicate annotation = getMethodParameterFor("specificFind", Predicate.class).getParameterAnnotation( + QuerydslPredicate.class); + + resolver = new QuerydslPredicateArgumentResolver(null); + ReflectionTestUtils.setField(resolver, "beanFactory", beanFactory); + + QuerydslBindings bindings = resolver.createBindings(annotation, User.class); + MultiValueBinding, Object> binding = bindings.getBindingForPath(PropertyPath.from("firstname", + User.class)); + + assertThat(binding.bind((Path) QUser.user.firstname, Collections.singleton("rand")), + is((Predicate) QUser.user.firstname.eq("RAND"))); + } + private static MethodParameter getMethodParameterFor(String methodName, Class... args) throws RuntimeException { try { @@ -265,4 +320,19 @@ public class QuerydslPredicateArgumentResolverUnitTests { User specificFind(@QuerydslPredicate(bindings = SpecificBinding.class) Predicate predicate); } + + public static class SampleRepo implements QuerydslBinderCustomizer { + + @Override + public void customize(QuerydslBindings bindings, QUser user) { + + bindings.bind(QUser.user.firstname).single(new SingleValueBinding() { + + @Override + public Predicate bind(StringPath path, String value) { + return path.contains(value); + } + }); + } + } }