diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/LocalContainerEntityManagerFactoryBean.java b/spring-orm/src/main/java/org/springframework/orm/jpa/LocalContainerEntityManagerFactoryBean.java index f977c277bd1..ec5dbaddaea 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/LocalContainerEntityManagerFactoryBean.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/LocalContainerEntityManagerFactoryBean.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,6 +33,7 @@ import org.springframework.instrument.classloading.LoadTimeWeaver; import org.springframework.jdbc.datasource.lookup.SingleDataSourceLookup; import org.springframework.lang.Nullable; import org.springframework.orm.jpa.persistenceunit.DefaultPersistenceUnitManager; +import org.springframework.orm.jpa.persistenceunit.ManagedClassNameFilter; import org.springframework.orm.jpa.persistenceunit.PersistenceManagedTypes; import org.springframework.orm.jpa.persistenceunit.PersistenceUnitManager; import org.springframework.orm.jpa.persistenceunit.PersistenceUnitPostProcessor; @@ -200,6 +201,16 @@ public class LocalContainerEntityManagerFactoryBean extends AbstractEntityManage this.internalPersistenceUnitManager.setPackagesToScan(packagesToScan); } + /** + * Set the {@link ManagedClassNameFilter} to apply on entity classes discovered + * using {@linkplain #setPackagesToScan(String...) classpath scanning}. + * @param managedClassNameFilter the predicate to filter entity classes + * @since 6.1.4 + */ + public void setManagedClassNameFilter(ManagedClassNameFilter managedClassNameFilter) { + this.internalPersistenceUnitManager.setManagedClassNameFilter(managedClassNameFilter); + } + /** * Specify one or more mapping resources (equivalent to {@code } * entries in {@code persistence.xml}) for the default persistence unit. diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/DefaultPersistenceUnitManager.java b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/DefaultPersistenceUnitManager.java index daeb31b9f1d..2cb8950d822 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/DefaultPersistenceUnitManager.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/DefaultPersistenceUnitManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -116,6 +116,9 @@ public class DefaultPersistenceUnitManager @Nullable private String[] packagesToScan; + @Nullable + private ManagedClassNameFilter managedClassNameFilter; + @Nullable private String[] mappingResources; @@ -223,6 +226,7 @@ public class DefaultPersistenceUnitManager * resource for the default unit if the mapping file is not co-located with a * {@code persistence.xml} file (in which case we assume it is only meant to be * used with the persistence units defined there, like in standard JPA). + * @see #setManagedClassNameFilter(ManagedClassNameFilter) * @see #setManagedTypes(PersistenceManagedTypes) * @see #setDefaultPersistenceUnitName * @see #setMappingResources @@ -231,6 +235,16 @@ public class DefaultPersistenceUnitManager this.packagesToScan = packagesToScan; } + /** + * Set the {@link ManagedClassNameFilter} to apply on entity classes discovered + * using {@linkplain #setPackagesToScan(String...) classpath scanning}. + * @param managedClassNameFilter the predicate to filter entity classes + * @since 6.1.4 + */ + public void setManagedClassNameFilter(ManagedClassNameFilter managedClassNameFilter) { + this.managedClassNameFilter = managedClassNameFilter; + } + /** * Specify one or more mapping resources (equivalent to {@code } * entries in {@code persistence.xml}) for the default persistence unit. @@ -535,8 +549,9 @@ public class DefaultPersistenceUnitManager applyManagedTypes(scannedUnit, this.managedTypes); } else if (this.packagesToScan != null) { - applyManagedTypes(scannedUnit, new PersistenceManagedTypesScanner( - this.resourcePatternResolver).scan(this.packagesToScan)); + PersistenceManagedTypesScanner scanner = new PersistenceManagedTypesScanner( + this.resourcePatternResolver, this.managedClassNameFilter); + applyManagedTypes(scannedUnit, scanner.scan(this.packagesToScan)); } if (this.mappingResources != null) { diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/ManagedClassNameFilter.java b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/ManagedClassNameFilter.java new file mode 100644 index 00000000000..8e4cb0e2901 --- /dev/null +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/ManagedClassNameFilter.java @@ -0,0 +1,36 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.orm.jpa.persistenceunit; + +/** + * Strategy interface to filter the list of persistent managed types to include + * in the persistence unit. Only class names that match the filter are managed. + * + * @author Stephane Nicoll + * @since 6.1.4 + */ +@FunctionalInterface +public interface ManagedClassNameFilter { + + /** + * Test if the given clas name matches the filter. + * @param className the fully qualified class name of the persistent type to test + * @return {@code true} if the class name matches + */ + boolean matches(String className); + +} diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScanner.java b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScanner.java index 192db122fbf..2ca1a4bf523 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScanner.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScanner.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -80,10 +80,20 @@ public final class PersistenceManagedTypesScanner { @Nullable private final CandidateComponentsIndex componentsIndex; + private final ManagedClassNameFilter managedClassNameFilter; + + + public PersistenceManagedTypesScanner(ResourceLoader resourceLoader, + @Nullable ManagedClassNameFilter managedClassNameFilter) { - public PersistenceManagedTypesScanner(ResourceLoader resourceLoader) { this.resourcePatternResolver = ResourcePatternUtils.getResourcePatternResolver(resourceLoader); this.componentsIndex = CandidateComponentsIndexLoader.loadIndex(resourceLoader.getClassLoader()); + this.managedClassNameFilter = (managedClassNameFilter != null ? managedClassNameFilter + : className -> true); + } + + public PersistenceManagedTypesScanner(ResourceLoader resourceLoader) { + this(resourceLoader, null); } @@ -107,7 +117,7 @@ public final class PersistenceManagedTypesScanner { for (AnnotationTypeFilter filter : entityTypeFilters) { candidates.addAll(this.componentsIndex.getCandidateTypes(pkg, filter.getAnnotationType().getName())); } - scanResult.managedClassNames.addAll(candidates); + scanResult.managedClassNames.addAll(candidates.stream().filter(this.managedClassNameFilter::matches).toList()); scanResult.managedPackages.addAll(this.componentsIndex.getCandidateTypes(pkg, "package-info")); return; } @@ -121,7 +131,8 @@ public final class PersistenceManagedTypesScanner { try { MetadataReader reader = readerFactory.getMetadataReader(resource); String className = reader.getClassMetadata().getClassName(); - if (matchesFilter(reader, readerFactory)) { + if (matchesEntityTypeFilter(reader, readerFactory) + && this.managedClassNameFilter.matches(className)) { scanResult.managedClassNames.add(className); if (scanResult.persistenceUnitRootUrl == null) { URL url = resource.getURL(); @@ -157,7 +168,7 @@ public final class PersistenceManagedTypesScanner { * Check whether any of the configured entity type filters matches * the current class descriptor contained in the metadata reader. */ - private boolean matchesFilter(MetadataReader reader, MetadataReaderFactory readerFactory) throws IOException { + private boolean matchesEntityTypeFilter(MetadataReader reader, MetadataReaderFactory readerFactory) throws IOException { for (TypeFilter filter : entityTypeFilters) { if (filter.match(reader, readerFactory)) { return true; diff --git a/spring-orm/src/test/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScannerTests.java b/spring-orm/src/test/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScannerTests.java index bf88cc09627..3bab6697635 100644 --- a/spring-orm/src/test/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScannerTests.java +++ b/spring-orm/src/test/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesScannerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.orm.jpa.persistenceunit; +import java.util.List; + import org.junit.jupiter.api.Test; import org.springframework.context.testfixture.index.CandidateComponentsTestClassLoader; @@ -28,6 +30,11 @@ import org.springframework.orm.jpa.domain.Person; import org.springframework.orm.jpa.domain2.entity.User; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; /** * Tests for {@link PersistenceManagedTypesScanner}. @@ -36,7 +43,9 @@ import static org.assertj.core.api.Assertions.assertThat; */ class PersistenceManagedTypesScannerTests { - private final PersistenceManagedTypesScanner scanner = new PersistenceManagedTypesScanner(new DefaultResourceLoader()); + public static final DefaultResourceLoader RESOURCE_LOADER = new DefaultResourceLoader(); + + private final PersistenceManagedTypesScanner scanner = new PersistenceManagedTypesScanner(RESOURCE_LOADER); @Test void scanPackageWithOnlyEntities() { @@ -47,6 +56,29 @@ class PersistenceManagedTypesScannerTests { assertThat(managedTypes.getManagedPackages()).isEmpty(); } + @Test + void scanPackageInvokesManagedClassNamesFilter() { + ManagedClassNameFilter filter = mock(ManagedClassNameFilter.class); + given(filter.matches(anyString())).willReturn(true); + new PersistenceManagedTypesScanner(RESOURCE_LOADER, filter) + .scan("org.springframework.orm.jpa.domain"); + verify(filter).matches(Person.class.getName()); + verify(filter).matches(DriversLicense.class.getName()); + verify(filter).matches(Employee.class.getName()); + verify(filter).matches(EmployeeLocationConverter.class.getName()); + verifyNoMoreInteractions(filter); + } + + @Test + void scanPackageWithUseManagedClassNamesFilter() { + List candidates = List.of(Person.class.getName(), DriversLicense.class.getName()); + PersistenceManagedTypes managedTypes = new PersistenceManagedTypesScanner( + RESOURCE_LOADER, candidates::contains).scan("org.springframework.orm.jpa.domain"); + assertThat(managedTypes.getManagedClassNames()).containsExactlyInAnyOrder( + Person.class.getName(), DriversLicense.class.getName()); + assertThat(managedTypes.getManagedPackages()).isEmpty(); + } + @Test void scanPackageWithEntitiesAndManagedPackages() { PersistenceManagedTypes managedTypes = this.scanner.scan("org.springframework.orm.jpa.domain2"); @@ -65,7 +97,20 @@ class PersistenceManagedTypesScannerTests { "com.example.domain.Person", "com.example.domain.Address"); assertThat(managedTypes.getManagedPackages()).containsExactlyInAnyOrder( "com.example.domain"); + } + @Test + void scanPackageUsesIndexAndClassNameFilterIfPresent() { + List candidates = List.of("com.example.domain.Address"); + DefaultResourceLoader resourceLoader = new DefaultResourceLoader( + CandidateComponentsTestClassLoader.index(getClass().getClassLoader(), + new ClassPathResource("test-spring.components", getClass()))); + PersistenceManagedTypes managedTypes = new PersistenceManagedTypesScanner( + resourceLoader, candidates::contains).scan("com.example"); + assertThat(managedTypes.getManagedClassNames()).containsExactlyInAnyOrder( + "com.example.domain.Address"); + assertThat(managedTypes.getManagedPackages()).containsExactlyInAnyOrder( + "com.example.domain"); } }