From 079c177d720693c036517801f818996a881787d7 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Thu, 14 Aug 2025 12:23:36 +0200 Subject: [PATCH] Add support for embeddable mapping to QueryMapper and UpdateMapper. We now support querying, updating, sorting and projecting embeddables by resolving these to their individual columns. Closes #2011 Original pull request: #2114 --- .../r2dbc/core/DefaultStatementMapper.java | 2 +- .../data/r2dbc/core/R2dbcEntityTemplate.java | 16 +- .../data/r2dbc/query/QueryMapper.java | 281 +++++++++++++++++- .../data/r2dbc/query/UpdateMapper.java | 38 ++- ...ltReactiveDataAccessStrategyUnitTests.java | 37 ++- .../r2dbc/query/QueryMapperUnitTests.java | 129 +++++++- .../r2dbc/query/UpdateMapperUnitTests.java | 91 +++++- ...dbcRepositoryEmbeddedIntegrationTests.java | 22 +- .../query/PartTreeR2dbcQueryUnitTests.java | 75 +++++ 9 files changed, 636 insertions(+), 55 deletions(-) diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/DefaultStatementMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/DefaultStatementMapper.java index a7fcf2a13..24f7f46a5 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/DefaultStatementMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/DefaultStatementMapper.java @@ -143,7 +143,7 @@ class DefaultStatementMapper implements StatementMapper { List mapped = new ArrayList<>(selectList.size()); for (Expression expression : selectList) { - mapped.add(updateMapper.getMappedObject(expression, entity)); + mapped.addAll(updateMapper.getMappedObjects(expression, entity)); } return mapped; diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java index a414483b8..353c550fd 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java @@ -23,7 +23,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import java.util.Collections; -import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -34,6 +34,7 @@ import java.util.function.Function; import java.util.stream.Collectors; import org.reactivestreams.Publisher; + import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; @@ -60,7 +61,6 @@ import org.springframework.data.r2dbc.mapping.event.BeforeConvertCallback; import org.springframework.data.r2dbc.mapping.event.BeforeSaveCallback; import org.springframework.data.relational.core.conversion.AbstractRelationalConverter; import org.springframework.data.relational.core.mapping.PersistentPropertyTranslator; -import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; import org.springframework.data.relational.core.query.Criteria; @@ -621,9 +621,13 @@ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAw return maybeCallBeforeSave(entityToUse, outboundRow, tableName) // .flatMap(onBeforeSave -> { - Map idValues = new HashMap<>(); - ((RelationalMappingContext) mappingContext).getAggregatePath(persistentEntity).getTableInfo() - .idColumnInfos().forEach((ap, ci) -> idValues.put(ci.name(), outboundRow.remove(ci.name()))); + Map idValues = new LinkedHashMap<>(); + List identifierColumns = dataAccessStrategy.getIdentifierColumns(persistentEntity.getType()); + Assert.state(!identifierColumns.isEmpty(), entityToUse + " has no Identifier. Update is not possible."); + + identifierColumns.forEach(sqlIdentifier -> { + idValues.put(sqlIdentifier, outboundRow.remove(sqlIdentifier)); + }); persistentEntity.forEach(p -> { if (p.isInsertOnly()) { @@ -631,8 +635,6 @@ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAw } }); - Assert.state(!idValues.isEmpty(), entityToUse + " has no id. Update is not possible"); - Criteria criteria = null; for (Map.Entry idAndValue : idValues.entrySet()) { if (criteria == null) { diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java index f6ee60dd0..7b9b7d89a 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java @@ -24,6 +24,8 @@ import java.util.regex.Pattern; import org.springframework.data.domain.Sort; import org.springframework.data.mapping.MappingException; +import org.springframework.data.mapping.PersistentProperty; +import org.springframework.data.mapping.PersistentPropertyAccessor; import org.springframework.data.mapping.PersistentPropertyPath; import org.springframework.data.mapping.PropertyPath; import org.springframework.data.mapping.PropertyReferenceException; @@ -112,22 +114,42 @@ public class QueryMapper { SqlSort.validate(order); - OrderByField simpleOrderByField = createSimpleOrderByField(table, entity, order); - OrderByField orderBy = simpleOrderByField.withNullHandling(order.getNullHandling()); - mappedOrder.add(order.isAscending() ? orderBy.asc() : orderBy.desc()); + List simpleOrderByFields = createSimpleOrderByFields(table, entity, order); + + simpleOrderByFields.forEach(field -> { + + OrderByField orderBy = field.withNullHandling(order.getNullHandling()); + mappedOrder.add(order.isAscending() ? orderBy.asc() : orderBy.desc()); + }); } return mappedOrder; } - private OrderByField createSimpleOrderByField(Table table, RelationalPersistentEntity entity, Sort.Order order) { + private List createSimpleOrderByFields(Table table, @Nullable RelationalPersistentEntity entity, + Sort.Order order) { if (order instanceof SqlSort.SqlOrder sqlOrder && sqlOrder.isUnsafe()) { - return OrderByField.from(Expressions.just(sqlOrder.getProperty())); + return List.of(OrderByField.from(Expressions.just(sqlOrder.getProperty()))); } Field field = createPropertyField(entity, SqlIdentifier.unquoted(order.getProperty()), this.mappingContext); - return OrderByField.from(table.column(field.getMappedColumnName())); + + if (field.isEmbedded() && entity != null) { + + RelationalPersistentEntity embeddedEntity = getMappingContext() + .getRequiredPersistentEntity(field.getRequiredProperty()); + + List fields = new ArrayList<>(); + + for (RelationalPersistentProperty embeddedProperty : embeddedEntity) { + fields.addAll(createSimpleOrderByFields(table, embeddedEntity, order.withProperty(embeddedProperty.getName()))); + } + + return fields; + } + + return List.of(OrderByField.from(table.column(field.getMappedColumnName()))); } /** @@ -137,12 +159,35 @@ public class QueryMapper { * @param entity related {@link RelationalPersistentEntity}, can be {@literal null}. * @return the mapped {@link Expression}. * @since 1.1 + * @deprecated since 4.0 in favor of {@link #getMappedObjects(Expression, RelationalPersistentEntity)} where usage of + * {@link org.springframework.data.relational.core.mapping.Embedded embeddable properties} can return more + * than one mapped result. */ + @Deprecated(since = "4.0") public Expression getMappedObject(Expression expression, @Nullable RelationalPersistentEntity entity) { + List mappedObjects = getMappedObjects(expression, entity); + + if (mappedObjects.isEmpty()) { + throw new IllegalArgumentException(String.format("Cannot map %s", expression)); + } + + return mappedObjects.get(0); + } + + /** + * Map the {@link Expression} object to apply field name mapping using {@link Class the type to read}. + * + * @param expression must not be {@literal null}. + * @param entity related {@link RelationalPersistentEntity}, can be {@literal null}. + * @return the mapped {@link Expression}s. + * @since 4.0 + */ + public List getMappedObjects(Expression expression, @Nullable RelationalPersistentEntity entity) { + if (entity == null || expression instanceof AsteriskFromTable || expression instanceof Expressions.SimpleExpression) { - return expression; + return List.of(expression); } if (expression instanceof Column column) { @@ -150,8 +195,22 @@ public class QueryMapper { Field field = createPropertyField(entity, column.getName()); TableLike table = column.getTable(); + if (field.isEmbedded()) { + + RelationalPersistentEntity embeddedEntity = getMappingContext() + .getRequiredPersistentEntity(field.getRequiredProperty()); + + List expressions = new ArrayList<>(); + + for (RelationalPersistentProperty embeddedProperty : embeddedEntity) { + expressions.addAll(getMappedObjects(Column.create(embeddedProperty.getName(), table), embeddedEntity)); + } + + return expressions; + } + Column columnFromTable = table.column(field.getMappedColumnName()); - return column instanceof Aliased ? columnFromTable.as(((Aliased) column).getAlias()) : columnFromTable; + return List.of(column instanceof Aliased ? columnFromTable.as(((Aliased) column).getAlias()) : columnFromTable); } if (expression instanceof SimpleFunction function) { @@ -160,12 +219,12 @@ public class QueryMapper { List mappedArguments = new ArrayList<>(arguments.size()); for (Expression argument : arguments) { - mappedArguments.add(getMappedObject(argument, entity)); + mappedArguments.addAll(getMappedObjects(argument, entity)); } SimpleFunction mappedFunction = SimpleFunction.create(function.getFunctionName(), mappedArguments); - return function instanceof Aliased ? mappedFunction.as(((Aliased) function).getAlias()) : mappedFunction; + return List.of(function instanceof Aliased ? mappedFunction.as(((Aliased) function).getAlias()) : mappedFunction); } throw new IllegalArgumentException(String.format("Cannot map %s", expression)); @@ -297,6 +356,43 @@ public class QueryMapper { @Nullable RelationalPersistentEntity entity) { Field propertyField = createPropertyField(entity, criteria.getColumn(), this.mappingContext); + + if (propertyField.isEmbedded() && entity != null) { + + Object value = criteria.getValue(); + + RelationalPersistentEntity embeddedEntity = mappingContext + .getRequiredPersistentEntity(propertyField.getRequiredProperty()); + PersistentPropertyAccessor propertyAccessor = getEmbeddedPropertyAccessor(value, embeddedEntity, + propertyField); + + Condition condition = Conditions.unrestricted(); + + for (RelationalPersistentProperty embeddedProperty : embeddedEntity) { + + Object propertyValue = propertyAccessor.getProperty(embeddedProperty); + + CriteriaWrapper cw = new CriteriaWrapper(criteria) { + + @Override + public SqlIdentifier getColumn() { + return SqlIdentifier.unquoted(embeddedProperty.getName()); + } + + @Nullable + @Override + public Object getValue() { + return propertyValue; + } + }; + + Condition mapped = mapCondition(cw, bindings, table, embeddedEntity); + condition = condition.and(mapped); + } + + return condition; + } + Column column = table.column(propertyField.getMappedColumnName()); TypeInformation actualType = propertyField.getTypeHint().getRequiredActualType(); @@ -321,6 +417,39 @@ public class QueryMapper { } return createCondition(column, mappedValue, typeHint, bindings, comparator, criteria.isIgnoreCase()); + + } + + static PersistentPropertyAccessor getEmbeddedPropertyAccessor(@Nullable Object value, + RelationalPersistentEntity embeddedEntity, Field propertyField) { + + if (value != null) { + + Class propertyType = embeddedEntity.getType(); + if (!propertyType.isInstance(value)) { + throw new IllegalArgumentException("Value of property " + propertyField.getRequiredProperty().getName() + + " is not an instance of " + embeddedEntity.getType().getName() + " but " + value.getClass().getName()); + } + + return embeddedEntity.getPropertyAccessor(value); + } + + return new PersistentPropertyAccessor<>() { + @Override + public void setProperty(PersistentProperty property, @org.jspecify.annotations.Nullable Object value) { + + } + + @Override + public @org.jspecify.annotations.Nullable Object getProperty(PersistentProperty property) { + return null; + } + + @Override + public Object getBean() { + return null; + } + }; } private Escaper getEscaper(Comparator comparator) { @@ -587,6 +716,25 @@ public class QueryMapper { public TypeInformation getTypeHint() { return TypeInformation.OBJECT; } + + public boolean isEmbedded() { + return false; + } + + public @org.jspecify.annotations.Nullable RelationalPersistentProperty getProperty() { + return null; + } + + public RelationalPersistentProperty getRequiredProperty() { + + RelationalPersistentProperty property = getProperty(); + + if (property == null) { + throw new IllegalStateException("No property found for field: " + this.name); + } + + return property; + } } /** @@ -633,13 +781,34 @@ public class QueryMapper { this.mappingContext = context; this.path = getPath(name.getReference()); - this.property = this.path == null ? property : this.path.getLeafProperty(); + + RelationalPersistentProperty persistentProperty = null; + if (this.path != null) { + + RelationalPersistentEntity currentEntity = entity; + RelationalPersistentProperty currentProperty = null; + for (RelationalPersistentProperty p : path) { + + currentProperty = currentEntity.getPersistentProperty(p.getName()); + + if (currentProperty == null) { + break; + } + + if (currentProperty.isEntity()) { + currentEntity = mappingContext.getRequiredPersistentEntity(currentProperty); + } + } + + persistentProperty = currentProperty; + } + + this.property = persistentProperty; } @Override public SqlIdentifier getMappedColumnName() { - return this.path == null || this.path.getLeafProperty() == null ? super.getMappedColumnName() - : this.path.getLeafProperty().getColumnName(); + return this.property == null ? super.getMappedColumnName() : this.property.getColumnName(); } /** @@ -700,5 +869,91 @@ public class QueryMapper { return this.property.getTypeInformation(); } + + @Override + public boolean isEmbedded() { + return this.property != null && this.property.isEmbedded(); + } + + @Override + public @org.jspecify.annotations.Nullable RelationalPersistentProperty getProperty() { + return this.property; + } + } + + abstract static class CriteriaWrapper extends AbstractCriteria { + + private final CriteriaDefinition delegate; + + public CriteriaWrapper(CriteriaDefinition delegate) { + this.delegate = delegate; + } + + @Nullable + @Override + public Comparator getComparator() { + return delegate.getComparator(); + } + + @Override + public boolean isIgnoreCase() { + return delegate.isIgnoreCase(); + } + } + + abstract static class AbstractCriteria implements CriteriaDefinition { + @Override + public boolean isGroup() { + return false; + } + + @Override + public List getGroup() { + return List.of(); + } + + @Nullable + @Override + public SqlIdentifier getColumn() { + return null; + } + + @Nullable + @Override + public Comparator getComparator() { + return null; + } + + @Nullable + @Override + public Object getValue() { + return null; + } + + @Override + public boolean isIgnoreCase() { + return false; + } + + @Nullable + @Override + public CriteriaDefinition getPrevious() { + return null; + } + + @Override + public boolean hasPrevious() { + return false; + } + + @Override + public boolean isEmpty() { + return false; + } + + @Override + public Combinator getCombinator() { + return null; + } } } diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java index fb9eec7ed..4070959e2 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java @@ -16,13 +16,16 @@ package org.springframework.data.r2dbc.query; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; +import org.springframework.data.mapping.PersistentPropertyAccessor; import org.springframework.data.r2dbc.convert.R2dbcConverter; import org.springframework.data.r2dbc.dialect.R2dbcDialect; import org.springframework.data.relational.core.dialect.Escaper; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; +import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; import org.springframework.data.relational.core.query.Update; import org.springframework.data.relational.core.query.ValueFunction; import org.springframework.data.relational.core.sql.AssignValue; @@ -94,23 +97,42 @@ public class UpdateMapper extends QueryMapper { List result = new ArrayList<>(); assignments.forEach((column, value) -> { - Assignment assignment = getAssignment(column, value, bindings, table, entity); - result.add(assignment); + result.addAll(getAssignments(column, value, bindings, table, entity)); }); return new BoundAssignments(bindings, result); } - private Assignment getAssignment(SqlIdentifier columnName, Object value, MutableBindings bindings, Table table, - @Nullable RelationalPersistentEntity entity) { + private Collection getAssignments(SqlIdentifier columnName, Object value, MutableBindings bindings, + Table table, @Nullable RelationalPersistentEntity entity) { Field propertyField = createPropertyField(entity, columnName, getMappingContext()); + + if (propertyField.isEmbedded() && entity != null) { + + RelationalPersistentEntity embeddedEntity = getMappingContext() + .getRequiredPersistentEntity(propertyField.getRequiredProperty()); + PersistentPropertyAccessor propertyAccessor = getEmbeddedPropertyAccessor(value, embeddedEntity, + propertyField); + + List assignments = new ArrayList<>(); + + for (RelationalPersistentProperty embeddedProperty : embeddedEntity) { + + Object propertyValue = propertyAccessor.getProperty(embeddedProperty); + + assignments.addAll(getAssignments(SqlIdentifier.unquoted(embeddedProperty.getName()), propertyValue, bindings, + table, embeddedEntity)); + } + + return assignments; + } + Column column = table.column(propertyField.getMappedColumnName()); TypeInformation actualType = propertyField.getTypeHint().getRequiredActualType(); Object mappedValue; Class typeHint; - if (value instanceof Parameter parameter) { mappedValue = convertValue(parameter.getValue(), propertyField.getTypeHint()); @@ -121,7 +143,7 @@ public class UpdateMapper extends QueryMapper { mappedValue = valueFunction.map(v -> convertValue(v, propertyField.getTypeHint())).apply(Escaper.DEFAULT); if (mappedValue == null) { - return Assignments.value(column, SQL.nullLiteral()); + return List.of(Assignments.value(column, SQL.nullLiteral())); } typeHint = actualType.getType(); @@ -130,13 +152,13 @@ public class UpdateMapper extends QueryMapper { mappedValue = convertValue(value, propertyField.getTypeHint()); if (mappedValue == null) { - return Assignments.value(column, SQL.nullLiteral()); + return List.of(Assignments.value(column, SQL.nullLiteral())); } typeHint = actualType.getType(); } - return createAssignment(column, mappedValue, typeHint, bindings); + return List.of(createAssignment(column, mappedValue, typeHint, bindings)); } private Assignment createAssignment(Column column, Object value, Class type, MutableBindings bindings) { diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/DefaultReactiveDataAccessStrategyUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/DefaultReactiveDataAccessStrategyUnitTests.java index 8f1a3e355..e5db9514d 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/DefaultReactiveDataAccessStrategyUnitTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/DefaultReactiveDataAccessStrategyUnitTests.java @@ -1,10 +1,14 @@ package org.springframework.data.r2dbc.core; +import static org.assertj.core.api.Assertions.*; + import java.util.Arrays; import java.util.List; +import java.util.stream.Stream; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; -import org.assertj.core.api.SoftAssertions; -import org.junit.jupiter.api.Test; import org.springframework.data.annotation.Id; import org.springframework.data.r2dbc.dialect.H2Dialect; import org.springframework.data.relational.core.mapping.Embedded; @@ -19,21 +23,28 @@ class DefaultReactiveDataAccessStrategyUnitTests { DefaultReactiveDataAccessStrategy dataAccessStrategy = new DefaultReactiveDataAccessStrategy(H2Dialect.INSTANCE); - @Test - void getAllColumns() { + @ParameterizedTest + @MethodSource("fixtures") + void shouldReportAllColumns(Fixture fixture) { - SoftAssertions.assertSoftly(softly -> { - check(softly, SimpleEntity.class, "ID", "NAME"); - check(softly, WithEmbedded.class, "ID", "L1_NAME", "L1_L2_NAME", "L1_L2_NUMBER"); - check(softly, WithEmbeddedId.class, "ID_NAME", "ID_NUMBER", "NAME"); - }); + List sqlIdentifiers = Arrays.stream(fixture.allColumns()).map(SqlIdentifier::quoted).toList(); + + assertThat(dataAccessStrategy.getAllColumns(fixture.entityType())) + .containsExactlyInAnyOrder(sqlIdentifiers.toArray(new SqlIdentifier[0])); } - private void check(SoftAssertions softly, Class entityType, String... columnNames) { + static Stream fixtures() { + return Stream.of(new Fixture(SimpleEntity.class, "ID", "NAME"), + new Fixture(WithEmbedded.class, "ID", "L1_NAME", "L1_L2_NAME", "L1_L2_NUMBER"), + new Fixture(WithEmbeddedId.class, "ID_NAME", "ID_NUMBER", "NAME")); + } - List sqlIdentifiers = Arrays.stream(columnNames).map(SqlIdentifier::quoted).toList(); - softly.assertThat(dataAccessStrategy.getAllColumns(entityType)).describedAs(entityType.getName()) - .containsExactlyInAnyOrder(sqlIdentifiers.toArray(new SqlIdentifier[0])); + record Fixture(Class entityType, String... allColumns) { + + @Override + public String toString() { + return entityType.getSimpleName(); + } } record SimpleEntity(int id, String name) { diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/QueryMapperUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/QueryMapperUnitTests.java index cb7ef38a1..e29ea2075 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/QueryMapperUnitTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/QueryMapperUnitTests.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Objects; import org.junit.jupiter.api.Test; + import org.springframework.core.convert.converter.Converter; import org.springframework.data.domain.Sort; import org.springframework.data.r2dbc.convert.MappingR2dbcConverter; @@ -36,6 +37,7 @@ import org.springframework.data.r2dbc.dialect.PostgresDialect; import org.springframework.data.r2dbc.dialect.R2dbcDialect; import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; import org.springframework.data.relational.core.mapping.Column; +import org.springframework.data.relational.core.mapping.Embedded; import org.springframework.data.relational.core.query.Criteria; import org.springframework.data.relational.core.sql.Expression; import org.springframework.data.relational.core.sql.Functions; @@ -45,6 +47,7 @@ import org.springframework.data.relational.domain.SqlSort; import org.springframework.r2dbc.core.Parameter; import org.springframework.r2dbc.core.binding.BindMarkersFactory; import org.springframework.r2dbc.core.binding.BindTarget; + import org.testcontainers.shaded.com.fasterxml.jackson.databind.JsonNode; import org.testcontainers.shaded.com.fasterxml.jackson.databind.node.TextNode; @@ -215,8 +218,8 @@ class QueryMapperUnitTests { Table table = Table.create("my_table").as("my_aliased_table"); - Expression mappedObject = mapper.getMappedObject(table.column("alternative").as("my_aliased_col"), - mapper.getMappingContext().getRequiredPersistentEntity(Person.class)); + Expression mappedObject = mapper.getMappedObjects(table.column("alternative").as("my_aliased_col"), + mapper.getMappingContext().getRequiredPersistentEntity(Person.class)).get(0); assertThat(mappedObject).hasToString("my_aliased_table.another_name AS my_aliased_col"); } @@ -226,8 +229,8 @@ class QueryMapperUnitTests { Table table = Table.create("my_table").as("my_aliased_table"); - Expression mappedObject = mapper.getMappedObject(Functions.count(table.column("alternative")), - mapper.getMappingContext().getRequiredPersistentEntity(Person.class)); + Expression mappedObject = mapper.getMappedObjects(Functions.count(table.column("alternative")), + mapper.getMappingContext().getRequiredPersistentEntity(Person.class)).get(0); assertThat(mappedObject).hasToString("COUNT(my_aliased_table.another_name)"); } @@ -237,8 +240,8 @@ class QueryMapperUnitTests { Table table = Table.create("my_table").as("my_aliased_table"); - Expression mappedObject = mapper.getMappedObject(table.column("unknown").as("my_aliased_col"), - mapper.getMappingContext().getRequiredPersistentEntity(Person.class)); + Expression mappedObject = mapper.getMappedObjects(table.column("unknown").as("my_aliased_col"), + mapper.getMappingContext().getRequiredPersistentEntity(Person.class)).get(0); assertThat(mappedObject).hasToString("my_aliased_table.unknown AS my_aliased_col"); } @@ -248,7 +251,7 @@ class QueryMapperUnitTests { Table table = Table.create("my_table").as("my_aliased_table"); - Expression mappedObject = mapper.getMappedObject(table.column("my_col").as("my_aliased_col"), null); + Expression mappedObject = mapper.getMappedObjects(table.column("my_col").as("my_aliased_col"), null).get(0); assertThat(mappedObject).hasToString("my_aliased_table.my_col AS my_aliased_col"); } @@ -541,12 +544,94 @@ class QueryMapperUnitTests { assertThat(bindings.getBindings().iterator().next().getValue()).isEqualTo("foo"); } + @Test // GH-2096 + void shouldMapPathToEmbeddable() { + + Criteria criteria = Criteria.where("home").is(new Address(new Country("DE"))); + + BoundCondition bindings = map(criteria, WithEmbeddable.class); + + assertThat(bindings.getCondition()) + .hasToString("withembeddable.home_country_name = ?[$1] AND withembeddable.home_street = ?[$2]"); + } + + @Test // GH-2096 + void shouldMapPathToNestedEmbeddable() { + + Criteria criteria = Criteria.where("home.country").is(new Country("DE")); + + BoundCondition bindings = map(criteria, WithEmbeddable.class); + + assertThat(bindings.getCondition()).hasToString("withembeddable.home_country_name = ?[$1]"); + } + + @Test // GH-2096 + void shouldMapPathIntoEmbeddable() { + + Criteria criteria = Criteria.where("home.country.name").is("DE"); + + BoundCondition bindings = map(criteria, WithEmbeddable.class); + + assertThat(bindings.getCondition()).hasToString("withembeddable.home_country_name = ?[$1]"); + } + + @Test // GH-2096 + void shouldMapSortPathForEmbeddable() { + + List orderByFields = map(Sort.by("home"), WithEmbeddable.class); + + Table table = Table.create("withembeddable"); + assertThat(orderByFields).contains(OrderByField.from(table.column("home_country_name"), Sort.Direction.ASC)) + .contains(OrderByField.from(table.column("home_street"), Sort.Direction.ASC)); + } + + @Test // GH-2096 + void shouldMapSortPathIntoNestedEmbeddable() { + + List orderByFields = map(Sort.by("home.country"), WithEmbeddable.class); + + Table table = Table.create("withembeddable"); + assertThat(orderByFields).contains(OrderByField.from(table.column("home_country_name"), Sort.Direction.ASC)); + } + + @Test // GH-2096 + void shouldMapSortPathIntoEmbeddable() { + + List orderByFields = map(Sort.by("home.country.name"), WithEmbeddable.class); + + Table table = Table.create("withembeddable"); + assertThat(orderByFields).contains(OrderByField.from(table.column("home_country_name"), Sort.Direction.ASC)); + } + + @Test // GH-2096 + void shouldMapSelectionForEmbeddable() { + + Table table = Table.create("my_table").as("my_aliased_table"); + + List mappedObject = mapper.getMappedObjects(table.column("home"), + mapper.getMappingContext().getRequiredPersistentEntity(WithEmbeddable.class)); + + assertThat(mappedObject).extracting(Expression::toString) // + .hasSize(2) // + .contains("my_aliased_table.home_street", "my_aliased_table.home_country_name"); + } + private BoundCondition map(Criteria criteria) { + return map(criteria, Person.class); + } + + private BoundCondition map(Criteria criteria, Class entityType) { BindMarkersFactory markers = BindMarkersFactory.indexed("$", 1); - return mapper.getMappedObject(markers.create(), criteria, Table.create("person"), - mapper.getMappingContext().getRequiredPersistentEntity(Person.class)); + return mapper.getMappedObject(markers.create(), criteria, Table.create(entityType.getSimpleName().toLowerCase()), + mapper.getMappingContext().getRequiredPersistentEntity(entityType)); + } + + private List map(Sort sort, Class entityType) { + + return mapper.getMappedSort(Table.create(entityType.getSimpleName().toLowerCase()), sort, + mapper.getMappingContext().getRequiredPersistentEntity(entityType)); } static class Person { @@ -560,6 +645,32 @@ class QueryMapperUnitTests { JsonNode jsonNode; } + static class WithEmbeddable { + + @Embedded.Nullable(prefix = "home_") Address home; + + @Embedded.Nullable(prefix = "work_") Address work; + } + + static class Address { + + @Embedded.Nullable(prefix = "country_") Country country; + String street; + + public Address(Country country) { + this.country = country; + } + } + + static class Country { + + String name; + + public Country(String name) { + this.name = name; + } + } + enum MyEnum { ONE, TWO, } diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/UpdateMapperUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/UpdateMapperUnitTests.java index 60100dd71..0c63722e6 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/UpdateMapperUnitTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/UpdateMapperUnitTests.java @@ -22,11 +22,13 @@ import java.util.Map; import java.util.stream.Collectors; import org.junit.jupiter.api.Test; + import org.springframework.data.r2dbc.convert.MappingR2dbcConverter; import org.springframework.data.r2dbc.convert.R2dbcConverter; import org.springframework.data.r2dbc.dialect.PostgresDialect; import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; import org.springframework.data.relational.core.mapping.Column; +import org.springframework.data.relational.core.mapping.Embedded; import org.springframework.data.relational.core.query.Update; import org.springframework.data.relational.core.sql.AssignValue; import org.springframework.data.relational.core.sql.Expression; @@ -108,12 +110,66 @@ public class UpdateMapperUnitTests { .containsEntry(SqlIdentifier.unquoted("c2"), SQL.bindMarker("$2")); } + @Test // GH-2096 + void shouldMapPathToEmbeddable() { + + Update update = Update.update("home", new Address(new Country("DE"), "foo")); + + BoundAssignments mapped = map(update, WithEmbeddable.class); + + Map assignments = mapped.getAssignments().stream().map(it -> (AssignValue) it) + .collect(Collectors.toMap(k -> k.getColumn().getName(), AssignValue::getValue)); + + assertThat(assignments).hasSize(2).containsEntry(SqlIdentifier.unquoted("home_country_name"), SQL.bindMarker("$1")) + .containsEntry(SqlIdentifier.unquoted("home_street"), SQL.bindMarker("$2")); + + mapped.getBindings().forEach(it -> { + assertThat(it.getValue()).isIn("DE", "foo"); + }); + } + + @Test // GH-2096 + void shouldMapPathToNestedEmbeddable() { + + Update update = Update.update("home.country", new Country("DE")); + + BoundAssignments mapped = map(update, WithEmbeddable.class); + + Map assignments = mapped.getAssignments().stream().map(it -> (AssignValue) it) + .collect(Collectors.toMap(k -> k.getColumn().getName(), AssignValue::getValue)); + + assertThat(assignments).hasSize(1).containsEntry(SqlIdentifier.unquoted("home_country_name"), SQL.bindMarker("$1")); + mapped.getBindings().forEach(it -> { + assertThat(it.getValue()).isEqualTo("DE"); + }); + } + + @Test // GH-2096 + void shouldMapPathIntoEmbeddable() { + + Update update = Update.update("home.country.name", "DE"); + + BoundAssignments mapped = map(update, WithEmbeddable.class); + + Map assignments = mapped.getAssignments().stream().map(it -> (AssignValue) it) + .collect(Collectors.toMap(k -> k.getColumn().getName(), AssignValue::getValue)); + + assertThat(assignments).hasSize(1).containsEntry(SqlIdentifier.unquoted("home_country_name"), SQL.bindMarker("$1")); + mapped.getBindings().forEach(it -> { + assertThat(it.getValue()).isEqualTo("DE"); + }); + } + private BoundAssignments map(Update update) { + return map(update, Person.class); + } + + private BoundAssignments map(Update update, Class entityType) { BindMarkersFactory markers = BindMarkersFactory.indexed("$", 1); - return mapper.getMappedObject(markers.create(), update, Table.create("person"), - converter.getMappingContext().getRequiredPersistentEntity(Person.class)); + return mapper.getMappedObject(markers.create(), update, Table.create(entityType.getSimpleName().toLowerCase()), + converter.getMappingContext().getRequiredPersistentEntity(entityType)); } static class Person { @@ -121,4 +177,35 @@ public class UpdateMapperUnitTests { String name; @Column("another_name") String alternative; } + + static class WithEmbeddable { + + @Embedded.Nullable(prefix = "home_") Address home; + + @Embedded.Nullable(prefix = "work_") Address work; + } + + static class Address { + + @Embedded.Nullable(prefix = "country_") Country country; + String street; + + public Address(Country country) { + this.country = country; + } + + public Address(Country country, String street) { + this.country = country; + this.street = street; + } + } + + static class Country { + + String name; + + public Country(String name) { + this.name = name; + } + } } diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/H2R2dbcRepositoryEmbeddedIntegrationTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/H2R2dbcRepositoryEmbeddedIntegrationTests.java index 31e3bcbb6..714ecf623 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/H2R2dbcRepositoryEmbeddedIntegrationTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/H2R2dbcRepositoryEmbeddedIntegrationTests.java @@ -15,6 +15,8 @@ */ package org.springframework.data.r2dbc.repository; +import static org.assertj.core.api.Assertions.*; + import io.r2dbc.spi.ConnectionFactory; import reactor.core.publisher.Hooks; import reactor.test.StepVerifier; @@ -28,6 +30,7 @@ import javax.sql.DataSource; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan; @@ -35,6 +38,7 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.FilterType; import org.springframework.dao.DataAccessException; import org.springframework.data.annotation.Id; +import org.springframework.data.domain.Example; import org.springframework.data.r2dbc.config.AbstractR2dbcConfiguration; import org.springframework.data.r2dbc.convert.R2dbcCustomConversions; import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; @@ -44,6 +48,7 @@ import org.springframework.data.r2dbc.testing.R2dbcIntegrationTestSupport; import org.springframework.data.relational.RelationalManagedTypes; import org.springframework.data.relational.core.mapping.Embedded; import org.springframework.data.relational.core.mapping.NamingStrategy; +import org.springframework.data.repository.query.ReactiveQueryByExampleExecutor; import org.springframework.data.repository.reactive.ReactiveCrudRepository; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.test.context.ContextConfiguration; @@ -53,6 +58,7 @@ import org.springframework.test.context.junit.jupiter.SpringExtension; * Tests for support of embedded entities. * * @author Jens Schauder + * @author Mark Paluch */ @ExtendWith(SpringExtension.class) @ContextConfiguration @@ -63,7 +69,6 @@ class H2R2dbcRepositoryEmbeddedIntegrationTests extends R2dbcIntegrationTestSupp } @Autowired private PersonRepository repository; - @Autowired private ConnectionFactory connectionFactory; protected JdbcTemplate jdbc; @Configuration @@ -146,7 +151,20 @@ class H2R2dbcRepositoryEmbeddedIntegrationTests extends R2dbcIntegrationTestSupp .verifyComplete(); } - interface PersonRepository extends ReactiveCrudRepository {} + @Test // GH-2096 + void shouldFindUsingQueryByExample() { + + shouldInsertNewItems(); + + Person probe = new Person(null, new Name("Frodo", "Baggins")); + + repository.findAll(Example.of(probe)) // + .as(StepVerifier::create) // + .assertNext(p -> assertThat(p.name.first).isEqualTo("Frodo")) // + .verifyComplete(); + } + + interface PersonRepository extends ReactiveCrudRepository, ReactiveQueryByExampleExecutor {} record Person(@Id Integer id, @Embedded.Empty(prefix = "name_") Name name) { diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java index 9450c0b1b..8dbafba7f 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java @@ -50,6 +50,7 @@ import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy; import org.springframework.data.r2dbc.dialect.DialectResolver; import org.springframework.data.r2dbc.dialect.R2dbcDialect; import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; +import org.springframework.data.relational.core.mapping.Embedded; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.Table; import org.springframework.data.relational.core.sql.LockMode; @@ -789,6 +790,46 @@ class PartTreeR2dbcQueryUnitTests { .where(TABLE + ".first_name = $1 LIMIT 1"); } + @Test // GH-2096 + void createsQueryForEmbeddable() throws Exception { + + R2dbcQueryMethod queryMethod = getQueryMethod(WithEmbeddableRepository.class, "findByHome", Address.class); + PartTreeR2dbcQuery r2dbcQuery = new PartTreeR2dbcQuery(queryMethod, operations, r2dbcConverter, dataAccessStrategy); + PreparedOperation query = createQuery(queryMethod, r2dbcQuery, new Address(new Country("DE"))); + + PreparedOperationAssert.assertThat(query) // + .selects("with_embeddable.home_country_name", "with_embeddable.work_country_name") // + .from("with_embeddable") // + .where("with_embeddable.home_country_name = $1"); + } + + @Test // GH-2096 + void createsQueryForNestedEmbeddable() throws Exception { + + R2dbcQueryMethod queryMethod = getQueryMethod(WithEmbeddableRepository.class, "findByHomeCountry", Country.class); + PartTreeR2dbcQuery r2dbcQuery = new PartTreeR2dbcQuery(queryMethod, operations, r2dbcConverter, dataAccessStrategy); + PreparedOperation query = createQuery(queryMethod, r2dbcQuery, new Country("DE")); + + PreparedOperationAssert.assertThat(query) // + .selects("with_embeddable.home_country_name", "with_embeddable.work_country_name") // + .from("with_embeddable") // + .where("with_embeddable.home_country_name = $1"); + } + + @Test // GH-2096 + void createsQueryForNestedEmbeddableValue() throws Exception { + + R2dbcQueryMethod queryMethod = getQueryMethod(WithEmbeddableRepository.class, "findByHomeCountryName", + String.class); + PartTreeR2dbcQuery r2dbcQuery = new PartTreeR2dbcQuery(queryMethod, operations, r2dbcConverter, dataAccessStrategy); + PreparedOperation query = createQuery(queryMethod, r2dbcQuery, "DE"); + + PreparedOperationAssert.assertThat(query) // + .selects("with_embeddable.home_country_name", "with_embeddable.work_country_name") // + .from("with_embeddable") // + .where("with_embeddable.home_country_name = $1"); + } + private PreparedOperation createQuery(R2dbcQueryMethod queryMethod, PartTreeR2dbcQuery r2dbcQuery, Object... parameters) { return createQuery(r2dbcQuery, getAccessor(queryMethod, parameters)); @@ -1001,6 +1042,15 @@ class PartTreeR2dbcQueryUnitTests { Mono countByFirstName(String firstName); } + interface WithEmbeddableRepository extends Repository { + + Mono findByHome(Address home); + + Mono findByHomeCountry(Country homeCountry); + + Mono findByHomeCountryName(String homeCountryName); + } + @Table("users") private static class User { @@ -1038,4 +1088,29 @@ class PartTreeR2dbcQueryUnitTests { String firstName; String unknown; } + + static class WithEmbeddable { + + @Embedded.Nullable(prefix = "home_") Address home; + + @Embedded.Nullable(prefix = "work_") Address work; + } + + static class Address { + + @Embedded.Nullable(prefix = "country_") Country country; + + public Address(Country country) { + this.country = country; + } + } + + static class Country { + + String name; + + public Country(String name) { + this.name = name; + } + } }