This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new d0f9f3e613 GH-43966: [Java] Check for nullabilities when comparing StructVector (#43968)
d0f9f3e613 is described below

commit d0f9f3e6136bffaa94b25d4a1c95576d4747773d
Author: hellishfire <[email protected]>
AuthorDate: Mon Sep 9 15:26:15 2024 +0800

    GH-43966: [Java] Check for nullabilities when comparing StructVector (#43968)
    
    
    
    ### Rationale for this change
    
    See #43966
    
    ### What changes are included in this PR?
    
    Check for nullabilities when comparing StructVector with RangeEqualsVisitor.
    
    ### Are these changes tested?
    
    Yes
    
    ### Are there any user-facing changes?
    No
    
    * GitHub Issue: #43966
    
    Authored-by: youming.whl <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 .../arrow/vector/compare/RangeEqualsVisitor.java   | 61 +++++++++++++++++++---
 .../vector/compare/TestRangeEqualsVisitor.java     | 20 +++++--
 2 files changed, 72 insertions(+), 9 deletions(-)

diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java
index 9aa1bffb84..ed51f748af 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java
@@ -41,6 +41,7 @@ import org.apache.arrow.vector.complex.LargeListViewVector;
 import org.apache.arrow.vector.complex.ListVector;
 import org.apache.arrow.vector.complex.ListViewVector;
 import org.apache.arrow.vector.complex.NonNullableStructVector;
+import org.apache.arrow.vector.complex.StructVector;
 import org.apache.arrow.vector.complex.UnionVector;
 
 /** Visitor to compare a range of values for vectors. */
@@ -345,6 +346,20 @@ public class RangeEqualsVisitor implements VectorVisitor<Boolean, Range> {
     return true;
   }
 
+  private boolean compareStructVectorsInternal(
+      NonNullableStructVector leftVector, NonNullableStructVector rightVector, Range range) {
+    List<String> leftChildNames = leftVector.getChildFieldNames();
+    for (String name : leftChildNames) {
+      RangeEqualsVisitor visitor =
+          createInnerVisitor(
+              leftVector.getChild(name), rightVector.getChild(name), /*type comparator*/ null);
+      if (!visitor.rangeEquals(range)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   protected boolean compareStructVectors(Range range) {
     NonNullableStructVector leftVector = (NonNullableStructVector) left;
     NonNullableStructVector rightVector = (NonNullableStructVector) right;
@@ -354,15 +369,49 @@ public class RangeEqualsVisitor implements VectorVisitor<Boolean, Range> {
       return false;
     }
 
-    for (String name : leftChildNames) {
-      RangeEqualsVisitor visitor =
-          createInnerVisitor(
-              leftVector.getChild(name), rightVector.getChild(name), /*type comparator*/ null);
-      if (!visitor.rangeEquals(range)) {
+    if (!(leftVector instanceof StructVector || rightVector instanceof StructVector)) {
+      // neither struct vector is nullable
+      return compareStructVectorsInternal(leftVector, rightVector, range);
+    }
+
+    Range subRange = new Range(0, 0, 0);
+    boolean lastIsNull = true;
+    int lastNullIndex = -1;
+    for (int i = 0; i < range.getLength(); i++) {
+      int leftIndex = range.getLeftStart() + i;
+      int rightIndex = range.getRightStart() + i;
+      boolean isLeftNull = leftVector.isNull(leftIndex);
+      boolean isRightNull = rightVector.isNull(rightIndex);
+
+      if (isLeftNull != isRightNull) {
+        // exactly one slot is null, unequal
         return false;
       }
+      if (isLeftNull) {
+        // slots are null
+        if (!lastIsNull) {
+          subRange
+              .setLeftStart(range.getLeftStart() + lastNullIndex + 1)
+              .setRightStart(range.getRightStart() + lastNullIndex + 1)
+              .setLength(i - (lastNullIndex + 1));
+          if (!compareStructVectorsInternal(leftVector, rightVector, subRange)) {
+            return false;
+          }
+        }
+        lastIsNull = true;
+        lastNullIndex = i;
+      } else {
+        // slots are not null
+        lastIsNull = false;
+      }
+    }
+    if (!lastIsNull) {
+      subRange
+          .setLeftStart(range.getLeftStart() + lastNullIndex + 1)
+          .setRightStart(range.getRightStart() + lastNullIndex + 1)
+          .setLength(range.getLength() - (lastNullIndex + 1));
+      return compareStructVectorsInternal(leftVector, rightVector, subRange);
     }
-
     return true;
   }
 
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java
index eca5c2d9b2..08da786eb2 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java
@@ -434,17 +434,18 @@ public class TestRangeEqualsVisitor {
       NullableStructWriter writer1 = vector1.getWriter();
       writer1.allocate();
 
+      writeStructVector(writer1, 0, 0L);
       writeStructVector(writer1, 1, 10L);
       writeStructVector(writer1, 2, 20L);
       writeStructVector(writer1, 3, 30L);
       writeStructVector(writer1, 4, 40L);
       writeStructVector(writer1, 5, 50L);
-      writer1.setValueCount(5);
+      writer1.setValueCount(6);
 
       NullableStructWriter writer2 = vector2.getWriter();
       writer2.allocate();
 
-      writeStructVector(writer2, 0, 00L);
+      writeStructVector(writer2, 0, 0L);
       writeStructVector(writer2, 2, 20L);
       writeStructVector(writer2, 3, 30L);
       writeStructVector(writer2, 4, 40L);
@@ -452,7 +453,20 @@ public class TestRangeEqualsVisitor {
       writer2.setValueCount(5);
 
       RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector1, vector2);
-      assertTrue(visitor.rangeEquals(new Range(1, 1, 3)));
+      assertTrue(visitor.rangeEquals(new Range(2, 1, 3)));
+
+      // different nullability but same values
+      vector1.setNull(3);
+      assertFalse(visitor.rangeEquals(new Range(2, 1, 3)));
+      // both null and same values
+      vector2.setNull(2);
+      assertTrue(visitor.rangeEquals(new Range(2, 1, 3)));
+      // both not null but different values
+      assertFalse(visitor.rangeEquals(new Range(2, 1, 4)));
+      // both null but different values
+      vector1.setNull(5);
+      vector2.setNull(4);
+      assertTrue(visitor.rangeEquals(new Range(2, 1, 4)));
     }
   }
 

Reply via email to