From 5e7b3a3bedb846f01ab01e056e58c1df4d2ab3d7 Mon Sep 17 00:00:00 2001 From: Stefano Cordio Date: Sat, 23 Nov 2024 16:47:20 +0100 Subject: [PATCH 1/2] Avoid infinite recursion in BeanValidationBeanRegistrationAotProcessor Prior to this commit, AOT processing for bean validation failed with a StackOverflowError for constraints with fields having recursive generic types. With this commit, the algorithm tracks visited classes and aborts preemptively when a cycle is detected. Closes gh-33950 Co-authored-by: Sam Brannen --- ...alidationBeanRegistrationAotProcessor.java | 19 ++++++++------ ...tionBeanRegistrationAotProcessorTests.java | 26 +++++++++++++++++++ 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java b/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java index 0647386c21c..dfc55d7c0b4 100644 --- a/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java +++ b/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java @@ -18,7 +18,6 @@ package org.springframework.validation.beanvalidation; import java.util.Collection; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -104,10 +103,11 @@ class BeanValidationBeanRegistrationAotProcessor implements BeanRegistrationAotP } Class beanClass = registeredBean.getBeanClass(); + Set> visitedClasses = new HashSet<>(); Set> validatedClasses = new HashSet<>(); Set>> constraintValidatorClasses = new HashSet<>(); - processAheadOfTime(beanClass, validatedClasses, constraintValidatorClasses); + processAheadOfTime(beanClass, visitedClasses, validatedClasses, constraintValidatorClasses); if (!validatedClasses.isEmpty() || !constraintValidatorClasses.isEmpty()) { return new AotContribution(validatedClasses, constraintValidatorClasses); @@ -115,9 +115,12 @@ class BeanValidationBeanRegistrationAotProcessor implements BeanRegistrationAotP return null; } - private static void processAheadOfTime(Class clazz, Collection> validatedClasses, - Collection>> constraintValidatorClasses) { + private static void processAheadOfTime(Class clazz, Set> visitedClasses, Set> validatedClasses, + Set>> constraintValidatorClasses) { + if (!visitedClasses.add(clazz)) { + return; + } Assert.notNull(validator, "Validator can't be null"); BeanDescriptor descriptor; @@ -149,12 +152,12 @@ class BeanValidationBeanRegistrationAotProcessor implements BeanRegistrationAotP ReflectionUtils.doWithFields(clazz, field -> { Class type = field.getType(); - if (Iterable.class.isAssignableFrom(type) || List.class.isAssignableFrom(type) || Optional.class.isAssignableFrom(type)) { + if (Iterable.class.isAssignableFrom(type) || Optional.class.isAssignableFrom(type)) { ResolvableType resolvableType = ResolvableType.forField(field); Class genericType = resolvableType.getGeneric(0).toClass(); if (shouldProcess(genericType)) { validatedClasses.add(clazz); - processAheadOfTime(genericType, validatedClasses, constraintValidatorClasses); + processAheadOfTime(genericType, visitedClasses, validatedClasses, constraintValidatorClasses); } } if (Map.class.isAssignableFrom(type)) { @@ -163,11 +166,11 @@ class BeanValidationBeanRegistrationAotProcessor implements BeanRegistrationAotP Class valueGenericType = resolvableType.getGeneric(1).toClass(); if (shouldProcess(keyGenericType)) { validatedClasses.add(clazz); - processAheadOfTime(keyGenericType, validatedClasses, constraintValidatorClasses); + processAheadOfTime(keyGenericType, visitedClasses, validatedClasses, constraintValidatorClasses); } if (shouldProcess(valueGenericType)) { validatedClasses.add(clazz); - processAheadOfTime(valueGenericType, validatedClasses, constraintValidatorClasses); + processAheadOfTime(valueGenericType, visitedClasses, validatedClasses, constraintValidatorClasses); } } }); diff --git a/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java b/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java index d43d8033317..bbcdf3e4b70 100644 --- a/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java +++ b/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java @@ -22,6 +22,9 @@ import java.lang.annotation.Retention; import java.lang.annotation.Target; import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; import jakarta.validation.Constraint; import jakarta.validation.ConstraintValidator; @@ -31,6 +34,8 @@ import jakarta.validation.Valid; import jakarta.validation.constraints.Pattern; import org.hibernate.validator.internal.constraintvalidators.bv.PatternValidator; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.hint.MemberCategory; @@ -121,6 +126,15 @@ class BeanValidationBeanRegistrationAotProcessorTests { .withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.generationContext.getRuntimeHints()); } + @ParameterizedTest // gh-33936 + @ValueSource(classes = {BeanWithIterable.class, BeanWithMap.class, BeanWithOptional.class}) + void shouldProcessRecursiveGenericsWithoutInfiniteRecursion(Class beanClass) { + process(beanClass); + assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(1); + assertThat(RuntimeHintsPredicates.reflection().onType(beanClass) + .withMemberCategory(MemberCategory.DECLARED_FIELDS)).accepts(this.generationContext.getRuntimeHints()); + } + private void process(Class beanClass) { BeanRegistrationAotContribution contribution = createContribution(beanClass); if (contribution != null) { @@ -244,4 +258,16 @@ class BeanValidationBeanRegistrationAotProcessorTests { } } + static class BeanWithIterable { + private final Iterable beans = Set.of(); + } + + static class BeanWithMap { + private final Map beans = Map.of(); + } + + static class BeanWithOptional { + private final Optional beans = Optional.empty(); + } + } From 051f1dac241825159618c0397fd7966c78ece549 Mon Sep 17 00:00:00 2001 From: Sam Brannen <104798+sbrannen@users.noreply.github.com> Date: Sun, 24 Nov 2024 14:09:38 +0100 Subject: [PATCH 2/2] Polish contribution See gh-33950 --- ...nValidationBeanRegistrationAotProcessor.java | 8 ++++---- ...dationBeanRegistrationAotProcessorTests.java | 17 ++++++++--------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java b/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java index dfc55d7c0b4..d9680e775d8 100644 --- a/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java +++ b/spring-context/src/main/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -103,11 +103,10 @@ class BeanValidationBeanRegistrationAotProcessor implements BeanRegistrationAotP } Class beanClass = registeredBean.getBeanClass(); - Set> visitedClasses = new HashSet<>(); Set> validatedClasses = new HashSet<>(); Set>> constraintValidatorClasses = new HashSet<>(); - processAheadOfTime(beanClass, visitedClasses, validatedClasses, constraintValidatorClasses); + processAheadOfTime(beanClass, new HashSet<>(), validatedClasses, constraintValidatorClasses); if (!validatedClasses.isEmpty() || !constraintValidatorClasses.isEmpty()) { return new AotContribution(validatedClasses, constraintValidatorClasses); @@ -118,10 +117,11 @@ class BeanValidationBeanRegistrationAotProcessor implements BeanRegistrationAotP private static void processAheadOfTime(Class clazz, Set> visitedClasses, Set> validatedClasses, Set>> constraintValidatorClasses) { + Assert.notNull(validator, "Validator cannot be null"); + if (!visitedClasses.add(clazz)) { return; } - Assert.notNull(validator, "Validator can't be null"); BeanDescriptor descriptor; try { diff --git a/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java b/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java index bbcdf3e4b70..b3f3de83cf8 100644 --- a/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java +++ b/spring-context/src/test/java/org/springframework/validation/beanvalidation/BeanValidationBeanRegistrationAotProcessorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import jakarta.validation.Constraint; import jakarta.validation.ConstraintValidator; @@ -127,7 +126,7 @@ class BeanValidationBeanRegistrationAotProcessorTests { } @ParameterizedTest // gh-33936 - @ValueSource(classes = {BeanWithIterable.class, BeanWithMap.class, BeanWithOptional.class}) + @ValueSource(classes = {BeanWithRecursiveIterable.class, BeanWithRecursiveMap.class, BeanWithRecursiveOptional.class}) void shouldProcessRecursiveGenericsWithoutInfiniteRecursion(Class beanClass) { process(beanClass); assertThat(this.generationContext.getRuntimeHints().reflection().typeHints()).hasSize(1); @@ -258,16 +257,16 @@ class BeanValidationBeanRegistrationAotProcessorTests { } } - static class BeanWithIterable { - private final Iterable beans = Set.of(); + static class BeanWithRecursiveIterable { + Iterable iterable; } - static class BeanWithMap { - private final Map beans = Map.of(); + static class BeanWithRecursiveMap { + Map map; } - static class BeanWithOptional { - private final Optional beans = Optional.empty(); + static class BeanWithRecursiveOptional { + Optional optional; } }