This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new a4080209a9 GH-38662: [Java] Add comparators (#38669)
a4080209a9 is described below

commit a4080209a97a5d66accdeb71c5c1ffa982fed51e
Author: James Duong <[email protected]>
AuthorDate: Tue Nov 14 05:10:36 2023 -0800

    GH-38662: [Java] Add comparators (#38669)
    
    ### Rationale for this change
    Add missing default VectorValueComparators for additional vector types.
    
    ### What changes are included in this PR?
    Add comparators for:
    - FixedSizeBinaryVector
    - LargeListVector
    - FixedSizeListVector
    - NullVector
    
    ### Are these changes tested?
    Yes, unit tests added.
    
    ### Are there any user-facing changes?
    No
    * Closes: #38662
    
    Authored-by: James Duong <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 .../algorithm/sort/DefaultVectorComparators.java   | 140 +++++++++++++++++++--
 .../sort/TestDefaultVectorComparator.java          | 132 +++++++++++++++++++
 2 files changed, 259 insertions(+), 13 deletions(-)

diff --git a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java
index 4f9c8b7d71..588876aa99 100644
--- a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java
+++ b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java
@@ -32,11 +32,13 @@ import org.apache.arrow.vector.DateMilliVector;
 import org.apache.arrow.vector.Decimal256Vector;
 import org.apache.arrow.vector.DecimalVector;
 import org.apache.arrow.vector.DurationVector;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
 import org.apache.arrow.vector.Float4Vector;
 import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.IntVector;
 import org.apache.arrow.vector.IntervalDayVector;
 import org.apache.arrow.vector.IntervalMonthDayNanoVector;
+import org.apache.arrow.vector.NullVector;
 import org.apache.arrow.vector.SmallIntVector;
 import org.apache.arrow.vector.TimeMicroVector;
 import org.apache.arrow.vector.TimeMilliVector;
@@ -50,7 +52,9 @@ import org.apache.arrow.vector.UInt4Vector;
 import org.apache.arrow.vector.UInt8Vector;
 import org.apache.arrow.vector.ValueVector;
 import org.apache.arrow.vector.VariableWidthVector;
-import org.apache.arrow.vector.complex.BaseRepeatedValueVector;
+import org.apache.arrow.vector.complex.FixedSizeListVector;
+import org.apache.arrow.vector.complex.RepeatedValueVector;
+import org.apache.arrow.vector.holders.NullableFixedSizeBinaryHolder;
 
 /**
  * Default comparator implementations for different types of vectors.
@@ -111,13 +115,21 @@ public class DefaultVectorComparators {
         return (VectorValueComparator<T>) new TimeSecComparator();
       } else if (vector instanceof TimeStampVector) {
         return (VectorValueComparator<T>) new TimeStampComparator();
+      } else if (vector instanceof FixedSizeBinaryVector) {
+        return (VectorValueComparator<T>) new FixedSizeBinaryComparator();
       }
     } else if (vector instanceof VariableWidthVector) {
       return (VectorValueComparator<T>) new VariableWidthComparator();
-    } else if (vector instanceof BaseRepeatedValueVector) {
+    } else if (vector instanceof RepeatedValueVector) {
       VectorValueComparator<?> innerComparator =
-              createDefaultComparator(((BaseRepeatedValueVector) 
vector).getDataVector());
+              createDefaultComparator(((RepeatedValueVector) 
vector).getDataVector());
       return new RepeatedValueComparator(innerComparator);
+    } else if (vector instanceof FixedSizeListVector) {
+      VectorValueComparator<?> innerComparator =
+          createDefaultComparator(((FixedSizeListVector) 
vector).getDataVector());
+      return new FixedSizeListComparator(innerComparator);
+    } else if (vector instanceof NullVector) {
+      return (VectorValueComparator<T>) new NullComparator();
     }
 
     throw new IllegalArgumentException("No default comparator for " + 
vector.getClass().getCanonicalName());
@@ -674,6 +686,61 @@ public class DefaultVectorComparators {
     }
   }
 
+  /**
+   * Default comparator for {@link 
org.apache.arrow.vector.FixedSizeBinaryVector}.
+   * The comparison is in lexicographic order, with null comes first.
+   */
+  public static class FixedSizeBinaryComparator extends 
VectorValueComparator<FixedSizeBinaryVector> {
+
+    @Override
+    public int compare(int index1, int index2) {
+      NullableFixedSizeBinaryHolder holder1 = new 
NullableFixedSizeBinaryHolder();
+      NullableFixedSizeBinaryHolder holder2 = new 
NullableFixedSizeBinaryHolder();
+      vector1.get(index1, holder1);
+      vector2.get(index2, holder2);
+
+      return ByteFunctionHelpers.compare(
+          holder1.buffer, 0, holder1.byteWidth, holder2.buffer, 0, 
holder2.byteWidth);
+    }
+
+    @Override
+    public int compareNotNull(int index1, int index2) {
+      NullableFixedSizeBinaryHolder holder1 = new 
NullableFixedSizeBinaryHolder();
+      NullableFixedSizeBinaryHolder holder2 = new 
NullableFixedSizeBinaryHolder();
+      vector1.get(index1, holder1);
+      vector2.get(index2, holder2);
+
+      return ByteFunctionHelpers.compare(
+          holder1.buffer, 0, holder1.byteWidth, holder2.buffer, 0, 
holder2.byteWidth);
+    }
+
+    @Override
+    public VectorValueComparator<FixedSizeBinaryVector> createNew() {
+      return new FixedSizeBinaryComparator();
+    }
+  }
+
+  /**
+   * Default comparator for {@link org.apache.arrow.vector.NullVector}.
+   */
+  public static class NullComparator extends VectorValueComparator<NullVector> 
{
+    @Override
+    public int compare(int index1, int index2) {
+      // Values are always equal (and are always null).
+      return 0;
+    }
+
+    @Override
+    public int compareNotNull(int index1, int index2) {
+      throw new AssertionError("Cannot compare non-null values in a 
NullVector.");
+    }
+
+    @Override
+    public VectorValueComparator<NullVector> createNew() {
+      return new NullComparator();
+    }
+  }
+
   /**
    * Default comparator for {@link 
org.apache.arrow.vector.VariableWidthVector}.
    * The comparison is in lexicographic order, with null comes first.
@@ -705,14 +772,14 @@ public class DefaultVectorComparators {
   }
 
   /**
-   * Default comparator for {@link BaseRepeatedValueVector}.
+   * Default comparator for {@link RepeatedValueVector}.
    * It works by comparing the underlying vector in a lexicographic order.
    * @param <T> inner vector type.
    */
   public static class RepeatedValueComparator<T extends ValueVector>
-          extends VectorValueComparator<BaseRepeatedValueVector> {
+          extends VectorValueComparator<RepeatedValueVector> {
 
-    private VectorValueComparator<T> innerComparator;
+    private final VectorValueComparator<T> innerComparator;
 
     public RepeatedValueComparator(VectorValueComparator<T> innerComparator) {
       this.innerComparator = innerComparator;
@@ -720,16 +787,16 @@ public class DefaultVectorComparators {
 
     @Override
     public int compareNotNull(int index1, int index2) {
-      int startIdx1 = vector1.getOffsetBuffer().getInt(index1 * OFFSET_WIDTH);
-      int startIdx2 = vector2.getOffsetBuffer().getInt(index2 * OFFSET_WIDTH);
+      int startIdx1 = vector1.getOffsetBuffer().getInt((long) index1 * 
OFFSET_WIDTH);
+      int startIdx2 = vector2.getOffsetBuffer().getInt((long) index2 * 
OFFSET_WIDTH);
 
-      int endIdx1 = vector1.getOffsetBuffer().getInt((index1 + 1) * 
OFFSET_WIDTH);
-      int endIdx2 = vector2.getOffsetBuffer().getInt((index2 + 1) * 
OFFSET_WIDTH);
+      int endIdx1 = vector1.getOffsetBuffer().getInt((long) (index1 + 1) * 
OFFSET_WIDTH);
+      int endIdx2 = vector2.getOffsetBuffer().getInt((long) (index2 + 1) * 
OFFSET_WIDTH);
 
       int length1 = endIdx1 - startIdx1;
       int length2 = endIdx2 - startIdx2;
 
-      int length = length1 < length2 ? length1 : length2;
+      int length = Math.min(length1, length2);
 
       for (int i = 0; i < length; i++) {
         int result = innerComparator.compare(startIdx1 + i, startIdx2 + i);
@@ -741,13 +808,60 @@ public class DefaultVectorComparators {
     }
 
     @Override
-    public VectorValueComparator<BaseRepeatedValueVector> createNew() {
+    public VectorValueComparator<RepeatedValueVector> createNew() {
       VectorValueComparator<T> newInnerComparator = 
innerComparator.createNew();
       return new RepeatedValueComparator<>(newInnerComparator);
     }
 
     @Override
-    public void attachVectors(BaseRepeatedValueVector vector1, 
BaseRepeatedValueVector vector2) {
+    public void attachVectors(RepeatedValueVector vector1, RepeatedValueVector 
vector2) {
+      this.vector1 = vector1;
+      this.vector2 = vector2;
+
+      innerComparator.attachVectors((T) vector1.getDataVector(), (T) 
vector2.getDataVector());
+    }
+  }
+
+  /**
+   * Default comparator for {@link RepeatedValueVector}.
+   * It works by comparing the underlying vector in a lexicographic order.
+   * @param <T> inner vector type.
+   */
+  public static class FixedSizeListComparator<T extends ValueVector>
+      extends VectorValueComparator<FixedSizeListVector> {
+
+    private final VectorValueComparator<T> innerComparator;
+
+    public FixedSizeListComparator(VectorValueComparator<T> innerComparator) {
+      this.innerComparator = innerComparator;
+    }
+
+    @Override
+    public int compareNotNull(int index1, int index2) {
+      int length1 = vector1.getListSize();
+      int length2 = vector2.getListSize();
+
+      int length = Math.min(length1, length2);
+      int startIdx1 = vector1.getElementStartIndex(index1);
+      int startIdx2 = vector2.getElementStartIndex(index2);
+
+      for (int i = 0; i < length; i++) {
+        int result = innerComparator.compare(startIdx1 + i, startIdx2 + i);
+        if (result != 0) {
+          return result;
+        }
+      }
+      return length1 - length2;
+    }
+
+    @Override
+    public VectorValueComparator<FixedSizeListVector> createNew() {
+      VectorValueComparator<T> newInnerComparator = 
innerComparator.createNew();
+      return new FixedSizeListComparator<>(newInnerComparator);
+    }
+
+    @Override
+    public void attachVectors(FixedSizeListVector vector1, FixedSizeListVector 
vector2) {
       this.vector1 = vector1;
       this.vector2 = vector2;
 
diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java
index bdae85110a..43c634b764 100644
--- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java
+++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java
@@ -31,12 +31,14 @@ import org.apache.arrow.vector.DateMilliVector;
 import org.apache.arrow.vector.Decimal256Vector;
 import org.apache.arrow.vector.DecimalVector;
 import org.apache.arrow.vector.DurationVector;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
 import org.apache.arrow.vector.Float4Vector;
 import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.IntVector;
 import org.apache.arrow.vector.IntervalDayVector;
 import org.apache.arrow.vector.LargeVarBinaryVector;
 import org.apache.arrow.vector.LargeVarCharVector;
+import org.apache.arrow.vector.NullVector;
 import org.apache.arrow.vector.SmallIntVector;
 import org.apache.arrow.vector.TimeMicroVector;
 import org.apache.arrow.vector.TimeMilliVector;
@@ -52,6 +54,8 @@ import org.apache.arrow.vector.UInt8Vector;
 import org.apache.arrow.vector.ValueVector;
 import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.complex.FixedSizeListVector;
+import org.apache.arrow.vector.complex.LargeListVector;
 import org.apache.arrow.vector.complex.ListVector;
 import org.apache.arrow.vector.testing.ValueVectorDataPopulator;
 import org.apache.arrow.vector.types.TimeUnit;
@@ -158,6 +162,61 @@ public class TestDefaultVectorComparator {
     }
   }
 
+  private FixedSizeListVector createFixedSizeListVector(int count) {
+    FixedSizeListVector listVector = FixedSizeListVector.empty("list vector", 
count, allocator);
+    Types.MinorType type = Types.MinorType.INT;
+    listVector.addOrGetVector(FieldType.nullable(type.getType()));
+    listVector.allocateNew();
+
+    IntVector dataVector = (IntVector) listVector.getDataVector();
+
+    for (int i = 0; i < count; i++) {
+      dataVector.set(i, i);
+    }
+    dataVector.setValueCount(count);
+
+    listVector.setNotNull(0);
+    listVector.setValueCount(1);
+
+    return listVector;
+  }
+
+  @Test
+  public void testCompareFixedSizeLists() {
+    try (FixedSizeListVector listVector1 = createFixedSizeListVector(10);
+         FixedSizeListVector listVector2 = createFixedSizeListVector(11)) {
+      VectorValueComparator<FixedSizeListVector> comparator =
+          DefaultVectorComparators.createDefaultComparator(listVector1);
+      comparator.attachVectors(listVector1, listVector2);
+
+      // prefix is smaller
+      assertTrue(comparator.compare(0, 0) < 0);
+    }
+
+    try (FixedSizeListVector listVector1 = createFixedSizeListVector(11);
+         FixedSizeListVector listVector2 = createFixedSizeListVector(11)) {
+      ((IntVector) listVector2.getDataVector()).set(10, 110);
+
+      VectorValueComparator<FixedSizeListVector> comparator =
+          DefaultVectorComparators.createDefaultComparator(listVector1);
+      comparator.attachVectors(listVector1, listVector2);
+
+      // breaking tie by the last element
+      assertTrue(comparator.compare(0, 0) < 0);
+    }
+
+    try (FixedSizeListVector listVector1 = createFixedSizeListVector(10);
+         FixedSizeListVector listVector2 = createFixedSizeListVector(10)) {
+
+      VectorValueComparator<FixedSizeListVector> comparator =
+          DefaultVectorComparators.createDefaultComparator(listVector1);
+      comparator.attachVectors(listVector1, listVector2);
+
+      // list vector elements equal
+      assertTrue(comparator.compare(0, 0) == 0);
+    }
+  }
+
   @Test
   public void testCompareUInt1() {
     try (UInt1Vector vec = new UInt1Vector("", allocator)) {
@@ -845,6 +904,65 @@ public class TestDefaultVectorComparator {
     }
   }
 
+  @Test
+  public void testCompareFixedSizeBinary() {
+    try (FixedSizeBinaryVector vector1 = new FixedSizeBinaryVector("test1", 
allocator, 2);
+         FixedSizeBinaryVector vector2 = new FixedSizeBinaryVector("test1", 
allocator, 3)) {
+      vector1.allocateNew();
+      vector2.allocateNew();
+      vector1.set(0, new byte[] {1, 1});
+      vector2.set(0, new byte[] {1, 1, 0});
+      VectorValueComparator<FixedSizeBinaryVector> comparator =
+          DefaultVectorComparators.createDefaultComparator(vector1);
+      comparator.attachVectors(vector1, vector2);
+
+      // prefix is smaller
+      assertTrue(comparator.compare(0, 0) < 0);
+    }
+
+    try (FixedSizeBinaryVector vector1 = new FixedSizeBinaryVector("test1", 
allocator, 3);
+         FixedSizeBinaryVector vector2 = new FixedSizeBinaryVector("test1", 
allocator, 3)) {
+      vector1.allocateNew();
+      vector2.allocateNew();
+      vector1.set(0, new byte[] {1, 1, 0});
+      vector2.set(0, new byte[] {1, 1, 1});
+      VectorValueComparator<FixedSizeBinaryVector> comparator =
+          DefaultVectorComparators.createDefaultComparator(vector1);
+      comparator.attachVectors(vector1, vector2);
+
+      // breaking tie by the last element
+      assertTrue(comparator.compare(0, 0) < 0);
+    }
+
+    try (FixedSizeBinaryVector vector1 = new FixedSizeBinaryVector("test1", 
allocator, 3);
+         FixedSizeBinaryVector vector2 = new FixedSizeBinaryVector("test1", 
allocator, 3)) {
+      vector1.allocateNew();
+      vector2.allocateNew();
+      vector1.set(0, new byte[] {1, 1, 1});
+      vector2.set(0, new byte[] {1, 1, 1});
+      VectorValueComparator<FixedSizeBinaryVector> comparator =
+          DefaultVectorComparators.createDefaultComparator(vector1);
+      comparator.attachVectors(vector1, vector2);
+
+      // list vector elements equal
+      assertTrue(comparator.compare(0, 0) == 0);
+    }
+  }
+
+  @Test
+  public void testCompareNull() {
+    try (NullVector vec = new NullVector("test",
+        FieldType.notNullable(new ArrowType.Int(32, false)))) {
+      vec.setValueCount(2);
+
+      VectorValueComparator<NullVector> comparator =
+          DefaultVectorComparators.createDefaultComparator(vec);
+      comparator.attachVector(vec);
+      assertEquals(DefaultVectorComparators.NullComparator.class, 
comparator.getClass());
+      assertEquals(0, comparator.compare(0, 1));
+    }
+  }
+
   @Test
   public void testCheckNullsOnCompareIsFalseForNonNullableVector() {
     try (IntVector vec = new IntVector("not nullable",
@@ -937,4 +1055,18 @@ public class TestDefaultVectorComparator {
     VectorValueComparator<V> comparator = 
DefaultVectorComparators.createDefaultComparator(vec);
     assertEquals(DefaultVectorComparators.VariableWidthComparator.class, 
comparator.getClass());
   }
+
+  @Test
+  public void testRepeatedDefaultComparators() {
+    final FieldType type = FieldType.nullable(Types.MinorType.INT.getType());
+    try (final LargeListVector vector = new LargeListVector("list", allocator, 
type, null)) {
+      vector.addOrGetVector(FieldType.nullable(type.getType()));
+      verifyRepeatedComparatorReturned(vector);
+    }
+  }
+
+  private static <V extends ValueVector> void 
verifyRepeatedComparatorReturned(V vec) {
+    VectorValueComparator<V> comparator = 
DefaultVectorComparators.createDefaultComparator(vec);
+    assertEquals(DefaultVectorComparators.RepeatedValueComparator.class, 
comparator.getClass());
+  }
 }

Reply via email to