Browse Source

Polishing.

Refine MongoVector factory methods for a more natural adoption and terminology when creating vectors.

See #4706
pull/4915/merge
Mark Paluch 7 months ago
parent
commit
df3abef717
No known key found for this signature in database
GPG Key ID: 55BC6374BAA9D973
  1. 88
      spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java
  2. 6
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MongoConvertersIntegrationTests.java
  3. 79
      spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/mapping/MongoVectorUnitTests.java

88
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java

@ -24,9 +24,9 @@ import org.springframework.data.domain.Vector; @@ -24,9 +24,9 @@ import org.springframework.data.domain.Vector;
import org.springframework.util.ObjectUtils;
/**
* MongoDB-specific extension to {@link Vector} based on Mongo's {@link BinaryVector}. Note that only float32 and int8
* variants can be represented as floating-point numbers. int1 returns an all-zero array for {@link #toFloatArray()} and
* {@link #toDoubleArray()}.
* MongoDB-specific extension to {@link Vector} based on Mongo's {@link BinaryVector}. Note that only {@code float32}
* and {@code int8} variants can be represented as floating-point numbers. {@code int1} throws
* {@link UnsupportedOperationException} when calling {@link #toFloatArray()} and {@link #toDoubleArray()}.
*
* @author Mark Paluch
* @since 4.5
@ -40,15 +40,65 @@ public class MongoVector implements Vector { @@ -40,15 +40,65 @@ public class MongoVector implements Vector {
}
/**
* Creates a new {@link MongoVector} from the given {@link BinaryVector}.
* Creates a new binary {@link MongoVector} using the given {@link BinaryVector}.
*
* @param v binary vector representation.
* @return the {@link MongoVector} for the given vector values.
* @return the {@link MongoVector} wrapping {@link BinaryVector}.
*/
public static MongoVector of(BinaryVector v) {
return new MongoVector(v);
}
/**
* Creates a new binary {@link MongoVector} using the given {@code data}.
* <p>
* A {@link BinaryVector.DataType#INT8} vector is a vector of 8-bit signed integers where each byte in the vector
* represents an element of a vector, with values in the range {@code [-128, 127]}.
* <p>
* NOTE: The byte array is not copied; changes to the provided array will be referenced in the created
* {@code MongoVector} instance.
*
* @param data the byte array representing the {@link BinaryVector.DataType#INT8} vector data.
* @return the {@link MongoVector} containing the given vector values to be represented as binary {@code int8}.
*/
public static MongoVector ofInt8(byte[] data) {
return of(BinaryVector.int8Vector(data));
}
/**
* Creates a new binary {@link MongoVector} using the given {@code data}.
* <p>
* A {@link BinaryVector.DataType#FLOAT32} vector is a vector of floating-point numbers, where each element in the
* vector is a {@code float}.
* <p>
* NOTE: The float array is not copied; changes to the provided array will be referenced in the created
* {@code MongoVector} instance.
*
* @param data the float array representing the {@link BinaryVector.DataType#FLOAT32} vector data.
* @return the {@link MongoVector} containing the given vector values to be represented as binary {@code float32}.
*/
public static MongoVector ofFloat(float... data) {
return of(BinaryVector.floatVector(data));
}
/**
* Creates a new binary {@link MongoVector} from the given {@link Vector}.
* <p>
* A {@link BinaryVector.DataType#FLOAT32} vector is a vector of floating-point numbers, where each element in the
* vector is a {@code float}. The given {@link Vector} must be able to return a {@link Vector#toFloatArray() float}
* array.
* <p>
* NOTE: The float array is not copied; changes to the provided array will be referenced in the created
* {@code MongoVector} instance.
*
* @param v the
* @return the {@link MongoVector} using vector values from the given {@link Vector} to be represented as binary
* float32.
*/
public static MongoVector fromFloat(Vector v) {
return of(BinaryVector.floatVector(v.toFloatArray()));
}
@Override
public Class<? extends Number> getType() {
@ -90,6 +140,11 @@ public class MongoVector implements Vector { @@ -90,6 +140,11 @@ public class MongoVector implements Vector {
return 0;
}
/**
* {@inheritDoc}
*
* @throws UnsupportedOperationException if the underlying data type is {@code int1} {@link PackedBitBinaryVector}.
*/
@Override
public float[] toFloatArray() {
@ -102,14 +157,22 @@ public class MongoVector implements Vector { @@ -102,14 +157,22 @@ public class MongoVector implements Vector {
if (v instanceof Int8BinaryVector i) {
float[] result = new float[i.getData().length];
System.arraycopy(i.getData(), 0, result, 0, result.length);
byte[] data = i.getData();
float[] result = new float[data.length];
for (int j = 0; j < data.length; j++) {
result[j] = data[j];
}
return result;
}
return new float[size()];
throw new UnsupportedOperationException("Cannot return float array for " + v.getClass());
}
/**
* {@inheritDoc}
*
* @throws UnsupportedOperationException if the underlying data type is {@code int1} {@link PackedBitBinaryVector}.
*/
@Override
public double[] toDoubleArray() {
@ -126,12 +189,15 @@ public class MongoVector implements Vector { @@ -126,12 +189,15 @@ public class MongoVector implements Vector {
if (v instanceof Int8BinaryVector i) {
double[] result = new double[i.getData().length];
System.arraycopy(i.getData(), 0, result, 0, result.length);
byte[] data = i.getData();
double[] result = new double[data.length];
for (int j = 0; j < data.length; j++) {
result[j] = data[j];
}
return result;
}
return new double[size()];
throw new UnsupportedOperationException("Cannot return double array for " + v.getClass());
}
@Override

6
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MongoConvertersIntegrationTests.java

@ -138,7 +138,7 @@ public class MongoConvertersIntegrationTests { @@ -138,7 +138,7 @@ public class MongoConvertersIntegrationTests {
WithVectors source = new WithVectors();
source.binVector = BinaryVector.floatVector(new float[] { 1.1f, 2.2f, 3.3f });
source.vector = MongoVector.of(source.binVector);
source.vector = MongoVector.ofFloat(new float[] { 1.1f, 2.2f, 3.3f });
template.save(source);
@ -146,6 +146,7 @@ public class MongoConvertersIntegrationTests { @@ -146,6 +146,7 @@ public class MongoConvertersIntegrationTests {
assertThat(loaded.vector).isEqualTo(source.vector);
assertThat(loaded.binVector).isEqualTo(source.binVector);
assertThat(loaded.binVector).isEqualTo(source.vector.getSource());
}
@Test // GH-4706
@ -153,7 +154,7 @@ public class MongoConvertersIntegrationTests { @@ -153,7 +154,7 @@ public class MongoConvertersIntegrationTests {
WithVectors source = new WithVectors();
source.binVector = BinaryVector.int8Vector(new byte[] { 1, 2, 3 });
source.vector = MongoVector.of(source.binVector);
source.vector = MongoVector.ofInt8(new byte[] { 1, 2, 3 });
template.save(source);
@ -161,6 +162,7 @@ public class MongoConvertersIntegrationTests { @@ -161,6 +162,7 @@ public class MongoConvertersIntegrationTests {
assertThat(loaded.vector).isEqualTo(source.vector);
assertThat(loaded.binVector).isEqualTo(source.binVector);
assertThat(loaded.binVector).isEqualTo(source.vector.getSource());
}
@Test // GH-4706

79
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/mapping/MongoVectorUnitTests.java

@ -0,0 +1,79 @@ @@ -0,0 +1,79 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core.mapping;
import static org.springframework.data.mongodb.test.util.Assertions.*;
import org.bson.BinaryVector;
import org.bson.Float32BinaryVector;
import org.junit.jupiter.api.Test;
import org.springframework.data.domain.Vector;
/**
* Unit tests for {@link MongoVector}.
*
* @author Mark Paluch
*/
class MongoVectorUnitTests {
@Test // GH-4706
void shouldReturnInt8AsFloatingPoints() {
MongoVector vector = MongoVector.ofInt8(new byte[] { 1, 2, 3 });
assertThat(vector.toDoubleArray()).contains(1, 2, 3);
assertThat(vector.toFloatArray()).contains(1, 2, 3);
}
@Test // GH-4706
void shouldReturnFloatAsFloatingPoints() {
MongoVector vector = MongoVector.ofFloat(1f, 2f, 3f);
assertThat(vector.toDoubleArray()).contains(1, 2, 3);
assertThat(vector.toFloatArray()).contains(1, 2, 3);
}
@Test // GH-4706
void ofFloatIsNotEqualToVectorOf() {
MongoVector mv = MongoVector.ofFloat(1f, 2f, 3f);
Vector v = Vector.of(1f, 2f, 3f);
assertThat(v).isNotEqualTo(mv);
}
@Test // GH-4706
void mongoVectorCanAdaptToFloatVector() {
Vector v = Vector.of(1f, 2f, 3f);
MongoVector mv = MongoVector.fromFloat(v);
assertThat(mv.toFloatArray()).isEqualTo(v.toFloatArray());
assertThat(mv.getSource()).isInstanceOf(Float32BinaryVector.class);
}
@Test // GH-4706
void shouldNotReturnFloatsForPackedBit() {
MongoVector vector = MongoVector.of(BinaryVector.packedBitVector(new byte[] { 1, 2, 3 }, (byte) 1));
assertThatExceptionOfType(UnsupportedOperationException.class).isThrownBy(vector::toFloatArray);
assertThatExceptionOfType(UnsupportedOperationException.class).isThrownBy(vector::toDoubleArray);
}
}
Loading…
Cancel
Save