Browse Source

Refine Repository Composition retrieval during AOT.

Add module identifier and base repository implementation properties.
Fix fragment function previously overriding already set property due to name clash.
Extend tests for bean definition resolution and code block creation.

See: #3279
Original Pull Request: #3282
pull/3304/head
Christoph Strobl 8 months ago committed by Mark Paluch
parent
commit
9c966376a7
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 7
      src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java
  2. 2
      src/main/java/org/springframework/data/repository/aot/generate/MethodContributor.java
  3. 2
      src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java
  4. 36
      src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java
  5. 131
      src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java
  6. 21
      src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java
  7. 5
      src/main/java/org/springframework/data/repository/core/RepositoryInformation.java
  8. 2
      src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java
  9. 16
      src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java
  10. 2
      src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java
  11. 2
      src/test/java/example/UserRepository.java
  12. 25
      src/test/java/example/UserRepositoryExtension.java
  13. 29
      src/test/java/example/UserRepositoryExtensionImpl.java
  14. 157
      src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java
  15. 88
      src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java
  16. 57
      src/test/java/org/springframework/data/repository/aot/generate/MethodCapturingRepositoryContributor.java
  17. 167
      src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java
  18. 122
      src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java

7
src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java

@ -138,9 +138,8 @@ class AotRepositoryBuilder { @@ -138,9 +138,8 @@ class AotRepositoryBuilder {
this.customizer.customize(repositoryInformation, generationMetadata, builder);
JavaFile javaFile = JavaFile.builder(packageName(), builder.build()).build();
// TODO: module identifier
AotRepositoryMetadata metadata = new AotRepositoryMetadata(repositoryInformation.getRepositoryInterface().getName(),
"", repositoryType, methodMetadata);
repositoryInformation.moduleName() != null ? repositoryInformation.moduleName() : "", repositoryType, methodMetadata);
return new AotBundle(javaFile, metadata.toJson());
}
@ -148,15 +147,15 @@ class AotRepositoryBuilder { @@ -148,15 +147,15 @@ class AotRepositoryBuilder {
private void contributeMethod(Method method, RepositoryComposition repositoryComposition,
List<AotRepositoryMethod> methodMetadata, TypeSpec.Builder builder) {
if (repositoryInformation.isCustomMethod(method) || repositoryInformation.isBaseClassMethod(method)) {
if (repositoryInformation.isCustomMethod(method) || (repositoryInformation.isBaseClassMethod(method) && !repositoryInformation.isQueryMethod(method))) {
RepositoryFragment<?> fragment = repositoryComposition.findFragment(method);
if (fragment != null) {
methodMetadata.add(getFragmentMetadata(method, fragment));
}
return;
}
}
if (method.isBridge() || method.isDefault() || java.lang.reflect.Modifier.isStatic(method.getModifiers())) {
return;

2
src/main/java/org/springframework/data/repository/aot/generate/MethodContributor.java

@ -36,7 +36,7 @@ public abstract class MethodContributor<M extends QueryMethod> { @@ -36,7 +36,7 @@ public abstract class MethodContributor<M extends QueryMethod> {
private final M queryMethod;
private final QueryMetadata metadata;
private MethodContributor(M queryMethod, QueryMetadata metadata) {
MethodContributor(M queryMethod, QueryMetadata metadata) {
this.queryMethod = queryMethod;
this.metadata = metadata;
}

2
src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java

@ -55,7 +55,7 @@ class AotRepositoryBeanDefinitionPropertiesDecorator { @@ -55,7 +55,7 @@ class AotRepositoryBeanDefinitionPropertiesDecorator {
// bring in properties as usual
builder.add(inheritedProperties.get());
builder.add("beanDefinition.getPropertyValues().addPropertyValue(\"repositoryFragments\", new $T() {\n",
builder.add("beanDefinition.getPropertyValues().addPropertyValue(\"repositoryFragmentsFunction\", new $T() {\n",
RepositoryFactoryBeanSupport.RepositoryFragmentsFunction.class);
builder.indent();
builder.add("public $T getRepositoryFragments($T beanFactory, $T context) {\n",

36
src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java

@ -21,10 +21,12 @@ import java.util.LinkedHashSet; @@ -21,10 +21,12 @@ import java.util.LinkedHashSet;
import java.util.Set;
import java.util.function.Supplier;
import org.jspecify.annotations.Nullable;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.RepositoryInformationSupport;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.RepositoryComposition;
import org.springframework.data.repository.core.support.RepositoryComposition.RepositoryFragments;
import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.data.util.Lazy;
@ -36,16 +38,31 @@ import org.springframework.data.util.Lazy; @@ -36,16 +38,31 @@ import org.springframework.data.util.Lazy;
*/
class AotRepositoryInformation extends RepositoryInformationSupport implements RepositoryInformation {
private final @Nullable String moduleName;
private final Supplier<Collection<RepositoryFragment<?>>> fragments;
private Lazy<RepositoryComposition> baseComposition = Lazy.of(() -> {
return RepositoryComposition.of(RepositoryFragment.structural(getRepositoryBaseClass()));
});
AotRepositoryInformation(Supplier<RepositoryMetadata> repositoryMetadata, Supplier<Class<?>> repositoryBaseClass,
Supplier<Collection<RepositoryFragment<?>>> fragments) {
private final Lazy<RepositoryComposition> repositoryComposition;
private final Lazy<RepositoryComposition> baseComposition;
AotRepositoryInformation(@Nullable String moduleName, Supplier<RepositoryMetadata> repositoryMetadata,
Supplier<Class<?>> repositoryBaseClass, Supplier<Collection<RepositoryFragment<?>>> fragments) {
super(repositoryMetadata, repositoryBaseClass);
this.moduleName = moduleName;
this.fragments = fragments;
this.repositoryComposition = Lazy
.of(() -> RepositoryComposition.fromMetadata(getMetadata()).append(RepositoryFragments.from(getFragments())));
this.baseComposition = Lazy.of(() -> {
RepositoryComposition targetRepoComposition = repositoryComposition.get();
return RepositoryComposition.of(RepositoryFragment.structural(getRepositoryBaseClass())) //
.withArgumentConverter(targetRepoComposition.getArgumentConverter()) //
.withMethodLookup(targetRepoComposition.getMethodLookup());
});
}
/**
@ -57,10 +74,9 @@ class AotRepositoryInformation extends RepositoryInformationSupport implements R @@ -57,10 +74,9 @@ class AotRepositoryInformation extends RepositoryInformationSupport implements R
return new LinkedHashSet<>(fragments.get());
}
// Not required during AOT processing.
@Override
public boolean isCustomMethod(Method method) {
return false;
return repositoryComposition.get().findMethod(method).isPresent();
}
@Override
@ -75,7 +91,11 @@ class AotRepositoryInformation extends RepositoryInformationSupport implements R @@ -75,7 +91,11 @@ class AotRepositoryInformation extends RepositoryInformationSupport implements R
@Override
public RepositoryComposition getRepositoryComposition() {
return baseComposition.get().append(RepositoryComposition.RepositoryFragments.from(fragments.get()));
return repositoryComposition.get();
}
@Override
public @Nullable String moduleName() {
return moduleName;
}
}

131
src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java

@ -16,17 +16,21 @@ @@ -16,17 +16,21 @@
package org.springframework.data.repository.config;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.core.ResolvableType;
import org.springframework.data.repository.CrudRepository;
import org.springframework.data.repository.PagingAndSortingRepository;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.DefaultRepositoryMetadata;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.AbstractRepositoryMetadata;
import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.data.util.Lazy;
import org.springframework.data.repository.core.support.RepositoryFragment.ImplementedRepositoryFragment;
import org.springframework.util.ClassUtils;
/**
@ -38,49 +42,108 @@ import org.springframework.util.ClassUtils; @@ -38,49 +42,108 @@ import org.springframework.util.ClassUtils;
*/
class RepositoryBeanDefinitionReader {
static RepositoryInformation readRepositoryInformation(RepositoryConfiguration<?> metadata,
ConfigurableListableBeanFactory beanFactory) {
return new AotRepositoryInformation(metadataSupplier(metadata, beanFactory),
repositoryBaseClass(metadata, beanFactory), fragments(metadata, beanFactory));
/**
* @return
*/
static RepositoryInformation repositoryInformation(RepositoryConfiguration<?> repoConfig, RegisteredBean repoBean) {
return repositoryInformation(repoConfig, repoBean.getMergedBeanDefinition(), repoBean.getBeanFactory());
}
private static Supplier<Collection<RepositoryFragment<?>>> fragments(RepositoryConfiguration<?> metadata,
/**
* @param source the RepositoryFactoryBeanSupport bean definition.
* @param beanFactory
* @return
*/
@SuppressWarnings("NullAway")
static RepositoryInformation repositoryInformation(RepositoryConfiguration<?> repoConfig, BeanDefinition source,
ConfigurableListableBeanFactory beanFactory) {
if (metadata instanceof RepositoryFragmentConfigurationProvider provider) {
return Lazy.of(() -> {
return provider.getFragmentConfiguration().stream().flatMap(it -> {
RepositoryMetadata metadata = AbstractRepositoryMetadata
.getMetadata(forName(repoConfig.getRepositoryInterface(), beanFactory));
Class<?> repositoryBaseClass = readRepositoryBaseClass(source, beanFactory);
List<RepositoryFragment<?>> fragmentList = readRepositoryFragments(source, beanFactory);
if (source.getPropertyValues().contains("customImplementation")) {
Object o = source.getPropertyValues().get("customImplementation");
if (o instanceof RuntimeBeanReference rbr) {
BeanDefinition customImplBeanDefintion = beanFactory.getBeanDefinition(rbr.getBeanName());
Class<?> beanType = forName(customImplBeanDefintion.getBeanClassName(), beanFactory);
ResolvableType[] interfaces = ResolvableType.forClass(beanType).getInterfaces();
if (interfaces.length == 1) {
fragmentList.add(new ImplementedRepositoryFragment(interfaces[0].toClass(), beanType));
} else {
boolean found = false;
for (ResolvableType i : interfaces) {
if (beanType.getSimpleName().contains(i.resolve().getSimpleName())) {
fragmentList.add(new ImplementedRepositoryFragment(interfaces[0].toClass(), beanType));
found = true;
break;
}
}
if (!found) {
fragmentList.add(RepositoryFragment.implemented(beanType));
}
}
}
}
List<RepositoryFragment<?>> fragments = new ArrayList<>(1);
String moduleName = (String) source.getPropertyValues().get("moduleName");
AotRepositoryInformation repositoryInformation = new AotRepositoryInformation(moduleName, () -> metadata,
() -> repositoryBaseClass, () -> fragmentList);
return repositoryInformation;
}
fragments.add(RepositoryFragment.implemented(forName(it.getClassName(), beanFactory)));
@SuppressWarnings("NullAway")
private static Class<?> readRepositoryBaseClass(BeanDefinition source, ConfigurableListableBeanFactory beanFactory) {
if (it.getInterfaceName() != null) {
fragments.add(RepositoryFragment.structural(forName(it.getInterfaceName(), beanFactory)));
Object repoBaseClassName = source.getPropertyValues().get("repositoryBaseClass");
if (repoBaseClassName != null) {
return forName(repoBaseClassName.toString(), beanFactory);
}
return fragments.stream();
}).collect(Collectors.toList());
});
if (source.getPropertyValues().contains("moduleBaseClass")) {
return forName((String) source.getPropertyValues().get("moduleBaseClass"), beanFactory);
}
return Lazy.of(Collections::emptyList);
return Dummy.class;
}
@SuppressWarnings({ "rawtypes", "unchecked" })
private static Supplier<Class<?>> repositoryBaseClass(RepositoryConfiguration metadata,
@SuppressWarnings("NullAway")
private static List<RepositoryFragment<?>> readRepositoryFragments(BeanDefinition source,
ConfigurableListableBeanFactory beanFactory) {
return Lazy.of(() -> (Class<?>) metadata.getRepositoryBaseClassName().map(it -> forName(it.toString(), beanFactory))
.orElse(Object.class));
}
RuntimeBeanReference beanReference = (RuntimeBeanReference) source.getPropertyValues().get("repositoryFragments");
BeanDefinition fragments = beanFactory.getBeanDefinition(beanReference.getBeanName());
ValueHolder fragmentBeanNameList = fragments.getConstructorArgumentValues().getArgumentValue(0, List.class);
List<String> fragmentBeanNames = (List<String>) fragmentBeanNameList.getValue();
List<RepositoryFragment<?>> fragmentList = new ArrayList<>();
for (String beanName : fragmentBeanNames) {
private static Supplier<org.springframework.data.repository.core.RepositoryMetadata> metadataSupplier(
RepositoryConfiguration<?> metadata, ConfigurableListableBeanFactory beanFactory) {
return Lazy.of(() -> new DefaultRepositoryMetadata(forName(metadata.getRepositoryInterface(), beanFactory)));
BeanDefinition fragmentBeanDefinition = beanFactory.getBeanDefinition(beanName);
ValueHolder argumentValue = fragmentBeanDefinition.getConstructorArgumentValues().getArgumentValue(0,
String.class);
ValueHolder argumentValue1 = fragmentBeanDefinition.getConstructorArgumentValues().getArgumentValue(1, null, null,
null);
Object fragmentClassName = argumentValue.getValue();
try {
Class<?> type = ClassUtils.forName(fragmentClassName.toString(), beanFactory.getBeanClassLoader());
if (argumentValue1 != null && argumentValue1.getValue() instanceof RuntimeBeanReference rbf) {
BeanDefinition implBeanDef = beanFactory.getBeanDefinition(rbf.getBeanName());
Class implClass = ClassUtils.forName(implBeanDef.getBeanClassName(), beanFactory.getBeanClassLoader());
fragmentList.add(new RepositoryFragment.ImplementedRepositoryFragment(type, implClass));
} else {
fragmentList.add(RepositoryFragment.structural(type));
}
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
return fragmentList;
}
static abstract class Dummy implements CrudRepository<Object, Object>, PagingAndSortingRepository<Object, Object> {}
static Class<?> forName(String name, ConfigurableListableBeanFactory beanFactory) {
try {

21
src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java

@ -28,7 +28,6 @@ import java.util.function.BiFunction; @@ -28,7 +28,6 @@ import java.util.function.BiFunction;
import java.util.function.Predicate;
import org.jspecify.annotations.Nullable;
import org.springframework.aop.SpringProxy;
import org.springframework.aop.framework.Advised;
import org.springframework.aot.generate.GenerationContext;
@ -49,7 +48,6 @@ import org.springframework.data.projection.TargetAware; @@ -49,7 +48,6 @@ import org.springframework.data.projection.TargetAware;
import org.springframework.data.repository.Repository;
import org.springframework.data.repository.aot.generate.RepositoryContributor;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.data.util.Predicates;
import org.springframework.data.util.QTypeContributor;
@ -90,8 +88,7 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo @@ -90,8 +88,7 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo
* @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}.
* @see RepositoryRegistrationAotProcessor
*/
protected RepositoryRegistrationAotContribution(
RepositoryRegistrationAotProcessor processor) {
protected RepositoryRegistrationAotContribution(RepositoryRegistrationAotProcessor processor) {
Assert.notNull(processor, "RepositoryRegistrationAotProcessor must not be null");
@ -108,8 +105,7 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo @@ -108,8 +105,7 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo
* @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}.
* @see RepositoryRegistrationAotProcessor
*/
public static RepositoryRegistrationAotContribution fromProcessor(
RepositoryRegistrationAotProcessor processor) {
public static RepositoryRegistrationAotContribution fromProcessor(RepositoryRegistrationAotProcessor processor) {
return new RepositoryRegistrationAotContribution(processor);
}
@ -255,7 +251,8 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo @@ -255,7 +251,8 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo
});
implementation.ifPresent(impl -> {
contribution.getRuntimeHints().reflection().registerType(impl.getClass(), hint -> {
Class<?> typeToRegister = impl instanceof Class c ? c : impl.getClass();
contribution.getRuntimeHints().reflection().registerType(typeToRegister, hint -> {
hint.withMembers(MemberCategory.INVOKE_PUBLIC_METHODS);
@ -365,18 +362,16 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo @@ -365,18 +362,16 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo
@SuppressWarnings("rawtypes")
private DefaultAotRepositoryContext buildAotRepositoryContext(RegisteredBean bean,
RepositoryConfiguration<?> repositoryMetadata) {
RepositoryConfiguration<?> repositoryConfiguration) {
DefaultAotRepositoryContext repositoryContext = new DefaultAotRepositoryContext(
AotContext.from(getBeanFactory(), getRepositoryRegistrationAotProcessor().getEnvironment()));
RepositoryFactoryBeanSupport rfbs = bean.getBeanFactory().getBean("&" + bean.getBeanName(),
RepositoryFactoryBeanSupport.class);
repositoryContext.setBeanName(bean.getBeanName());
repositoryContext.setBasePackages(repositoryMetadata.getBasePackages().toSet());
repositoryContext.setBasePackages(repositoryConfiguration.getBasePackages().toSet());
repositoryContext.setIdentifyingAnnotations(resolveIdentifyingAnnotations());
repositoryContext.setRepositoryInformation(rfbs.getRepositoryInformation());
repositoryContext
.setRepositoryInformation(RepositoryBeanDefinitionReader.repositoryInformation(repositoryConfiguration, bean));
return repositoryContext;
}

5
src/main/java/org/springframework/data/repository/core/RepositoryInformation.java

@ -18,6 +18,7 @@ package org.springframework.data.repository.core; @@ -18,6 +18,7 @@ package org.springframework.data.repository.core;
import java.lang.reflect.Method;
import java.util.List;
import org.jspecify.annotations.Nullable;
import org.springframework.data.repository.core.support.RepositoryComposition;
/**
@ -105,4 +106,8 @@ public interface RepositoryInformation extends RepositoryMetadata { @@ -105,4 +106,8 @@ public interface RepositoryInformation extends RepositoryMetadata {
*/
RepositoryComposition getRepositoryComposition();
default @Nullable String moduleName() {
return null;
}
}

2
src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java

@ -184,7 +184,7 @@ public abstract class RepositoryInformationSupport implements RepositoryInformat @@ -184,7 +184,7 @@ public abstract class RepositoryInformationSupport implements RepositoryInformat
return true;
}
private RepositoryMetadata getMetadata() {
protected RepositoryMetadata getMetadata() {
return metadata.get();
}

16
src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java

@ -95,6 +95,10 @@ public abstract class RepositoryFactoryBeanSupport<T extends Repository<S, ID>, @@ -95,6 +95,10 @@ public abstract class RepositoryFactoryBeanSupport<T extends Repository<S, ID>,
private @Nullable Lazy<T> repository;
private @Nullable RepositoryMetadata repositoryMetadata;
// AOT bean factory hint?
private @Nullable String moduleBaseClass;
private @Nullable String moduleName;
/**
* Creates a new {@link RepositoryFactoryBeanSupport} for the given repository interface.
*
@ -155,7 +159,7 @@ public abstract class RepositoryFactoryBeanSupport<T extends Repository<S, ID>, @@ -155,7 +159,7 @@ public abstract class RepositoryFactoryBeanSupport<T extends Repository<S, ID>,
* @param repositoryFragments
*/
public void setRepositoryFragments(RepositoryFragments repositoryFragments) {
setRepositoryFragments(RepositoryFragmentsFunction.just(repositoryFragments));
setRepositoryFragmentsFunction(RepositoryFragmentsFunction.just(repositoryFragments));
}
/**
@ -165,7 +169,7 @@ public abstract class RepositoryFactoryBeanSupport<T extends Repository<S, ID>, @@ -165,7 +169,7 @@ public abstract class RepositoryFactoryBeanSupport<T extends Repository<S, ID>,
* @param fragmentsFunction
* @since 4.0
*/
public void setRepositoryFragments(RepositoryFragmentsFunction fragmentsFunction) {
public void setRepositoryFragmentsFunction(RepositoryFragmentsFunction fragmentsFunction) {
this.fragments.add(fragmentsFunction);
}
@ -257,6 +261,14 @@ public abstract class RepositoryFactoryBeanSupport<T extends Repository<S, ID>, @@ -257,6 +261,14 @@ public abstract class RepositoryFactoryBeanSupport<T extends Repository<S, ID>,
this.publisher = publisher;
}
public void setModuleBaseClass(String moduleBaseClass) {
this.moduleBaseClass = moduleBaseClass;
}
public void setModuleName(String moduleName) {
this.moduleName = moduleName;
}
@Override
@SuppressWarnings("unchecked")
public EntityInformation<S, ID> getEntityInformation() {

2
src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java

@ -265,7 +265,7 @@ public interface RepositoryFragment<T> { @@ -265,7 +265,7 @@ public interface RepositoryFragment<T> {
Assert.notNull(implementation, "Implementation object must not be null");
if (interfaceClass != null) {
if (interfaceClass != null && !(implementation instanceof Class)) {
Assert
.isTrue(ClassUtils.isAssignableValue(interfaceClass, implementation),

2
src/test/java/example/UserRepository.java

@ -24,7 +24,7 @@ import org.springframework.data.repository.CrudRepository; @@ -24,7 +24,7 @@ import org.springframework.data.repository.CrudRepository;
/**
* @author Christoph Strobl
*/
public interface UserRepository extends CrudRepository<User, Long> {
public interface UserRepository extends CrudRepository<User, Long>, UserRepositoryExtension {
User findByFirstname(String firstname);

25
src/test/java/example/UserRepositoryExtension.java

@ -0,0 +1,25 @@ @@ -0,0 +1,25 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package example;
import example.UserRepository.User;
/**
* @author Christoph Strobl
*/
public interface UserRepositoryExtension {
User findUserByExtensionMethod();
}

29
src/test/java/example/UserRepositoryExtensionImpl.java

@ -0,0 +1,29 @@ @@ -0,0 +1,29 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package example;
import example.UserRepository.User;
/**
* @author Christoph Strobl
*/
public class UserRepositoryExtensionImpl implements UserRepositoryExtension {
@Override
public User findUserByExtensionMethod() {
return null;
}
}

157
src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java

@ -0,0 +1,157 @@ @@ -0,0 +1,157 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.repository.aot.generate;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import example.UserRepository;
import example.UserRepository.User;
import java.util.TimeZone;
import javax.lang.model.element.Modifier;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.data.geo.Metric;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.util.TypeInformation;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.TypeName;
import org.springframework.stereotype.Repository;
/**
* @author Christoph Strobl
*/
class AotRepositoryBuilderUnitTests {
RepositoryInformation repositoryInformation;
@BeforeEach
void beforeEach() {
repositoryInformation = mock(RepositoryInformation.class);
doReturn(UserRepository.class).when(repositoryInformation).getRepositoryInterface();
}
@Test // GH-3279
void writesClassSkeleton() {
AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation,
new SpelAwareProxyProjectionFactory());
assertThat(repoBuilder.build().javaFile().toString())
.contains("package %s;".formatted(UserRepository.class.getPackageName())) // same package as source repo
.contains("@Generated") // marked as generated source
.contains("public class %sImpl__Aot".formatted(UserRepository.class.getSimpleName())) // target name
.contains("public UserRepositoryImpl__Aot()"); // default constructor if not arguments to wire
}
@Test // GH-3279
void appliesCtorArguments() {
AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation,
new SpelAwareProxyProjectionFactory());
repoBuilder.withConstructorCustomizer(ctor -> {
ctor.addParameter("param1", Metric.class);
ctor.addParameter("param2", String.class);
ctor.addParameter("ctorScoped", TypeName.OBJECT, false);
});
assertThat(repoBuilder.build().javaFile().toString()) //
.contains("private final Metric param1;") //
.contains("private final String param2;") //
.doesNotContain("private final Object ctorScoped;") //
.contains("public UserRepositoryImpl__Aot(Metric param1, String param2, Object ctorScoped)") //
.contains("this.param1 = param1") //
.contains("this.param2 = param2") //
.doesNotContain("this.ctorScoped = ctorScoped");
}
@Test // GH-3279
void appliesCtorCodeBlock() {
AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation,
new SpelAwareProxyProjectionFactory());
repoBuilder.withConstructorCustomizer(ctor -> {
ctor.customize((info, code) -> {
code.addStatement("throw new $T($S)", IllegalStateException.class, "initialization error");
});
});
assertThat(repoBuilder.build().javaFile().toString()).containsIgnoringWhitespaces(
"UserRepositoryImpl__Aot() { throw new IllegalStateException(\"initialization error\"); }");
}
@Test // GH-3279
void appliesClassCustomizations() {
AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation,
new SpelAwareProxyProjectionFactory());
repoBuilder.withClassCustomizer((info, metadata, clazz) -> {
clazz.addField(Float.class, "f", Modifier.PRIVATE, Modifier.STATIC);
clazz.addField(Double.class, "d", Modifier.PUBLIC);
clazz.addField(TimeZone.class, "t", Modifier.FINAL);
clazz.addAnnotation(Repository.class);
clazz.addMethod(MethodSpec.methodBuilder("oops").build());
});
assertThat(repoBuilder.build().javaFile().toString()) //
.contains("@Repository") //
.contains("private static Float f;") //
.contains("public Double d;") //
.contains("final TimeZone t;") //
.containsIgnoringWhitespaces("void oops() { }");
}
@Test // GH-3279
void appliesQueryMethodContributor() {
AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation,
new SpelAwareProxyProjectionFactory());
when(repositoryInformation.isQueryMethod(Mockito.argThat(arg -> arg.getName().equals("findByFirstname"))))
.thenReturn(true);
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
repoBuilder.withQueryMethodContributor((method, info) -> {
return new MethodContributor<>(mock(QueryMethod.class), null) {
@Override
public MethodSpec contribute(AotQueryMethodGenerationContext context) {
return MethodSpec.methodBuilder("oops").build();
}
@Override
public boolean contributesMethodSpec() {
return true;
}
};
});
assertThat(repoBuilder.build().javaFile().toString()) //
.containsIgnoringWhitespaces("void oops() { }");
}
}

88
src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java

@ -0,0 +1,88 @@ @@ -0,0 +1,88 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.repository.aot.generate;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.when;
import example.UserRepository;
import example.UserRepository.User;
import java.lang.reflect.Method;
import java.util.List;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.core.ResolvableType;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.util.TypeInformation;
import org.springframework.javapoet.ParameterSpec;
import org.springframework.javapoet.ParameterizedTypeName;
/**
* @author Christoph Strobl
*/
class AotRepositoryMethodBuilderUnitTests {
RepositoryInformation repositoryInformation;
AotQueryMethodGenerationContext methodGenerationContext;
@BeforeEach
void beforeEach() {
repositoryInformation = Mockito.mock(RepositoryInformation.class);
methodGenerationContext = Mockito.mock(AotQueryMethodGenerationContext.class);
when(methodGenerationContext.getRepositoryInformation()).thenReturn(repositoryInformation);
}
@Test // GH-3279
void generatesMethodSkeletonBasedOnGenerationMetadata() throws NoSuchMethodException {
Method method = UserRepository.class.getMethod("findByFirstname", String.class);
when(methodGenerationContext.getMethod()).thenReturn(method);
when(methodGenerationContext.getReturnType()).thenReturn(ResolvableType.forClass(User.class));
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method);
methodMetadata.addParameter(ParameterSpec.builder(String.class, "firstname").build());
when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata);
AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext);
assertThat(builder.buildMethod().toString()) //
.containsPattern("public .*User findByFirstname\\(.*String firstname\\)");
}
@Test // GH-3279
void generatesMethodWithGenerics() throws NoSuchMethodException {
Method method = UserRepository.class.getMethod("findByFirstnameIn", List.class);
when(methodGenerationContext.getMethod()).thenReturn(method);
when(methodGenerationContext.getReturnType())
.thenReturn(ResolvableType.forClassWithGenerics(List.class, User.class));
doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any());
MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method);
methodMetadata
.addParameter(ParameterSpec.builder(ParameterizedTypeName.get(List.class, String.class), "firstnames").build());
when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata);
AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext);
assertThat(builder.buildMethod().toString()) //
.containsPattern("public .*List<.*User> findByFirstnameIn\\(") //
.containsPattern(".*List<.*String> firstnames\\)");
}
}

57
src/test/java/org/springframework/data/repository/aot/generate/MethodCapturingRepositoryContributor.java

@ -0,0 +1,57 @@ @@ -0,0 +1,57 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.repository.aot.generate;
import static org.assertj.core.api.Assertions.assertThat;
import java.lang.reflect.Method;
import java.util.List;
import org.assertj.core.api.MapAssert;
import org.jspecify.annotations.Nullable;
import org.springframework.data.repository.config.AotRepositoryContext;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* @author Christoph Strobl
*/
public class MethodCapturingRepositoryContributor extends RepositoryContributor {
MultiValueMap<String, Method> capturedInvocations;
public MethodCapturingRepositoryContributor(AotRepositoryContext repositoryContext) {
super(repositoryContext);
this.capturedInvocations = new LinkedMultiValueMap<>(3);
}
@Override
protected @Nullable MethodContributor<? extends QueryMethod> contributeQueryMethod(Method method,
RepositoryInformation repositoryInformation) {
capturedInvocations.add(method.getName(), method);
return null;
}
void verifyContributionFor(String methodName) {
assertThat(capturedInvocations).containsKey(methodName);
}
MapAssert<String, List<Method>> verifyContributedMethods() {
return assertThat(capturedInvocations);
}
}

167
src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java

@ -15,31 +15,45 @@ @@ -15,31 +15,45 @@
*/
package org.springframework.data.repository.aot.generate;
import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.when;
import example.UserRepository;
import example.UserRepositoryExtension;
import example.UserRepositoryExtensionImpl;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.jspecify.annotations.Nullable;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.core.test.tools.TestCompiler;
import org.springframework.data.aot.CodeContributionAssert;
import org.springframework.data.repository.CrudRepository;
import org.springframework.data.repository.config.AotRepositoryContext;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.RepositoryComposition;
import org.springframework.data.repository.core.support.RepositoryComposition.RepositoryFragments;
import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.javapoet.CodeBlock;
import org.springframework.util.ClassUtils;
/**
* Unit tests targeting {@link RepositoryContributor}.
*
* @author Christoph Strobl
*/
class RepositoryContributorUnitTests {
@Test
void testCompile() {
@Test // GH-3279
void createsCompilableClassStub() {
DummyModuleAotRepositoryContext aotContext = new DummyModuleAotRepositoryContext(UserRepository.class, null);
RepositoryContributor repositoryContributor = new RepositoryContributor(aotContext) {
@ -55,8 +69,7 @@ class RepositoryContributorUnitTests { @@ -55,8 +69,7 @@ class RepositoryContributorUnitTests {
public Map<String, Object> serialize() {
return Map.of();
}
})
.contribute(context -> {
}).contribute(context -> {
CodeBlock.Builder builder = CodeBlock.builder();
if (!ClassUtils.isVoidType(method.getReturnType())) {
@ -81,4 +94,146 @@ class RepositoryContributorUnitTests { @@ -81,4 +94,146 @@ class RepositoryContributorUnitTests {
new CodeContributionAssert(generationContext).contributesReflectionFor(expectedTypeName);
}
@Test // GH-3279
void callsMethodContributionForQueryMethod() {
AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class);
RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class);
when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation);
when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class);
when(repositoryInformation.isQueryMethod(argThat(it -> it.getName().equals("findByFirstname")))).thenReturn(true);
MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext);
contributor.contribute(new TestGenerationContext(UserRepository.class));
contributor.verifyContributionFor("findByFirstname");
}
@Test // GH-3279
void doesNotContributeBaseClassMethods() {
AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class);
RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class);
when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation);
when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class);
when(repositoryInformation.getRepositoryComposition())
.thenReturn(RepositoryComposition.of(RepositoryFragment.structural(RepoBaseClass.class)));
when(repositoryInformation.isBaseClassMethod(argThat(it -> it.getName().equals("findByFirstname"))))
.thenReturn(true);
when(repositoryInformation.isQueryMethod(argThat(it -> !it.getName().equals("findByFirstname")))).thenReturn(true);
MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext);
contributor.contribute(new TestGenerationContext(UserRepository.class));
contributor.verifyContributedMethods().isNotEmpty().doesNotContainKey("findByFirstname");
}
@Test // GH-3279
void doesNotContributeFragmentMethod() {
AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class);
RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class);
when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation);
when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class);
when(repositoryInformation.getRepositoryComposition())
.thenReturn(RepositoryComposition.of(RepositoryFragment.structural(UserRepository.class))
.append(RepositoryFragments
.from(Set.of(new RepositoryFragment.ImplementedRepositoryFragment(UserRepositoryExtension.class,
UserRepositoryExtensionImpl.class)))));
when(repositoryInformation.isCustomMethod(argThat(it -> it.getName().equals("findUserByExtensionMethod"))))
.thenReturn(true);
when(repositoryInformation.isQueryMethod(argThat(it -> it.getName().equals("findByFirstname")))).thenReturn(true);
MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext);
contributor.contribute(new TestGenerationContext(UserRepository.class));
contributor.verifyContributedMethods().isNotEmpty().doesNotContainKey("findUserByExtensionMethod");
}
@Test // GH-3279
void contributesBaseClassMethodIfQueryMethod() {
AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class);
RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class);
when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation);
when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class);
when(repositoryInformation.getRepositoryComposition())
.thenReturn(RepositoryComposition.of(RepositoryFragment.structural(RepoBaseClass.class)));
when(repositoryInformation.isBaseClassMethod(argThat(it -> it.getName().equals("findByFirstname"))))
.thenReturn(true);
when(repositoryInformation.isQueryMethod(any())).thenReturn(true);
MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext);
contributor.contribute(new TestGenerationContext(UserRepository.class));
contributor.verifyContributedMethods().containsKey("findByFirstname").hasSizeGreaterThan(1);
}
static class RepoBaseClass<T, ID> implements CrudRepository<T, ID> {
private CrudRepository<T, ID> delegate;
public <S extends T> S save(S entity) {
return this.delegate.save(entity);
}
@Override
public <S extends T> Iterable<S> saveAll(Iterable<S> entities) {
return this.delegate.saveAll(entities);
}
public Optional<T> findById(ID id) {
return this.delegate.findById(id);
}
@Override
public boolean existsById(ID id) {
return this.delegate.existsById(id);
}
@Override
public Iterable<T> findAll() {
return this.delegate.findAll();
}
@Override
public Iterable<T> findAllById(Iterable<ID> ids) {
return this.delegate.findAllById(ids);
}
@Override
public long count() {
return this.delegate.count();
}
@Override
public void deleteById(ID id) {
this.delegate.deleteById(id);
}
@Override
public void delete(T entity) {
this.delegate.delete(entity);
}
@Override
public void deleteAllById(Iterable<? extends ID> ids) {
this.delegate.deleteAllById(ids);
}
@Override
public void deleteAll(Iterable<? extends T> entities) {
this.delegate.deleteAll(entities);
}
@Override
public void deleteAll() {
this.delegate.deleteAll();
}
}
}

122
src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java

@ -0,0 +1,122 @@ @@ -0,0 +1,122 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.repository.config;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.data.aot.sample.ConfigWithCustomImplementation;
import org.springframework.data.aot.sample.ConfigWithCustomRepositoryBaseClass;
import org.springframework.data.aot.sample.ConfigWithCustomRepositoryBaseClass.CustomerRepositoryWithCustomBaseRepo;
import org.springframework.data.aot.sample.ConfigWithSimpleCrudRepository;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
/**
* @author Christoph Strobl
*/
class RepositoryBeanDefinitionReaderTests {
@Test // GH-3279
void readsSimpleConfigFromBeanFactory() {
RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithSimpleCrudRepository.class);
RepositoryConfiguration<?> repoConfig = mock(RepositoryConfiguration.class);
Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(ConfigWithSimpleCrudRepository.MyRepo.class.getName());
RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig,
repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory());
assertThat(repositoryInformation.getRepositoryInterface()).isEqualTo(ConfigWithSimpleCrudRepository.MyRepo.class);
assertThat(repositoryInformation.getDomainType()).isEqualTo(ConfigWithSimpleCrudRepository.Person.class);
assertThat(repositoryInformation.getFragments()).isEmpty();
}
@Test // GH-3279
void readsCustomRepoBaseClassFromBeanFactory() {
RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithCustomRepositoryBaseClass.class);
RepositoryConfiguration<?> repoConfig = mock(RepositoryConfiguration.class);
Class<?> repositoryInterfaceType = CustomerRepositoryWithCustomBaseRepo.class;
Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName());
RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig,
repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory());
assertThat(repositoryInformation.getRepositoryBaseClass())
.isEqualTo(ConfigWithCustomRepositoryBaseClass.RepoBaseClass.class);
}
@Test // GH-3279
void readsFragmentsFromBeanFactory() {
RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithCustomImplementation.class);
RepositoryConfiguration<?> repoConfig = mock(RepositoryConfiguration.class);
Class<?> repositoryInterfaceType = ConfigWithCustomImplementation.RepositoryWithCustomImplementation.class;
Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName());
RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig,
repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory());
assertThat(repositoryInformation.getFragments()).satisfiesExactly(fragment -> {
assertThat(fragment.getSignatureContributor())
.isEqualTo(ConfigWithCustomImplementation.CustomImplInterface.class);
});
}
@Test // GH-3279
void fallsBackToModuleBaseClassIfSetAndNoRepoBaseDefined() {
RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithSimpleCrudRepository.class);
RootBeanDefinition rootBeanDefinition = repoFactoryBean.getMergedBeanDefinition().cloneBeanDefinition();
// need to unset because its defined as non default
rootBeanDefinition.getPropertyValues().removePropertyValue("repositoryBaseClass");
rootBeanDefinition.getPropertyValues().add("moduleBaseClass", ModuleBase.class.getName());
RepositoryConfiguration<?> repoConfig = mock(RepositoryConfiguration.class);
Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(ConfigWithSimpleCrudRepository.MyRepo.class.getName());
RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig,
rootBeanDefinition, repoFactoryBean.getBeanFactory());
assertThat(repositoryInformation.getRepositoryBaseClass()).isEqualTo(ModuleBase.class);
}
static RegisteredBean repositoryFactory(Class<?> configClass) {
AnnotationConfigApplicationContext applicationContext = new AnnotationConfigApplicationContext();
applicationContext.register(configClass);
applicationContext.refreshForAotProcessing(new RuntimeHints());
String[] beanNamesForType = applicationContext.getBeanNamesForType(RepositoryFactoryBeanSupport.class);
if (beanNamesForType.length != 1) {
throw new IllegalStateException("Unable to find repository FactoryBean");
}
return RegisteredBean.of(applicationContext.getBeanFactory(), beanNamesForType[0]);
}
static class ModuleBase {}
}
Loading…
Cancel
Save