This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 293f153395 Fixing flaky VectorTest (#11514)
293f153395 is described below

commit 293f153395aefb1316e2219d855d43f7664f5292
Author: Xiang Fu <[email protected]>
AuthorDate: Sun Sep 10 04:22:17 2023 -0700

    Fixing flaky VectorTest (#11514)
---
 .../pinot/integration/tests/custom/VectorTest.java | 142 +++++++++++++++------
 1 file changed, 102 insertions(+), 40 deletions(-)

diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java
index 70a07cbefc..245cd31d05 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java
@@ -23,18 +23,18 @@ import com.google.common.collect.ImmutableList;
 import java.io.File;
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.List;
+import java.util.stream.IntStream;
 import org.apache.avro.file.DataFileWriter;
 import org.apache.avro.generic.GenericData;
 import org.apache.avro.generic.GenericDatumWriter;
 import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.StringUtils;
+import org.apache.pinot.common.function.scalar.VectorFunctions;
 import org.apache.pinot.spi.data.FieldSpec;
 import org.apache.pinot.spi.data.Schema;
 import org.testng.annotations.Test;
 
 import static org.testng.Assert.assertEquals;
-import static org.testng.Assert.assertTrue;
 
 
 @Test(suiteName = "CustomClusterIntegrationTest")
@@ -44,6 +44,14 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
   private static final String VECTOR_1 = "vector1";
   private static final String VECTOR_2 = "vector2";
   private static final String ZERO_VECTOR = "zeroVector";
+  private static final String VECTOR_1_NORM = "vector1Norm";
+  private static final String VECTOR_2_NORM = "vector2Norm";
+  private static final String VECTORS_COSINE_DIST = "vectorsCosineDist";
+  private static final String VECTORS_INNER_PRODUCT = "vectorsInnerProduct";
+  private static final String VECTORS_L1_DIST = "vectorsL1Dist";
+  private static final String VECTORS_L2_DIST = "vectorsL2Dist";
+  private static final String VECTOR_ZERO_L1_DIST = "vectorZeroL1Dist";
+  private static final String VECTOR_ZERO_L2_DIST = "vectorZeroL2Dist";
   private static final int VECTOR_DIM_SIZE = 512;
 
   @Override
@@ -58,35 +66,42 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
     String query =
         String.format("SELECT "
             + "cosineDistance(vector1, vector2), "
+            + VECTORS_COSINE_DIST + ", "
             + "innerProduct(vector1, vector2), "
+            + VECTORS_INNER_PRODUCT + ", "
             + "l1Distance(vector1, vector2), "
+            + VECTORS_L1_DIST + ", "
             + "l2Distance(vector1, vector2), "
+            + VECTORS_L2_DIST + ", "
             + "vectorDims(vector1), vectorDims(vector2), "
-            + "vectorNorm(vector1), vectorNorm(vector2), "
+            + "vectorNorm(vector1), "
+            + VECTOR_1_NORM + ", "
+            + "vectorNorm(vector2), "
+            + VECTOR_2_NORM + ", "
             + "cosineDistance(vector1, zeroVector), "
             + "cosineDistance(vector1, zeroVector, 0) "
             + "FROM %s LIMIT %d", getTableName(), getCountStarResult());
     JsonNode jsonNode = postQuery(query);
     for (int i = 0; i < getCountStarResult(); i++) {
       double cosineDistance = jsonNode.get("resultTable").get("rows").get(i).get(0).asDouble();
-      assertTrue(cosineDistance > 0.1 && cosineDistance < 0.4);
-      double innerProduce = jsonNode.get("resultTable").get("rows").get(i).get(1).asDouble();
-      assertTrue(innerProduce > 100 && innerProduce < 160);
-      double l1Distance = jsonNode.get("resultTable").get("rows").get(i).get(2).asDouble();
-      assertTrue(l1Distance > 140 && l1Distance < 210);
-      double l2Distance = jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble();
-      assertTrue(l2Distance > 8 && l2Distance < 11);
-      int vectorDimsVector1 = jsonNode.get("resultTable").get("rows").get(i).get(4).asInt();
+      assertEquals(cosineDistance, jsonNode.get("resultTable").get("rows").get(i).get(1).asDouble());
+      double innerProduce = jsonNode.get("resultTable").get("rows").get(i).get(2).asDouble();
+      assertEquals(innerProduce, jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble());
+      double l1Distance = jsonNode.get("resultTable").get("rows").get(i).get(4).asDouble();
+      assertEquals(l1Distance, jsonNode.get("resultTable").get("rows").get(i).get(5).asDouble());
+      double l2Distance = jsonNode.get("resultTable").get("rows").get(i).get(6).asDouble();
+      assertEquals(l2Distance, jsonNode.get("resultTable").get("rows").get(i).get(7).asDouble());
+      int vectorDimsVector1 = jsonNode.get("resultTable").get("rows").get(i).get(8).asInt();
       assertEquals(vectorDimsVector1, VECTOR_DIM_SIZE);
-      int vectorDimsVector2 = jsonNode.get("resultTable").get("rows").get(i).get(5).asInt();
+      int vectorDimsVector2 = jsonNode.get("resultTable").get("rows").get(i).get(9).asInt();
       assertEquals(vectorDimsVector2, VECTOR_DIM_SIZE);
-      double vectorNormVector1 = jsonNode.get("resultTable").get("rows").get(i).get(6).asInt();
-      assertTrue(vectorNormVector1 > 10 && vectorNormVector1 < 16);
-      double vectorNormVector2 = jsonNode.get("resultTable").get("rows").get(i).get(7).asInt();
-      assertTrue(vectorNormVector2 > 10 && vectorNormVector2 < 16);
-      cosineDistance = jsonNode.get("resultTable").get("rows").get(i).get(8).asDouble();
+      double vectorNormVector1 = jsonNode.get("resultTable").get("rows").get(i).get(10).asDouble();
+      assertEquals(vectorNormVector1, jsonNode.get("resultTable").get("rows").get(i).get(11).asDouble());
+      double vectorNormVector2 = jsonNode.get("resultTable").get("rows").get(i).get(12).asDouble();
+      assertEquals(vectorNormVector2, jsonNode.get("resultTable").get("rows").get(i).get(13).asDouble());
+      cosineDistance = jsonNode.get("resultTable").get("rows").get(i).get(14).asDouble();
       assertEquals(cosineDistance, Double.NaN);
-      cosineDistance = jsonNode.get("resultTable").get("rows").get(i).get(9).asDouble();
+      cosineDistance = jsonNode.get("resultTable").get("rows").get(i).get(15).asDouble();
       assertEquals(cosineDistance, 0.0);
     }
   }
@@ -106,7 +121,9 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
                 + "cosineDistance(vector1, %s), "
                 + "innerProduct(vector1, %s), "
                 + "l1Distance(vector1, %s), "
+                + VECTOR_ZERO_L1_DIST + ", "
                 + "l2Distance(vector1, %s), "
+                + VECTOR_ZERO_L2_DIST + ", "
                 + "vectorDims(%s), "
                 + "vectorNorm(%s) "
                 + "FROM %s LIMIT %d",
@@ -119,12 +136,12 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
       double innerProduce = jsonNode.get("resultTable").get("rows").get(i).get(1).asDouble();
       assertEquals(innerProduce, 0.0);
       double l1Distance = jsonNode.get("resultTable").get("rows").get(i).get(2).asDouble();
-      assertTrue(l1Distance > 100 && l1Distance < 300);
-      double l2Distance = jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble();
-      assertTrue(l2Distance > 10 && l2Distance < 16);
-      int vectorDimsVector = jsonNode.get("resultTable").get("rows").get(i).get(4).asInt();
+      assertEquals(l1Distance, jsonNode.get("resultTable").get("rows").get(i).get(3).asDouble());
+      double l2Distance = jsonNode.get("resultTable").get("rows").get(i).get(4).asDouble();
+      assertEquals(l2Distance, jsonNode.get("resultTable").get("rows").get(i).get(5).asDouble());
+      int vectorDimsVector = jsonNode.get("resultTable").get("rows").get(i).get(6).asInt();
       assertEquals(vectorDimsVector, VECTOR_DIM_SIZE);
-      double vectorNormVector = jsonNode.get("resultTable").get("rows").get(i).get(5).asInt();
+      double vectorNormVector = jsonNode.get("resultTable").get("rows").get(i).get(7).asDouble();
       assertEquals(vectorNormVector, 0.0);
     }
 
@@ -166,6 +183,14 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
         .addMultiValueDimension(VECTOR_1, FieldSpec.DataType.FLOAT)
         .addMultiValueDimension(VECTOR_2, FieldSpec.DataType.FLOAT)
         .addMultiValueDimension(ZERO_VECTOR, FieldSpec.DataType.FLOAT)
+        .addSingleValueDimension(VECTOR_1_NORM, FieldSpec.DataType.DOUBLE)
+        .addSingleValueDimension(VECTOR_2_NORM, FieldSpec.DataType.DOUBLE)
+        .addSingleValueDimension(VECTORS_COSINE_DIST, FieldSpec.DataType.DOUBLE)
+        .addSingleValueDimension(VECTORS_INNER_PRODUCT, FieldSpec.DataType.DOUBLE)
+        .addSingleValueDimension(VECTORS_L1_DIST, FieldSpec.DataType.DOUBLE)
+        .addSingleValueDimension(VECTORS_L2_DIST, FieldSpec.DataType.DOUBLE)
+        .addSingleValueDimension(VECTOR_ZERO_L1_DIST, FieldSpec.DataType.DOUBLE)
+        .addSingleValueDimension(VECTOR_ZERO_L2_DIST, FieldSpec.DataType.DOUBLE)
         .build();
   }
 
@@ -183,7 +208,31 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
             null),
         new org.apache.avro.Schema.Field(ZERO_VECTOR, org.apache.avro.Schema.createArray(org.apache.avro.Schema.create(
             org.apache.avro.Schema.Type.FLOAT)), null,
-            null)
+            null),
+        new org.apache.avro.Schema.Field(VECTOR_1_NORM,
+            org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+            null, null),
+        new org.apache.avro.Schema.Field(VECTOR_2_NORM,
+            org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+            null, null),
+        new org.apache.avro.Schema.Field(VECTORS_COSINE_DIST,
+            org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+            null, null),
+        new org.apache.avro.Schema.Field(VECTORS_INNER_PRODUCT,
+            org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+            null, null),
+        new org.apache.avro.Schema.Field(VECTORS_L1_DIST,
+            org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+            null, null),
+        new org.apache.avro.Schema.Field(VECTORS_L2_DIST,
+            org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+            null, null),
+        new org.apache.avro.Schema.Field(VECTOR_ZERO_L1_DIST,
+            org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+            null, null),
+        new org.apache.avro.Schema.Field(VECTOR_ZERO_L2_DIST,
+            org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE),
+            null, null)
     ));
 
     // create avro file
@@ -194,12 +243,21 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
         // create avro record
         GenericData.Record record = new GenericData.Record(avroSchema);
 
-        Collection<Float> vector1 = createRandomVector(VECTOR_DIM_SIZE);
-        Collection<Float> vector2 = createRandomVector(VECTOR_DIM_SIZE);
-        Collection<Float> zeroVector = createZeroVector(VECTOR_DIM_SIZE);
-        record.put(VECTOR_1, vector1);
-        record.put(VECTOR_2, vector2);
-        record.put(ZERO_VECTOR, zeroVector);
+        float[] vector1 = createRandomVector(VECTOR_DIM_SIZE);
+        float[] vector2 = createRandomVector(VECTOR_DIM_SIZE);
+        float[] zeroVector = createZeroVector(VECTOR_DIM_SIZE);
+
+        record.put(VECTOR_1, convertToFloatCollection(vector1));
+        record.put(VECTOR_2, convertToFloatCollection(vector2));
+        record.put(ZERO_VECTOR, convertToFloatCollection(zeroVector));
+        record.put(VECTOR_1_NORM, VectorFunctions.vectorNorm(vector1));
+        record.put(VECTOR_2_NORM, VectorFunctions.vectorNorm(vector2));
+        record.put(VECTORS_COSINE_DIST, VectorFunctions.cosineDistance(vector1, vector2));
+        record.put(VECTORS_INNER_PRODUCT, VectorFunctions.innerProduct(vector1, vector2));
+        record.put(VECTORS_L1_DIST, VectorFunctions.l1Distance(vector1, vector2));
+        record.put(VECTORS_L2_DIST, VectorFunctions.l2Distance(vector1, vector2));
+        record.put(VECTOR_ZERO_L1_DIST, VectorFunctions.l1Distance(vector1, zeroVector));
+        record.put(VECTOR_ZERO_L2_DIST, VectorFunctions.l2Distance(vector1, zeroVector));
 
         // add avro record to file
         fileWriter.append(record);
@@ -208,19 +266,23 @@ public class VectorTest extends CustomDataQueryClusterIntegrationTest {
     return avroFile;
   }
 
-  private Collection<Float> createZeroVector(int vectorDimSize) {
-    List<Float> vector = new ArrayList<>();
-    for (int i = 0; i < vectorDimSize; i++) {
-      vector.add(i, 0.0f);
-    }
+  private float[] createZeroVector(int vectorDimSize) {
+    float[] vector = new float[vectorDimSize];
+    IntStream.range(0, vectorDimSize).forEach(i -> vector[i] = 0.0f);
     return vector;
   }
 
-  private Collection<Float> createRandomVector(int vectorDimSize) {
-    List<Float> vector = new ArrayList<>();
-    for (int i = 0; i < vectorDimSize; i++) {
-      vector.add(i, RandomUtils.nextFloat(0.0f, 1.0f));
-    }
+  private float[] createRandomVector(int vectorDimSize) {
+    float[] vector = new float[vectorDimSize];
+    IntStream.range(0, vectorDimSize).forEach(i -> vector[i] = RandomUtils.nextFloat(0.0f, 1.0f));
     return vector;
   }
+
+  private Collection<Float> convertToFloatCollection(float[] vector) {
+    Collection<Float> vectorCollection = new ArrayList<>();
+    for (float v : vector) {
+      vectorCollection.add(v);
+    }
+    return vectorCollection;
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to