From ca77912f861dff5a63f992f62db82fd8ea9c3cf4 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 10 Sep 2025 10:39:41 +0200 Subject: [PATCH] Refine JavaPoet usage. See: #2121 Original pull request: #2124 --- .../jdbc/repository/aot/JdbcCodeBlocks.java | 123 ++++++++---------- .../aot/JdbcRepositoryContributor.java | 2 + .../AotJdbcRepositoryIntegrationTests.java | 12 ++ .../AotFragmentTestConfigurationSupport.java | 20 +-- ...RepositoryContributorIntegrationTests.java | 11 ++ ...dbcRepositoryMetadataIntegrationTests.java | 6 + 6 files changed, 88 insertions(+), 86 deletions(-) diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java index 4df6d81cf..6919f7264 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java @@ -34,6 +34,7 @@ import org.jspecify.annotations.Nullable; import org.springframework.core.annotation.MergedAnnotation; import org.springframework.data.domain.SliceImpl; import org.springframework.data.domain.Sort; +import org.springframework.data.javapoet.LordOfTheStrings; import org.springframework.data.jdbc.repository.query.JdbcQueryMethod; import org.springframework.data.jdbc.repository.query.Modifying; import org.springframework.data.jdbc.repository.query.ParameterBinding; @@ -44,10 +45,10 @@ import org.springframework.data.relational.core.query.CriteriaDefinition; import org.springframework.data.relational.core.sql.LockMode; import org.springframework.data.relational.repository.Lock; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.data.repository.aot.generate.MethodReturn; import org.springframework.data.repository.query.parser.Part; import org.springframework.data.support.PageableExecutionUtils; import org.springframework.data.util.Pair; -import org.springframework.data.util.ReflectionUtils; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock.Builder; import org.springframework.javapoet.TypeName; @@ -58,7 +59,7 @@ import org.springframework.jdbc.core.SingleColumnRowMapper; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; import org.springframework.jdbc.core.namedparam.SqlParameterSource; import org.springframework.util.Assert; -import org.springframework.util.ObjectUtils; +import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; /** @@ -568,25 +569,26 @@ class JdbcCodeBlocks { Builder builder = CodeBlock.builder(); - boolean isProjecting = !ObjectUtils.nullSafeEquals( - TypeName.get(context.getRepositoryInformation().getDomainType()), context.getActualReturnType()); - Type actualReturnType = isProjecting ? context.getActualReturnType().getType() + MethodReturn methodReturn = context.getMethodReturn(); + boolean isProjecting = methodReturn.isProjecting() + || StringUtils.hasText(context.getDynamicProjectionParameterName()); + Type actualReturnType = isProjecting ? methodReturn.getActualReturnClass() : context.getRepositoryInformation().getDomainType(); - builder.add("\n"); - Class returnType = context.getMethod().getReturnType(); - TypeName queryResultType = TypeName.get(context.getActualReturnType().toClass()); + Class returnType = context.getMethodReturn().toClass(); + + TypeName queryResultType = methodReturn.getActualClassName(); String result = context.localVariable("result"); String rowMapper = context.localVariable("rowMapper"); if (modifying.isPresent()) { - return update(builder, returnType); + return update(returnType); } else if (aotQuery.isCount()) { - return count(builder, result, returnType, queryResultType); + return count(result, returnType, queryResultType); } else if (aotQuery.isExists()) { - return exists(builder, queryResultType); + return exists(queryResultType); } else if (aotQuery.isDelete()) { - return delete(builder, rowMapper, result, queryResultType, returnType, actualReturnType); + return delete(rowMapper, result, queryResultType, returnType, actualReturnType); } else { String resultSetExtractor = null; @@ -603,7 +605,7 @@ class JdbcCodeBlocks { if (isProjecting) { typeToRead = context.getReturnedType().getDomainType(); } else { - typeToRead = context.getActualReturnType().getType(); + typeToRead = methodReturn.getActualReturnClass(); } builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper, @@ -667,25 +669,25 @@ class JdbcCodeBlocks { } builder.addStatement("return ($T) convertMany($L, %s)".formatted(dynamicProjection ? "$L" : "$T.class"), - context.getReturnTypeName(), result, queryResultTypeRef); + methodReturn.getTypeName(), result, queryResultTypeRef); } else if (queryMethod.isStreamQuery()) { builder.addStatement("$1T $2L = getJdbcOperations().queryForStream($3L, $4L, $5L)", Stream.class, result, queryVariableName, parameterSourceVariableName, rowMapper); - builder.addStatement("return ($T) convertMany($L, $T.class)", context.getReturnTypeName(), result, + builder.addStatement("return ($T) convertMany($L, $T.class)", methodReturn.getTypeName(), result, queryResultTypeRef); } else { builder.addStatement("$T $L = queryForObject($L, $L, $L)", Object.class, result, queryVariableName, parameterSourceVariableName, rowMapper); - if (Optional.class.isAssignableFrom(context.getReturnType().toClass())) { + if (methodReturn.isOptional()) { builder.addStatement( "return ($1T) $1T.ofNullable(convertOne($2L, %s))".formatted(dynamicProjection ? "$3L" : "$3T.class"), Optional.class, result, queryResultTypeRef); } else { builder.addStatement("return ($T) convertOne($L, %s)".formatted(dynamicProjection ? "$L" : "$T.class"), - context.getReturnTypeName(), result, queryResultTypeRef); + methodReturn.getTypeName(), result, queryResultTypeRef); } } } @@ -693,37 +695,35 @@ class JdbcCodeBlocks { return builder.build(); } - private CodeBlock update(Builder builder, Class returnType) { + private CodeBlock update(Class returnType) { String result = context.localVariable("result"); - builder.add("$["); - - if (!ReflectionUtils.isVoid(returnType)) { - builder.add("int $L = ", result); - } + Builder builder = CodeBlock.builder(); - builder.add("getJdbcOperations().update($L, $L)", queryVariableName, parameterSourceVariableName); - builder.add(";\n$]"); + LordOfTheStrings.InvocationBuilder invoke = LordOfTheStrings.invoke("getJdbcOperations().update($L, $L)", + queryVariableName, parameterSourceVariableName); - if (returnType == boolean.class || returnType == Boolean.class) { - builder.addStatement("return $L != 0", result); - } else if (returnType == Long.class) { - builder.addStatement("return (long) $L", result); - } else if (ReflectionUtils.isVoid(returnType)) { - if (returnType == Void.class) { - builder.addStatement("return null"); - } + if (context.getMethodReturn().isVoid()) { + builder.addStatement(invoke.build()); } else { - builder.addStatement("return $L", result); + builder.addStatement(invoke.assignTo("int $L", result)); } + builder.addStatement(LordOfTheStrings.returning(returnType) // + .whenBoolean("$L != 0", result) // + .whenBoxedLong("(long) $L", result) // + .otherwise("$L", result)// + .build()); + return builder.build(); } - private CodeBlock delete(Builder builder, String rowMapper, String result, TypeName queryResultType, + private CodeBlock delete(String rowMapper, String result, TypeName queryResultType, Class returnType, Type actualReturnType) { + CodeBlock.Builder builder = CodeBlock.builder(); + builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper, context.getRepositoryInformation().getDomainType()); @@ -732,48 +732,37 @@ class JdbcCodeBlocks { builder.addStatement("$L.forEach(getOperations()::delete)", result); - if (Collection.class.isAssignableFrom(context.getReturnType().toClass())) { - builder.addStatement("return ($T) convertMany($L, $T.class)", context.getReturnTypeName(), result, - queryResultType); - } else if (returnType == context.getRepositoryInformation().getDomainType()) { - builder.addStatement("return ($1T) ($2L.isEmpty() ? null : $2L.iterator().next())", actualReturnType, result); - } else if (returnType == boolean.class || returnType == Boolean.class) { - builder.addStatement("return !$L.isEmpty()", result); - } else if (returnType == Long.class) { - builder.addStatement("return (long) $L.size()", result); - } else if (ReflectionUtils.isVoid(returnType)) { - if (returnType == Void.class) { - builder.addStatement("return null"); - } - } else { - builder.addStatement("return $L.size()", result); - } + builder.addStatement(LordOfTheStrings.returning(returnType) // + .when(Collection.class.isAssignableFrom(context.getMethodReturn().toClass()), + "($T) convertMany($L, $T.class)", context.getMethodReturn().getTypeName(), result, queryResultType) // + .when(context.getRepositoryInformation().getDomainType(), + "($1T) ($2L.isEmpty() ? null : $2L.iterator().next())", actualReturnType, result) // + .whenBoolean("!$L.isEmpty()", result) // + .whenBoxedLong("(long) $L.size()", result) // + .otherwise("$L.size()", result) // + .build()); return builder.build(); } - private CodeBlock count(Builder builder, String result, Class returnType, TypeName queryResultType) { + private CodeBlock count(String result, Class returnType, TypeName queryResultType) { + + CodeBlock.Builder builder = CodeBlock.builder(); builder.addStatement("$1T $2L = queryForObject($3L, $4L, new $5T<>($1T.class))", Number.class, result, queryVariableName, parameterSourceVariableName, SingleColumnRowMapper.class); - if (returnType == Long.class) { - builder.addStatement("return $1L != null ? $1L.longValue() : null", result); - } else if (returnType == Integer.class) { - builder.addStatement("return $1L != null ? $1L.intValue() : null", result); - } else if (returnType == Long.TYPE) { - builder.addStatement("return $1L != null ? $1L.longValue() : 0L", result); - } else if (returnType == Integer.TYPE) { - builder.addStatement("return $1L != null ? $1L.intValue() : 0", result); - } else { - builder.addStatement("return ($T) convertOne($L, $T.class)", context.getReturnTypeName(), result, - queryResultType); - } + builder.addStatement(LordOfTheStrings.returning(returnType) // + .number(result) // + .otherwise("($T) convertOne($L, $T.class)", context.getMethodReturn().getTypeName(), result, queryResultType) // + .build()); return builder.build(); } - private CodeBlock exists(Builder builder, TypeName queryResultType) { + private CodeBlock exists(TypeName queryResultType) { + + CodeBlock.Builder builder = CodeBlock.builder(); builder.addStatement("return ($T) getJdbcOperations().query($L, $L, $T::next)", queryResultType, queryVariableName, parameterSourceVariableName, ResultSet.class); @@ -783,8 +772,8 @@ class JdbcCodeBlocks { public static boolean returnsModifying(Class returnType) { - return returnType == int.class || returnType == long.class || returnType == Integer.class - || returnType == Long.class; + return ClassUtils.resolvePrimitiveIfNecessary(returnType) == Integer.class + || ClassUtils.resolvePrimitiveIfNecessary(returnType) == Long.class; } } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryContributor.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryContributor.java index ba780fdf4..dedf1b2f1 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryContributor.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryContributor.java @@ -146,6 +146,8 @@ public class JdbcRepositoryContributor extends RepositoryContributor { body.add(JdbcCodeBlocks.queryBuilder(context, queryMethod).filter(aotQueries) .usingQueryVariableName(queryVariable).parameterSource(parameterSourceVariable).lock(lock).build()); + body.add("\n"); + body.add(JdbcCodeBlocks.executionBuilder(context, queryMethod).modifying(modifying) .usingQueryVariableName(queryVariable).parameterSource(parameterSourceVariable).queries(aotQueries) .queryAnnotation(query).build()); diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/AotJdbcRepositoryIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/AotJdbcRepositoryIntegrationTests.java index 7fb014c67..3578f6014 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/AotJdbcRepositoryIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/AotJdbcRepositoryIntegrationTests.java @@ -15,16 +15,21 @@ */ package org.springframework.data.jdbc.repository; +import java.util.Optional; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.FilterType; +import org.springframework.data.jdbc.core.JdbcAggregateOperations; +import org.springframework.data.jdbc.core.convert.QueryMappingConfiguration; import org.springframework.data.jdbc.core.dialect.JdbcH2Dialect; import org.springframework.data.jdbc.repository.aot.AotFragmentTestConfigurationSupport; import org.springframework.data.jdbc.repository.aot.UserRepository; import org.springframework.data.jdbc.repository.config.EnableJdbcRepositories; +import org.springframework.data.jdbc.repository.support.BeanFactoryAwareRowMapperFactory; import org.springframework.data.jdbc.testing.DatabaseType; import org.springframework.data.jdbc.testing.EnabledOnDatabase; import org.springframework.data.jdbc.testing.IntegrationTest; @@ -58,6 +63,13 @@ class AotJdbcRepositoryIntegrationTests extends JdbcRepositoryIntegrationTests { false); } + @Bean + BeanFactoryAwareRowMapperFactory rowMapperFactory(ApplicationContext context, + JdbcAggregateOperations aggregateOperations, Optional queryMappingConfiguration) { + return new BeanFactoryAwareRowMapperFactory(context, aggregateOperations, + queryMappingConfiguration.orElse(QueryMappingConfiguration.EMPTY)); + } + @Bean @Override DummyEntityRepository dummyEntityRepository() { diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/AotFragmentTestConfigurationSupport.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/AotFragmentTestConfigurationSupport.java index 1ac066746..6e3f69d31 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/AotFragmentTestConfigurationSupport.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/AotFragmentTestConfigurationSupport.java @@ -19,7 +19,6 @@ import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.util.Arrays; import java.util.List; -import java.util.Optional; import org.mockito.Mockito; @@ -34,9 +33,6 @@ import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.DefaultBeanNameGenerator; -import org.springframework.context.ApplicationContext; -import org.springframework.context.ApplicationContextAware; -import org.springframework.context.annotation.Bean; import org.springframework.core.env.Environment; import org.springframework.core.env.StandardEnvironment; import org.springframework.core.io.DefaultResourceLoader; @@ -45,12 +41,10 @@ import org.springframework.core.type.AnnotationMetadata; import org.springframework.data.expression.ValueExpressionParser; import org.springframework.data.jdbc.core.JdbcAggregateOperations; import org.springframework.data.jdbc.core.convert.MappingJdbcConverter; -import org.springframework.data.jdbc.core.convert.QueryMappingConfiguration; import org.springframework.data.jdbc.core.dialect.JdbcDialect; import org.springframework.data.jdbc.core.mapping.JdbcMappingContext; import org.springframework.data.jdbc.repository.config.EnableJdbcRepositories; import org.springframework.data.jdbc.repository.query.RowMapperFactory; -import org.springframework.data.jdbc.repository.support.BeanFactoryAwareRowMapperFactory; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.repository.config.AnnotationRepositoryConfigurationSource; @@ -71,13 +65,12 @@ import org.springframework.util.ReflectionUtils; * * @author Mark Paluch */ -public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProcessor, ApplicationContextAware { +public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProcessor { private final Class repositoryInterface; private final JdbcDialect dialect; private final boolean registerFragmentFacade; private final TestJdbcAotRepositoryContext repositoryContext; - private ApplicationContext applicationContext; public AotFragmentTestConfigurationSupport(Class repositoryInterface, JdbcDialect dialect, Class configClass) { this(repositoryInterface, dialect, configClass, true); @@ -98,13 +91,6 @@ public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProce this.registerFragmentFacade = registerFragmentFacade; } - @Bean - BeanFactoryAwareRowMapperFactory rowMapperFactory(ApplicationContext context, - JdbcAggregateOperations aggregateOperations, Optional queryMappingConfiguration) { - return new BeanFactoryAwareRowMapperFactory(context, aggregateOperations, - queryMappingConfiguration.orElse(QueryMappingConfiguration.EMPTY)); - } - @Override public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { @@ -191,8 +177,4 @@ public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProce return creationContext; } - @Override - public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { - this.applicationContext = applicationContext; - } } diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryContributorIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryContributorIntegrationTests.java index a89454570..9d874e973 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryContributorIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryContributorIntegrationTests.java @@ -18,11 +18,13 @@ package org.springframework.data.jdbc.repository.aot; import static org.assertj.core.api.Assertions.*; import java.util.List; +import java.util.Optional; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Configuration; @@ -34,8 +36,10 @@ import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; import org.springframework.data.jdbc.core.JdbcAggregateOperations; +import org.springframework.data.jdbc.core.convert.QueryMappingConfiguration; import org.springframework.data.jdbc.core.dialect.JdbcH2Dialect; import org.springframework.data.jdbc.repository.config.EnableJdbcRepositories; +import org.springframework.data.jdbc.repository.support.BeanFactoryAwareRowMapperFactory; import org.springframework.data.jdbc.testing.DatabaseType; import org.springframework.data.jdbc.testing.EnabledOnDatabase; import org.springframework.data.jdbc.testing.IntegrationTest; @@ -71,6 +75,13 @@ class JdbcRepositoryContributorIntegrationTests { return TestClass.of(JdbcRepositoryContributorIntegrationTests.class); } + @Bean + BeanFactoryAwareRowMapperFactory rowMapperFactory(ApplicationContext context, + JdbcAggregateOperations aggregateOperations, Optional queryMappingConfiguration) { + return new BeanFactoryAwareRowMapperFactory(context, aggregateOperations, + queryMappingConfiguration.orElse(QueryMappingConfiguration.EMPTY)); + } + @Bean MyRowMapper myRowMapper() { return new MyRowMapper(); diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryMetadataIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryMetadataIntegrationTests.java index f1b0c6600..d41d22879 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryMetadataIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryMetadataIntegrationTests.java @@ -41,6 +41,7 @@ import org.springframework.data.jdbc.core.convert.MappingJdbcConverter; import org.springframework.data.jdbc.core.convert.RelationResolver; import org.springframework.data.jdbc.core.dialect.JdbcH2Dialect; import org.springframework.data.jdbc.repository.config.EnableJdbcRepositories; +import org.springframework.data.jdbc.repository.query.RowMapperFactory; import org.springframework.data.mapping.PersistentPropertyPath; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; @@ -69,6 +70,11 @@ class JdbcRepositoryMetadataIntegrationTests { return new RelationalMappingContext(); } + @Bean + RowMapperFactory rowMapperFactory() { + return mock(RowMapperFactory.class); + } + @Bean JdbcConverter converter(RelationalMappingContext mappingContext) { return new MappingJdbcConverter(mappingContext, new RelationResolver() {