Browse Source

Refine JavaPoet usage.

See: #2121
Original pull request: #2124
pull/2144/head
Mark Paluch 3 months ago
parent
commit
ca77912f86
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 123
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java
  2. 2
      spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryContributor.java
  3. 12
      spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/AotJdbcRepositoryIntegrationTests.java
  4. 20
      spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/AotFragmentTestConfigurationSupport.java
  5. 11
      spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryContributorIntegrationTests.java
  6. 6
      spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryMetadataIntegrationTests.java

123
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcCodeBlocks.java

@ -34,6 +34,7 @@ import org.jspecify.annotations.Nullable; @@ -34,6 +34,7 @@ import org.jspecify.annotations.Nullable;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.data.domain.SliceImpl;
import org.springframework.data.domain.Sort;
import org.springframework.data.javapoet.LordOfTheStrings;
import org.springframework.data.jdbc.repository.query.JdbcQueryMethod;
import org.springframework.data.jdbc.repository.query.Modifying;
import org.springframework.data.jdbc.repository.query.ParameterBinding;
@ -44,10 +45,10 @@ import org.springframework.data.relational.core.query.CriteriaDefinition; @@ -44,10 +45,10 @@ import org.springframework.data.relational.core.query.CriteriaDefinition;
import org.springframework.data.relational.core.sql.LockMode;
import org.springframework.data.relational.repository.Lock;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
import org.springframework.data.repository.aot.generate.MethodReturn;
import org.springframework.data.repository.query.parser.Part;
import org.springframework.data.support.PageableExecutionUtils;
import org.springframework.data.util.Pair;
import org.springframework.data.util.ReflectionUtils;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.javapoet.TypeName;
@ -58,7 +59,7 @@ import org.springframework.jdbc.core.SingleColumnRowMapper; @@ -58,7 +59,7 @@ import org.springframework.jdbc.core.SingleColumnRowMapper;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
/**
@ -568,25 +569,26 @@ class JdbcCodeBlocks { @@ -568,25 +569,26 @@ class JdbcCodeBlocks {
Builder builder = CodeBlock.builder();
boolean isProjecting = !ObjectUtils.nullSafeEquals(
TypeName.get(context.getRepositoryInformation().getDomainType()), context.getActualReturnType());
Type actualReturnType = isProjecting ? context.getActualReturnType().getType()
MethodReturn methodReturn = context.getMethodReturn();
boolean isProjecting = methodReturn.isProjecting()
|| StringUtils.hasText(context.getDynamicProjectionParameterName());
Type actualReturnType = isProjecting ? methodReturn.getActualReturnClass()
: context.getRepositoryInformation().getDomainType();
builder.add("\n");
Class<?> returnType = context.getMethod().getReturnType();
TypeName queryResultType = TypeName.get(context.getActualReturnType().toClass());
Class<?> returnType = context.getMethodReturn().toClass();
TypeName queryResultType = methodReturn.getActualClassName();
String result = context.localVariable("result");
String rowMapper = context.localVariable("rowMapper");
if (modifying.isPresent()) {
return update(builder, returnType);
return update(returnType);
} else if (aotQuery.isCount()) {
return count(builder, result, returnType, queryResultType);
return count(result, returnType, queryResultType);
} else if (aotQuery.isExists()) {
return exists(builder, queryResultType);
return exists(queryResultType);
} else if (aotQuery.isDelete()) {
return delete(builder, rowMapper, result, queryResultType, returnType, actualReturnType);
return delete(rowMapper, result, queryResultType, returnType, actualReturnType);
} else {
String resultSetExtractor = null;
@ -603,7 +605,7 @@ class JdbcCodeBlocks { @@ -603,7 +605,7 @@ class JdbcCodeBlocks {
if (isProjecting) {
typeToRead = context.getReturnedType().getDomainType();
} else {
typeToRead = context.getActualReturnType().getType();
typeToRead = methodReturn.getActualReturnClass();
}
builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper,
@ -667,25 +669,25 @@ class JdbcCodeBlocks { @@ -667,25 +669,25 @@ class JdbcCodeBlocks {
}
builder.addStatement("return ($T) convertMany($L, %s)".formatted(dynamicProjection ? "$L" : "$T.class"),
context.getReturnTypeName(), result, queryResultTypeRef);
methodReturn.getTypeName(), result, queryResultTypeRef);
} else if (queryMethod.isStreamQuery()) {
builder.addStatement("$1T $2L = getJdbcOperations().queryForStream($3L, $4L, $5L)", Stream.class, result,
queryVariableName, parameterSourceVariableName, rowMapper);
builder.addStatement("return ($T) convertMany($L, $T.class)", context.getReturnTypeName(), result,
builder.addStatement("return ($T) convertMany($L, $T.class)", methodReturn.getTypeName(), result,
queryResultTypeRef);
} else {
builder.addStatement("$T $L = queryForObject($L, $L, $L)", Object.class, result, queryVariableName,
parameterSourceVariableName, rowMapper);
if (Optional.class.isAssignableFrom(context.getReturnType().toClass())) {
if (methodReturn.isOptional()) {
builder.addStatement(
"return ($1T) $1T.ofNullable(convertOne($2L, %s))".formatted(dynamicProjection ? "$3L" : "$3T.class"),
Optional.class, result, queryResultTypeRef);
} else {
builder.addStatement("return ($T) convertOne($L, %s)".formatted(dynamicProjection ? "$L" : "$T.class"),
context.getReturnTypeName(), result, queryResultTypeRef);
methodReturn.getTypeName(), result, queryResultTypeRef);
}
}
}
@ -693,37 +695,35 @@ class JdbcCodeBlocks { @@ -693,37 +695,35 @@ class JdbcCodeBlocks {
return builder.build();
}
private CodeBlock update(Builder builder, Class<?> returnType) {
private CodeBlock update(Class<?> returnType) {
String result = context.localVariable("result");
builder.add("$[");
if (!ReflectionUtils.isVoid(returnType)) {
builder.add("int $L = ", result);
}
Builder builder = CodeBlock.builder();
builder.add("getJdbcOperations().update($L, $L)", queryVariableName, parameterSourceVariableName);
builder.add(";\n$]");
LordOfTheStrings.InvocationBuilder invoke = LordOfTheStrings.invoke("getJdbcOperations().update($L, $L)",
queryVariableName, parameterSourceVariableName);
if (returnType == boolean.class || returnType == Boolean.class) {
builder.addStatement("return $L != 0", result);
} else if (returnType == Long.class) {
builder.addStatement("return (long) $L", result);
} else if (ReflectionUtils.isVoid(returnType)) {
if (returnType == Void.class) {
builder.addStatement("return null");
}
if (context.getMethodReturn().isVoid()) {
builder.addStatement(invoke.build());
} else {
builder.addStatement("return $L", result);
builder.addStatement(invoke.assignTo("int $L", result));
}
builder.addStatement(LordOfTheStrings.returning(returnType) //
.whenBoolean("$L != 0", result) //
.whenBoxedLong("(long) $L", result) //
.otherwise("$L", result)//
.build());
return builder.build();
}
private CodeBlock delete(Builder builder, String rowMapper, String result, TypeName queryResultType,
private CodeBlock delete(String rowMapper, String result, TypeName queryResultType,
Class<?> returnType, Type actualReturnType) {
CodeBlock.Builder builder = CodeBlock.builder();
builder.addStatement("$T $L = getRowMapperFactory().create($T.class)", RowMapper.class, rowMapper,
context.getRepositoryInformation().getDomainType());
@ -732,48 +732,37 @@ class JdbcCodeBlocks { @@ -732,48 +732,37 @@ class JdbcCodeBlocks {
builder.addStatement("$L.forEach(getOperations()::delete)", result);
if (Collection.class.isAssignableFrom(context.getReturnType().toClass())) {
builder.addStatement("return ($T) convertMany($L, $T.class)", context.getReturnTypeName(), result,
queryResultType);
} else if (returnType == context.getRepositoryInformation().getDomainType()) {
builder.addStatement("return ($1T) ($2L.isEmpty() ? null : $2L.iterator().next())", actualReturnType, result);
} else if (returnType == boolean.class || returnType == Boolean.class) {
builder.addStatement("return !$L.isEmpty()", result);
} else if (returnType == Long.class) {
builder.addStatement("return (long) $L.size()", result);
} else if (ReflectionUtils.isVoid(returnType)) {
if (returnType == Void.class) {
builder.addStatement("return null");
}
} else {
builder.addStatement("return $L.size()", result);
}
builder.addStatement(LordOfTheStrings.returning(returnType) //
.when(Collection.class.isAssignableFrom(context.getMethodReturn().toClass()),
"($T) convertMany($L, $T.class)", context.getMethodReturn().getTypeName(), result, queryResultType) //
.when(context.getRepositoryInformation().getDomainType(),
"($1T) ($2L.isEmpty() ? null : $2L.iterator().next())", actualReturnType, result) //
.whenBoolean("!$L.isEmpty()", result) //
.whenBoxedLong("(long) $L.size()", result) //
.otherwise("$L.size()", result) //
.build());
return builder.build();
}
private CodeBlock count(Builder builder, String result, Class<?> returnType, TypeName queryResultType) {
private CodeBlock count(String result, Class<?> returnType, TypeName queryResultType) {
CodeBlock.Builder builder = CodeBlock.builder();
builder.addStatement("$1T $2L = queryForObject($3L, $4L, new $5T<>($1T.class))", Number.class, result,
queryVariableName, parameterSourceVariableName, SingleColumnRowMapper.class);
if (returnType == Long.class) {
builder.addStatement("return $1L != null ? $1L.longValue() : null", result);
} else if (returnType == Integer.class) {
builder.addStatement("return $1L != null ? $1L.intValue() : null", result);
} else if (returnType == Long.TYPE) {
builder.addStatement("return $1L != null ? $1L.longValue() : 0L", result);
} else if (returnType == Integer.TYPE) {
builder.addStatement("return $1L != null ? $1L.intValue() : 0", result);
} else {
builder.addStatement("return ($T) convertOne($L, $T.class)", context.getReturnTypeName(), result,
queryResultType);
}
builder.addStatement(LordOfTheStrings.returning(returnType) //
.number(result) //
.otherwise("($T) convertOne($L, $T.class)", context.getMethodReturn().getTypeName(), result, queryResultType) //
.build());
return builder.build();
}
private CodeBlock exists(Builder builder, TypeName queryResultType) {
private CodeBlock exists(TypeName queryResultType) {
CodeBlock.Builder builder = CodeBlock.builder();
builder.addStatement("return ($T) getJdbcOperations().query($L, $L, $T::next)", queryResultType,
queryVariableName, parameterSourceVariableName, ResultSet.class);
@ -783,8 +772,8 @@ class JdbcCodeBlocks { @@ -783,8 +772,8 @@ class JdbcCodeBlocks {
public static boolean returnsModifying(Class<?> returnType) {
return returnType == int.class || returnType == long.class || returnType == Integer.class
|| returnType == Long.class;
return ClassUtils.resolvePrimitiveIfNecessary(returnType) == Integer.class
|| ClassUtils.resolvePrimitiveIfNecessary(returnType) == Long.class;
}
}

2
spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryContributor.java

@ -146,6 +146,8 @@ public class JdbcRepositoryContributor extends RepositoryContributor { @@ -146,6 +146,8 @@ public class JdbcRepositoryContributor extends RepositoryContributor {
body.add(JdbcCodeBlocks.queryBuilder(context, queryMethod).filter(aotQueries)
.usingQueryVariableName(queryVariable).parameterSource(parameterSourceVariable).lock(lock).build());
body.add("\n");
body.add(JdbcCodeBlocks.executionBuilder(context, queryMethod).modifying(modifying)
.usingQueryVariableName(queryVariable).parameterSource(parameterSourceVariable).queries(aotQueries)
.queryAnnotation(query).build());

12
spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/AotJdbcRepositoryIntegrationTests.java

@ -15,16 +15,21 @@ @@ -15,16 +15,21 @@
*/
package org.springframework.data.jdbc.repository;
import java.util.Optional;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.FilterType;
import org.springframework.data.jdbc.core.JdbcAggregateOperations;
import org.springframework.data.jdbc.core.convert.QueryMappingConfiguration;
import org.springframework.data.jdbc.core.dialect.JdbcH2Dialect;
import org.springframework.data.jdbc.repository.aot.AotFragmentTestConfigurationSupport;
import org.springframework.data.jdbc.repository.aot.UserRepository;
import org.springframework.data.jdbc.repository.config.EnableJdbcRepositories;
import org.springframework.data.jdbc.repository.support.BeanFactoryAwareRowMapperFactory;
import org.springframework.data.jdbc.testing.DatabaseType;
import org.springframework.data.jdbc.testing.EnabledOnDatabase;
import org.springframework.data.jdbc.testing.IntegrationTest;
@ -58,6 +63,13 @@ class AotJdbcRepositoryIntegrationTests extends JdbcRepositoryIntegrationTests { @@ -58,6 +63,13 @@ class AotJdbcRepositoryIntegrationTests extends JdbcRepositoryIntegrationTests {
false);
}
@Bean
BeanFactoryAwareRowMapperFactory rowMapperFactory(ApplicationContext context,
JdbcAggregateOperations aggregateOperations, Optional<QueryMappingConfiguration> queryMappingConfiguration) {
return new BeanFactoryAwareRowMapperFactory(context, aggregateOperations,
queryMappingConfiguration.orElse(QueryMappingConfiguration.EMPTY));
}
@Bean
@Override
DummyEntityRepository dummyEntityRepository() {

20
spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/AotFragmentTestConfigurationSupport.java

@ -19,7 +19,6 @@ import java.lang.reflect.Method; @@ -19,7 +19,6 @@ import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import org.mockito.Mockito;
@ -34,9 +33,6 @@ import org.springframework.beans.factory.support.AbstractBeanDefinition; @@ -34,9 +33,6 @@ import org.springframework.beans.factory.support.AbstractBeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.DefaultBeanNameGenerator;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Bean;
import org.springframework.core.env.Environment;
import org.springframework.core.env.StandardEnvironment;
import org.springframework.core.io.DefaultResourceLoader;
@ -45,12 +41,10 @@ import org.springframework.core.type.AnnotationMetadata; @@ -45,12 +41,10 @@ import org.springframework.core.type.AnnotationMetadata;
import org.springframework.data.expression.ValueExpressionParser;
import org.springframework.data.jdbc.core.JdbcAggregateOperations;
import org.springframework.data.jdbc.core.convert.MappingJdbcConverter;
import org.springframework.data.jdbc.core.convert.QueryMappingConfiguration;
import org.springframework.data.jdbc.core.dialect.JdbcDialect;
import org.springframework.data.jdbc.core.mapping.JdbcMappingContext;
import org.springframework.data.jdbc.repository.config.EnableJdbcRepositories;
import org.springframework.data.jdbc.repository.query.RowMapperFactory;
import org.springframework.data.jdbc.repository.support.BeanFactoryAwareRowMapperFactory;
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
import org.springframework.data.repository.config.AnnotationRepositoryConfigurationSource;
@ -71,13 +65,12 @@ import org.springframework.util.ReflectionUtils; @@ -71,13 +65,12 @@ import org.springframework.util.ReflectionUtils;
*
* @author Mark Paluch
*/
public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProcessor, ApplicationContextAware {
public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProcessor {
private final Class<?> repositoryInterface;
private final JdbcDialect dialect;
private final boolean registerFragmentFacade;
private final TestJdbcAotRepositoryContext<?> repositoryContext;
private ApplicationContext applicationContext;
public AotFragmentTestConfigurationSupport(Class<?> repositoryInterface, JdbcDialect dialect, Class<?> configClass) {
this(repositoryInterface, dialect, configClass, true);
@ -98,13 +91,6 @@ public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProce @@ -98,13 +91,6 @@ public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProce
this.registerFragmentFacade = registerFragmentFacade;
}
@Bean
BeanFactoryAwareRowMapperFactory rowMapperFactory(ApplicationContext context,
JdbcAggregateOperations aggregateOperations, Optional<QueryMappingConfiguration> queryMappingConfiguration) {
return new BeanFactoryAwareRowMapperFactory(context, aggregateOperations,
queryMappingConfiguration.orElse(QueryMappingConfiguration.EMPTY));
}
@Override
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
@ -191,8 +177,4 @@ public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProce @@ -191,8 +177,4 @@ public class AotFragmentTestConfigurationSupport implements BeanFactoryPostProce
return creationContext;
}
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = applicationContext;
}
}

11
spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryContributorIntegrationTests.java

@ -18,11 +18,13 @@ package org.springframework.data.jdbc.repository.aot; @@ -18,11 +18,13 @@ package org.springframework.data.jdbc.repository.aot;
import static org.assertj.core.api.Assertions.*;
import java.util.List;
import java.util.Optional;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
@ -34,8 +36,10 @@ import org.springframework.data.domain.Pageable; @@ -34,8 +36,10 @@ import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Sort;
import org.springframework.data.jdbc.core.JdbcAggregateOperations;
import org.springframework.data.jdbc.core.convert.QueryMappingConfiguration;
import org.springframework.data.jdbc.core.dialect.JdbcH2Dialect;
import org.springframework.data.jdbc.repository.config.EnableJdbcRepositories;
import org.springframework.data.jdbc.repository.support.BeanFactoryAwareRowMapperFactory;
import org.springframework.data.jdbc.testing.DatabaseType;
import org.springframework.data.jdbc.testing.EnabledOnDatabase;
import org.springframework.data.jdbc.testing.IntegrationTest;
@ -71,6 +75,13 @@ class JdbcRepositoryContributorIntegrationTests { @@ -71,6 +75,13 @@ class JdbcRepositoryContributorIntegrationTests {
return TestClass.of(JdbcRepositoryContributorIntegrationTests.class);
}
@Bean
BeanFactoryAwareRowMapperFactory rowMapperFactory(ApplicationContext context,
JdbcAggregateOperations aggregateOperations, Optional<QueryMappingConfiguration> queryMappingConfiguration) {
return new BeanFactoryAwareRowMapperFactory(context, aggregateOperations,
queryMappingConfiguration.orElse(QueryMappingConfiguration.EMPTY));
}
@Bean
MyRowMapper myRowMapper() {
return new MyRowMapper();

6
spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/aot/JdbcRepositoryMetadataIntegrationTests.java

@ -41,6 +41,7 @@ import org.springframework.data.jdbc.core.convert.MappingJdbcConverter; @@ -41,6 +41,7 @@ import org.springframework.data.jdbc.core.convert.MappingJdbcConverter;
import org.springframework.data.jdbc.core.convert.RelationResolver;
import org.springframework.data.jdbc.core.dialect.JdbcH2Dialect;
import org.springframework.data.jdbc.repository.config.EnableJdbcRepositories;
import org.springframework.data.jdbc.repository.query.RowMapperFactory;
import org.springframework.data.mapping.PersistentPropertyPath;
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
@ -69,6 +70,11 @@ class JdbcRepositoryMetadataIntegrationTests { @@ -69,6 +70,11 @@ class JdbcRepositoryMetadataIntegrationTests {
return new RelationalMappingContext();
}
@Bean
RowMapperFactory rowMapperFactory() {
return mock(RowMapperFactory.class);
}
@Bean
JdbcConverter converter(RelationalMappingContext mappingContext) {
return new MappingJdbcConverter(mappingContext, new RelationResolver() {

Loading…
Cancel
Save