Browse Source

Refine Repository Composition retrieval during AOT.

Add module identifier and base repository implementation properties.
Fix fragment function previously overriding already set property due to name clash.
Extend tests for bean definition resolution and code block creation.

See: #3279
Original Pull Request: #3282
pull/3304/head
Christoph Strobl 8 months ago committed by Mark Paluch
parent
commit
9c966376a7
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 7
      src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java
  2. 2
      src/main/java/org/springframework/data/repository/aot/generate/MethodContributor.java
  3. 2
      src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java
  4. 36
      src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java
  5. 131
      src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java
  6. 21
      src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java
  7. 5
      src/main/java/org/springframework/data/repository/core/RepositoryInformation.java
  8. 2
      src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java
  9. 16
      src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java
  10. 2
      src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java
  11. 2
      src/test/java/example/UserRepository.java
  12. 25
      src/test/java/example/UserRepositoryExtension.java
  13. 29
      src/test/java/example/UserRepositoryExtensionImpl.java
  14. 157
      src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java
  15. 88
      src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java
  16. 57
      src/test/java/org/springframework/data/repository/aot/generate/MethodCapturingRepositoryContributor.java
  17. 167
      src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java
  18. 122
      src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java

7
src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java

@ -138,9 +138,8 @@ class AotRepositoryBuilder {
this.customizer.customize(repositoryInformation, generationMetadata, builder); this.customizer.customize(repositoryInformation, generationMetadata, builder);
JavaFile javaFile = JavaFile.builder(packageName(), builder.build()).build(); JavaFile javaFile = JavaFile.builder(packageName(), builder.build()).build();
// TODO: module identifier
AotRepositoryMetadata metadata = new AotRepositoryMetadata(repositoryInformation.getRepositoryInterface().getName(), AotRepositoryMetadata metadata = new AotRepositoryMetadata(repositoryInformation.getRepositoryInterface().getName(),
"", repositoryType, methodMetadata); repositoryInformation.moduleName() != null ? repositoryInformation.moduleName() : "", repositoryType, methodMetadata);
return new AotBundle(javaFile, metadata.toJson()); return new AotBundle(javaFile, metadata.toJson());
} }
@ -148,15 +147,15 @@ class AotRepositoryBuilder {
private void contributeMethod(Method method, RepositoryComposition repositoryComposition, private void contributeMethod(Method method, RepositoryComposition repositoryComposition,
List<AotRepositoryMethod> methodMetadata, TypeSpec.Builder builder) { List<AotRepositoryMethod> methodMetadata, TypeSpec.Builder builder) {
if (repositoryInformation.isCustomMethod(method) || repositoryInformation.isBaseClassMethod(method)) { if (repositoryInformation.isCustomMethod(method) || (repositoryInformation.isBaseClassMethod(method) && !repositoryInformation.isQueryMethod(method))) {
RepositoryFragment<?> fragment = repositoryComposition.findFragment(method); RepositoryFragment<?> fragment = repositoryComposition.findFragment(method);
if (fragment != null) { if (fragment != null) {
methodMetadata.add(getFragmentMetadata(method, fragment)); methodMetadata.add(getFragmentMetadata(method, fragment));
}
return; return;
} }
}
if (method.isBridge() || method.isDefault() || java.lang.reflect.Modifier.isStatic(method.getModifiers())) { if (method.isBridge() || method.isDefault() || java.lang.reflect.Modifier.isStatic(method.getModifiers())) {
return; return;

2
src/main/java/org/springframework/data/repository/aot/generate/MethodContributor.java

@ -36,7 +36,7 @@ public abstract class MethodContributor<M extends QueryMethod> {
private final M queryMethod; private final M queryMethod;
private final QueryMetadata metadata; private final QueryMetadata metadata;
private MethodContributor(M queryMethod, QueryMetadata metadata) { MethodContributor(M queryMethod, QueryMetadata metadata) {
this.queryMethod = queryMethod; this.queryMethod = queryMethod;
this.metadata = metadata; this.metadata = metadata;
} }

2
src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java

@ -55,7 +55,7 @@ class AotRepositoryBeanDefinitionPropertiesDecorator {
// bring in properties as usual // bring in properties as usual
builder.add(inheritedProperties.get()); builder.add(inheritedProperties.get());
builder.add("beanDefinition.getPropertyValues().addPropertyValue(\"repositoryFragments\", new $T() {\n", builder.add("beanDefinition.getPropertyValues().addPropertyValue(\"repositoryFragmentsFunction\", new $T() {\n",
RepositoryFactoryBeanSupport.RepositoryFragmentsFunction.class); RepositoryFactoryBeanSupport.RepositoryFragmentsFunction.class);
builder.indent(); builder.indent();
builder.add("public $T getRepositoryFragments($T beanFactory, $T context) {\n", builder.add("public $T getRepositoryFragments($T beanFactory, $T context) {\n",

36
src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java

@ -21,10 +21,12 @@ import java.util.LinkedHashSet;
import java.util.Set; import java.util.Set;
import java.util.function.Supplier; import java.util.function.Supplier;
import org.jspecify.annotations.Nullable;
import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.RepositoryInformationSupport; import org.springframework.data.repository.core.RepositoryInformationSupport;
import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.RepositoryComposition; import org.springframework.data.repository.core.support.RepositoryComposition;
import org.springframework.data.repository.core.support.RepositoryComposition.RepositoryFragments;
import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.data.util.Lazy; import org.springframework.data.util.Lazy;
@ -36,16 +38,31 @@ import org.springframework.data.util.Lazy;
*/ */
class AotRepositoryInformation extends RepositoryInformationSupport implements RepositoryInformation { class AotRepositoryInformation extends RepositoryInformationSupport implements RepositoryInformation {
private final @Nullable String moduleName;
private final Supplier<Collection<RepositoryFragment<?>>> fragments; private final Supplier<Collection<RepositoryFragment<?>>> fragments;
private Lazy<RepositoryComposition> baseComposition = Lazy.of(() -> {
return RepositoryComposition.of(RepositoryFragment.structural(getRepositoryBaseClass()));
});
AotRepositoryInformation(Supplier<RepositoryMetadata> repositoryMetadata, Supplier<Class<?>> repositoryBaseClass, private final Lazy<RepositoryComposition> repositoryComposition;
Supplier<Collection<RepositoryFragment<?>>> fragments) { private final Lazy<RepositoryComposition> baseComposition;
AotRepositoryInformation(@Nullable String moduleName, Supplier<RepositoryMetadata> repositoryMetadata,
Supplier<Class<?>> repositoryBaseClass, Supplier<Collection<RepositoryFragment<?>>> fragments) {
super(repositoryMetadata, repositoryBaseClass); super(repositoryMetadata, repositoryBaseClass);
this.moduleName = moduleName;
this.fragments = fragments; this.fragments = fragments;
this.repositoryComposition = Lazy
.of(() -> RepositoryComposition.fromMetadata(getMetadata()).append(RepositoryFragments.from(getFragments())));
this.baseComposition = Lazy.of(() -> {
RepositoryComposition targetRepoComposition = repositoryComposition.get();
return RepositoryComposition.of(RepositoryFragment.structural(getRepositoryBaseClass())) //
.withArgumentConverter(targetRepoComposition.getArgumentConverter()) //
.withMethodLookup(targetRepoComposition.getMethodLookup());
});
} }
/** /**
@ -57,10 +74,9 @@ class AotRepositoryInformation extends RepositoryInformationSupport implements R
return new LinkedHashSet<>(fragments.get()); return new LinkedHashSet<>(fragments.get());
} }
// Not required during AOT processing.
@Override @Override
public boolean isCustomMethod(Method method) { public boolean isCustomMethod(Method method) {
return false; return repositoryComposition.get().findMethod(method).isPresent();
} }
@Override @Override
@ -75,7 +91,11 @@ class AotRepositoryInformation extends RepositoryInformationSupport implements R
@Override @Override
public RepositoryComposition getRepositoryComposition() { public RepositoryComposition getRepositoryComposition() {
return baseComposition.get().append(RepositoryComposition.RepositoryFragments.from(fragments.get())); return repositoryComposition.get();
} }
@Override
public @Nullable String moduleName() {
return moduleName;
}
} }

131
src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java

@ -16,17 +16,21 @@
package org.springframework.data.repository.config; package org.springframework.data.repository.config;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.core.ResolvableType;
import org.springframework.data.repository.CrudRepository;
import org.springframework.data.repository.PagingAndSortingRepository;
import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.DefaultRepositoryMetadata; import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.AbstractRepositoryMetadata;
import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.data.util.Lazy; import org.springframework.data.repository.core.support.RepositoryFragment.ImplementedRepositoryFragment;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
/** /**
@ -38,49 +42,108 @@ import org.springframework.util.ClassUtils;
*/ */
class RepositoryBeanDefinitionReader { class RepositoryBeanDefinitionReader {
static RepositoryInformation readRepositoryInformation(RepositoryConfiguration<?> metadata, /**
ConfigurableListableBeanFactory beanFactory) { * @return
*/
return new AotRepositoryInformation(metadataSupplier(metadata, beanFactory), static RepositoryInformation repositoryInformation(RepositoryConfiguration<?> repoConfig, RegisteredBean repoBean) {
repositoryBaseClass(metadata, beanFactory), fragments(metadata, beanFactory)); return repositoryInformation(repoConfig, repoBean.getMergedBeanDefinition(), repoBean.getBeanFactory());
} }
private static Supplier<Collection<RepositoryFragment<?>>> fragments(RepositoryConfiguration<?> metadata, /**
* @param source the RepositoryFactoryBeanSupport bean definition.
* @param beanFactory
* @return
*/
@SuppressWarnings("NullAway")
static RepositoryInformation repositoryInformation(RepositoryConfiguration<?> repoConfig, BeanDefinition source,
ConfigurableListableBeanFactory beanFactory) { ConfigurableListableBeanFactory beanFactory) {
if (metadata instanceof RepositoryFragmentConfigurationProvider provider) { RepositoryMetadata metadata = AbstractRepositoryMetadata
.getMetadata(forName(repoConfig.getRepositoryInterface(), beanFactory));
return Lazy.of(() -> { Class<?> repositoryBaseClass = readRepositoryBaseClass(source, beanFactory);
return provider.getFragmentConfiguration().stream().flatMap(it -> { List<RepositoryFragment<?>> fragmentList = readRepositoryFragments(source, beanFactory);
if (source.getPropertyValues().contains("customImplementation")) {
Object o = source.getPropertyValues().get("customImplementation");
if (o instanceof RuntimeBeanReference rbr) {
BeanDefinition customImplBeanDefintion = beanFactory.getBeanDefinition(rbr.getBeanName());
Class<?> beanType = forName(customImplBeanDefintion.getBeanClassName(), beanFactory);
ResolvableType[] interfaces = ResolvableType.forClass(beanType).getInterfaces();
if (interfaces.length == 1) {
fragmentList.add(new ImplementedRepositoryFragment(interfaces[0].toClass(), beanType));
} else {
boolean found = false;
for (ResolvableType i : interfaces) {
if (beanType.getSimpleName().contains(i.resolve().getSimpleName())) {
fragmentList.add(new ImplementedRepositoryFragment(interfaces[0].toClass(), beanType));
found = true;
break;
}
}
if (!found) {
fragmentList.add(RepositoryFragment.implemented(beanType));
}
}
}
}
List<RepositoryFragment<?>> fragments = new ArrayList<>(1); String moduleName = (String) source.getPropertyValues().get("moduleName");
AotRepositoryInformation repositoryInformation = new AotRepositoryInformation(moduleName, () -> metadata,
() -> repositoryBaseClass, () -> fragmentList);
return repositoryInformation;
}
fragments.add(RepositoryFragment.implemented(forName(it.getClassName(), beanFactory))); @SuppressWarnings("NullAway")
private static Class<?> readRepositoryBaseClass(BeanDefinition source, ConfigurableListableBeanFactory beanFactory) {
if (it.getInterfaceName() != null) { Object repoBaseClassName = source.getPropertyValues().get("repositoryBaseClass");
fragments.add(RepositoryFragment.structural(forName(it.getInterfaceName(), beanFactory))); if (repoBaseClassName != null) {
return forName(repoBaseClassName.toString(), beanFactory);
} }
if (source.getPropertyValues().contains("moduleBaseClass")) {
return fragments.stream(); return forName((String) source.getPropertyValues().get("moduleBaseClass"), beanFactory);
}).collect(Collectors.toList());
});
} }
return Dummy.class;
return Lazy.of(Collections::emptyList);
} }
@SuppressWarnings({ "rawtypes", "unchecked" }) @SuppressWarnings("NullAway")
private static Supplier<Class<?>> repositoryBaseClass(RepositoryConfiguration metadata, private static List<RepositoryFragment<?>> readRepositoryFragments(BeanDefinition source,
ConfigurableListableBeanFactory beanFactory) { ConfigurableListableBeanFactory beanFactory) {
return Lazy.of(() -> (Class<?>) metadata.getRepositoryBaseClassName().map(it -> forName(it.toString(), beanFactory)) RuntimeBeanReference beanReference = (RuntimeBeanReference) source.getPropertyValues().get("repositoryFragments");
.orElse(Object.class)); BeanDefinition fragments = beanFactory.getBeanDefinition(beanReference.getBeanName());
}
ValueHolder fragmentBeanNameList = fragments.getConstructorArgumentValues().getArgumentValue(0, List.class);
List<String> fragmentBeanNames = (List<String>) fragmentBeanNameList.getValue();
List<RepositoryFragment<?>> fragmentList = new ArrayList<>();
for (String beanName : fragmentBeanNames) {
BeanDefinition fragmentBeanDefinition = beanFactory.getBeanDefinition(beanName);
ValueHolder argumentValue = fragmentBeanDefinition.getConstructorArgumentValues().getArgumentValue(0,
String.class);
ValueHolder argumentValue1 = fragmentBeanDefinition.getConstructorArgumentValues().getArgumentValue(1, null, null,
null);
Object fragmentClassName = argumentValue.getValue();
private static Supplier<org.springframework.data.repository.core.RepositoryMetadata> metadataSupplier( try {
RepositoryConfiguration<?> metadata, ConfigurableListableBeanFactory beanFactory) { Class<?> type = ClassUtils.forName(fragmentClassName.toString(), beanFactory.getBeanClassLoader());
return Lazy.of(() -> new DefaultRepositoryMetadata(forName(metadata.getRepositoryInterface(), beanFactory)));
if (argumentValue1 != null && argumentValue1.getValue() instanceof RuntimeBeanReference rbf) {
BeanDefinition implBeanDef = beanFactory.getBeanDefinition(rbf.getBeanName());
Class implClass = ClassUtils.forName(implBeanDef.getBeanClassName(), beanFactory.getBeanClassLoader());
fragmentList.add(new RepositoryFragment.ImplementedRepositoryFragment(type, implClass));
} else {
fragmentList.add(RepositoryFragment.structural(type));
}
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
} }
return fragmentList;
}
static abstract class Dummy implements CrudRepository<Object, Object>, PagingAndSortingRepository<Object, Object> {}
static Class<?> forName(String name, ConfigurableListableBeanFactory beanFactory) { static Class<?> forName(String name, ConfigurableListableBeanFactory beanFactory) {
try { try {

21
src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java

@ -28,7 +28,6 @@ import java.util.function.BiFunction;
import java.util.function.Predicate; import java.util.function.Predicate;
import org.jspecify.annotations.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.aop.SpringProxy; import org.springframework.aop.SpringProxy;
import org.springframework.aop.framework.Advised; import org.springframework.aop.framework.Advised;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
@ -49,7 +48,6 @@ import org.springframework.data.projection.TargetAware;
import org.springframework.data.repository.Repository; import org.springframework.data.repository.Repository;
import org.springframework.data.repository.aot.generate.RepositoryContributor; import org.springframework.data.repository.aot.generate.RepositoryContributor;
import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.data.util.Predicates; import org.springframework.data.util.Predicates;
import org.springframework.data.util.QTypeContributor; import org.springframework.data.util.QTypeContributor;
@ -90,8 +88,7 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo
* @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}. * @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}.
* @see RepositoryRegistrationAotProcessor * @see RepositoryRegistrationAotProcessor
*/ */
protected RepositoryRegistrationAotContribution( protected RepositoryRegistrationAotContribution(RepositoryRegistrationAotProcessor processor) {
RepositoryRegistrationAotProcessor processor) {
Assert.notNull(processor, "RepositoryRegistrationAotProcessor must not be null"); Assert.notNull(processor, "RepositoryRegistrationAotProcessor must not be null");
@ -108,8 +105,7 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo
* @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}. * @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}.
* @see RepositoryRegistrationAotProcessor * @see RepositoryRegistrationAotProcessor
*/ */
public static RepositoryRegistrationAotContribution fromProcessor( public static RepositoryRegistrationAotContribution fromProcessor(RepositoryRegistrationAotProcessor processor) {
RepositoryRegistrationAotProcessor processor) {
return new RepositoryRegistrationAotContribution(processor); return new RepositoryRegistrationAotContribution(processor);
} }
@ -255,7 +251,8 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo
}); });
implementation.ifPresent(impl -> { implementation.ifPresent(impl -> {
contribution.getRuntimeHints().reflection().registerType(impl.getClass(), hint -> { Class<?> typeToRegister = impl instanceof Class c ? c : impl.getClass();
contribution.getRuntimeHints().reflection().registerType(typeToRegister, hint -> {
hint.withMembers(MemberCategory.INVOKE_PUBLIC_METHODS); hint.withMembers(MemberCategory.INVOKE_PUBLIC_METHODS);
@ -365,18 +362,16 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo
@SuppressWarnings("rawtypes") @SuppressWarnings("rawtypes")
private DefaultAotRepositoryContext buildAotRepositoryContext(RegisteredBean bean, private DefaultAotRepositoryContext buildAotRepositoryContext(RegisteredBean bean,
RepositoryConfiguration<?> repositoryMetadata) { RepositoryConfiguration<?> repositoryConfiguration) {
DefaultAotRepositoryContext repositoryContext = new DefaultAotRepositoryContext( DefaultAotRepositoryContext repositoryContext = new DefaultAotRepositoryContext(
AotContext.from(getBeanFactory(), getRepositoryRegistrationAotProcessor().getEnvironment())); AotContext.from(getBeanFactory(), getRepositoryRegistrationAotProcessor().getEnvironment()));
RepositoryFactoryBeanSupport rfbs = bean.getBeanFactory().getBean("&" + bean.getBeanName(),
RepositoryFactoryBeanSupport.class);
repositoryContext.setBeanName(bean.getBeanName()); repositoryContext.setBeanName(bean.getBeanName());
repositoryContext.setBasePackages(repositoryMetadata.getBasePackages().toSet()); repositoryContext.setBasePackages(repositoryConfiguration.getBasePackages().toSet());
repositoryContext.setIdentifyingAnnotations(resolveIdentifyingAnnotations()); repositoryContext.setIdentifyingAnnotations(resolveIdentifyingAnnotations());
repositoryContext.setRepositoryInformation(rfbs.getRepositoryInformation()); repositoryContext
.setRepositoryInformation(RepositoryBeanDefinitionReader.repositoryInformation(repositoryConfiguration, bean));
return repositoryContext; return repositoryContext;
} }

5
src/main/java/org/springframework/data/repository/core/RepositoryInformation.java

@ -18,6 +18,7 @@ package org.springframework.data.repository.core;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.List; import java.util.List;
import org.jspecify.annotations.Nullable;
import org.springframework.data.repository.core.support.RepositoryComposition; import org.springframework.data.repository.core.support.RepositoryComposition;
/** /**
@ -105,4 +106,8 @@ public interface RepositoryInformation extends RepositoryMetadata {
*/ */
RepositoryComposition getRepositoryComposition(); RepositoryComposition getRepositoryComposition();
default @Nullable String moduleName() {
return null;
}
} }

2
src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java

@ -184,7 +184,7 @@ public abstract class RepositoryInformationSupport implements RepositoryInformat
return true; return true;
} }
private RepositoryMetadata getMetadata() { protected RepositoryMetadata getMetadata() {
return metadata.get(); return metadata.get();
} }

16
src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java

@ -95,6 +95,10 @@ public abstract class RepositoryFactoryBeanSupport<T extends Repository<S, ID>,
private @Nullable Lazy<T> repository; private @Nullable Lazy<T> repository;
private @Nullable RepositoryMetadata repositoryMetadata; private @Nullable RepositoryMetadata repositoryMetadata;
// AOT bean factory hint?
private @Nullable String moduleBaseClass;
private @Nullable String moduleName;
/** /**
* Creates a new {@link RepositoryFactoryBeanSupport} for the given repository interface. * Creates a new {@link RepositoryFactoryBeanSupport} for the given repository interface.
* *
@ -155,7 +159,7 @@ public abstract class RepositoryFactoryBeanSupport<T extends Repository<S, ID>,
* @param repositoryFragments * @param repositoryFragments
*/ */
public void setRepositoryFragments(RepositoryFragments repositoryFragments) { public void setRepositoryFragments(RepositoryFragments repositoryFragments) {
setRepositoryFragments(RepositoryFragmentsFunction.just(repositoryFragments)); setRepositoryFragmentsFunction(RepositoryFragmentsFunction.just(repositoryFragments));
} }
/** /**
@ -165,7 +169,7 @@ public abstract class RepositoryFactoryBeanSupport<T extends Repository<S, ID>,
* @param fragmentsFunction * @param fragmentsFunction
* @since 4.0 * @since 4.0
*/ */
public void setRepositoryFragments(RepositoryFragmentsFunction fragmentsFunction) { public void setRepositoryFragmentsFunction(RepositoryFragmentsFunction fragmentsFunction) {
this.fragments.add(fragmentsFunction); this.fragments.add(fragmentsFunction);
} }
@ -257,6 +261,14 @@ public abstract class RepositoryFactoryBeanSupport<T extends Repository<S, ID>,
this.publisher = publisher; this.publisher = publisher;
} }
public void setModuleBaseClass(String moduleBaseClass) {
this.moduleBaseClass = moduleBaseClass;
}
public void setModuleName(String moduleName) {
this.moduleName = moduleName;
}
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public EntityInformation<S, ID> getEntityInformation() { public EntityInformation<S, ID> getEntityInformation() {

2
src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java

@ -265,7 +265,7 @@ public interface RepositoryFragment<T> {
Assert.notNull(implementation, "Implementation object must not be null"); Assert.notNull(implementation, "Implementation object must not be null");
if (interfaceClass != null) { if (interfaceClass != null && !(implementation instanceof Class)) {
Assert Assert
.isTrue(ClassUtils.isAssignableValue(interfaceClass, implementation), .isTrue(ClassUtils.isAssignableValue(interfaceClass, implementation),

2
src/test/java/example/UserRepository.java

@ -24,7 +24,7 @@ import org.springframework.data.repository.CrudRepository;
/** /**
* @author Christoph Strobl * @author Christoph Strobl
*/ */
public interface UserRepository extends CrudRepository<User, Long> { public interface UserRepository extends CrudRepository<User, Long>, UserRepositoryExtension {
User findByFirstname(String firstname); User findByFirstname(String firstname);

25
src/test/java/example/UserRepositoryExtension.java

@ -0,0 +1,25 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package example;
import example.UserRepository.User;
/**
* @author Christoph Strobl
*/
public interface UserRepositoryExtension {
User findUserByExtensionMethod();
}

29
src/test/java/example/UserRepositoryExtensionImpl.java

@ -0,0 +1,29 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package example;
import example.UserRepository.User;
/**
* @author Christoph Strobl
*/
public class UserRepositoryExtensionImpl implements UserRepositoryExtension {
@Override
public User findUserByExtensionMethod() {
return null;
}
}

157
src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java

@ -0,0 +1,157 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.repository.aot.generate;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import example.UserRepository;
import example.UserRepository.User;
import java.util.TimeZone;
import javax.lang.model.element.Modifier;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.data.geo.Metric;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.util.TypeInformation;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.TypeName;
import org.springframework.stereotype.Repository;
/**
* @author Christoph Strobl
*/
class AotRepositoryBuilderUnitTests {
RepositoryInformation repositoryInformation;
@BeforeEach
void beforeEach() {
repositoryInformation = mock(RepositoryInformation.class);
doReturn(UserRepository.class).when(repositoryInformation).getRepositoryInterface();
}
@Test // GH-3279
void writesClassSkeleton() {
AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation,
new SpelAwareProxyProjectionFactory());
assertThat(repoBuilder.build().javaFile().toString())
.contains("package %s;".formatted(UserRepository.class.getPackageName())) // same package as source repo
.contains("@Generated") // marked as generated source
.contains("public class %sImpl__Aot".formatted(UserRepository.class.getSimpleName())) // target name
.contains("public UserRepositoryImpl__Aot()"); // default constructor if not arguments to wire
}
@Test // GH-3279
void appliesCtorArguments() {
AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation,
new SpelAwareProxyProjectionFactory());
repoBuilder.withConstructorCustomizer(ctor -> {
ctor.addParameter("param1", Metric.class);
ctor.addParameter("param2", String.class);
ctor.addParameter("ctorScoped", TypeName.OBJECT, false);
});
assertThat(repoBuilder.build().javaFile().toString()) //
.contains("private final Metric param1;") //
.contains("private final String param2;") //
.doesNotContain("private final Object ctorScoped;") //
.contains("public UserRepositoryImpl__Aot(Metric param1, String param2, Object ctorScoped)") //
.contains("this.param1 = param1") //
.contains("this.param2 = param2") //
.doesNotContain("this.ctorScoped = ctorScoped");
}
@Test // GH-3279
void appliesCtorCodeBlock() {
AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation,
new SpelAwareProxyProjectionFactory());
repoBuilder.withConstructorCustomizer(ctor -> {
ctor.customize((info, code) -> {
code.addStatement("throw new $T($S)", IllegalStateException.class, "initialization error");
});
});
assertThat(repoBuilder.build().javaFile().toString()).containsIgnoringWhitespaces(
"UserRepositoryImpl__Aot() { throw new IllegalStateException(\"initialization error\"); }");
}
@Test // GH-3279
void appliesClassCustomizations() {
AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation,
new SpelAwareProxyProjectionFactory());
repoBuilder.withClassCustomizer((info, metadata, clazz) -> {
clazz.addField(Float.class, "f", Modifier.PRIVATE, Modifier.STATIC);
clazz.addField(Double.class, "d", Modifier.PUBLIC);
clazz.addField(TimeZone.class, "t", Modifier.FINAL);
clazz.addAnnotation(Repository.class);
clazz.addMethod(MethodSpec.methodBuilder("oops").build());
});
assertThat(repoBuilder.build().javaFile().toString()) //
.contains("@Repository") //
.contains("private static Float f;") //
.contains("public Double d;") //
.contains("final TimeZone t;") //
.containsIgnoringWhitespaces("void oops() { }");
}
@Test // GH-3279
void appliesQueryMethodContributor() {
AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation,
new SpelAwareProxyProjectionFactory());
when(repositoryInformation.isQueryMethod(Mockito.argThat(arg -> arg.getName().equals("findByFirstname"))))
.thenReturn(true);
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
repoBuilder.withQueryMethodContributor((method, info) -> {
return new MethodContributor<>(mock(QueryMethod.class), null) {
@Override
public MethodSpec contribute(AotQueryMethodGenerationContext context) {
return MethodSpec.methodBuilder("oops").build();
}
@Override
public boolean contributesMethodSpec() {
return true;
}
};
});
assertThat(repoBuilder.build().javaFile().toString()) //
.containsIgnoringWhitespaces("void oops() { }");
}
}

88
src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java

@ -0,0 +1,88 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.repository.aot.generate;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.when;
import example.UserRepository;
import example.UserRepository.User;
import java.lang.reflect.Method;
import java.util.List;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.core.ResolvableType;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.util.TypeInformation;
import org.springframework.javapoet.ParameterSpec;
import org.springframework.javapoet.ParameterizedTypeName;
/**
* @author Christoph Strobl
*/
class AotRepositoryMethodBuilderUnitTests {
RepositoryInformation repositoryInformation;
AotQueryMethodGenerationContext methodGenerationContext;
@BeforeEach
void beforeEach() {
repositoryInformation = Mockito.mock(RepositoryInformation.class);
methodGenerationContext = Mockito.mock(AotQueryMethodGenerationContext.class);
when(methodGenerationContext.getRepositoryInformation()).thenReturn(repositoryInformation);
}
@Test // GH-3279
void generatesMethodSkeletonBasedOnGenerationMetadata() throws NoSuchMethodException {
Method method = UserRepository.class.getMethod("findByFirstname", String.class);
when(methodGenerationContext.getMethod()).thenReturn(method);
when(methodGenerationContext.getReturnType()).thenReturn(ResolvableType.forClass(User.class));
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method);
methodMetadata.addParameter(ParameterSpec.builder(String.class, "firstname").build());
when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata);
AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext);
assertThat(builder.buildMethod().toString()) //
.containsPattern("public .*User findByFirstname\\(.*String firstname\\)");
}
@Test // GH-3279
void generatesMethodWithGenerics() throws NoSuchMethodException {
Method method = UserRepository.class.getMethod("findByFirstnameIn", List.class);
when(methodGenerationContext.getMethod()).thenReturn(method);
when(methodGenerationContext.getReturnType())
.thenReturn(ResolvableType.forClassWithGenerics(List.class, User.class));
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method);
methodMetadata
.addParameter(ParameterSpec.builder(ParameterizedTypeName.get(List.class, String.class), "firstnames").build());
when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata);
AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext);
assertThat(builder.buildMethod().toString()) //
.containsPattern("public .*List<.*User> findByFirstnameIn\\(") //
.containsPattern(".*List<.*String> firstnames\\)");
}
}

57
src/test/java/org/springframework/data/repository/aot/generate/MethodCapturingRepositoryContributor.java

@ -0,0 +1,57 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.repository.aot.generate;
import static org.assertj.core.api.Assertions.assertThat;
import java.lang.reflect.Method;
import java.util.List;
import org.assertj.core.api.MapAssert;
import org.jspecify.annotations.Nullable;
import org.springframework.data.repository.config.AotRepositoryContext;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* @author Christoph Strobl
*/
public class MethodCapturingRepositoryContributor extends RepositoryContributor {
MultiValueMap<String, Method> capturedInvocations;
public MethodCapturingRepositoryContributor(AotRepositoryContext repositoryContext) {
super(repositoryContext);
this.capturedInvocations = new LinkedMultiValueMap<>(3);
}
@Override
protected @Nullable MethodContributor<? extends QueryMethod> contributeQueryMethod(Method method,
RepositoryInformation repositoryInformation) {
capturedInvocations.add(method.getName(), method);
return null;
}
void verifyContributionFor(String methodName) {
assertThat(capturedInvocations).containsKey(methodName);
}
MapAssert<String, List<Method>> verifyContributedMethods() {
return assertThat(capturedInvocations);
}
}

167
src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java

@ -15,31 +15,45 @@
*/ */
package org.springframework.data.repository.aot.generate; package org.springframework.data.repository.aot.generate;
import static org.assertj.core.api.Assertions.*; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.when;
import example.UserRepository; import example.UserRepository;
import example.UserRepositoryExtension;
import example.UserRepositoryExtensionImpl;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.jspecify.annotations.Nullable; import org.jspecify.annotations.Nullable;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.core.test.tools.TestCompiler; import org.springframework.core.test.tools.TestCompiler;
import org.springframework.data.aot.CodeContributionAssert; import org.springframework.data.aot.CodeContributionAssert;
import org.springframework.data.repository.CrudRepository;
import org.springframework.data.repository.config.AotRepositoryContext;
import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.RepositoryComposition;
import org.springframework.data.repository.core.support.RepositoryComposition.RepositoryFragments;
import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.repository.query.QueryMethod;
import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
/** /**
* Unit tests targeting {@link RepositoryContributor}.
*
* @author Christoph Strobl * @author Christoph Strobl
*/ */
class RepositoryContributorUnitTests { class RepositoryContributorUnitTests {
@Test @Test // GH-3279
void testCompile() { void createsCompilableClassStub() {
DummyModuleAotRepositoryContext aotContext = new DummyModuleAotRepositoryContext(UserRepository.class, null); DummyModuleAotRepositoryContext aotContext = new DummyModuleAotRepositoryContext(UserRepository.class, null);
RepositoryContributor repositoryContributor = new RepositoryContributor(aotContext) { RepositoryContributor repositoryContributor = new RepositoryContributor(aotContext) {
@ -55,8 +69,7 @@ class RepositoryContributorUnitTests {
public Map<String, Object> serialize() { public Map<String, Object> serialize() {
return Map.of(); return Map.of();
} }
}) }).contribute(context -> {
.contribute(context -> {
CodeBlock.Builder builder = CodeBlock.builder(); CodeBlock.Builder builder = CodeBlock.builder();
if (!ClassUtils.isVoidType(method.getReturnType())) { if (!ClassUtils.isVoidType(method.getReturnType())) {
@ -81,4 +94,146 @@ class RepositoryContributorUnitTests {
new CodeContributionAssert(generationContext).contributesReflectionFor(expectedTypeName); new CodeContributionAssert(generationContext).contributesReflectionFor(expectedTypeName);
} }
@Test // GH-3279
void callsMethodContributionForQueryMethod() {
AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class);
RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class);
when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation);
when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class);
when(repositoryInformation.isQueryMethod(argThat(it -> it.getName().equals("findByFirstname")))).thenReturn(true);
MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext);
contributor.contribute(new TestGenerationContext(UserRepository.class));
contributor.verifyContributionFor("findByFirstname");
}
@Test // GH-3279
void doesNotContributeBaseClassMethods() {
AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class);
RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class);
when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation);
when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class);
when(repositoryInformation.getRepositoryComposition())
.thenReturn(RepositoryComposition.of(RepositoryFragment.structural(RepoBaseClass.class)));
when(repositoryInformation.isBaseClassMethod(argThat(it -> it.getName().equals("findByFirstname"))))
.thenReturn(true);
when(repositoryInformation.isQueryMethod(argThat(it -> !it.getName().equals("findByFirstname")))).thenReturn(true);
MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext);
contributor.contribute(new TestGenerationContext(UserRepository.class));
contributor.verifyContributedMethods().isNotEmpty().doesNotContainKey("findByFirstname");
}
@Test // GH-3279
void doesNotContributeFragmentMethod() {
AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class);
RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class);
when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation);
when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class);
when(repositoryInformation.getRepositoryComposition())
.thenReturn(RepositoryComposition.of(RepositoryFragment.structural(UserRepository.class))
.append(RepositoryFragments
.from(Set.of(new RepositoryFragment.ImplementedRepositoryFragment(UserRepositoryExtension.class,
UserRepositoryExtensionImpl.class)))));
when(repositoryInformation.isCustomMethod(argThat(it -> it.getName().equals("findUserByExtensionMethod"))))
.thenReturn(true);
when(repositoryInformation.isQueryMethod(argThat(it -> it.getName().equals("findByFirstname")))).thenReturn(true);
MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext);
contributor.contribute(new TestGenerationContext(UserRepository.class));
contributor.verifyContributedMethods().isNotEmpty().doesNotContainKey("findUserByExtensionMethod");
}
@Test // GH-3279
void contributesBaseClassMethodIfQueryMethod() {
AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class);
RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class);
when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation);
when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class);
when(repositoryInformation.getRepositoryComposition())
.thenReturn(RepositoryComposition.of(RepositoryFragment.structural(RepoBaseClass.class)));
when(repositoryInformation.isBaseClassMethod(argThat(it -> it.getName().equals("findByFirstname"))))
.thenReturn(true);
when(repositoryInformation.isQueryMethod(any())).thenReturn(true);
MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext);
contributor.contribute(new TestGenerationContext(UserRepository.class));
contributor.verifyContributedMethods().containsKey("findByFirstname").hasSizeGreaterThan(1);
}
static class RepoBaseClass<T, ID> implements CrudRepository<T, ID> {
private CrudRepository<T, ID> delegate;
public <S extends T> S save(S entity) {
return this.delegate.save(entity);
}
@Override
public <S extends T> Iterable<S> saveAll(Iterable<S> entities) {
return this.delegate.saveAll(entities);
}
public Optional<T> findById(ID id) {
return this.delegate.findById(id);
}
@Override
public boolean existsById(ID id) {
return this.delegate.existsById(id);
}
@Override
public Iterable<T> findAll() {
return this.delegate.findAll();
}
@Override
public Iterable<T> findAllById(Iterable<ID> ids) {
return this.delegate.findAllById(ids);
}
@Override
public long count() {
return this.delegate.count();
}
@Override
public void deleteById(ID id) {
this.delegate.deleteById(id);
}
@Override
public void delete(T entity) {
this.delegate.delete(entity);
}
@Override
public void deleteAllById(Iterable<? extends ID> ids) {
this.delegate.deleteAllById(ids);
}
@Override
public void deleteAll(Iterable<? extends T> entities) {
this.delegate.deleteAll(entities);
}
@Override
public void deleteAll() {
this.delegate.deleteAll();
}
}
} }

122
src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java

@ -0,0 +1,122 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.repository.config;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.data.aot.sample.ConfigWithCustomImplementation;
import org.springframework.data.aot.sample.ConfigWithCustomRepositoryBaseClass;
import org.springframework.data.aot.sample.ConfigWithCustomRepositoryBaseClass.CustomerRepositoryWithCustomBaseRepo;
import org.springframework.data.aot.sample.ConfigWithSimpleCrudRepository;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
/**
* @author Christoph Strobl
*/
class RepositoryBeanDefinitionReaderTests {
@Test // GH-3279
void readsSimpleConfigFromBeanFactory() {
RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithSimpleCrudRepository.class);
RepositoryConfiguration<?> repoConfig = mock(RepositoryConfiguration.class);
Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(ConfigWithSimpleCrudRepository.MyRepo.class.getName());
RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig,
repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory());
assertThat(repositoryInformation.getRepositoryInterface()).isEqualTo(ConfigWithSimpleCrudRepository.MyRepo.class);
assertThat(repositoryInformation.getDomainType()).isEqualTo(ConfigWithSimpleCrudRepository.Person.class);
assertThat(repositoryInformation.getFragments()).isEmpty();
}
@Test // GH-3279
void readsCustomRepoBaseClassFromBeanFactory() {
RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithCustomRepositoryBaseClass.class);
RepositoryConfiguration<?> repoConfig = mock(RepositoryConfiguration.class);
Class<?> repositoryInterfaceType = CustomerRepositoryWithCustomBaseRepo.class;
Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName());
RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig,
repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory());
assertThat(repositoryInformation.getRepositoryBaseClass())
.isEqualTo(ConfigWithCustomRepositoryBaseClass.RepoBaseClass.class);
}
@Test // GH-3279
void readsFragmentsFromBeanFactory() {
RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithCustomImplementation.class);
RepositoryConfiguration<?> repoConfig = mock(RepositoryConfiguration.class);
Class<?> repositoryInterfaceType = ConfigWithCustomImplementation.RepositoryWithCustomImplementation.class;
Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName());
RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig,
repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory());
assertThat(repositoryInformation.getFragments()).satisfiesExactly(fragment -> {
assertThat(fragment.getSignatureContributor())
.isEqualTo(ConfigWithCustomImplementation.CustomImplInterface.class);
});
}
@Test // GH-3279
void fallsBackToModuleBaseClassIfSetAndNoRepoBaseDefined() {
RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithSimpleCrudRepository.class);
RootBeanDefinition rootBeanDefinition = repoFactoryBean.getMergedBeanDefinition().cloneBeanDefinition();
// need to unset because its defined as non default
rootBeanDefinition.getPropertyValues().removePropertyValue("repositoryBaseClass");
rootBeanDefinition.getPropertyValues().add("moduleBaseClass", ModuleBase.class.getName());
RepositoryConfiguration<?> repoConfig = mock(RepositoryConfiguration.class);
Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(ConfigWithSimpleCrudRepository.MyRepo.class.getName());
RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig,
rootBeanDefinition, repoFactoryBean.getBeanFactory());
assertThat(repositoryInformation.getRepositoryBaseClass()).isEqualTo(ModuleBase.class);
}
static RegisteredBean repositoryFactory(Class<?> configClass) {
AnnotationConfigApplicationContext applicationContext = new AnnotationConfigApplicationContext();
applicationContext.register(configClass);
applicationContext.refreshForAotProcessing(new RuntimeHints());
String[] beanNamesForType = applicationContext.getBeanNamesForType(RepositoryFactoryBeanSupport.class);
if (beanNamesForType.length != 1) {
throw new IllegalStateException("Unable to find repository FactoryBean");
}
return RegisteredBean.of(applicationContext.getBeanFactory(), beanNamesForType[0]);
}
static class ModuleBase {}
}
Loading…
Cancel
Save