diff --git a/src/main/java/org/springframework/data/repository/config/CustomRepositoryImplementationDetector.java b/src/main/java/org/springframework/data/repository/config/CustomRepositoryImplementationDetector.java index 3552ab2ff..8447247ca 100644 --- a/src/main/java/org/springframework/data/repository/config/CustomRepositoryImplementationDetector.java +++ b/src/main/java/org/springframework/data/repository/config/CustomRepositoryImplementationDetector.java @@ -72,7 +72,7 @@ public class CustomRepositoryImplementationDetector { return detectCustomImplementation( // configuration.getImplementationClassName(), // configuration.getImplementationBeanName(), // - configuration.getImplementationBasePackages(configuration.getImplementationClassName()), // + configuration.getImplementationBasePackages(), // configuration.getExcludeFilters(), // bd -> configuration.getConfigurationSource().generateBeanName(bd)); } diff --git a/src/main/java/org/springframework/data/repository/config/DefaultRepositoryConfiguration.java b/src/main/java/org/springframework/data/repository/config/DefaultRepositoryConfiguration.java index 98180ddc4..42aef84ce 100644 --- a/src/main/java/org/springframework/data/repository/config/DefaultRepositoryConfiguration.java +++ b/src/main/java/org/springframework/data/repository/config/DefaultRepositoryConfiguration.java @@ -74,13 +74,13 @@ public class DefaultRepositoryConfiguration getImplementationBasePackages(String interfaceClassName) { + public Streamable getImplementationBasePackages() { return configurationSource.shouldLimitRepositoryImplementationBasePackages() - ? Streamable.of(ClassUtils.getPackageName(interfaceClassName)) + ? Streamable.of(ClassUtils.getPackageName(getRepositoryInterface())) : getBasePackages(); } diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionBuilder.java b/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionBuilder.java index 7d5b81c81..13ff86667 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionBuilder.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionBuilder.java @@ -15,16 +15,21 @@ */ package org.springframework.data.repository.config; +import lombok.Value; + import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.BeanDefinitionStoreException; +import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; @@ -42,6 +47,7 @@ import org.springframework.data.repository.NoRepositoryBean; import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.repository.core.support.RepositoryFragmentsFactoryBean; import org.springframework.data.repository.query.ExtensionAwareEvaluationContextProvider; +import org.springframework.data.util.Optionals; import org.springframework.data.util.StreamUtils; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -127,7 +133,6 @@ class RepositoryBeanDefinitionBuilder { .rootBeanDefinition(RepositoryFragmentsFactoryBean.class); List fragmentBeanNames = registerRepositoryFragmentsImplementation(configuration) // - .stream() // .map(RepositoryFragmentConfiguration::getFragmentBeanName) // .collect(Collectors.toList()); @@ -171,41 +176,29 @@ class RepositoryBeanDefinitionBuilder { }); } - private List registerRepositoryFragmentsImplementation( + private Stream registerRepositoryFragmentsImplementation( RepositoryConfiguration configuration) { ClassMetadata classMetadata = getClassMetadata(configuration.getRepositoryInterface()); return Arrays.stream(classMetadata.getInterfaceNames()) // - .filter(this::isFragmentInterfaceCandidate) // - .map(it -> detectRepositoryFragmentConfiguration(configuration, it)) // - .filter(Optional::isPresent) // - .map(Optional::get) // + .filter(it -> FragmentMetadata.isCandidate(it, metadataReaderFactory)) // + .map(it -> FragmentMetadata.of(it, configuration)) // + .map(it -> detectRepositoryFragmentConfiguration(it)) // + .flatMap(it -> Optionals.toStream(it)) // .peek(it -> potentiallyRegisterFragmentImplementation(configuration, it)) // - .peek(it -> potentiallyRegisterRepositoryFragment(configuration, it)) // - .collect(Collectors.toList()); - } - - private boolean isFragmentInterfaceCandidate(String interfaceName) { - - AnnotationMetadata metadata = getAnnotationMetadata(interfaceName); - - return !metadata.hasAnnotation(NoRepositoryBean.class.getName()); + .peek(it -> potentiallyRegisterRepositoryFragment(configuration, it)); } private Optional detectRepositoryFragmentConfiguration( - RepositoryConfiguration configuration, String fragmentInterfaceName) { + FragmentMetadata configuration) { - List exclusions = getExclusions(configuration); - - String className = ClassUtils.getShortName(fragmentInterfaceName) - .concat(configuration.getConfigurationSource().getRepositoryImplementationPostfix().orElse("Impl")); + String className = configuration.getFragmentImplementationClassName(); Optional beanDefinition = implementationDetector.detectCustomImplementation(className, null, - configuration.getImplementationBasePackages(fragmentInterfaceName), exclusions, - bd -> configuration.getConfigurationSource().generateBeanName(bd)); + configuration.getBasePackages(), configuration.getExclusions(), configuration.getBeanNameGenerator()); - return beanDefinition.map(bd -> new RepositoryFragmentConfiguration(fragmentInterfaceName, bd)); + return beanDefinition.map(bd -> new RepositoryFragmentConfiguration(configuration.getFragmentInterfaceName(), bd)); } private void potentiallyRegisterFragmentImplementation(RepositoryConfiguration repositoryConfiguration, @@ -263,19 +256,84 @@ class RepositoryBeanDefinitionBuilder { } } - private AnnotationMetadata getAnnotationMetadata(String className) { + @Value(staticConstructor = "of") + static class FragmentMetadata { - try { - return metadataReaderFactory.getMetadataReader(className).getAnnotationMetadata(); - } catch (IOException e) { - throw new BeanDefinitionStoreException(String.format("Cannot parse %s metadata.", className), e); + String fragmentInterfaceName; + RepositoryConfiguration configuration; + + /** + * Returns whether the given interface is a fragment candidate. + * + * @param interfaceName must not be {@literal null} or empty. + * @param factory must not be {@literal null}. + * @return + */ + public static boolean isCandidate(String interfaceName, MetadataReaderFactory factory) { + + Assert.hasText(interfaceName, "Interface name must not be null or empty!"); + Assert.notNull(factory, "MetadataReaderFactory must not be null!"); + + AnnotationMetadata metadata = getAnnotationMetadata(interfaceName, factory); + + return !metadata.hasAnnotation(NoRepositoryBean.class.getName()); } - } - private static List getExclusions(RepositoryConfiguration configuration) { + /** + * Returns the exclusions to be used when scanning for fragment implementations. + * + * @return + */ + public List getExclusions() { - return Stream - .concat(configuration.getExcludeFilters().stream(), Stream.of(new AnnotationTypeFilter(NoRepositoryBean.class)))// - .collect(StreamUtils.toUnmodifiableList()); + Stream configurationExcludes = configuration.getExcludeFilters().stream(); + Stream noRepositoryBeans = Stream.of(new AnnotationTypeFilter(NoRepositoryBean.class)); + + return Stream.concat(configurationExcludes, noRepositoryBeans).collect(StreamUtils.toUnmodifiableList()); + } + + /** + * Returns the name of the implementation class to be detected for the fragment interface. + * + * @return + */ + public String getFragmentImplementationClassName() { + + RepositoryConfigurationSource configurationSource = configuration.getConfigurationSource(); + String postfix = configurationSource.getRepositoryImplementationPostfix().orElse("Impl"); + + return ClassUtils.getShortName(fragmentInterfaceName).concat(postfix); + } + + /** + * Returns the base packages to be scanned to find implementations of the current fragment interface. + * + * @return + */ + public Iterable getBasePackages() { + + return configuration.getConfigurationSource().shouldLimitRepositoryImplementationBasePackages() ? // + Collections.singleton(ClassUtils.getPackageName(fragmentInterfaceName)) : // + configuration.getImplementationBasePackages(); + } + + /** + * Returns the bean name generating function to be used for the fragment. + * + * @return + */ + public Function getBeanNameGenerator() { + return definition -> configuration.getConfigurationSource().generateBeanName(definition); + } + + private static AnnotationMetadata getAnnotationMetadata(String className, + MetadataReaderFactory metadataReaderFactory) { + + try { + return metadataReaderFactory.getMetadataReader(className).getAnnotationMetadata(); + } catch (IOException e) { + throw new BeanDefinitionStoreException(String.format("Cannot parse %s metadata.", className), e); + } + } } } diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryConfiguration.java b/src/main/java/org/springframework/data/repository/config/RepositoryConfiguration.java index b1dc5368b..a4c31a17d 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryConfiguration.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryConfiguration.java @@ -40,11 +40,10 @@ public interface RepositoryConfiguration getImplementationBasePackages(String interfaceClassName); + Streamable getImplementationBasePackages(); /** * Returns the interface name of the repository. diff --git a/src/test/java/org/springframework/data/repository/config/DefaultRepositoryConfigurationUnitTests.java b/src/test/java/org/springframework/data/repository/config/DefaultRepositoryConfigurationUnitTests.java index e71f66c48..e945e1630 100755 --- a/src/test/java/org/springframework/data/repository/config/DefaultRepositoryConfigurationUnitTests.java +++ b/src/test/java/org/springframework/data/repository/config/DefaultRepositoryConfigurationUnitTests.java @@ -89,7 +89,7 @@ public class DefaultRepositoryConfigurationUnitTests { when(source.shouldLimitRepositoryImplementationBasePackages()).thenReturn(true); - assertThat(getConfiguration(source).getImplementationBasePackages("com.acme.MyRepository")) + assertThat(getConfiguration(source, "com.acme.MyRepository").getImplementationBasePackages()) .containsOnly("com.acme"); } @@ -98,7 +98,7 @@ public class DefaultRepositoryConfigurationUnitTests { when(source.shouldLimitRepositoryImplementationBasePackages()).thenReturn(true); - assertThat(getConfiguration(source).getImplementationBasePackages(NestedInterface.class.getName())) + assertThat(getConfiguration(source, NestedInterface.class.getName()).getImplementationBasePackages()) .containsOnly("org.springframework.data.repository.config"); } @@ -107,13 +107,18 @@ public class DefaultRepositoryConfigurationUnitTests { when(source.getBasePackages()).thenReturn(Streamable.of("com", "org.coyote")); - assertThat(getConfiguration(source).getImplementationBasePackages("com.acme.MyRepository")).contains("com", + assertThat(getConfiguration(source, "com.acme.MyRepository").getImplementationBasePackages()).contains("com", "org.coyote"); } private DefaultRepositoryConfiguration getConfiguration( RepositoryConfigurationSource source) { - RootBeanDefinition beanDefinition = createBeanDefinition(); + return getConfiguration(source, "com.acme.MyRepository"); + } + + private DefaultRepositoryConfiguration getConfiguration( + RepositoryConfigurationSource source, String repositoryInterfaceName) { + RootBeanDefinition beanDefinition = createBeanDefinition(repositoryInterfaceName); return new DefaultRepositoryConfiguration<>(source, beanDefinition, extension); } @@ -123,9 +128,9 @@ public class DefaultRepositoryConfigurationUnitTests { String repositoryFactoryBeanClassName, modulePrefix; } - private static RootBeanDefinition createBeanDefinition() { + private static RootBeanDefinition createBeanDefinition(String repositoryInterfaceName) { - RootBeanDefinition beanDefinition = new RootBeanDefinition("com.acme.MyRepository"); + RootBeanDefinition beanDefinition = new RootBeanDefinition(repositoryInterfaceName); ConstructorArgumentValues constructorArgumentValues = new ConstructorArgumentValues(); constructorArgumentValues.addGenericArgumentValue(MyRepository.class);