diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/PostgresReactiveDataAccessStrategyTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/PostgresReactiveDataAccessStrategyTests.java index 8014e08bf..dbee15388 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/PostgresReactiveDataAccessStrategyTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/PostgresReactiveDataAccessStrategyTests.java @@ -15,15 +15,10 @@ */ package org.springframework.data.r2dbc.core; +import static org.mockito.Mockito.*; +import static org.springframework.data.r2dbc.testing.Assertions.*; + import io.r2dbc.postgresql.codec.Interval; -import org.junit.jupiter.api.Test; -import org.springframework.core.convert.converter.Converter; -import org.springframework.data.convert.ReadingConverter; -import org.springframework.data.convert.WritingConverter; -import org.springframework.data.r2dbc.convert.EnumWriteSupport; -import org.springframework.data.r2dbc.dialect.PostgresDialect; -import org.springframework.data.r2dbc.mapping.OutboundRow; -import org.springframework.data.relational.core.sql.SqlIdentifier; import java.time.Duration; import java.util.ArrayList; @@ -33,7 +28,18 @@ import java.util.EnumSet; import java.util.List; import java.util.Set; -import static org.springframework.data.r2dbc.testing.Assertions.*; +import org.junit.jupiter.api.Test; +import org.springframework.core.convert.converter.Converter; +import org.springframework.data.convert.ReadingConverter; +import org.springframework.data.convert.WritingConverter; +import org.springframework.data.r2dbc.convert.EnumWriteSupport; +import org.springframework.data.r2dbc.core.StatementMapper.InsertSpec; +import org.springframework.data.r2dbc.dialect.PostgresDialect; +import org.springframework.data.r2dbc.mapping.OutboundRow; +import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.r2dbc.core.Parameter; +import org.springframework.r2dbc.core.PreparedOperation; +import org.springframework.r2dbc.core.binding.BindTarget; /** * {@link PostgresDialect} specific tests for {@link ReactiveDataAccessStrategy}. @@ -58,6 +64,20 @@ public class PostgresReactiveDataAccessStrategyTests extends ReactiveDataAccessS assertThat(row).withColumn("myarray").hasValueInstanceOf(Integer[][].class); } + @Test // GH-1593 + void shouldConvertEnumsCorrectly() { + + StatementMapper mapper = strategy.getStatementMapper(); + MyEnum[] value = { MyEnum.ONE }; + InsertSpec insert = mapper.createInsert("table").withColumn("my_col", Parameter.from(value)); + PreparedOperation mappedObject = mapper.getMappedObject(insert); + + BindTarget bindTarget = mock(BindTarget.class); + mappedObject.bindTo(bindTarget); + + verify(bindTarget).bind(0, new String[] { "ONE" }); + } + @Test // gh-161 void shouldConvertNullArrayToDriverArrayType() { diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/BasicRelationalConverter.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/BasicRelationalConverter.java index 2a355c84b..b5bf5534f 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/BasicRelationalConverter.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/BasicRelationalConverter.java @@ -15,6 +15,7 @@ */ package org.springframework.data.relational.core.conversion; +import java.lang.reflect.Array; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -161,44 +162,19 @@ public class BasicRelationalConverter implements RelationalConverter { if (getConversions().isSimpleType(value.getClass())) { - if (TypeInformation.OBJECT != type) { - - if (conversionService.canConvert(value.getClass(), type.getType())) { - value = conversionService.convert(value, type.getType()); - } + if (TypeInformation.OBJECT != type && conversionService.canConvert(value.getClass(), type.getType())) { + value = conversionService.convert(value, type.getType()); } return getPotentiallyConvertedSimpleWrite(value); } - // TODO: We should add conversion support for arrays, however, - // these should consider multi-dimensional arrays as well. - if (value.getClass().isArray() // - && !value.getClass().getComponentType().isEnum() // - && (TypeInformation.OBJECT.equals(type) // - || type.isCollectionLike()) // - ) { - return value; + if (value.getClass().isArray()) { + return writeArray(value, type); } if (value instanceof Collection) { - - List mapped = new ArrayList<>(); - - TypeInformation component = TypeInformation.OBJECT; - if (type.isCollectionLike() && type.getActualType() != null) { - component = type.getRequiredComponentType(); - } - - for (Object o : (Iterable) value) { - mapped.add(writeValue(o, component)); - } - - if (type.getType().isInstance(mapped) || !type.isCollectionLike()) { - return mapped; - } - - return conversionService.convert(mapped, type.getType()); + return writeCollection((Iterable) value, type); } RelationalPersistentEntity persistentEntity = context.getPersistentEntity(value.getClass()); @@ -212,6 +188,57 @@ public class BasicRelationalConverter implements RelationalConverter { return conversionService.convert(value, type.getType()); } + private Object writeArray(Object value, TypeInformation type) { + + Class componentType = value.getClass().getComponentType(); + Optional> optionalWriteTarget = getConversions().getCustomWriteTarget(componentType); + + if (optionalWriteTarget.isEmpty() && !componentType.isEnum()) { + return value; + } + + Class customWriteTarget = optionalWriteTarget + .orElseGet(() -> componentType.isEnum() ? String.class : componentType); + + // optimization: bypass identity conversion + if (customWriteTarget.equals(componentType)) { + return value; + } + + TypeInformation component = TypeInformation.OBJECT; + if (type.isCollectionLike() && type.getActualType() != null) { + component = type.getRequiredComponentType(); + } + + int length = Array.getLength(value); + Object target = Array.newInstance(customWriteTarget, length); + for (int i = 0; i < length; i++) { + Array.set(target, i, writeValue(Array.get(value, i), component)); + } + + return target; + } + + private Object writeCollection(Iterable value, TypeInformation type) { + + List mapped = new ArrayList<>(); + + TypeInformation component = TypeInformation.OBJECT; + if (type.isCollectionLike() && type.getActualType() != null) { + component = type.getRequiredComponentType(); + } + + for (Object o : value) { + mapped.add(writeValue(o, component)); + } + + if (type.getType().isInstance(mapped) || !type.isCollectionLike()) { + return mapped; + } + + return conversionService.convert(mapped, type.getType()); + } + @Override public EntityInstantiators getEntityInstantiators() { return this.entityInstantiators; diff --git a/spring-data-relational/src/test/java/org/springframework/data/relational/core/conversion/BasicRelationalConverterUnitTests.java b/spring-data-relational/src/test/java/org/springframework/data/relational/core/conversion/BasicRelationalConverterUnitTests.java index e058cb326..df13eb349 100644 --- a/spring-data-relational/src/test/java/org/springframework/data/relational/core/conversion/BasicRelationalConverterUnitTests.java +++ b/spring-data-relational/src/test/java/org/springframework/data/relational/core/conversion/BasicRelationalConverterUnitTests.java @@ -17,14 +17,15 @@ package org.springframework.data.relational.core.conversion; import static org.assertj.core.api.Assertions.*; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Set; +import java.util.function.Function; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.springframework.core.convert.converter.GenericConverter; import org.springframework.data.convert.ConverterBuilder; +import org.springframework.data.convert.ConverterBuilder.ConverterAware; import org.springframework.data.convert.CustomConversions; import org.springframework.data.mapping.PersistentPropertyAccessor; import org.springframework.data.relational.core.mapping.RelationalMappingContext; @@ -46,8 +47,13 @@ class BasicRelationalConverterUnitTests { @BeforeEach public void before() throws Exception { - Set converters = ConverterBuilder.writing(MyValue.class, String.class, MyValue::foo) - .andReading(MyValue::new).getConverters(); + List converters = new ArrayList<>(); + converters.addAll( + ConverterBuilder.writing(MyValue.class, String.class, MyValue::foo).andReading(MyValue::new).getConverters()); + + ConverterAware converterAware = ConverterBuilder + .writing(MySimpleEnum.class, MySimpleEnum.class, Function.identity()).andReading(mySimpleEnum -> mySimpleEnum); + converters.addAll(converterAware.getConverters()); CustomConversions conversions = new CustomConversions(CustomConversions.StoreConversions.NONE, converters); context.setSimpleTypeHolder(conversions.getSimpleTypeHolder()); @@ -79,7 +85,23 @@ class BasicRelationalConverterUnitTests { assertThat(result).isEqualTo("ON"); } - @Test // DATAJDBC-235 + @Test + void shouldConvertEnumArrayToStringArray() { + + Object result = converter.writeValue(new MyEnum[] { MyEnum.ON }, TypeInformation.OBJECT); + + assertThat(result).isEqualTo(new String[] { "ON" }); + } + + @Test // GH-1593 + void shouldRetainEnumArray() { + + Object result = converter.writeValue(new MySimpleEnum[] { MySimpleEnum.ON }, TypeInformation.OBJECT); + + assertThat(result).isEqualTo(new MySimpleEnum[] { MySimpleEnum.ON }); + } + + @Test // GH-1593 void shouldConvertStringToEnum() { Object result = converter.readValue("OFF", TypeInformation.of(MyEnum.class)); @@ -145,4 +167,8 @@ class BasicRelationalConverterUnitTests { enum MyEnum { ON, OFF; } + + enum MySimpleEnum { + ON, OFF; + } }