This is an automated email from the ASF dual-hosted git repository.
xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 293f153395 Fixing flaky VectorTest (#11514)
293f153395 is described below
commit 293f153395aefb1316e2219d855d43f7664f5292
Author: Xiang Fu <[email protected]>
AuthorDate: Sun Sep 10 04:22:17 2023 -0700
Fixing flaky VectorTest (#11514)
---
.../pinot/integration/tests/custom/VectorTest.java | 142 +++++++++++++++------
1 file changed, 102 insertions(+), 40 deletions(-)
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java
index 70a07cbefc..245cd31d05 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java
@@ -23,18 +23,18 @@ import com.google.common.collect.ImmutableList;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
-import java.util.List;
+import java.util.stream.IntStream;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.commons.lang3.RandomUtils;
import org.apache.commons.lang3.StringUtils;
+import org.apache.pinot.common.function.scalar.VectorFunctions;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.testng.annotations.Test;
import static org.testng.Assert.assertEquals;
-import static org.testng.Assert.assertTrue;
@Test(suiteName = "CustomClusterIntegrationTest")
@@ -44,6 +44,14 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
private static final String VECTOR_1 = "vector1";
private static final String VECTOR_2 = "vector2";
private static final String ZERO_VECTOR = "zeroVector";
+ private static final String VECTOR_1_NORM = "vector1Norm";
+ private static final String VECTOR_2_NORM = "vector2Norm";
+ private static final String VECTORS_COSINE_DIST = "vectorsCosineDist";
+ private static final String VECTORS_INNER_PRODUCT = "vectorsInnerProduct";
+ private static final String VECTORS_L1_DIST = "vectorsL1Dist";
+ private static final String VECTORS_L2_DIST = "vectorsL2Dist";
+ private static final String VECTOR_ZERO_L1_DIST = "vectorZeroL1Dist";
+ private static final String VECTOR_ZERO_L2_DIST = "vectorZeroL2Dist";
private static final int VECTOR_DIM_SIZE = 512;
@Override
@@ -58,35 +66,42 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
String query =
String.format("SELECT "
+ "cosineDistance(vector1, vector2), "
+ + VECTORS_COSINE_DIST + ", "
+ "innerProduct(vector1, vector2), "
+ + VECTORS_INNER_PRODUCT + ", "
+ "l1Distance(vector1, vector2), "
+ + VECTORS_L1_DIST + ", "
+ "l2Distance(vector1, vector2), "
+ + VECTORS_L2_DIST + ", "
+ "vectorDims(vector1), vectorDims(vector2), "
- + "vectorNorm(vector1), vectorNorm(vector2), "
+ + "vectorNorm(vector1), "
+ + VECTOR_1_NORM + ", "
+ + "vectorNorm(vector2), "
+ + VECTOR_2_NORM + ", "
+ "cosineDistance(vector1, zeroVector), "
+ "cosineDistance(vector1, zeroVector, 0) "
+ "FROM %s LIMIT %d", getTableName(), getCountStarResult());
JsonNode jsonNode = postQuery(query);
for (int i = 0; i < getCountStarResult(); i++) {
double cosineDistance =
jsonNode.get("resultTable").get("rows").get(i).get(0).asDouble();
- assertTrue(cosineDistance > 0.1 && cosineDistance < 0.4);
- double innerProduce =
jsonNode.get("resultTable").get("rows").get(i).get(1).asDouble();
- assertTrue(innerProduce > 100 && innerProduce < 160);
- double l1Distance =
jsonNode.get("resultTable").get("rows").get(i).get(2).asDouble();
- assertTrue(l1Distance > 140 && l1Distance < 210);
- double l2Distance =
jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble();
- assertTrue(l2Distance > 8 && l2Distance < 11);
- int vectorDimsVector1 =
jsonNode.get("resultTable").get("rows").get(i).get(4).asInt();
+ assertEquals(cosineDistance,
jsonNode.get("resultTable").get("rows").get(i).get(1).asDouble());
+ double innerProduce =
jsonNode.get("resultTable").get("rows").get(i).get(2).asDouble();
+ assertEquals(innerProduce,
jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble());
+ double l1Distance =
jsonNode.get("resultTable").get("rows").get(i).get(4).asDouble();
+ assertEquals(l1Distance,
jsonNode.get("resultTable").get("rows").get(i).get(5).asDouble());
+ double l2Distance =
jsonNode.get("resultTable").get("rows").get(i).get(6).asDouble();
+ assertEquals(l2Distance,
jsonNode.get("resultTable").get("rows").get(i).get(7).asDouble());
+ int vectorDimsVector1 =
jsonNode.get("resultTable").get("rows").get(i).get(8).asInt();
assertEquals(vectorDimsVector1, VECTOR_DIM_SIZE);
- int vectorDimsVector2 =
jsonNode.get("resultTable").get("rows").get(i).get(5).asInt();
+ int vectorDimsVector2 =
jsonNode.get("resultTable").get("rows").get(i).get(9).asInt();
assertEquals(vectorDimsVector2, VECTOR_DIM_SIZE);
- double vectorNormVector1 =
jsonNode.get("resultTable").get("rows").get(i).get(6).asInt();
- assertTrue(vectorNormVector1 > 10 && vectorNormVector1 < 16);
- double vectorNormVector2 =
jsonNode.get("resultTable").get("rows").get(i).get(7).asInt();
- assertTrue(vectorNormVector2 > 10 && vectorNormVector2 < 16);
- cosineDistance =
jsonNode.get("resultTable").get("rows").get(i).get(8).asDouble();
+ double vectorNormVector1 =
jsonNode.get("resultTable").get("rows").get(i).get(10).asDouble();
+ assertEquals(vectorNormVector1,
jsonNode.get("resultTable").get("rows").get(i).get(11).asDouble());
+ double vectorNormVector2 =
jsonNode.get("resultTable").get("rows").get(i).get(12).asDouble();
+ assertEquals(vectorNormVector2,
jsonNode.get("resultTable").get("rows").get(i).get(13).asDouble());
+ cosineDistance =
jsonNode.get("resultTable").get("rows").get(i).get(14).asDouble();
assertEquals(cosineDistance, Double.NaN);
- cosineDistance =
jsonNode.get("resultTable").get("rows").get(i).get(9).asDouble();
+ cosineDistance =
jsonNode.get("resultTable").get("rows").get(i).get(15).asDouble();
assertEquals(cosineDistance, 0.0);
}
}
@@ -106,7 +121,9 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
+ "cosineDistance(vector1, %s), "
+ "innerProduct(vector1, %s), "
+ "l1Distance(vector1, %s), "
+ + VECTOR_ZERO_L1_DIST + ", "
+ "l2Distance(vector1, %s), "
+ + VECTOR_ZERO_L2_DIST + ", "
+ "vectorDims(%s), "
+ "vectorNorm(%s) "
+ "FROM %s LIMIT %d",
@@ -119,12 +136,12 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
double innerProduce =
jsonNode.get("resultTable").get("rows").get(i).get(1).asDouble();
assertEquals(innerProduce, 0.0);
double l1Distance =
jsonNode.get("resultTable").get("rows").get(i).get(2).asDouble();
- assertTrue(l1Distance > 100 && l1Distance < 300);
- double l2Distance =
jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble();
- assertTrue(l2Distance > 10 && l2Distance < 16);
- int vectorDimsVector =
jsonNode.get("resultTable").get("rows").get(i).get(4).asInt();
+ assertEquals(l1Distance,
jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble());
+ double l2Distance =
jsonNode.get("resultTable").get("rows").get(i).get(4).asDouble();
+ assertEquals(l2Distance,
jsonNode.get("resultTable").get("rows").get(i).get(5).asDouble());
+ int vectorDimsVector =
jsonNode.get("resultTable").get("rows").get(i).get(6).asInt();
assertEquals(vectorDimsVector, VECTOR_DIM_SIZE);
- double vectorNormVector =
jsonNode.get("resultTable").get("rows").get(i).get(5).asInt();
+ double vectorNormVector =
jsonNode.get("resultTable").get("rows").get(i).get(7).asDouble();
assertEquals(vectorNormVector, 0.0);
}
@@ -166,6 +183,14 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
.addMultiValueDimension(VECTOR_1, FieldSpec.DataType.FLOAT)
.addMultiValueDimension(VECTOR_2, FieldSpec.DataType.FLOAT)
.addMultiValueDimension(ZERO_VECTOR, FieldSpec.DataType.FLOAT)
+ .addSingleValueDimension(VECTOR_1_NORM, FieldSpec.DataType.DOUBLE)
+ .addSingleValueDimension(VECTOR_2_NORM, FieldSpec.DataType.DOUBLE)
+ .addSingleValueDimension(VECTORS_COSINE_DIST, FieldSpec.DataType.DOUBLE)
+ .addSingleValueDimension(VECTORS_INNER_PRODUCT, FieldSpec.DataType.DOUBLE)
+ .addSingleValueDimension(VECTORS_L1_DIST, FieldSpec.DataType.DOUBLE)
+ .addSingleValueDimension(VECTORS_L2_DIST, FieldSpec.DataType.DOUBLE)
+ .addSingleValueDimension(VECTOR_ZERO_L1_DIST, FieldSpec.DataType.DOUBLE)
+ .addSingleValueDimension(VECTOR_ZERO_L2_DIST, FieldSpec.DataType.DOUBLE)
.build();
}
@@ -183,7 +208,31 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
null),
new org.apache.avro.Schema.Field(ZERO_VECTOR,
org.apache.avro.Schema.createArray(org.apache.avro.Schema.create(
org.apache.avro.Schema.Type.FLOAT)), null,
- null)
+ null),
+ new org.apache.avro.Schema.Field(VECTOR_1_NORM,
+ org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+ null, null),
+ new org.apache.avro.Schema.Field(VECTOR_2_NORM,
+ org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+ null, null),
+ new org.apache.avro.Schema.Field(VECTORS_COSINE_DIST,
+ org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+ null, null),
+ new org.apache.avro.Schema.Field(VECTORS_INNER_PRODUCT,
+ org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+ null, null),
+ new org.apache.avro.Schema.Field(VECTORS_L1_DIST,
+ org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+ null, null),
+ new org.apache.avro.Schema.Field(VECTORS_L2_DIST,
+ org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+ null, null),
+ new org.apache.avro.Schema.Field(VECTOR_ZERO_L1_DIST,
+ org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+ null, null),
+ new org.apache.avro.Schema.Field(VECTOR_ZERO_L2_DIST,
+ org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+ null, null)
));
// create avro file
@@ -194,12 +243,21 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
// create avro record
GenericData.Record record = new GenericData.Record(avroSchema);
- Collection<Float> vector1 = createRandomVector(VECTOR_DIM_SIZE);
- Collection<Float> vector2 = createRandomVector(VECTOR_DIM_SIZE);
- Collection<Float> zeroVector = createZeroVector(VECTOR_DIM_SIZE);
- record.put(VECTOR_1, vector1);
- record.put(VECTOR_2, vector2);
- record.put(ZERO_VECTOR, zeroVector);
+ float[] vector1 = createRandomVector(VECTOR_DIM_SIZE);
+ float[] vector2 = createRandomVector(VECTOR_DIM_SIZE);
+ float[] zeroVector = createZeroVector(VECTOR_DIM_SIZE);
+
+ record.put(VECTOR_1, convertToFloatCollection(vector1));
+ record.put(VECTOR_2, convertToFloatCollection(vector2));
+ record.put(ZERO_VECTOR, convertToFloatCollection(zeroVector));
+ record.put(VECTOR_1_NORM, VectorFunctions.vectorNorm(vector1));
+ record.put(VECTOR_2_NORM, VectorFunctions.vectorNorm(vector2));
+ record.put(VECTORS_COSINE_DIST, VectorFunctions.cosineDistance(vector1, vector2));
+ record.put(VECTORS_INNER_PRODUCT, VectorFunctions.innerProduct(vector1, vector2));
+ record.put(VECTORS_L1_DIST, VectorFunctions.l1Distance(vector1, vector2));
+ record.put(VECTORS_L2_DIST, VectorFunctions.l2Distance(vector1, vector2));
+ record.put(VECTOR_ZERO_L1_DIST, VectorFunctions.l1Distance(vector1, zeroVector));
+ record.put(VECTOR_ZERO_L2_DIST, VectorFunctions.l2Distance(vector1, zeroVector));
// add avro record to file
fileWriter.append(record);
@@ -208,19 +266,23 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
return avroFile;
}
- private Collection<Float> createZeroVector(int vectorDimSize) {
- List<Float> vector = new ArrayList<>();
- for (int i = 0; i < vectorDimSize; i++) {
- vector.add(i, 0.0f);
- }
+ private float[] createZeroVector(int vectorDimSize) {
+ float[] vector = new float[vectorDimSize];
+ IntStream.range(0, vectorDimSize).forEach(i -> vector[i] = 0.0f);
return vector;
}
- private Collection<Float> createRandomVector(int vectorDimSize) {
- List<Float> vector = new ArrayList<>();
- for (int i = 0; i < vectorDimSize; i++) {
- vector.add(i, RandomUtils.nextFloat(0.0f, 1.0f));
- }
+ private float[] createRandomVector(int vectorDimSize) {
+ float[] vector = new float[vectorDimSize];
+ IntStream.range(0, vectorDimSize).forEach(i -> vector[i] = RandomUtils.nextFloat(0.0f, 1.0f));
return vector;
}
+
+ private Collection<Float> convertToFloatCollection(float[] vector) {
+ Collection<Float> vectorCollection = new ArrayList<>();
+ for (float v : vector) {
+ vectorCollection.add(v);
+ }
+ return vectorCollection;
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]