This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new d0f9f3e613 GH-43966: [Java] Check for nullabilities when comparing StructVector (#43968)
d0f9f3e613 is described below
commit d0f9f3e6136bffaa94b25d4a1c95576d4747773d
Author: hellishfire <[email protected]>
AuthorDate: Mon Sep 9 15:26:15 2024 +0800
GH-43966: [Java] Check for nullabilities when comparing StructVector (#43968)
### Rationale for this change
See #43966
### What changes are included in this PR?
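Check the null state of each slot when comparing StructVectors with RangeEqualsVisitor. Previously only the child vectors were compared, so a null struct slot could compare equal to a non-null slot holding the same child values. Now, if either side is a nullable StructVector, a pair of slots is unequal when exactly one of them is null. Slots that are null on both sides are treated as equal without inspecting their children, and the child vectors are compared only over the runs of non-null slots. Comparison of two NonNullableStructVectors is unchanged.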
### Are these changes tested?
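Yes. TestRangeEqualsVisitor now covers four struct-slot cases: different nullability with the same values, both slots null, both slots non-null with different values, and both slots null with different underlying values.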
### Are there any user-facing changes?
No
* GitHub Issue: #43966
Authored-by: youming.whl <[email protected]>
Signed-off-by: David Li <[email protected]>
---
.../arrow/vector/compare/RangeEqualsVisitor.java | 61 +++++++++++++++++++---
.../vector/compare/TestRangeEqualsVisitor.java | 20 +++++--
2 files changed, 72 insertions(+), 9 deletions(-)
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java
index 9aa1bffb84..ed51f748af 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java
@@ -41,6 +41,7 @@ import org.apache.arrow.vector.complex.LargeListViewVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.ListViewVector;
import org.apache.arrow.vector.complex.NonNullableStructVector;
+import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.complex.UnionVector;
/** Visitor to compare a range of values for vectors. */
@@ -345,6 +346,20 @@ public class RangeEqualsVisitor implements VectorVisitor<Boolean, Range> {
return true;
}
+ private boolean compareStructVectorsInternal(
+ NonNullableStructVector leftVector, NonNullableStructVector rightVector,
Range range) {
+ List<String> leftChildNames = leftVector.getChildFieldNames();
+ for (String name : leftChildNames) {
+ RangeEqualsVisitor visitor =
+ createInnerVisitor(
+ leftVector.getChild(name), rightVector.getChild(name), /*type
comparator*/ null);
+ if (!visitor.rangeEquals(range)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
protected boolean compareStructVectors(Range range) {
NonNullableStructVector leftVector = (NonNullableStructVector) left;
NonNullableStructVector rightVector = (NonNullableStructVector) right;
@@ -354,15 +369,49 @@ public class RangeEqualsVisitor implements VectorVisitor<Boolean, Range> {
return false;
}
- for (String name : leftChildNames) {
- RangeEqualsVisitor visitor =
- createInnerVisitor(
- leftVector.getChild(name), rightVector.getChild(name), /*type
comparator*/ null);
- if (!visitor.rangeEquals(range)) {
+ if (!(leftVector instanceof StructVector || rightVector instanceof
StructVector)) {
+ // neither struct vector is nullable
+ return compareStructVectorsInternal(leftVector, rightVector, range);
+ }
+
+ Range subRange = new Range(0, 0, 0);
+ boolean lastIsNull = true;
+ int lastNullIndex = -1;
+ for (int i = 0; i < range.getLength(); i++) {
+ int leftIndex = range.getLeftStart() + i;
+ int rightIndex = range.getRightStart() + i;
+ boolean isLeftNull = leftVector.isNull(leftIndex);
+ boolean isRightNull = rightVector.isNull(rightIndex);
+
+ if (isLeftNull != isRightNull) {
+ // exactly one slot is null, unequal
return false;
}
+ if (isLeftNull) {
+ // slots are null
+ if (!lastIsNull) {
+ subRange
+ .setLeftStart(range.getLeftStart() + lastNullIndex + 1)
+ .setRightStart(range.getRightStart() + lastNullIndex + 1)
+ .setLength(i - (lastNullIndex + 1));
+ if (!compareStructVectorsInternal(leftVector, rightVector,
subRange)) {
+ return false;
+ }
+ }
+ lastIsNull = true;
+ lastNullIndex = i;
+ } else {
+ // slots are not null
+ lastIsNull = false;
+ }
+ }
+ if (!lastIsNull) {
+ subRange
+ .setLeftStart(range.getLeftStart() + lastNullIndex + 1)
+ .setRightStart(range.getRightStart() + lastNullIndex + 1)
+ .setLength(range.getLength() - (lastNullIndex + 1));
+ return compareStructVectorsInternal(leftVector, rightVector, subRange);
}
-
return true;
}
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java
index eca5c2d9b2..08da786eb2 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java
@@ -434,17 +434,18 @@ public class TestRangeEqualsVisitor {
NullableStructWriter writer1 = vector1.getWriter();
writer1.allocate();
+ writeStructVector(writer1, 0, 0L);
writeStructVector(writer1, 1, 10L);
writeStructVector(writer1, 2, 20L);
writeStructVector(writer1, 3, 30L);
writeStructVector(writer1, 4, 40L);
writeStructVector(writer1, 5, 50L);
- writer1.setValueCount(5);
+ writer1.setValueCount(6);
NullableStructWriter writer2 = vector2.getWriter();
writer2.allocate();
- writeStructVector(writer2, 0, 00L);
+ writeStructVector(writer2, 0, 0L);
writeStructVector(writer2, 2, 20L);
writeStructVector(writer2, 3, 30L);
writeStructVector(writer2, 4, 40L);
@@ -452,7 +453,20 @@ public class TestRangeEqualsVisitor {
writer2.setValueCount(5);
RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector1, vector2);
- assertTrue(visitor.rangeEquals(new Range(1, 1, 3)));
+ assertTrue(visitor.rangeEquals(new Range(2, 1, 3)));
+
+ // different nullability but same values
+ vector1.setNull(3);
+ assertFalse(visitor.rangeEquals(new Range(2, 1, 3)));
+ // both null and same values
+ vector2.setNull(2);
+ assertTrue(visitor.rangeEquals(new Range(2, 1, 3)));
+ // both not null but different values
+ assertFalse(visitor.rangeEquals(new Range(2, 1, 4)));
+ // both null but different values
+ vector1.setNull(5);
+ vector2.setNull(4);
+ assertTrue(visitor.rangeEquals(new Range(2, 1, 4)));
}
}