This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-java.git
The following commit(s) were added to refs/heads/main by this push:
new 7c25ce5d8 GH-52: Make RangeEqualsVisitor of RunEndEncodedVector more
efficient (#761)
7c25ce5d8 is described below
commit 7c25ce5d86490822600b49928d34a08b4dddad46
Author: ViggoC <[email protected]>
AuthorDate: Thu May 22 21:40:44 2025 +0800
GH-52: Make RangeEqualsVisitor of RunEndEncodedVector more efficient (#761)
## What's Changed
Make RangeEqualsVisitor more efficient for RunEndEncodedVector by walking the
runs of both vectors with a new RunEndEncodedVector.RangeIterator instead of
doing a binary search on every step.
Closes #52.
---
.../arrow/vector/compare/RangeEqualsVisitor.java | 44 +++++-----
.../arrow/vector/complex/RunEndEncodedVector.java | 98 ++++++++++++++++++++++
.../arrow/vector/TestRunEndEncodedVector.java | 18 ++--
.../vector/compare/TestRangeEqualsVisitor.java | 52 ++++++++++++
4 files changed, 181 insertions(+), 31 deletions(-)
diff --git
a/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java
b/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java
index abcf312c5..bc2e3a6aa 100644
---
a/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java
+++
b/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java
@@ -43,6 +43,7 @@ import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.ListViewVector;
import org.apache.arrow.vector.complex.NonNullableStructVector;
import org.apache.arrow.vector.complex.RunEndEncodedVector;
+import org.apache.arrow.vector.complex.RunEndEncodedVector.RangeIterator;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.complex.UnionVector;
@@ -270,42 +271,35 @@ public class RangeEqualsVisitor implements
VectorVisitor<Boolean, Range> {
RunEndEncodedVector leftVector = (RunEndEncodedVector) left;
RunEndEncodedVector rightVector = (RunEndEncodedVector) right;
- final int leftRangeEnd = range.getLeftStart() + range.getLength();
- final int rightRangeEnd = range.getRightStart() + range.getLength();
+ final RunEndEncodedVector.RangeIterator leftIterator =
+ new RunEndEncodedVector.RangeIterator(leftVector,
range.getLeftStart(), range.getLength());
+ final RunEndEncodedVector.RangeIterator rightIterator =
+ new RunEndEncodedVector.RangeIterator(
+ rightVector, range.getRightStart(), range.getLength());
FieldVector leftValuesVector = leftVector.getValuesVector();
FieldVector rightValuesVector = rightVector.getValuesVector();
RangeEqualsVisitor innerVisitor = createInnerVisitor(leftValuesVector,
rightValuesVector, null);
- int leftLogicalIndex = range.getLeftStart();
- int rightLogicalIndex = range.getRightStart();
+ while (nextRun(leftIterator, rightIterator)) {
+ int leftPhysicalIndex = leftIterator.getRunIndex();
+ int rightPhysicalIndex = rightIterator.getRunIndex();
- while (leftLogicalIndex < leftRangeEnd) {
- // TODO: implement it more efficient
- // https://github.com/apache/arrow/issues/44157
- int leftPhysicalIndex = leftVector.getPhysicalIndex(leftLogicalIndex);
- int rightPhysicalIndex = rightVector.getPhysicalIndex(rightLogicalIndex);
- if (leftValuesVector.accept(
- innerVisitor, new Range(leftPhysicalIndex, rightPhysicalIndex, 1))) {
- int leftRunEnd = leftVector.getRunEnd(leftLogicalIndex);
- int rightRunEnd = rightVector.getRunEnd(rightLogicalIndex);
-
- int leftRunLength = Math.min(leftRunEnd, leftRangeEnd) -
leftLogicalIndex;
- int rightRunLength = Math.min(rightRunEnd, rightRangeEnd) -
rightLogicalIndex;
-
- if (leftRunLength != rightRunLength) {
- return false;
- } else {
- leftLogicalIndex = leftRunEnd;
- rightLogicalIndex = rightRunEnd;
- }
- } else {
+ if (leftIterator.getRunLength() != rightIterator.getRunLength()
+ || !leftValuesVector.accept(
+ innerVisitor, new Range(leftPhysicalIndex, rightPhysicalIndex,
1))) {
return false;
}
}
- return true;
+ return leftIterator.isEnd() && rightIterator.isEnd();
+ }
+
+  /**
+   * Advances both iterators to their next run.
+   *
+   * <p>Both {@code nextRun()} calls are always evaluated (the results are combined only after
+   * both have run, rather than short-circuiting with {@code &&} on the calls themselves). This
+   * keeps the two iterators advanced in lockstep, so the caller's final {@code isEnd()} checks
+   * see both iterators in their advanced state.
+   *
+   * @param leftIterator iterator over the left vector's range
+   * @param rightIterator iterator over the right vector's range
+   * @return true only if both iterators produced another run
+   */
+ private static boolean nextRun(RangeIterator leftIterator, RangeIterator
rightIterator) {
+ boolean left = leftIterator.nextRun();
+ boolean right = rightIterator.nextRun();
+ return left && right;
+ }
protected RangeEqualsVisitor createInnerVisitor(
diff --git
a/vector/src/main/java/org/apache/arrow/vector/complex/RunEndEncodedVector.java
b/vector/src/main/java/org/apache/arrow/vector/complex/RunEndEncodedVector.java
index 1bb9a3d6c..b83e13449 100644
---
a/vector/src/main/java/org/apache/arrow/vector/complex/RunEndEncodedVector.java
+++
b/vector/src/main/java/org/apache/arrow/vector/complex/RunEndEncodedVector.java
@@ -28,6 +28,7 @@ import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.OutOfMemoryException;
import org.apache.arrow.memory.util.ByteFunctionHelpers;
import org.apache.arrow.memory.util.hash.ArrowBufHasher;
+import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.BaseIntVector;
import org.apache.arrow.vector.BaseValueVector;
import org.apache.arrow.vector.BigIntVector;
@@ -820,4 +821,101 @@ public class RunEndEncodedVector extends BaseValueVector
implements FieldVector
return result;
}
+
+  /**
+   * Iterates over a logical range of a {@link RunEndEncodedVector}, advancing either one run at a
+   * time ({@link #nextRun()}) or one value at a time ({@link #nextValue()}). Only the starting
+   * position needs a physical-index lookup; later positions are reached by stepping through the
+   * run-ends vector.
+   *
+   * <p>{@link #getRunIndex()} and {@link #getRunLength()} describe the current position and are
+   * only meaningful after {@code nextRun()} or {@code nextValue()} has returned {@code true}.
+   */
+  public static class RangeIterator {
+
+    private final RunEndEncodedVector runEndEncodedVector;
+    /** Exclusive logical end of the iterated range. */
+    private final int rangeEnd;
+    /** Index of the current run in the run-ends vector. */
+    private int runIndex;
+    /** Exclusive logical end of the current run, as read from the run-ends vector. */
+    private int runEnd;
+    /** Current logical position. */
+    private int logicalPos;
+
+    /**
+     * Constructs a new RangeIterator for iterating over a range of values in a
+     * RunEndEncodedVector.
+     *
+     * @param runEndEncodedVector The vector to iterate over
+     * @param startIndex The logical start index of the range (inclusive)
+     * @param length The number of values to include in the range
+     * @throws IllegalArgumentException if startIndex or length is negative, or (startIndex +
+     *     length) exceeds the vector's value count
+     */
+    public RangeIterator(RunEndEncodedVector runEndEncodedVector, int startIndex, int length) {
+      Preconditions.checkArgument(
+          startIndex >= 0, "startIndex %s must be non negative.", startIndex);
+      Preconditions.checkArgument(length >= 0, "length %s must be non negative.", length);
+      // Widen to long so a large startIndex + length cannot overflow and slip past the check.
+      final long rangeEnd = (long) startIndex + length;
+      Preconditions.checkArgument(
+          rangeEnd <= runEndEncodedVector.getValueCount(),
+          "(startIndex + length) %s out of range[0, %s].",
+          rangeEnd,
+          runEndEncodedVector.getValueCount());
+
+      this.rangeEnd = (int) rangeEnd;
+      this.runEndEncodedVector = runEndEncodedVector;
+      // Start one run before the physical index of startIndex; the first nextRun()/nextValue()
+      // call steps onto it via updateRun().
+      this.runIndex = runEndEncodedVector.getPhysicalIndex(startIndex) - 1;
+      // With runEnd == startIndex, the first nextRun() begins at startIndex, and the first
+      // nextValue() (which lands on startIndex) triggers the run update.
+      this.runEnd = startIndex;
+      // One before startIndex so that the first nextValue() yields startIndex, not 0.
+      this.logicalPos = startIndex - 1;
+    }
+
+    /**
+     * Advances to the next run in the range.
+     *
+     * @return true if there is another run available, false if iteration has completed
+     */
+    public boolean nextRun() {
+      logicalPos = runEnd;
+      if (logicalPos >= rangeEnd) {
+        return false;
+      }
+      updateRun();
+      return true;
+    }
+
+    /** Moves to the following run and caches its exclusive logical end. */
+    private void updateRun() {
+      runIndex++;
+      runEnd =
+          (int) ((BaseIntVector) runEndEncodedVector.runEndsVector).getValueAsLong(runIndex);
+    }
+
+    /**
+     * Advances to the next value in the range, moving to the next run when the current run is
+     * exhausted.
+     *
+     * @return true if there is another value available, false if iteration has completed
+     */
+    public boolean nextValue() {
+      logicalPos++;
+      if (logicalPos >= rangeEnd) {
+        return false;
+      }
+      if (logicalPos == runEnd) {
+        updateRun();
+      }
+      return true;
+    }
+
+    /**
+     * Gets the current run index (physical position in the run-ends vector).
+     *
+     * @return the current run index
+     */
+    public int getRunIndex() {
+      return runIndex;
+    }
+
+    /**
+     * Gets the length of the current run within the iterator's range.
+     *
+     * @return the number of remaining values in the current run, counted from the current
+     *     position and clipped to the end of the iterator's range
+     */
+    public int getRunLength() {
+      return Math.min(runEnd, rangeEnd) - logicalPos;
+    }
+
+    /**
+     * Checks if iteration has completed.
+     *
+     * @return true if all values in the range have been processed, false otherwise
+     */
+    public boolean isEnd() {
+      return logicalPos >= rangeEnd;
+    }
+  }
}
diff --git
a/vector/src/test/java/org/apache/arrow/vector/TestRunEndEncodedVector.java
b/vector/src/test/java/org/apache/arrow/vector/TestRunEndEncodedVector.java
index adf51c073..9fa153e92 100644
--- a/vector/src/test/java/org/apache/arrow/vector/TestRunEndEncodedVector.java
+++ b/vector/src/test/java/org/apache/arrow/vector/TestRunEndEncodedVector.java
@@ -148,12 +148,18 @@ public class TestRunEndEncodedVector {
assertTrue(
constantVector.accept(
new RangeEqualsVisitor(constantVector, constantVector), new
Range(1, 2, 13)));
- assertFalse(
- constantVector.accept(
- new RangeEqualsVisitor(constantVector, constantVector), new
Range(1, 10, 10)));
- assertFalse(
- constantVector.accept(
- new RangeEqualsVisitor(constantVector, constantVector), new
Range(10, 1, 10)));
+
+ // throws exception if the range end is out the bound of the vector
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ constantVector.accept(
+ new RangeEqualsVisitor(constantVector, constantVector), new
Range(1, 10, 10)));
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ constantVector.accept(
+ new RangeEqualsVisitor(constantVector, constantVector), new
Range(10, 1, 10)));
// Create REE vector representing: [1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5,
5, 5].
RunEndEncodedVector reeVector =
diff --git
a/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java
b/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java
index 08da786eb..962473435 100644
---
a/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java
+++
b/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java
@@ -22,6 +22,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import java.nio.charset.Charset;
import java.util.Arrays;
+import java.util.List;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
@@ -39,6 +40,7 @@ import org.apache.arrow.vector.complex.FixedSizeListVector;
import org.apache.arrow.vector.complex.LargeListViewVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.ListViewVector;
+import org.apache.arrow.vector.complex.RunEndEncodedVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.complex.UnionVector;
import org.apache.arrow.vector.complex.impl.NullableStructWriter;
@@ -53,7 +55,9 @@ import org.apache.arrow.vector.holders.NullableIntHolder;
import org.apache.arrow.vector.holders.NullableUInt4Holder;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.ArrowType.RunEndEncoded;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.junit.jupiter.api.AfterEach;
@@ -1003,6 +1007,54 @@ public class TestRangeEqualsVisitor {
}
}
+  /**
+   * Checks ApproxEqualsVisitor on run-end encoded Float8 vectors. All three encoded vectors share
+   * the same run ends. vector2 differs from vector1 by epsilon / 2 and is expected to compare
+   * equal; vector3 differs by 2 * epsilon and is expected to compare unequal.
+   */
+ @Test
+ public void testRunEndEncodedFloat8ApproxEquals() {
+ try (final Float8Vector vector1 = new Float8Vector("float", allocator);
+ final Float8Vector vector2 = new Float8Vector("float", allocator);
+ final Float8Vector vector3 = new Float8Vector("float", allocator);
+ final IntVector reeVector = new IntVector("ree", allocator)) {
+
+ final float epsilon = 1.0E-6f;
+    // Base values, a copy shifted by epsilon / 2 (within tolerance), and a copy shifted by
+    // 2 * epsilon (outside tolerance).
+ setVector(vector1, 1.1, 2.2);
+ setVector(vector2, 1.1 + epsilon / 2, 2.2 + epsilon / 2);
+ setVector(vector3, 1.1 + epsilon * 2, 2.2 + epsilon * 2);
+    // Run ends 1 and 3 (assuming setVector writes them in order): logical index 0 maps to
+    // value 0 and logical indices 1-2 map to value 1, i.e. 3 logical values.
+ setVector(reeVector, 1, 3);
+
+ ArrowType type = MinorType.FLOAT8.getType();
+ final FieldType valueType = FieldType.notNullable(type);
+ final FieldType runEndType =
FieldType.notNullable(MinorType.INT.getType());
+
+ final Field valueField = new Field("value", valueType, null);
+ final Field runEndField = new Field("ree", runEndType, null);
+
+ Field field =
+ new Field(
+ "ree_float",
+ FieldType.notNullable(RunEndEncoded.INSTANCE),
+ List.of(runEndField, valueField));
+
+    // NOTE(review): reeVector and vector1..3 are passed into these RunEndEncodedVectors and are
+    // also closed by the outer try-with-resources; confirm that closing them more than once is
+    // safe.
+ try (final RunEndEncodedVector encodedVector1 =
+ new RunEndEncodedVector(field, allocator, reeVector, vector1,
null);
+ final RunEndEncodedVector encodedVector2 =
+ new RunEndEncodedVector(field, allocator, reeVector, vector2,
null);
+ final RunEndEncodedVector encodedVector3 =
+ new RunEndEncodedVector(field, allocator, reeVector, vector3,
null)) {
+
+      // 3 matches the last run end written to reeVector.
+ encodedVector1.setValueCount(3);
+ encodedVector2.setValueCount(3);
+ encodedVector3.setValueCount(3);
+
+      // Compare the full logical range of both vectors starting at index 0.
+ Range range = new Range(0, 0, encodedVector1.getValueCount());
+ assertTrue(
+ new ApproxEqualsVisitor(encodedVector1, encodedVector2, epsilon,
epsilon)
+ .rangeEquals(range));
+ assertFalse(
+ new ApproxEqualsVisitor(encodedVector1, encodedVector3, epsilon,
epsilon)
+ .rangeEquals(range));
+ }
+ }
+ }
+
private void writeStructVector(NullableStructWriter writer, int value1, long
value2) {
writer.start();
writer.integer("f0").writeInt(value1);