diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java index 3b2e0a45f..f7e0d1ee3 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java @@ -24,9 +24,9 @@ import org.springframework.data.domain.Vector; import org.springframework.util.ObjectUtils; /** - * MongoDB-specific extension to {@link Vector} based on Mongo's {@link BinaryVector}. Note that only float32 and int8 - * variants can be represented as floating-point numbers. int1 returns an all-zero array for {@link #toFloatArray()} and - * {@link #toDoubleArray()}. + * MongoDB-specific extension to {@link Vector} based on Mongo's {@link BinaryVector}. Note that only {@code float32} + * and {@code int8} variants can be represented as floating-point numbers. {@code int1} throws + * {@link UnsupportedOperationException} when calling {@link #toFloatArray()} and {@link #toDoubleArray()}. * * @author Mark Paluch * @since 4.5 @@ -40,15 +40,65 @@ public class MongoVector implements Vector { } /** - * Creates a new {@link MongoVector} from the given {@link BinaryVector}. + * Creates a new binary {@link MongoVector} using the given {@link BinaryVector}. * * @param v binary vector representation. - * @return the {@link MongoVector} for the given vector values. + * @return the {@link MongoVector} wrapping {@link BinaryVector}. */ public static MongoVector of(BinaryVector v) { return new MongoVector(v); } + /** + * Creates a new binary {@link MongoVector} using the given {@code data}. + *

+ * A {@link BinaryVector.DataType#INT8} vector is a vector of 8-bit signed integers where each byte in the vector + * represents an element of a vector, with values in the range {@code [-128, 127]}. + *

+ * NOTE: The byte array is not copied; changes to the provided array will be referenced in the created + * {@code MongoVector} instance. + * + * @param data the byte array representing the {@link BinaryVector.DataType#INT8} vector data. + * @return the {@link MongoVector} containing the given vector values to be represented as binary {@code int8}. + */ + public static MongoVector ofInt8(byte[] data) { + return of(BinaryVector.int8Vector(data)); + } + + /** + * Creates a new binary {@link MongoVector} using the given {@code data}. + *

+ * A {@link BinaryVector.DataType#FLOAT32} vector is a vector of floating-point numbers, where each element in the + * vector is a {@code float}. + *

+ * NOTE: The float array is not copied; changes to the provided array will be referenced in the created + * {@code MongoVector} instance. + * + * @param data the float array representing the {@link BinaryVector.DataType#FLOAT32} vector data. + * @return the {@link MongoVector} containing the given vector values to be represented as binary {@code float32}. + */ + public static MongoVector ofFloat(float... data) { + return of(BinaryVector.floatVector(data)); + } + + /** + * Creates a new binary {@link MongoVector} from the given {@link Vector}. + *

+ * A {@link BinaryVector.DataType#FLOAT32} vector is a vector of floating-point numbers, where each element in the + * vector is a {@code float}. The given {@link Vector} must be able to return a {@link Vector#toFloatArray() float} + * array. + *

+ * NOTE: The float array is not copied; changes to the provided array will be referenced in the created + * {@code MongoVector} instance. + * + * @param v the + * @return the {@link MongoVector} using vector values from the given {@link Vector} to be represented as binary + * float32. + */ + public static MongoVector fromFloat(Vector v) { + return of(BinaryVector.floatVector(v.toFloatArray())); + } + @Override public Class getType() { @@ -90,6 +140,11 @@ public class MongoVector implements Vector { return 0; } + /** + * {@inheritDoc} + * + * @throws UnsupportedOperationException if the underlying data type is {@code int1} {@link PackedBitBinaryVector}. + */ @Override public float[] toFloatArray() { @@ -102,14 +157,22 @@ public class MongoVector implements Vector { if (v instanceof Int8BinaryVector i) { - float[] result = new float[i.getData().length]; - System.arraycopy(i.getData(), 0, result, 0, result.length); + byte[] data = i.getData(); + float[] result = new float[data.length]; + for (int j = 0; j < data.length; j++) { + result[j] = data[j]; + } return result; } - return new float[size()]; + throw new UnsupportedOperationException("Cannot return float array for " + v.getClass()); } + /** + * {@inheritDoc} + * + * @throws UnsupportedOperationException if the underlying data type is {@code int1} {@link PackedBitBinaryVector}. + */ @Override public double[] toDoubleArray() { @@ -126,12 +189,15 @@ public class MongoVector implements Vector { if (v instanceof Int8BinaryVector i) { - double[] result = new double[i.getData().length]; - System.arraycopy(i.getData(), 0, result, 0, result.length); + byte[] data = i.getData(); + double[] result = new double[data.length]; + for (int j = 0; j < data.length; j++) { + result[j] = data[j]; + } return result; } - return new double[size()]; + throw new UnsupportedOperationException("Cannot return double array for " + v.getClass()); } @Override diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MongoConvertersIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MongoConvertersIntegrationTests.java index b57ab35ea..a1c2fc089 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MongoConvertersIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MongoConvertersIntegrationTests.java @@ -138,7 +138,7 @@ public class MongoConvertersIntegrationTests { WithVectors source = new WithVectors(); source.binVector = BinaryVector.floatVector(new float[] { 1.1f, 2.2f, 3.3f }); - source.vector = MongoVector.of(source.binVector); + source.vector = MongoVector.ofFloat(new float[] { 1.1f, 2.2f, 3.3f }); template.save(source); @@ -146,6 +146,7 @@ public class MongoConvertersIntegrationTests { assertThat(loaded.vector).isEqualTo(source.vector); assertThat(loaded.binVector).isEqualTo(source.binVector); + assertThat(loaded.binVector).isEqualTo(source.vector.getSource()); } @Test // GH-4706 @@ -153,7 +154,7 @@ public class MongoConvertersIntegrationTests { WithVectors source = new WithVectors(); source.binVector = BinaryVector.int8Vector(new byte[] { 1, 2, 3 }); - source.vector = MongoVector.of(source.binVector); + source.vector = MongoVector.ofInt8(new byte[] { 1, 2, 3 }); template.save(source); @@ -161,6 +162,7 @@ public class MongoConvertersIntegrationTests { assertThat(loaded.vector).isEqualTo(source.vector); assertThat(loaded.binVector).isEqualTo(source.binVector); + assertThat(loaded.binVector).isEqualTo(source.vector.getSource()); } @Test // GH-4706 diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/mapping/MongoVectorUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/mapping/MongoVectorUnitTests.java new file mode 100644 index 000000000..31eeebdb8 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/mapping/MongoVectorUnitTests.java @@ -0,0 +1,79 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.mapping; + +import static org.springframework.data.mongodb.test.util.Assertions.*; + +import org.bson.BinaryVector; +import org.bson.Float32BinaryVector; +import org.junit.jupiter.api.Test; + +import org.springframework.data.domain.Vector; + +/** + * Unit tests for {@link MongoVector}. + * + * @author Mark Paluch + */ +class MongoVectorUnitTests { + + @Test // GH-4706 + void shouldReturnInt8AsFloatingPoints() { + + MongoVector vector = MongoVector.ofInt8(new byte[] { 1, 2, 3 }); + + assertThat(vector.toDoubleArray()).contains(1, 2, 3); + assertThat(vector.toFloatArray()).contains(1, 2, 3); + } + + @Test // GH-4706 + void shouldReturnFloatAsFloatingPoints() { + + MongoVector vector = MongoVector.ofFloat(1f, 2f, 3f); + + assertThat(vector.toDoubleArray()).contains(1, 2, 3); + assertThat(vector.toFloatArray()).contains(1, 2, 3); + } + + @Test // GH-4706 + void ofFloatIsNotEqualToVectorOf() { + + MongoVector mv = MongoVector.ofFloat(1f, 2f, 3f); + Vector v = Vector.of(1f, 2f, 3f); + + assertThat(v).isNotEqualTo(mv); + } + + @Test // GH-4706 + void mongoVectorCanAdaptToFloatVector() { + + Vector v = Vector.of(1f, 2f, 3f); + MongoVector mv = MongoVector.fromFloat(v); + + assertThat(mv.toFloatArray()).isEqualTo(v.toFloatArray()); + assertThat(mv.getSource()).isInstanceOf(Float32BinaryVector.class); + } + + @Test // GH-4706 + void shouldNotReturnFloatsForPackedBit() { + + MongoVector vector = MongoVector.of(BinaryVector.packedBitVector(new byte[] { 1, 2, 3 }, (byte) 1)); + + assertThatExceptionOfType(UnsupportedOperationException.class).isThrownBy(vector::toFloatArray); + assertThatExceptionOfType(UnsupportedOperationException.class).isThrownBy(vector::toDoubleArray); + } + +}