diff --git a/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java b/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java index d13639a0c08..0647386c21c 100644 --- a/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java +++ b/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java @@ -18,16 +18,20 @@ package org.springframework.validation.beanvalidation; import java.util.Collection; import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.Set; import jakarta.validation.ConstraintValidator; import jakarta.validation.NoProviderFoundException; import jakarta.validation.Validation; import jakarta.validation.Validator; +import jakarta.validation.ValidatorFactory; import jakarta.validation.metadata.BeanDescriptor; import jakarta.validation.metadata.ConstraintDescriptor; -import jakarta.validation.metadata.ConstructorDescriptor; -import jakarta.validation.metadata.MethodDescriptor; +import jakarta.validation.metadata.ContainerElementTypeDescriptor; +import jakarta.validation.metadata.ExecutableDescriptor; import jakarta.validation.metadata.MethodType; import jakarta.validation.metadata.ParameterDescriptor; import jakarta.validation.metadata.PropertyDescriptor; @@ -36,13 +40,17 @@ import org.apache.commons.logging.LogFactory; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.ReflectionHints; import org.springframework.beans.factory.aot.BeanRegistrationAotContribution; import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor; import org.springframework.beans.factory.aot.BeanRegistrationCode; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.core.KotlinDetector; +import org.springframework.core.ResolvableType; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; /** * AOT {@code BeanRegistrationAotProcessor} that adds additional hints @@ -80,8 +88,8 @@ class BeanValidationBeanRegistrationAotProcessor implements BeanRegistrationAotP @Nullable private static Validator getValidatorIfAvailable() { - try { - return Validation.buildDefaultValidatorFactory().getValidator(); + try (ValidatorFactory validator = Validation.buildDefaultValidatorFactory()) { + return validator.getValidator(); } catch (NoProviderFoundException ex) { logger.info("No Bean Validation provider available - skipping validation constraint hint inference"); @@ -95,64 +103,134 @@ class BeanValidationBeanRegistrationAotProcessor implements BeanRegistrationAotP return null; } + Class beanClass = registeredBean.getBeanClass(); + Set> validatedClasses = new HashSet<>(); + Set>> constraintValidatorClasses = new HashSet<>(); + + processAheadOfTime(beanClass, validatedClasses, constraintValidatorClasses); + + if (!validatedClasses.isEmpty() || !constraintValidatorClasses.isEmpty()) { + return new AotContribution(validatedClasses, constraintValidatorClasses); + } + return null; + } + + private static void processAheadOfTime(Class clazz, Collection> validatedClasses, + Collection>> constraintValidatorClasses) { + + Assert.notNull(validator, "Validator can't be null"); + BeanDescriptor descriptor; try { - descriptor = validator.getConstraintsForClass(registeredBean.getBeanClass()); + descriptor = validator.getConstraintsForClass(clazz); } catch (RuntimeException ex) { - if (KotlinDetector.isKotlinType(registeredBean.getBeanClass()) && ex instanceof ArrayIndexOutOfBoundsException) { + if (KotlinDetector.isKotlinType(clazz) && ex instanceof ArrayIndexOutOfBoundsException) { // See https://hibernate.atlassian.net/browse/HV-1796 and https://youtrack.jetbrains.com/issue/KT-40857 - logger.warn("Skipping validation constraint hint inference for bean " + registeredBean.getBeanName() + + logger.warn("Skipping validation constraint hint inference for class " + clazz + " due to an ArrayIndexOutOfBoundsException at validator level"); } else if (ex instanceof TypeNotPresentException) { - logger.debug("Skipping validation constraint hint inference for bean " + - registeredBean.getBeanName() + " due to a TypeNotPresentException at validator level: " + ex.getMessage()); + logger.debug("Skipping validation constraint hint inference for class " + + clazz + " due to a TypeNotPresentException at validator level: " + ex.getMessage()); } else { - logger.warn("Skipping validation constraint hint inference for bean " + - registeredBean.getBeanName(), ex); + logger.warn("Skipping validation constraint hint inference for class " + clazz, ex); } - return null; + return; } - Set> constraintDescriptors = new HashSet<>(); - for (MethodDescriptor methodDescriptor : descriptor.getConstrainedMethods(MethodType.NON_GETTER, MethodType.GETTER)) { - for (ParameterDescriptor parameterDescriptor : methodDescriptor.getParameterDescriptors()) { - constraintDescriptors.addAll(parameterDescriptor.getConstraintDescriptors()); - } + processExecutableDescriptor(descriptor.getConstrainedMethods(MethodType.NON_GETTER, MethodType.GETTER), constraintValidatorClasses); + processExecutableDescriptor(descriptor.getConstrainedConstructors(), constraintValidatorClasses); + processPropertyDescriptors(descriptor.getConstrainedProperties(), constraintValidatorClasses); + if (!constraintValidatorClasses.isEmpty() && shouldProcess(clazz)) { + validatedClasses.add(clazz); } - for (ConstructorDescriptor constructorDescriptor : descriptor.getConstrainedConstructors()) { - for (ParameterDescriptor parameterDescriptor : constructorDescriptor.getParameterDescriptors()) { - constraintDescriptors.addAll(parameterDescriptor.getConstraintDescriptors()); + + ReflectionUtils.doWithFields(clazz, field -> { + Class type = field.getType(); + if (Iterable.class.isAssignableFrom(type) || List.class.isAssignableFrom(type) || Optional.class.isAssignableFrom(type)) { + ResolvableType resolvableType = ResolvableType.forField(field); + Class genericType = resolvableType.getGeneric(0).toClass(); + if (shouldProcess(genericType)) { + validatedClasses.add(clazz); + processAheadOfTime(genericType, validatedClasses, constraintValidatorClasses); + } + } + if (Map.class.isAssignableFrom(type)) { + ResolvableType resolvableType = ResolvableType.forField(field); + Class keyGenericType = resolvableType.getGeneric(0).toClass(); + Class valueGenericType = resolvableType.getGeneric(1).toClass(); + if (shouldProcess(keyGenericType)) { + validatedClasses.add(clazz); + processAheadOfTime(keyGenericType, validatedClasses, constraintValidatorClasses); + } + if (shouldProcess(valueGenericType)) { + validatedClasses.add(clazz); + processAheadOfTime(valueGenericType, validatedClasses, constraintValidatorClasses); + } + } + }); + } + + private static boolean shouldProcess(Class clazz) { + return !clazz.getCanonicalName().startsWith("java."); + } + + private static void processExecutableDescriptor(Set executableDescriptors, + Collection>> constraintValidatorClasses) { + + for (ExecutableDescriptor executableDescriptor : executableDescriptors) { + for (ParameterDescriptor parameterDescriptor : executableDescriptor.getParameterDescriptors()) { + for (ConstraintDescriptor constraintDescriptor : parameterDescriptor.getConstraintDescriptors()) { + constraintValidatorClasses.addAll(constraintDescriptor.getConstraintValidatorClasses()); + } + for (ContainerElementTypeDescriptor typeDescriptor : parameterDescriptor.getConstrainedContainerElementTypes()) { + for (ConstraintDescriptor constraintDescriptor : typeDescriptor.getConstraintDescriptors()) { + constraintValidatorClasses.addAll(constraintDescriptor.getConstraintValidatorClasses()); + } + } } } - for (PropertyDescriptor propertyDescriptor : descriptor.getConstrainedProperties()) { - constraintDescriptors.addAll(propertyDescriptor.getConstraintDescriptors()); - } - if (!constraintDescriptors.isEmpty()) { - return new AotContribution(constraintDescriptors); + } + + private static void processPropertyDescriptors(Set propertyDescriptors, + Collection>> constraintValidatorClasses) { + + for (PropertyDescriptor propertyDescriptor : propertyDescriptors) { + for (ConstraintDescriptor constraintDescriptor : propertyDescriptor.getConstraintDescriptors()) { + constraintValidatorClasses.addAll(constraintDescriptor.getConstraintValidatorClasses()); + } + for (ContainerElementTypeDescriptor typeDescriptor : propertyDescriptor.getConstrainedContainerElementTypes()) { + for (ConstraintDescriptor constraintDescriptor : typeDescriptor.getConstraintDescriptors()) { + constraintValidatorClasses.addAll(constraintDescriptor.getConstraintValidatorClasses()); + } + } } - return null; } } private static class AotContribution implements BeanRegistrationAotContribution { - private final Collection> constraintDescriptors; + private final Collection> validatedClasses; + private final Collection>> constraintValidatorClasses; - public AotContribution(Collection> constraintDescriptors) { - this.constraintDescriptors = constraintDescriptors; + public AotContribution(Collection> validatedClasses, + Collection>> constraintValidatorClasses) { + + this.validatedClasses = validatedClasses; + this.constraintValidatorClasses = constraintValidatorClasses; } @Override public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) { - for (ConstraintDescriptor constraintDescriptor : this.constraintDescriptors) { - for (Class constraintValidatorClass : constraintDescriptor.getConstraintValidatorClasses()) { - generationContext.getRuntimeHints().reflection().registerType(constraintValidatorClass, - MemberCategory.INVOKE_DECLARED_CONSTRUCTORS); - } + ReflectionHints hints = generationContext.getRuntimeHints().reflection(); + for (Class validatedClass : this.validatedClasses) { + hints.registerType(validatedClass, MemberCategory.DECLARED_FIELDS); + } + for (Class> constraintValidatorClass : this.constraintValidatorClasses) { + hints.registerType(constraintValidatorClass, MemberCategory.INVOKE_DECLARED_CONSTRUCTORS); } } } diff --git a/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java b/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java index 4d28c5ba99b..d43d8033317 100644 --- a/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java +++ b/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java @@ -20,11 +20,16 @@ import java.lang.annotation.Documented; import java.lang.annotation.Repeatable; import java.lang.annotation.Retention; import java.lang.annotation.Target; +import java.util.ArrayList; +import java.util.List; import jakarta.validation.Constraint; import jakarta.validation.ConstraintValidator; import jakarta.validation.ConstraintValidatorContext; import jakarta.validation.Payload; +import jakarta.validation.Valid; +import jakarta.validation.constraints.Pattern; +import org.hibernate.validator.internal.constraintvalidators.bv.PatternValidator; import org.junit.jupiter.api.Test; import org.springframework.aot.generate.GenerationContext; @@ -67,6 +72,9 @@ class BeanValidationBeanRegistrationAotProcessorTests { @Test void shouldProcessMethodParameterLevelConstraint() { process(MethodParameterLevelConstraint.class); + assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(2); + assertThat(RuntimeHintsPredicates.reflection().onType(MethodParameterLevelConstraint.class) + .withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints()); assertThat(RuntimeHintsPredicates.reflection().onType(ExistsValidator.class) .withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints()); } @@ -74,6 +82,9 @@ class BeanValidationBeanRegistrationAotProcessorTests { @Test void shouldProcessConstructorParameterLevelConstraint() { process(ConstructorParameterLevelConstraint.class); + assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(2); + assertThat(RuntimeHintsPredicates.reflection().onType(ConstructorParameterLevelConstraint.class) + .withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints()); assertThat(RuntimeHintsPredicates.reflection().onType(ExistsValidator.class) .withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints()); } @@ -81,10 +92,35 @@ class BeanValidationBeanRegistrationAotProcessorTests { @Test void shouldProcessPropertyLevelConstraint() { process(PropertyLevelConstraint.class); + assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(2); + assertThat(RuntimeHintsPredicates.reflection().onType(PropertyLevelConstraint.class) + .withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints()); assertThat(RuntimeHintsPredicates.reflection().onType(ExistsValidator.class) .withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints()); } + @Test + void shouldProcessGenericTypeLevelConstraint() { + process(GenericTypeLevelConstraint.class); + assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(2); + assertThat(RuntimeHintsPredicates.reflection().onType(GenericTypeLevelConstraint.class) + .withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints()); + assertThat(RuntimeHintsPredicates.reflection().onType(PatternValidator.class) + .withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints()); + } + + @Test + void shouldProcessTransitiveGenericTypeLevelConstraint() { + process(TransitiveGenericTypeLevelConstraint.class); + assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(3); + assertThat(RuntimeHintsPredicates.reflection().onType(TransitiveGenericTypeLevelConstraint.class) + .withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints()); + assertThat(RuntimeHintsPredicates.reflection().onType(Exclude.class) + .withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints()); + assertThat(RuntimeHintsPredicates.reflection().onType(PatternValidator.class) + .withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints()); + } + private void process(Class beanClass) { BeanRegistrationAotContribution contribution = createContribution(beanClass); if (contribution != null) { @@ -168,4 +204,44 @@ class BeanValidationBeanRegistrationAotProcessorTests { } } + static class Exclude { + + @Valid + private List<@Pattern(regexp="^([1-5][x|X]{2}|[1-5][0-9]{2})\\$") String> httpStatus; + + public List getHttpStatus() { + return httpStatus; + } + + public void setHttpStatus(List httpStatus) { + this.httpStatus = httpStatus; + } + } + + static class GenericTypeLevelConstraint { + + private List<@Pattern(regexp="^([1-5][x|X]{2}|[1-5][0-9]{2})\\$") String> httpStatus; + + public List getHttpStatus() { + return httpStatus; + } + + public void setHttpStatus(List httpStatus) { + this.httpStatus = httpStatus; + } + } + + static class TransitiveGenericTypeLevelConstraint { + + private List exclude = new ArrayList<>(); + + public List getExclude() { + return exclude; + } + + public void setExclude(List exclude) { + this.exclude = exclude; + } + } + }