This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new a4080209a9 GH-38662: [Java] Add comparators (#38669)
a4080209a9 is described below
commit a4080209a97a5d66accdeb71c5c1ffa982fed51e
Author: James Duong <[email protected]>
AuthorDate: Tue Nov 14 05:10:36 2023 -0800
GH-38662: [Java] Add comparators (#38669)
### Rationale for this change
Add the missing default VectorValueComparator implementations for several more vector types.
### What changes are included in this PR?
Add comparators for:
- FixedSizeBinaryVector
- LargeListVector
- FixedSizeListVector
- NullVector
### Are these changes tested?
Yes, unit tests added.
### Are there any user-facing changes?
No
* Closes: #38662
Authored-by: James Duong <[email protected]>
Signed-off-by: David Li <[email protected]>
---
.../algorithm/sort/DefaultVectorComparators.java | 140 +++++++++++++++++++--
.../sort/TestDefaultVectorComparator.java | 132 +++++++++++++++++++
2 files changed, 259 insertions(+), 13 deletions(-)
diff --git
a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java
b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java
index 4f9c8b7d71..588876aa99 100644
---
a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java
+++
b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/DefaultVectorComparators.java
@@ -32,11 +32,13 @@ import org.apache.arrow.vector.DateMilliVector;
import org.apache.arrow.vector.Decimal256Vector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.DurationVector;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.IntervalDayVector;
import org.apache.arrow.vector.IntervalMonthDayNanoVector;
+import org.apache.arrow.vector.NullVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeMicroVector;
import org.apache.arrow.vector.TimeMilliVector;
@@ -50,7 +52,9 @@ import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.UInt8Vector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VariableWidthVector;
-import org.apache.arrow.vector.complex.BaseRepeatedValueVector;
+import org.apache.arrow.vector.complex.FixedSizeListVector;
+import org.apache.arrow.vector.complex.RepeatedValueVector;
+import org.apache.arrow.vector.holders.NullableFixedSizeBinaryHolder;
/**
* Default comparator implementations for different types of vectors.
@@ -111,13 +115,21 @@ public class DefaultVectorComparators {
return (VectorValueComparator<T>) new TimeSecComparator();
} else if (vector instanceof TimeStampVector) {
return (VectorValueComparator<T>) new TimeStampComparator();
+ } else if (vector instanceof FixedSizeBinaryVector) {
+ return (VectorValueComparator<T>) new FixedSizeBinaryComparator();
}
} else if (vector instanceof VariableWidthVector) {
return (VectorValueComparator<T>) new VariableWidthComparator();
- } else if (vector instanceof BaseRepeatedValueVector) {
+ } else if (vector instanceof RepeatedValueVector) {
VectorValueComparator<?> innerComparator =
- createDefaultComparator(((BaseRepeatedValueVector)
vector).getDataVector());
+ createDefaultComparator(((RepeatedValueVector)
vector).getDataVector());
return new RepeatedValueComparator(innerComparator);
+ } else if (vector instanceof FixedSizeListVector) {
+ VectorValueComparator<?> innerComparator =
+ createDefaultComparator(((FixedSizeListVector)
vector).getDataVector());
+ return new FixedSizeListComparator(innerComparator);
+ } else if (vector instanceof NullVector) {
+ return (VectorValueComparator<T>) new NullComparator();
}
throw new IllegalArgumentException("No default comparator for " +
vector.getClass().getCanonicalName());
@@ -674,6 +686,61 @@ public class DefaultVectorComparators {
}
}
+ /**
+ * Default comparator for {@link
org.apache.arrow.vector.FixedSizeBinaryVector}.
+ * The comparison is in lexicographic order, with nulls coming first.
+ */
+ public static class FixedSizeBinaryComparator extends
VectorValueComparator<FixedSizeBinaryVector> {
+
+ @Override
+ public int compare(int index1, int index2) {
+ NullableFixedSizeBinaryHolder holder1 = new
NullableFixedSizeBinaryHolder();
+ NullableFixedSizeBinaryHolder holder2 = new
NullableFixedSizeBinaryHolder();
+ vector1.get(index1, holder1);
+ vector2.get(index2, holder2);
+
+ return ByteFunctionHelpers.compare(
+ holder1.buffer, 0, holder1.byteWidth, holder2.buffer, 0,
holder2.byteWidth);
+ }
+
+ @Override
+ public int compareNotNull(int index1, int index2) {
+ NullableFixedSizeBinaryHolder holder1 = new
NullableFixedSizeBinaryHolder();
+ NullableFixedSizeBinaryHolder holder2 = new
NullableFixedSizeBinaryHolder();
+ vector1.get(index1, holder1);
+ vector2.get(index2, holder2);
+
+ return ByteFunctionHelpers.compare(
+ holder1.buffer, 0, holder1.byteWidth, holder2.buffer, 0,
holder2.byteWidth);
+ }
+
+ @Override
+ public VectorValueComparator<FixedSizeBinaryVector> createNew() {
+ return new FixedSizeBinaryComparator();
+ }
+ }
+
+ /**
+ * Default comparator for {@link org.apache.arrow.vector.NullVector}.
+ */
+ public static class NullComparator extends VectorValueComparator<NullVector>
{
+ @Override
+ public int compare(int index1, int index2) {
+ // Values are always equal (and are always null).
+ return 0;
+ }
+
+ @Override
+ public int compareNotNull(int index1, int index2) {
+ throw new AssertionError("Cannot compare non-null values in a
NullVector.");
+ }
+
+ @Override
+ public VectorValueComparator<NullVector> createNew() {
+ return new NullComparator();
+ }
+ }
+
/**
* Default comparator for {@link
org.apache.arrow.vector.VariableWidthVector}.
* The comparison is in lexicographic order, with null comes first.
@@ -705,14 +772,14 @@ public class DefaultVectorComparators {
}
/**
- * Default comparator for {@link BaseRepeatedValueVector}.
+ * Default comparator for {@link RepeatedValueVector}.
* It works by comparing the underlying vector in a lexicographic order.
* @param <T> inner vector type.
*/
public static class RepeatedValueComparator<T extends ValueVector>
- extends VectorValueComparator<BaseRepeatedValueVector> {
+ extends VectorValueComparator<RepeatedValueVector> {
- private VectorValueComparator<T> innerComparator;
+ private final VectorValueComparator<T> innerComparator;
public RepeatedValueComparator(VectorValueComparator<T> innerComparator) {
this.innerComparator = innerComparator;
@@ -720,16 +787,16 @@ public class DefaultVectorComparators {
@Override
public int compareNotNull(int index1, int index2) {
- int startIdx1 = vector1.getOffsetBuffer().getInt(index1 * OFFSET_WIDTH);
- int startIdx2 = vector2.getOffsetBuffer().getInt(index2 * OFFSET_WIDTH);
+ int startIdx1 = vector1.getOffsetBuffer().getInt((long) index1 *
OFFSET_WIDTH);
+ int startIdx2 = vector2.getOffsetBuffer().getInt((long) index2 *
OFFSET_WIDTH);
- int endIdx1 = vector1.getOffsetBuffer().getInt((index1 + 1) *
OFFSET_WIDTH);
- int endIdx2 = vector2.getOffsetBuffer().getInt((index2 + 1) *
OFFSET_WIDTH);
+ int endIdx1 = vector1.getOffsetBuffer().getInt((long) (index1 + 1) *
OFFSET_WIDTH);
+ int endIdx2 = vector2.getOffsetBuffer().getInt((long) (index2 + 1) *
OFFSET_WIDTH);
int length1 = endIdx1 - startIdx1;
int length2 = endIdx2 - startIdx2;
- int length = length1 < length2 ? length1 : length2;
+ int length = Math.min(length1, length2);
for (int i = 0; i < length; i++) {
int result = innerComparator.compare(startIdx1 + i, startIdx2 + i);
@@ -741,13 +808,60 @@ public class DefaultVectorComparators {
}
@Override
- public VectorValueComparator<BaseRepeatedValueVector> createNew() {
+ public VectorValueComparator<RepeatedValueVector> createNew() {
VectorValueComparator<T> newInnerComparator =
innerComparator.createNew();
return new RepeatedValueComparator<>(newInnerComparator);
}
@Override
- public void attachVectors(BaseRepeatedValueVector vector1,
BaseRepeatedValueVector vector2) {
+ public void attachVectors(RepeatedValueVector vector1, RepeatedValueVector
vector2) {
+ this.vector1 = vector1;
+ this.vector2 = vector2;
+
+ innerComparator.attachVectors((T) vector1.getDataVector(), (T)
vector2.getDataVector());
+ }
+ }
+
+ /**
+ * Default comparator for {@link FixedSizeListVector}.
+ * It works by comparing the underlying vector in a lexicographic order.
+ * @param <T> inner vector type.
+ */
+ public static class FixedSizeListComparator<T extends ValueVector>
+ extends VectorValueComparator<FixedSizeListVector> {
+
+ private final VectorValueComparator<T> innerComparator;
+
+ public FixedSizeListComparator(VectorValueComparator<T> innerComparator) {
+ this.innerComparator = innerComparator;
+ }
+
+ @Override
+ public int compareNotNull(int index1, int index2) {
+ int length1 = vector1.getListSize();
+ int length2 = vector2.getListSize();
+
+ int length = Math.min(length1, length2);
+ int startIdx1 = vector1.getElementStartIndex(index1);
+ int startIdx2 = vector2.getElementStartIndex(index2);
+
+ for (int i = 0; i < length; i++) {
+ int result = innerComparator.compare(startIdx1 + i, startIdx2 + i);
+ if (result != 0) {
+ return result;
+ }
+ }
+ return length1 - length2;
+ }
+
+ @Override
+ public VectorValueComparator<FixedSizeListVector> createNew() {
+ VectorValueComparator<T> newInnerComparator =
innerComparator.createNew();
+ return new FixedSizeListComparator<>(newInnerComparator);
+ }
+
+ @Override
+ public void attachVectors(FixedSizeListVector vector1, FixedSizeListVector
vector2) {
this.vector1 = vector1;
this.vector2 = vector2;
diff --git
a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java
b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java
index bdae85110a..43c634b764 100644
---
a/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java
+++
b/java/algorithm/src/test/java/org/apache/arrow/algorithm/sort/TestDefaultVectorComparator.java
@@ -31,12 +31,14 @@ import org.apache.arrow.vector.DateMilliVector;
import org.apache.arrow.vector.Decimal256Vector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.DurationVector;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.IntervalDayVector;
import org.apache.arrow.vector.LargeVarBinaryVector;
import org.apache.arrow.vector.LargeVarCharVector;
+import org.apache.arrow.vector.NullVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeMicroVector;
import org.apache.arrow.vector.TimeMilliVector;
@@ -52,6 +54,8 @@ import org.apache.arrow.vector.UInt8Vector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.complex.FixedSizeListVector;
+import org.apache.arrow.vector.complex.LargeListVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.testing.ValueVectorDataPopulator;
import org.apache.arrow.vector.types.TimeUnit;
@@ -158,6 +162,61 @@ public class TestDefaultVectorComparator {
}
}
+ private FixedSizeListVector createFixedSizeListVector(int count) {
+ FixedSizeListVector listVector = FixedSizeListVector.empty("list vector",
count, allocator);
+ Types.MinorType type = Types.MinorType.INT;
+ listVector.addOrGetVector(FieldType.nullable(type.getType()));
+ listVector.allocateNew();
+
+ IntVector dataVector = (IntVector) listVector.getDataVector();
+
+ for (int i = 0; i < count; i++) {
+ dataVector.set(i, i);
+ }
+ dataVector.setValueCount(count);
+
+ listVector.setNotNull(0);
+ listVector.setValueCount(1);
+
+ return listVector;
+ }
+
+ @Test
+ public void testCompareFixedSizeLists() {
+ try (FixedSizeListVector listVector1 = createFixedSizeListVector(10);
+ FixedSizeListVector listVector2 = createFixedSizeListVector(11)) {
+ VectorValueComparator<FixedSizeListVector> comparator =
+ DefaultVectorComparators.createDefaultComparator(listVector1);
+ comparator.attachVectors(listVector1, listVector2);
+
+ // prefix is smaller
+ assertTrue(comparator.compare(0, 0) < 0);
+ }
+
+ try (FixedSizeListVector listVector1 = createFixedSizeListVector(11);
+ FixedSizeListVector listVector2 = createFixedSizeListVector(11)) {
+ ((IntVector) listVector2.getDataVector()).set(10, 110);
+
+ VectorValueComparator<FixedSizeListVector> comparator =
+ DefaultVectorComparators.createDefaultComparator(listVector1);
+ comparator.attachVectors(listVector1, listVector2);
+
+ // breaking tie by the last element
+ assertTrue(comparator.compare(0, 0) < 0);
+ }
+
+ try (FixedSizeListVector listVector1 = createFixedSizeListVector(10);
+ FixedSizeListVector listVector2 = createFixedSizeListVector(10)) {
+
+ VectorValueComparator<FixedSizeListVector> comparator =
+ DefaultVectorComparators.createDefaultComparator(listVector1);
+ comparator.attachVectors(listVector1, listVector2);
+
+ // list vector elements equal
+ assertTrue(comparator.compare(0, 0) == 0);
+ }
+ }
+
@Test
public void testCompareUInt1() {
try (UInt1Vector vec = new UInt1Vector("", allocator)) {
@@ -845,6 +904,65 @@ public class TestDefaultVectorComparator {
}
}
+ @Test
+ public void testCompareFixedSizeBinary() {
+ try (FixedSizeBinaryVector vector1 = new FixedSizeBinaryVector("test1",
allocator, 2);
+ FixedSizeBinaryVector vector2 = new FixedSizeBinaryVector("test1",
allocator, 3)) {
+ vector1.allocateNew();
+ vector2.allocateNew();
+ vector1.set(0, new byte[] {1, 1});
+ vector2.set(0, new byte[] {1, 1, 0});
+ VectorValueComparator<FixedSizeBinaryVector> comparator =
+ DefaultVectorComparators.createDefaultComparator(vector1);
+ comparator.attachVectors(vector1, vector2);
+
+ // prefix is smaller
+ assertTrue(comparator.compare(0, 0) < 0);
+ }
+
+ try (FixedSizeBinaryVector vector1 = new FixedSizeBinaryVector("test1",
allocator, 3);
+ FixedSizeBinaryVector vector2 = new FixedSizeBinaryVector("test1",
allocator, 3)) {
+ vector1.allocateNew();
+ vector2.allocateNew();
+ vector1.set(0, new byte[] {1, 1, 0});
+ vector2.set(0, new byte[] {1, 1, 1});
+ VectorValueComparator<FixedSizeBinaryVector> comparator =
+ DefaultVectorComparators.createDefaultComparator(vector1);
+ comparator.attachVectors(vector1, vector2);
+
+ // breaking tie by the last element
+ assertTrue(comparator.compare(0, 0) < 0);
+ }
+
+ try (FixedSizeBinaryVector vector1 = new FixedSizeBinaryVector("test1",
allocator, 3);
+ FixedSizeBinaryVector vector2 = new FixedSizeBinaryVector("test1",
allocator, 3)) {
+ vector1.allocateNew();
+ vector2.allocateNew();
+ vector1.set(0, new byte[] {1, 1, 1});
+ vector2.set(0, new byte[] {1, 1, 1});
+ VectorValueComparator<FixedSizeBinaryVector> comparator =
+ DefaultVectorComparators.createDefaultComparator(vector1);
+ comparator.attachVectors(vector1, vector2);
+
+ // fixed-size binary values are equal
+ assertTrue(comparator.compare(0, 0) == 0);
+ }
+ }
+
+ @Test
+ public void testCompareNull() {
+ try (NullVector vec = new NullVector("test",
+ FieldType.notNullable(new ArrowType.Int(32, false)))) {
+ vec.setValueCount(2);
+
+ VectorValueComparator<NullVector> comparator =
+ DefaultVectorComparators.createDefaultComparator(vec);
+ comparator.attachVector(vec);
+ assertEquals(DefaultVectorComparators.NullComparator.class,
comparator.getClass());
+ assertEquals(0, comparator.compare(0, 1));
+ }
+ }
+
@Test
public void testCheckNullsOnCompareIsFalseForNonNullableVector() {
try (IntVector vec = new IntVector("not nullable",
@@ -937,4 +1055,18 @@ public class TestDefaultVectorComparator {
VectorValueComparator<V> comparator =
DefaultVectorComparators.createDefaultComparator(vec);
assertEquals(DefaultVectorComparators.VariableWidthComparator.class,
comparator.getClass());
}
+
+ @Test
+ public void testRepeatedDefaultComparators() {
+ final FieldType type = FieldType.nullable(Types.MinorType.INT.getType());
+ try (final LargeListVector vector = new LargeListVector("list", allocator,
type, null)) {
+ vector.addOrGetVector(FieldType.nullable(type.getType()));
+ verifyRepeatedComparatorReturned(vector);
+ }
+ }
+
+ private static <V extends ValueVector> void
verifyRepeatedComparatorReturned(V vec) {
+ VectorValueComparator<V> comparator =
DefaultVectorComparators.createDefaultComparator(vec);
+ assertEquals(DefaultVectorComparators.RepeatedValueComparator.class,
comparator.getClass());
+ }
}