This is an automated email from the ASF dual-hosted git repository. ravindra pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push: new 6c1fccd ARROW-6022: [Java] Support equals API in ValueVector to compare two vectors equal 6c1fccd is described below commit 6c1fccdb5bb040c981b8b12e9c27e565343e8952 Author: tianchen <niki...@alibaba-inc.com> AuthorDate: Mon Aug 12 14:19:51 2019 +0530 ARROW-6022: [Java] Support equals API in ValueVector to compare two vectors equal Related to [ARROW-6022](https://issues.apache.org/jira/browse/ARROW-6022). In some cases, this feature is useful. In ARROW-1184, Dictionary#equals does not work due to the lack of this API. Moreover, we already implemented equals(int index, ValueVector target, int targetIndex), so this newly added API can reuse it. Closes #4933 from tianchen92/ARROW-6022 and squashes the following commits: 7e20f79d5 <tianchen> remove CompareUtility a5d22fd4f <tianchen> fix variable names 226a20f8c <tianchen> refactor 694d9f6b9 <tianchen> make equals to visitor mode 0dfa943e9 <tianchen> compare struct child names and add UT c7081c274 <tianchen> check list validity bit b942794b2 <tianchen> use ArrowType for equal 1d95c9c7f <tianchen> fix Decimal equals 3c9f06600 <tianchen> fix 002688296 <tianchen> use MinorType and check isSet e58c15834 <tianchen> refactor Dictionary#equals 10dca2ccb <tianchen> fix 6bc3f681c <tianchen> ARROW-6022: Support equals API in ValueVector to compare two vectors equal Authored-by: tianchen <niki...@alibaba-inc.com> Signed-off-by: Pindikura Ravindra <ravin...@dremio.com> --- .../src/main/codegen/templates/UnionVector.java | 24 +- .../apache/arrow/vector/BaseFixedWidthVector.java | 27 +- .../arrow/vector/BaseVariableWidthVector.java | 24 +- .../org/apache/arrow/vector/DecimalVector.java | 1 - .../apache/arrow/vector/ExtensionTypeVector.java | 6 + .../java/org/apache/arrow/vector/ValueVector.java | 8 + .../java/org/apache/arrow/vector/ZeroVector.java | 6 + .../arrow/vector/compare/RangeEqualsVisitor.java | 288 +++++++++++++++ 
.../arrow/vector/compare/VectorEqualsVisitor.java | 39 ++ .../arrow/vector/complex/AbstractStructVector.java | 5 +- .../arrow/vector/complex/FixedSizeListVector.java | 21 +- .../apache/arrow/vector/complex/ListVector.java | 34 +- .../vector/complex/NonNullableStructVector.java | 43 +-- .../apache/arrow/vector/dictionary/Dictionary.java | 14 +- .../apache/arrow/vector/TestDictionaryVector.java | 24 +- .../org/apache/arrow/vector/TestValueVector.java | 395 ++++++++++++++++++++- 16 files changed, 836 insertions(+), 123 deletions(-) diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java index 59cc91f..6220e51 100644 --- a/java/vector/src/main/codegen/templates/UnionVector.java +++ b/java/vector/src/main/codegen/templates/UnionVector.java @@ -18,8 +18,10 @@ import io.netty.buffer.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.ReferenceManager; +import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.types.UnionMode; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.CallBack; @@ -36,6 +38,7 @@ import io.netty.buffer.ArrowBuf; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.complex.impl.ComplexCopier; import org.apache.arrow.vector.util.CallBack; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; @@ -677,16 +680,17 @@ public class UnionVector implements FieldVector { if (to == null) { return false; } - if (this.getClass() != to.getClass()) { - return false; - } - UnionVector that = (UnionVector) to; - ValueVector leftVector = getVector(index); - ValueVector rightVector = that.getVector(toIndex); + Preconditions.checkArgument(index >= 0 && index < valueCount, + "index 
%s out of range[0, %s]:", index, valueCount - 1); + Preconditions.checkArgument(toIndex >= 0 && toIndex < to.getValueCount(), + "index %s out of range[0, %s]:", index, to.getValueCount() - 1); - if (leftVector.getClass() != rightVector.getClass()) { - return false; - } - return leftVector.equals(index, rightVector, toIndex); + RangeEqualsVisitor visitor = new RangeEqualsVisitor(to, index, toIndex, 1); + return this.accept(visitor); + } + + @Override + public boolean accept(RangeEqualsVisitor visitor) { + return visitor.visit(this); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java index aca34af..4e953a3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java @@ -26,6 +26,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.util.ArrowBufPointer; import org.apache.arrow.memory.util.ByteFunctionHelpers; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.util.CallBack; @@ -71,6 +72,10 @@ public abstract class BaseFixedWidthVector extends BaseValueVector } + public int getTypeWidth() { + return typeWidth; + } + @Override public String getName() { return field.getName(); @@ -870,20 +875,18 @@ public abstract class BaseFixedWidthVector extends BaseValueVector if (to == null) { return false; } - if (this.getClass() != to.getClass()) { - return false; - } - BaseFixedWidthVector that = (BaseFixedWidthVector) to; + Preconditions.checkArgument(index >= 0 && index < valueCount, + "index %s out of range[0, %s]:", index, valueCount - 1); + Preconditions.checkArgument(toIndex >= 0 && toIndex < to.getValueCount(), + "index %s out 
of range[0, %s]:", index, to.getValueCount() - 1); - int leftStart = typeWidth * index; - int leftEnd = typeWidth * (index + 1); - - int rightStart = typeWidth * toIndex; - int rightEnd = typeWidth * (toIndex + 1); + RangeEqualsVisitor visitor = new RangeEqualsVisitor(to, index, toIndex, 1); + return this.accept(visitor); + } - int ret = ByteFunctionHelpers.equal(this.getDataBuffer(), leftStart, leftEnd, - that.getDataBuffer(), rightStart, rightEnd); - return ret == 1; + @Override + public boolean accept(RangeEqualsVisitor visitor) { + return visitor.visit(this); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java index 5bb54b9..b7aa816 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java @@ -28,6 +28,7 @@ import org.apache.arrow.memory.OutOfMemoryException; import org.apache.arrow.memory.util.ArrowBufPointer; import org.apache.arrow.memory.util.ByteFunctionHelpers; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.util.CallBack; @@ -1369,20 +1370,17 @@ public abstract class BaseVariableWidthVector extends BaseValueVector if (to == null) { return false; } - if (this.getClass() != to.getClass()) { - return false; - } - - BaseVariableWidthVector that = (BaseVariableWidthVector) to; + Preconditions.checkArgument(index >= 0 && index < valueCount, + "index %s out of range[0, %s]:", index, valueCount - 1); + Preconditions.checkArgument(toIndex >= 0 && toIndex < to.getValueCount(), + "index %s out of range[0, %s]:", index, to.getValueCount() - 1); - final int leftStart = getStartOffset(index); - final int leftEnd = getStartOffset(index + 
1); - - final int rightStart = that.getStartOffset(toIndex); - final int rightEnd = that.getStartOffset(toIndex + 1); + RangeEqualsVisitor visitor = new RangeEqualsVisitor(to, index, toIndex, 1); + return this.accept(visitor); + } - int ret = ByteFunctionHelpers.equal(this.getDataBuffer(), leftStart, leftEnd, - that.getDataBuffer(), rightStart, rightEnd); - return ret == 1; + @Override + public boolean accept(RangeEqualsVisitor visitor) { + return visitor.visit(this); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java index cf77186..5a450e6 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java @@ -489,7 +489,6 @@ public class DecimalVector extends BaseFixedWidthVector { set(index, isSet, start, buffer); } - /*----------------------------------------------------------------* | | | vector transfer | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java b/java/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java index 14a66f8..02dc8e3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java @@ -22,6 +22,7 @@ import java.util.List; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.OutOfMemoryException; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.types.Types.MinorType; @@ -256,4 +257,9 @@ public abstract class ExtensionTypeVector<T extends BaseValueVector & FieldVecto public BufferAllocator getAllocator() { return underlyingVector.getAllocator(); } + + @Override + public boolean accept(RangeEqualsVisitor visitor) { 
+ return visitor.visit(getUnderlyingVector()); + } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ValueVector.java b/java/vector/src/main/java/org/apache/arrow/vector/ValueVector.java index 4d98397..aed24f2 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ValueVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ValueVector.java @@ -21,6 +21,7 @@ import java.io.Closeable; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.OutOfMemoryException; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.Field; @@ -272,4 +273,11 @@ public interface ValueVector extends Closeable, Iterable<ValueVector> { * @param from source vector */ void copyFromSafe(int fromIndex, int thisIndex, ValueVector from); + + /** + * Compare range values in this vector and vector in visitor. + * @param visitor visitor which holds the vector to compare. + * @return true if equals, otherwise false. 
+ */ + boolean accept(RangeEqualsVisitor visitor); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ZeroVector.java b/java/vector/src/main/java/org/apache/arrow/vector/ZeroVector.java index 2a6e21c..4d5826c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ZeroVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ZeroVector.java @@ -25,6 +25,7 @@ import java.util.List; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.OutOfMemoryException; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.complex.impl.NullReader; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; @@ -264,4 +265,9 @@ public class ZeroVector implements FieldVector { public void copyFromSafe(int fromIndex, int thisIndex, ValueVector from) { throw new UnsupportedOperationException(); } + + @Override + public boolean accept(RangeEqualsVisitor visitor) { + return true; + } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java new file mode 100644 index 0000000..19cf79e --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.compare; + +import java.util.List; + +import org.apache.arrow.memory.util.ByteFunctionHelpers; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.ZeroVector; +import org.apache.arrow.vector.complex.BaseRepeatedValueVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.NonNullableStructVector; +import org.apache.arrow.vector.complex.UnionVector; + +/** + * Visitor to compare a range of values for vectors. + */ +public class RangeEqualsVisitor { + + protected final ValueVector right; + protected final int leftStart; + protected final int rightStart; + protected final int length; + + /** + * Constructs a new instance. 
+ */ + public RangeEqualsVisitor(ValueVector right, int leftStart, int rightStart, int length) { + this.leftStart = leftStart; + this.rightStart = rightStart; + this.right = right; + this.length = length; + } + + public boolean visit(BaseFixedWidthVector left) { + return compareBaseFixedWidthVectors(left); + } + + public boolean visit(BaseVariableWidthVector left) { + return compareBaseVariableWidthVectors(left); + } + + public boolean visit(ListVector left) { + return compareListVectors(left); + } + + public boolean visit(FixedSizeListVector left) { + return compareFixedSizeListVectors(left); + } + + public boolean visit(NonNullableStructVector left) { + return compareStructVectors(left); + } + + public boolean visit(UnionVector left) { + return compareUnionVectors(left); + } + + public boolean visit(ZeroVector left) { + return true; + } + + public boolean visit(ValueVector left) { + throw new UnsupportedOperationException(); + } + + protected boolean compareValueVector(ValueVector left, ValueVector right) { + return left.getField().getType().equals(right.getField().getType()); + } + + protected boolean compareUnionVectors(UnionVector left) { + + if (!compareValueVector(left, right)) { + return false; + } + + UnionVector rightVector = (UnionVector) right; + + List<FieldVector> leftChildren = left.getChildrenFromFields(); + List<FieldVector> rightChildren = rightVector.getChildrenFromFields(); + + if (leftChildren.size() != rightChildren.size()) { + return false; + } + + for (int k = 0; k < leftChildren.size(); k++) { + RangeEqualsVisitor visitor = new RangeEqualsVisitor(rightChildren.get(k), + leftStart, rightStart, length); + if (!leftChildren.get(k).accept(visitor)) { + return false; + } + } + return true; + } + + protected boolean compareStructVectors(NonNullableStructVector left) { + if (!compareValueVector(left, right)) { + return false; + } + + NonNullableStructVector rightVector = (NonNullableStructVector) right; + + if 
(!left.getChildFieldNames().equals(rightVector.getChildFieldNames())) { + return false; + } + + for (String name : left.getChildFieldNames()) { + RangeEqualsVisitor visitor = new RangeEqualsVisitor(rightVector.getChild(name), + leftStart, rightStart, length); + if (!left.getChild(name).accept(visitor)) { + return false; + } + } + + return true; + } + + protected boolean compareBaseFixedWidthVectors(BaseFixedWidthVector left) { + + if (!compareValueVector(left, right)) { + return false; + } + + for (int i = 0; i < length; i++) { + int leftIndex = leftStart + i; + int rightIndex = rightStart + i; + + boolean isNull = left.isNull(leftIndex); + + if (isNull != right.isNull(rightIndex)) { + return false; + } + + int typeWidth = left.getTypeWidth(); + if (!isNull) { + int startByteLeft = typeWidth * leftIndex; + int endByteLeft = typeWidth * (leftIndex + 1); + + int startByteRight = typeWidth * rightIndex; + int endByteRight = typeWidth * (rightIndex + 1); + + int ret = ByteFunctionHelpers.equal(left.getDataBuffer(), startByteLeft, endByteLeft, + right.getDataBuffer(), startByteRight, endByteRight); + + if (ret == 0) { + return false; + } + } + } + return true; + } + + protected boolean compareBaseVariableWidthVectors(BaseVariableWidthVector left) { + if (!compareValueVector(left, right)) { + return false; + } + + for (int i = 0; i < length; i++) { + int leftIndex = leftStart + i; + int rightIndex = rightStart + i; + + boolean isNull = left.isNull(leftIndex); + if (isNull != right.isNull(rightIndex)) { + return false; + } + + int offsetWidth = BaseVariableWidthVector.OFFSET_WIDTH; + + if (!isNull) { + final int startByteLeft = left.getOffsetBuffer().getInt(leftIndex * offsetWidth); + final int endByteLeft = left.getOffsetBuffer().getInt((leftIndex + 1) * offsetWidth); + + final int startByteRight = right.getOffsetBuffer().getInt(rightIndex * offsetWidth); + final int endByteRight = right.getOffsetBuffer().getInt((rightIndex + 1) * offsetWidth); + + int ret = 
ByteFunctionHelpers.equal(left.getDataBuffer(), startByteLeft, endByteLeft, + right.getDataBuffer(), startByteRight, endByteRight); + + if (ret == 0) { + return false; + } + } + } + return true; + } + + protected boolean compareListVectors(ListVector left) { + if (!compareValueVector(left, right)) { + return false; + } + + for (int i = 0; i < length; i++) { + int leftIndex = leftStart + i; + int rightIndex = rightStart + i; + + boolean isNull = left.isNull(leftIndex); + if (isNull != right.isNull(rightIndex)) { + return false; + } + + int offsetWidth = BaseRepeatedValueVector.OFFSET_WIDTH; + + if (!isNull) { + final int startByteLeft = left.getOffsetBuffer().getInt(leftIndex * offsetWidth); + final int endByteLeft = left.getOffsetBuffer().getInt((leftIndex + 1) * offsetWidth); + + final int startByteRight = right.getOffsetBuffer().getInt(rightIndex * offsetWidth); + final int endByteRight = right.getOffsetBuffer().getInt((rightIndex + 1) * offsetWidth); + + if ((endByteLeft - startByteLeft) != (endByteRight - startByteRight)) { + return false; + } + + ValueVector leftDataVector = left.getDataVector(); + ValueVector rightDataVector = ((ListVector)right).getDataVector(); + + if (!leftDataVector.accept(new RangeEqualsVisitor(rightDataVector, startByteLeft, + startByteRight, (endByteLeft - startByteLeft)))) { + return false; + } + } + } + return true; + } + + protected boolean compareFixedSizeListVectors(FixedSizeListVector left) { + if (!compareValueVector(left, right)) { + return false; + } + + if (left.getListSize() != ((FixedSizeListVector)right).getListSize()) { + return false; + } + + for (int i = 0; i < length; i++) { + int leftIndex = leftStart + i; + int rightIndex = rightStart + i; + + boolean isNull = left.isNull(leftIndex); + if (isNull != right.isNull(rightIndex)) { + return false; + } + + int listSize = left.getListSize(); + + if (!isNull) { + final int startByteLeft = leftIndex * listSize; + final int endByteLeft = (leftIndex + 1) * listSize; + + final 
int startByteRight = rightIndex * listSize; + final int endByteRight = (rightIndex + 1) * listSize; + + if ((endByteLeft - startByteLeft) != (endByteRight - startByteRight)) { + return false; + } + + ValueVector leftDataVector = left.getDataVector(); + ValueVector rightDataVector = ((FixedSizeListVector)right).getDataVector(); + + if (!leftDataVector.accept(new RangeEqualsVisitor(rightDataVector, startByteLeft, startByteRight, + (endByteLeft - startByteLeft)))) { + return false; + } + } + } + return true; + } + +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/VectorEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/VectorEqualsVisitor.java new file mode 100644 index 0000000..47071dd --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/VectorEqualsVisitor.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.compare; + +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.ValueVector; + +/** + * Visitor to compare vectors equal. 
+ */ +public class VectorEqualsVisitor extends RangeEqualsVisitor { + + public VectorEqualsVisitor(ValueVector right) { + super(Preconditions.checkNotNull(right), 0, 0, right.getValueCount()); + } + + @Override + protected boolean compareValueVector(ValueVector left, ValueVector right) { + if (!super.compareValueVector(left, right)) { + return false; + } + return left.getValueCount() == right.getValueCount(); + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java index dc9b1a1..25762fd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java @@ -226,7 +226,10 @@ public abstract class AbstractStructVector extends AbstractContainerVector { return children; } - protected List<String> getChildFieldNames() { + /** + * Get child field names. 
+ */ + public List<String> getChildFieldNames() { return getChildren().stream() .map(child -> child.getField().getName()) .collect(Collectors.toList()); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java index 5be308e..6bdf817 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java @@ -39,6 +39,7 @@ import org.apache.arrow.vector.BufferBacked; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.ZeroVector; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.complex.impl.UnionFixedSizeListReader; import org.apache.arrow.vector.complex.impl.UnionFixedSizeListWriter; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; @@ -539,18 +540,18 @@ public class FixedSizeListVector extends BaseValueVector implements FieldVector, if (to == null) { return false; } - if (this.getClass() != to.getClass()) { - return false; - } + Preconditions.checkArgument(index >= 0 && index < valueCount, + "index %s out of range[0, %s]:", index, valueCount - 1); + Preconditions.checkArgument(toIndex >= 0 && toIndex < to.getValueCount(), + "index %s out of range[0, %s]:", index, to.getValueCount() - 1); - FixedSizeListVector that = (FixedSizeListVector) to; + RangeEqualsVisitor visitor = new RangeEqualsVisitor(to, index, toIndex, 1); + return this.accept(visitor); + } - for (int i = 0; i < listSize; i++) { - if (!vector.equals(index * listSize + i, that, toIndex * listSize + i)) { - return false; - } - } - return true; + @Override + public boolean accept(RangeEqualsVisitor visitor) { + return visitor.visit(this); } private class TransferImpl implements TransferPair { diff --git 
a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java index 43b43bd..f07fbdb 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java @@ -28,12 +28,14 @@ import java.util.List; import org.apache.arrow.memory.BaseAllocator; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.OutOfMemoryException; +import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.AddOrGetResult; import org.apache.arrow.vector.BitVectorHelper; import org.apache.arrow.vector.BufferBacked; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.ZeroVector; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.complex.impl.ComplexCopier; import org.apache.arrow.vector.complex.impl.UnionListReader; import org.apache.arrow.vector.complex.impl.UnionListWriter; @@ -430,28 +432,13 @@ public class ListVector extends BaseRepeatedValueVector implements FieldVector, if (to == null) { return false; } - if (this.getClass() != to.getClass()) { - return false; - } - - ListVector that = (ListVector) to; - final int leftStart = offsetBuffer.getInt(index * OFFSET_WIDTH); - final int leftEnd = offsetBuffer.getInt((index + 1) * OFFSET_WIDTH); - - final int rightStart = that.offsetBuffer.getInt(toIndex * OFFSET_WIDTH); - final int rightEnd = that.offsetBuffer.getInt((toIndex + 1) * OFFSET_WIDTH); + Preconditions.checkArgument(index >= 0 && index < valueCount, + "index %s out of range[0, %s]:", index, valueCount - 1); + Preconditions.checkArgument(toIndex >= 0 && toIndex < to.getValueCount(), + "index %s out of range[0, %s]:", index, to.getValueCount() - 1); - if ((leftEnd - leftStart) != (rightEnd - rightStart)) { - return false; - } - - for (int i = 0; i < (leftEnd - 
leftStart); i++) { - if (!vector.equals(leftStart + i, that.vector, rightStart + i)) { - return false; - } - } - - return true; + RangeEqualsVisitor visitor = new RangeEqualsVisitor(to, index, toIndex, 1); + return this.accept(visitor); } private class TransferImpl implements TransferPair { @@ -837,4 +824,9 @@ public class ListVector extends BaseRepeatedValueVector implements FieldVector, public int getLastSet() { return lastSet; } + + @Override + public boolean accept(RangeEqualsVisitor visitor) { + return visitor.visit(this); + } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/NonNullableStructVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/NonNullableStructVector.java index 1d9b871..995751e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/NonNullableStructVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/NonNullableStructVector.java @@ -26,9 +26,11 @@ import java.util.List; import java.util.Map; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.DensityAwareVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.complex.impl.SingleStructReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.holders.ComplexHolder; @@ -299,41 +301,22 @@ public class NonNullableStructVector extends AbstractStructVector { } @Override + public boolean accept(RangeEqualsVisitor visitor) { + return visitor.visit(this); + } + + @Override public boolean equals(int index, ValueVector to, int toIndex) { if (to == null) { return false; } - if (this.getClass() != to.getClass()) { - return false; - } - NonNullableStructVector that = (NonNullableStructVector) to; - List<ValueVector> leftChildrens = new ArrayList<>(); - List<ValueVector> rightChildrens 
= new ArrayList<>(); - - for (String child : getChildFieldNames()) { - ValueVector v = getChild(child); - if (v != null) { - leftChildrens.add(v); - } - } + Preconditions.checkArgument(index >= 0 && index < valueCount, + "index %s out of range[0, %s]:", index, valueCount - 1); + Preconditions.checkArgument(toIndex >= 0 && toIndex < to.getValueCount(), + "index %s out of range[0, %s]:", index, to.getValueCount() - 1); - for (String child : that.getChildFieldNames()) { - ValueVector v = that.getChild(child); - if (v != null) { - rightChildrens.add(v); - } - } - - if (leftChildrens.size() != rightChildrens.size()) { - return false; - } - - for (int i = 0; i < leftChildrens.size(); i++) { - if (!leftChildrens.get(i).equals(index, rightChildrens.get(i), toIndex)) { - return false; - } - } - return true; + RangeEqualsVisitor visitor = new RangeEqualsVisitor(to, index, toIndex, 1); + return this.accept(visitor); } @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java index 082d2ba..72b3da0 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java @@ -20,9 +20,9 @@ package org.apache.arrow.vector.dictionary; import java.util.Objects; import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.compare.VectorEqualsVisitor; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; -import org.apache.arrow.vector.util.Validator; /** * A dictionary (integer to Value mapping) that is used to facilitate @@ -64,21 +64,11 @@ public class Dictionary { return false; } Dictionary that = (Dictionary) o; - return Objects.equals(encoding, that.encoding) && compareFieldVector(dictionary, that.dictionary); + return Objects.equals(encoding, that.encoding) && dictionary.accept(new 
VectorEqualsVisitor(that.dictionary)); } @Override public int hashCode() { return Objects.hash(encoding, dictionary); } - - //TODO after vector api support compare two vectors, this should be cleaned up - private boolean compareFieldVector(FieldVector vector1, FieldVector vector2) { - try { - Validator.compareFieldVectors(vector1, vector2); - } catch (IllegalArgumentException e) { - return false; - } - return true; - } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java index 2d6391b..ab0efa5 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java @@ -398,8 +398,8 @@ public class TestDictionaryVector { @Test public void testIntEquals() { //test Int - try (final IntVector vector1 = new IntVector("", allocator); - final IntVector vector2 = new IntVector("", allocator)) { + try (final IntVector vector1 = new IntVector("int", allocator); + final IntVector vector2 = new IntVector("int", allocator)) { Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null)); Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null)); @@ -426,8 +426,8 @@ public class TestDictionaryVector { @Test public void testVarcharEquals() { - try (final VarCharVector vector1 = new VarCharVector("", allocator); - final VarCharVector vector2 = new VarCharVector("", allocator)) { + try (final VarCharVector vector1 = new VarCharVector("varchar", allocator); + final VarCharVector vector2 = new VarCharVector("varchar", allocator)) { Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null)); Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null)); @@ -455,8 +455,8 @@ public class TestDictionaryVector { @Test public void testVarBinaryEquals() { - try (final VarBinaryVector 
vector1 = new VarBinaryVector("", allocator); - final VarBinaryVector vector2 = new VarBinaryVector("", allocator)) { + try (final VarBinaryVector vector1 = new VarBinaryVector("binary", allocator); + final VarBinaryVector vector2 = new VarBinaryVector("binary", allocator)) { Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null)); Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null)); @@ -484,8 +484,8 @@ public class TestDictionaryVector { @Test public void testListEquals() { - try (final ListVector vector1 = ListVector.empty("", allocator); - final ListVector vector2 = ListVector.empty("", allocator);) { + try (final ListVector vector1 = ListVector.empty("list", allocator); + final ListVector vector2 = ListVector.empty("list", allocator);) { Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null)); Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null)); @@ -514,8 +514,8 @@ public class TestDictionaryVector { @Test public void testStructEquals() { - try (final StructVector vector1 = StructVector.empty("", allocator); - final StructVector vector2 = StructVector.empty("", allocator);) { + try (final StructVector vector1 = StructVector.empty("struct", allocator); + final StructVector vector2 = StructVector.empty("struct", allocator);) { vector1.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); vector1.addOrGet("f1", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); vector2.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); @@ -544,8 +544,8 @@ public class TestDictionaryVector { @Test public void testUnionEquals() { - try (final UnionVector vector1 = new UnionVector("", allocator, null); - final UnionVector vector2 = new UnionVector("", allocator, null);) { + try (final UnionVector vector1 = new UnionVector("union", allocator, null); + final UnionVector vector2 = new 
UnionVector("union", allocator, null);) { final NullableUInt4Holder uInt4Holder = new NullableUInt4Holder(); uInt4Holder.value = 10; diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java index c38f4dc..b553f31 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java @@ -36,12 +36,22 @@ import org.apache.arrow.memory.BaseAllocator; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.memory.util.ArrowBufPointer; +import org.apache.arrow.vector.compare.VectorEqualsVisitor; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.UnionVector; +import org.apache.arrow.vector.complex.impl.NullableStructWriter; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.holders.NullableIntHolder; +import org.apache.arrow.vector.holders.NullableUInt4Holder; import org.apache.arrow.vector.holders.NullableVarBinaryHolder; import org.apache.arrow.vector.holders.NullableVarCharHolder; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.OversizedAllocationException; import org.apache.arrow.vector.util.Text; @@ -52,7 +62,6 @@ import org.junit.Test; import io.netty.buffer.ArrowBuf; - public class TestValueVector { private static final String EMPTY_SCHEMA_PATH = ""; @@ -2245,4 +2254,388 @@ public class TestValueVector { 
assertNull(varBinaryVector.get(0)); } } + + @Test + public void testZeroVectorEquals() { + try (final ZeroVector vector1 = new ZeroVector(); + final ZeroVector vector2 = new ZeroVector()) { + + VectorEqualsVisitor visitor = new VectorEqualsVisitor(vector2); + assertTrue(vector1.accept(visitor)); + } + } + + @Test + public void testIntVectorEqualsWithNull() { + try (final IntVector vector1 = new IntVector("int", allocator); + final IntVector vector2 = new IntVector("int", allocator)) { + + vector1.allocateNew(2); + vector1.setValueCount(2); + vector2.allocateNew(2); + vector2.setValueCount(2); + + vector1.setSafe(0, 1); + vector1.setSafe(1, 2); + + vector2.setSafe(0, 1); + VectorEqualsVisitor visitor = new VectorEqualsVisitor(vector2); + + assertFalse(vector1.accept(visitor)); + } + } + + @Test + public void testIntVectorEquals() { + try (final IntVector vector1 = new IntVector("int", allocator); + final IntVector vector2 = new IntVector("int", allocator)) { + + vector1.allocateNew(3); + vector1.setValueCount(3); + vector2.allocateNew(3); + vector2.setValueCount(2); + + vector1.setSafe(0, 1); + vector1.setSafe(1, 2); + vector1.setSafe(2, 3); + + vector2.setSafe(0, 1); + vector2.setSafe(1, 2); + + VectorEqualsVisitor visitor = new VectorEqualsVisitor(vector2); + + assertFalse(vector1.accept(visitor)); + + vector2.setValueCount(3); + vector2.setSafe(2, 2); + assertFalse(vector1.equals(vector2)); + + vector2.setSafe(2, 3); + assertTrue(vector1.accept(visitor)); + } + } + + @Test + public void testDecimalVectorEquals() { + try (final DecimalVector vector1 = new DecimalVector("decimal", allocator, 3, 3); + final DecimalVector vector2 = new DecimalVector("decimal", allocator, 3, 3); + final DecimalVector vector3 = new DecimalVector("decimal", allocator, 3, 2)) { + + vector1.allocateNew(2); + vector1.setValueCount(2); + vector2.allocateNew(2); + vector2.setValueCount(2); + vector3.allocateNew(2); + vector3.setValueCount(2); + + vector1.setSafe(0, 100); + vector1.setSafe(1, 
200); + + vector2.setSafe(0, 100); + vector2.setSafe(1, 200); + + vector3.setSafe(0, 100); + vector3.setSafe(1, 200); + + VectorEqualsVisitor visitor1 = new VectorEqualsVisitor(vector2); + VectorEqualsVisitor visitor2 = new VectorEqualsVisitor(vector3); + + assertTrue(vector1.accept(visitor1)); + assertFalse(vector1.accept(visitor2)); + } + } + + @Test + public void testVarcharVectorEuqalsWithNull() { + try (final VarCharVector vector1 = new VarCharVector("varchar", allocator); + final VarCharVector vector2 = new VarCharVector("varchar", allocator)) { + + vector1.allocateNew(); + vector2.allocateNew(); + + // set some values + vector1.setSafe(0, STR1, 0, STR1.length); + vector1.setSafe(1, STR2, 0, STR2.length); + vector1.setValueCount(2); + + vector2.setSafe(0, STR1, 0, STR1.length); + vector2.setValueCount(2); + + VectorEqualsVisitor visitor = new VectorEqualsVisitor(vector2); + assertFalse(vector1.accept(visitor)); + } + } + + @Test + public void testVarcharVectorEquals() { + try (final VarCharVector vector1 = new VarCharVector("varchar", allocator); + final VarCharVector vector2 = new VarCharVector("varchar", allocator)) { + + vector1.allocateNew(); + vector2.allocateNew(); + + // set some values + vector1.setSafe(0, STR1, 0, STR1.length); + vector1.setSafe(1, STR2, 0, STR2.length); + vector1.setSafe(2, STR3, 0, STR3.length); + vector1.setValueCount(3); + + vector2.setSafe(0, STR1, 0, STR1.length); + vector2.setSafe(1, STR2, 0, STR2.length); + vector2.setValueCount(2); + + VectorEqualsVisitor visitor = new VectorEqualsVisitor(vector2); + assertFalse(vector1.accept(visitor)); + + vector2.setSafe(2, STR3, 0, STR3.length); + vector2.setValueCount(3); + assertTrue(vector1.accept(visitor)); + } + } + + @Test + public void testVarBinaryVectorEquals() { + try (final VarBinaryVector vector1 = new VarBinaryVector("binary", allocator); + final VarBinaryVector vector2 = new VarBinaryVector("binary", allocator)) { + + vector1.allocateNew(); + vector2.allocateNew(); + + // 
set some values + vector1.setSafe(0, STR1, 0, STR1.length); + vector1.setSafe(1, STR2, 0, STR2.length); + vector1.setSafe(2, STR3, 0, STR3.length); + vector1.setValueCount(3); + + vector2.setSafe(0, STR1, 0, STR1.length); + vector2.setSafe(1, STR2, 0, STR2.length); + vector2.setValueCount(2); + + VectorEqualsVisitor visitor = new VectorEqualsVisitor(vector2); + assertFalse(vector1.accept(visitor)); + + vector2.setSafe(2, STR3, 0, STR3.length); + vector2.setValueCount(3); + assertTrue(vector1.accept(visitor)); + } + } + + @Test + public void testListVectorEqualsWithNull() { + try (final ListVector vector1 = ListVector.empty("list", allocator); + final ListVector vector2 = ListVector.empty("list", allocator);) { + + UnionListWriter writer1 = vector1.getWriter(); + writer1.allocate(); + + //set some values + writeListVector(writer1, new int[] {1, 2}); + writeListVector(writer1, new int[] {3, 4}); + writeListVector(writer1, new int[] {}); + writer1.setValueCount(3); + + UnionListWriter writer2 = vector2.getWriter(); + writer2.allocate(); + + //set some values + writeListVector(writer2, new int[] {1, 2}); + writeListVector(writer2, new int[] {3, 4}); + writer2.setValueCount(3); + + VectorEqualsVisitor visitor = new VectorEqualsVisitor(vector2); + + assertFalse(vector1.accept(visitor)); + } + } + + @Test + public void testListVectorEquals() { + try (final ListVector vector1 = ListVector.empty("list", allocator); + final ListVector vector2 = ListVector.empty("list", allocator);) { + + UnionListWriter writer1 = vector1.getWriter(); + writer1.allocate(); + + //set some values + writeListVector(writer1, new int[] {1, 2}); + writeListVector(writer1, new int[] {3, 4}); + writeListVector(writer1, new int[] {5, 6}); + writer1.setValueCount(3); + + UnionListWriter writer2 = vector2.getWriter(); + writer2.allocate(); + + //set some values + writeListVector(writer2, new int[] {1, 2}); + writeListVector(writer2, new int[] {3, 4}); + writer2.setValueCount(2); + + VectorEqualsVisitor 
visitor = new VectorEqualsVisitor(vector2); + assertFalse(vector1.accept(visitor)); + + writeListVector(writer2, new int[] {5, 6}); + writer2.setValueCount(3); + + assertTrue(vector1.accept(visitor)); + } + } + + @Test + public void testStructVectorEqualsWithNull() { + + try (final StructVector vector1 = StructVector.empty("struct", allocator); + final StructVector vector2 = StructVector.empty("struct", allocator);) { + vector1.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); + vector1.addOrGet("f1", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); + vector2.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); + vector2.addOrGet("f1", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); + + NullableStructWriter writer1 = vector1.getWriter(); + writer1.allocate(); + + writeStructVector(writer1, 1, 10L); + writeStructVector(writer1, 2, 20L); + writeStructVector(writer1, 3, 30L); + writer1.setValueCount(3); + + NullableStructWriter writer2 = vector2.getWriter(); + writer2.allocate(); + + writeStructVector(writer2, 1, 10L); + writeStructVector(writer2, 3, 30L); + writer2.setValueCount(3); + + VectorEqualsVisitor visitor = new VectorEqualsVisitor(vector2); + assertFalse(vector1.accept(visitor)); + } + } + + @Test + public void testStructVectorEquals() { + try (final StructVector vector1 = StructVector.empty("struct", allocator); + final StructVector vector2 = StructVector.empty("struct", allocator);) { + vector1.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); + vector1.addOrGet("f1", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); + vector2.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); + vector2.addOrGet("f1", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); + + NullableStructWriter writer1 = vector1.getWriter(); + writer1.allocate(); + + 
writeStructVector(writer1, 1, 10L); + writeStructVector(writer1, 2, 20L); + writeStructVector(writer1, 3, 30L); + writer1.setValueCount(3); + + NullableStructWriter writer2 = vector2.getWriter(); + writer2.allocate(); + + writeStructVector(writer2, 1, 10L); + writeStructVector(writer2, 2, 20L); + writer2.setValueCount(2); + + VectorEqualsVisitor visitor = new VectorEqualsVisitor(vector2); + assertFalse(vector1.accept(visitor)); + + writeStructVector(writer2, 3, 30L); + writer2.setValueCount(3); + + assertTrue(vector1.accept(visitor)); + } + } + + @Test + public void testStructVectorEqualsWithDiffChild() { + try (final StructVector vector1 = StructVector.empty("struct", allocator); + final StructVector vector2 = StructVector.empty("struct", allocator);) { + vector1.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); + vector1.addOrGet("f1", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); + vector2.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); + vector2.addOrGet("f10", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); + + NullableStructWriter writer1 = vector1.getWriter(); + writer1.allocate(); + + writeStructVector(writer1, 1, 10L); + writeStructVector(writer1, 2, 20L); + writer1.setValueCount(2); + + NullableStructWriter writer2 = vector2.getWriter(); + writer2.allocate(); + + writeStructVector(writer2, 1, 10L); + writeStructVector(writer2, 2, 20L); + writer2.setValueCount(2); + + VectorEqualsVisitor visitor = new VectorEqualsVisitor(vector2); + assertFalse(vector1.accept(visitor)); + } + } + + @Test + public void testUnionVectorEquals() { + try (final UnionVector vector1 = new UnionVector("union", allocator, null); + final UnionVector vector2 = new UnionVector("union", allocator, null);) { + + final NullableUInt4Holder uInt4Holder = new NullableUInt4Holder(); + uInt4Holder.value = 10; + uInt4Holder.isSet = 1; + + final NullableIntHolder intHolder = new 
NullableIntHolder(); + uInt4Holder.value = 20; + uInt4Holder.isSet = 1; + + vector1.setType(0, Types.MinorType.UINT4); + vector1.setSafe(0, uInt4Holder); + + vector1.setType(1, Types.MinorType.INT); + vector1.setSafe(1, intHolder); + vector1.setValueCount(2); + + vector2.setType(0, Types.MinorType.UINT4); + vector2.setSafe(0, uInt4Holder); + + vector2.setType(1, Types.MinorType.INT); + vector2.setSafe(1, intHolder); + vector2.setValueCount(2); + + VectorEqualsVisitor visitor = new VectorEqualsVisitor(vector2); + assertTrue(vector1.accept(visitor)); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testEqualsWithIndexOutOfRange() { + try (final IntVector vector1 = new IntVector("int", allocator); + final IntVector vector2 = new IntVector("int", allocator)) { + + vector1.allocateNew(2); + vector1.setValueCount(2); + vector2.allocateNew(2); + vector2.setValueCount(2); + + vector1.setSafe(0, 1); + vector1.setSafe(1, 2); + + vector2.setSafe(0, 1); + vector2.setSafe(1, 2); + + assertTrue(vector1.equals(3, vector2, 2)); + } + } + + private void writeStructVector(NullableStructWriter writer, int value1, long value2) { + writer.start(); + writer.integer("f0").writeInt(value1); + writer.bigInt("f1").writeBigInt(value2); + writer.end(); + } + + private void writeListVector(UnionListWriter writer, int[] values) { + writer.startList(); + for (int v: values) { + writer.integer().writeInt(v); + } + writer.endList(); + } }