diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java index c9e746b78..8321e2889 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java @@ -15,32 +15,33 @@ */ package org.springframework.data.mongodb.util; -import java.time.Instant; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.LocalTime; -import java.time.ZoneOffset; -import java.time.temporal.Temporal; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Date; +import java.util.List; import java.util.Map; import java.util.StringJoiner; import java.util.function.Function; import java.util.stream.StreamSupport; import org.bson.*; +import org.bson.codecs.Codec; import org.bson.codecs.DocumentCodec; +import org.bson.codecs.EncoderContext; +import org.bson.codecs.configuration.CodecConfigurationException; import org.bson.codecs.configuration.CodecRegistry; import org.bson.conversions.Bson; import org.bson.json.JsonParseException; import org.bson.types.Binary; +import org.bson.types.Decimal128; import org.bson.types.ObjectId; import org.springframework.core.convert.converter.Converter; import org.springframework.data.mongodb.CodecRegistryProvider; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; @@ -109,7 +110,7 @@ public class BsonUtils { return dbo.toMap(); } - return new Document((Map) bson.toBsonDocument(Document.class, codecRegistry)); + return new Document(bson.toBsonDocument(Document.class, codecRegistry)); } /** @@ -327,6 +328,20 @@ public class BsonUtils { * @since 3.0 */ public static BsonValue simpleToBsonValue(Object source) { + return simpleToBsonValue(source, MongoClientSettings.getDefaultCodecRegistry()); + } + + /** + * Convert a given simple value (eg. {@link String}, {@link Long}) to its corresponding {@link BsonValue}. + * + * @param source must not be {@literal null}. + * @param codecRegistry The {@link CodecRegistry} used as a fallback to convert types using native {@link Codec}. Must + * not be {@literal null}. + * @return the corresponding {@link BsonValue} representation. + * @throws IllegalArgumentException if {@literal source} does not correspond to a {@link BsonValue} type. + * @since 4.2 + */ + public static BsonValue simpleToBsonValue(Object source, CodecRegistry codecRegistry) { if (source instanceof BsonValue bsonValue) { return bsonValue; @@ -364,31 +379,30 @@ public class BsonUtils { return new BsonDouble(floatValue); } - if(source instanceof Binary binary) { + if (source instanceof Binary binary) { return new BsonBinary(binary.getType(), binary.getData()); } - if(source instanceof Temporal) { - if (source instanceof Instant value) { - return new BsonDateTime(value.toEpochMilli()); - } - if (source instanceof LocalDateTime value) { - return new BsonDateTime(value.toInstant(ZoneOffset.UTC).toEpochMilli()); - } - if(source instanceof LocalDate value) { - return new BsonDateTime(value.atStartOfDay(ZoneOffset.UTC).toInstant().toEpochMilli()); - } - if(source instanceof LocalTime value) { - return new BsonDateTime(value.atDate(LocalDate.ofEpochDay(0L)).toInstant(ZoneOffset.UTC).toEpochMilli()); - } - } - - if(source instanceof Date date) { + if (source instanceof Date date) { new BsonDateTime(date.getTime()); } - throw new IllegalArgumentException(String.format("Unable to convert %s (%s) to BsonValue.", source, - source != null ? source.getClass().getName() : "null")); + try { + + Object value = source; + if (ClassUtils.isPrimitiveArray(source.getClass())) { + value = CollectionUtils.arrayToList(source); + } + + Codec codec = codecRegistry.get(value.getClass()); + BsonCapturingWriter writer = new BsonCapturingWriter(value.getClass()); + codec.encode(writer, value, + ObjectUtils.isArray(value) || value instanceof Collection ? EncoderContext.builder().build() : null); + return writer.getCapturedValue(); + } catch (CodecConfigurationException e) { + throw new IllegalArgumentException( + String.format("Unable to convert %s to BsonValue.", source != null ? source.getClass().getName() : "null")); + } } /** @@ -694,7 +708,7 @@ public class BsonUtils { if (value instanceof Collection collection) { return toString(collection); - } else if (value instanceof Map map) { + } else if (value instanceof Map map) { return toString(map); } else if (ObjectUtils.isArray(value)) { return toString(Arrays.asList(ObjectUtils.toObjectArray(value))); @@ -733,4 +747,162 @@ public class BsonUtils { return joiner.toString(); } + + private static class BsonCapturingWriter extends AbstractBsonWriter { + + List values = new ArrayList<>(0); + + public BsonCapturingWriter(Class type) { + super(new BsonWriterSettings()); + if (ClassUtils.isAssignable(Map.class, type)) { + setContext(new Context(null, BsonContextType.DOCUMENT)); + } else if (ClassUtils.isAssignable(List.class, type) || type.isArray()) { + setContext(new Context(null, BsonContextType.ARRAY)); + } else { + setContext(new Context(null, BsonContextType.DOCUMENT)); + } + } + + BsonValue getCapturedValue() { + + if (values.isEmpty()) { + return null; + } + if (!getContext().getContextType().equals(BsonContextType.ARRAY)) { + return values.get(0); + } + + return new BsonArray(values); + } + + @Override + protected void doWriteStartDocument() { + + } + + @Override + protected void doWriteEndDocument() { + + } + + @Override + public void writeStartArray() { + setState(State.VALUE); + } + + @Override + public void writeEndArray() { + setState(State.NAME); + } + + @Override + protected void doWriteStartArray() { + + } + + @Override + protected void doWriteEndArray() { + + } + + @Override + protected void doWriteBinaryData(BsonBinary value) { + values.add(value); + } + + @Override + protected void doWriteBoolean(boolean value) { + values.add(BsonBoolean.valueOf(value)); + } + + @Override + protected void doWriteDateTime(long value) { + values.add(new BsonDateTime(value)); + } + + @Override + protected void doWriteDBPointer(BsonDbPointer value) { + values.add(value); + } + + @Override + protected void doWriteDouble(double value) { + values.add(new BsonDouble(value)); + } + + @Override + protected void doWriteInt32(int value) { + values.add(new BsonInt32(value)); + } + + @Override + protected void doWriteInt64(long value) { + values.add(new BsonInt64(value)); + } + + @Override + protected void doWriteDecimal128(Decimal128 value) { + values.add(new BsonDecimal128(value)); + } + + @Override + protected void doWriteJavaScript(String value) { + values.add(new BsonJavaScript(value)); + } + + @Override + protected void doWriteJavaScriptWithScope(String value) { + values.add(new BsonJavaScriptWithScope(value, null)); + } + + @Override + protected void doWriteMaxKey() { + + } + + @Override + protected void doWriteMinKey() { + + } + + @Override + protected void doWriteNull() { + values.add(new BsonNull()); + } + + @Override + protected void doWriteObjectId(ObjectId value) { + values.add(new BsonObjectId(value)); + } + + @Override + protected void doWriteRegularExpression(BsonRegularExpression value) { + values.add(value); + } + + @Override + protected void doWriteString(String value) { + values.add(new BsonString(value)); + } + + @Override + protected void doWriteSymbol(String value) { + values.add(new BsonSymbol(value)); + } + + @Override + protected void doWriteTimestamp(BsonTimestamp value) { + values.add(value); + } + + @Override + protected void doWriteUndefined() { + values.add(new BsonUndefined()); + } + + @Override + public void flush() { + values.clear(); + } + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/util/json/BsonUtilsTest.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/util/json/BsonUtilsTest.java index cd23e9581..e9cc62815 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/util/json/BsonUtilsTest.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/util/json/BsonUtilsTest.java @@ -17,10 +17,19 @@ package org.springframework.data.mongodb.util.json; import static org.assertj.core.api.Assertions.*; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.temporal.Temporal; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.Date; +import java.util.List; +import java.util.stream.Stream; +import org.bson.BsonArray; import org.bson.BsonDouble; import org.bson.BsonInt32; import org.bson.BsonInt64; @@ -29,7 +38,9 @@ import org.bson.BsonString; import org.bson.Document; import org.bson.types.ObjectId; import org.junit.jupiter.api.Test; - +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.springframework.data.mongodb.util.BsonUtils; import com.mongodb.BasicDBList; @@ -105,9 +116,9 @@ class BsonUtilsTest { @Test // GH-3571 void asCollectionConvertsArrayToCollection() { - Object source = new String[]{"one", "two"}; + Object source = new String[] { "one", "two" }; - assertThat((Collection)BsonUtils.asCollection(source)).containsExactly("one", "two"); + assertThat((Collection) BsonUtils.asCollection(source)).containsExactly("one", "two"); } @Test // GH-3571 @@ -115,7 +126,7 @@ class BsonUtilsTest { Object source = 100L; - assertThat((Collection)BsonUtils.asCollection(source)).containsExactly(source); + assertThat((Collection) BsonUtils.asCollection(source)).containsExactly(source); } @Test // GH-3702 @@ -126,4 +137,41 @@ class BsonUtilsTest { assertThat(BsonUtils.supportsBson(new BasicDBList())).isTrue(); assertThat(BsonUtils.supportsBson(Collections.emptyMap())).isTrue(); } + + @ParameterizedTest // GH-4432 + @MethodSource("javaTimeInstances") + void convertsJavaTimeTypesToBsonDateTime(Temporal source) { + + assertThat(BsonUtils.simpleToBsonValue(source)) + .isEqualTo(new Document("value", source).toBsonDocument().get("value")); + } + + @ParameterizedTest // GH-4432 + @MethodSource("collectionLikeInstances") + void convertsCollectionLikeToBsonArray(Object source) { + + assertThat(BsonUtils.simpleToBsonValue(source)) + .isEqualTo(new Document("value", source).toBsonDocument().get("value")); + } + + @Test // GH-4432 + void convertsPrimitiveArrayToBsonArray() { + + assertThat(BsonUtils.simpleToBsonValue(new int[] { 1, 2, 3 })) + .isEqualTo(new BsonArray(List.of(new BsonInt32(1), new BsonInt32(2), new BsonInt32(3)))); + } + + static Stream javaTimeInstances() { + + return Stream.of(Arguments.of(Instant.now()), Arguments.of(LocalDate.now()), Arguments.of(LocalDateTime.now()), + Arguments.of(LocalTime.now())); + } + + static Stream collectionLikeInstances() { + + return Stream.of(Arguments.of(new String[] { "1", "2", "3" }), Arguments.of(List.of("1", "2", "3")), + Arguments.of(new Integer[] { 1, 2, 3 }), Arguments.of(List.of(1, 2, 3)), + Arguments.of(new Date[] { new Date() }), Arguments.of(List.of(new Date())), + Arguments.of(new LocalDate[] { LocalDate.now() }), Arguments.of(List.of(LocalDate.now()))); + } }