diff --git a/src/main/java/org/springframework/data/repository/support/Repositories.java b/src/main/java/org/springframework/data/repository/support/Repositories.java index ef9e0b67c..cdf80f304 100644 --- a/src/main/java/org/springframework/data/repository/support/Repositories.java +++ b/src/main/java/org/springframework/data/repository/support/Repositories.java @@ -16,6 +16,8 @@ package org.springframework.data.repository.support; import java.io.Serializable; +import java.lang.reflect.Method; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; @@ -29,6 +31,7 @@ import org.springframework.data.mapping.PersistentEntity; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.repository.CrudRepository; import org.springframework.data.repository.core.CrudInvoker; +import org.springframework.data.repository.core.CrudMethods; import org.springframework.data.repository.core.EntityInformation; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.support.RepositoryFactoryInformation; @@ -202,7 +205,7 @@ public class Repositories implements Iterable> { Assert.notNull(repository, String.format("No repository found for domain class: %s", domainClass)); - if (repository instanceof CrudRepository) { + if (repository instanceof CrudRepository && !hasRedeclaredCrudMethods(domainClass)) { return new CrudRepositoryInvoker((CrudRepository) repository); } else { return new ReflectionRepositoryInvoker(repository, getRepositoryInformationFor(domainClass).getCrudMethods()); @@ -217,6 +220,25 @@ public class Repositories implements Iterable> { return repositoryFactoryInfos.keySet().iterator(); } + /** + * Returns whether any of the CRUD methods of the repository for the given domain type are redeclared. + * + * @param type must not be {@literal null}. + * @return + */ + private boolean hasRedeclaredCrudMethods(Class type) { + + CrudMethods crudMethods = getRepositoryInformationFor(type).getCrudMethods(); + + for (Method method : Arrays.asList(crudMethods.getFindOneMethod(), crudMethods.getSaveMethod())) { + if (!method.getDeclaringClass().equals(CrudRepository.class)) { + return true; + } + } + + return false; + } + /** * Null-object to avoid nasty {@literal null} checks in cache lookups. * diff --git a/src/test/java/org/springframework/data/repository/support/RepositoriesIntegrationTests.java b/src/test/java/org/springframework/data/repository/support/RepositoriesIntegrationTests.java index fce9fef7e..edb0a3fe1 100644 --- a/src/test/java/org/springframework/data/repository/support/RepositoriesIntegrationTests.java +++ b/src/test/java/org/springframework/data/repository/support/RepositoriesIntegrationTests.java @@ -78,6 +78,15 @@ public class RepositoriesIntegrationTests { public ProductRepository productRepository() { return mock(ProductRepository.class); } + + @Bean + public RepositoryFactoryBeanSupport, Order, Long> orderRepositoryFactory() { + + DummyRepositoryFactoryBean, Order, Long> factory = new DummyRepositoryFactoryBean, Order, Long>(); + factory.setRepositoryInterface(OrderRepository.class); + + return factory; + } } @Autowired Repositories repositories; @@ -97,6 +106,7 @@ public class RepositoriesIntegrationTests { assertThat(repositories, is(notNullValue())); assertThat(repositories.getCrudInvoker(User.class), is(instanceOf(CrudRepositoryInvoker.class))); assertThat(repositories.getCrudInvoker(Product.class), is(instanceOf(ReflectionRepositoryInvoker.class))); + assertThat(repositories.getCrudInvoker(Order.class), is(instanceOf(ReflectionRepositoryInvoker.class))); } /** @@ -132,12 +142,19 @@ public class RepositoriesIntegrationTests { } - public static class Product {} + static class Product {} - public static interface ProductRepository extends Repository { + interface ProductRepository extends Repository { Product findOne(Long id); Product save(Product product); } + + static class Order {} + + interface OrderRepository extends CrudRepository { + + Order findOne(Long id); + } }