This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 7bc2452b35 GH-41902: [Java] Variadic Buffer Counts Incorrect (#41930)
7bc2452b35 is described below
commit 7bc2452b350867b3ddc9de9ceceeef0e4d722941
Author: Vibhatha Lakmal Abeykoon <[email protected]>
AuthorDate: Tue Jun 4 12:35:13 2024 +0530
GH-41902: [Java] Variadic Buffer Counts Incorrect (#41930)
### Rationale for this change
In the initial PR that added `variadicBufferCounts` to the Java implementation, every vector without variadic buffers was assigned a `0` entry, so `variadicBufferCounts` was non-empty even when no vector had variadic buffers. This caused CI failures in Arrow Rust.
### What changes are included in this PR?
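This PR changes the Java unloaders so that vectors of non-view types no longer add entries. As a result, a record batch without view vectors has an empty `variadicBufferCounts` in the Java `ArrowRecordBatch` class. The loaders now consume an entry only for `BaseVariableWidthViewVector` fields, and they reject batches with leftover entries.
This PR also includes an upgrade to JUnit 5.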
### Are these changes tested?
Yes, by the existing tests and by newly added tests.
### Are there any user-facing changes?
No
* GitHub Issue: #41902
Authored-by: Vibhatha Abeykoon <[email protected]>
Signed-off-by: David Li <[email protected]>
---
.../org/apache/arrow/c/StructVectorLoader.java | 29 +++++---
.../org/apache/arrow/c/StructVectorUnloader.java | 5 +-
.../java/org/apache/arrow/c/DictionaryTest.java | 60 ++++++++++++++++
.../java/org/apache/arrow/vector/VectorLoader.java | 20 ++++--
.../org/apache/arrow/vector/VectorUnloader.java | 5 +-
.../org/apache/arrow/vector/TestValueVector.java | 56 +++++++++++++++
.../apache/arrow/vector/TestVarCharViewVector.java | 80 ++++++++++++++++++++++
7 files changed, 238 insertions(+), 17 deletions(-)
diff --git a/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java
b/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java
index 27acf84d30..1b0c59163a 100644
--- a/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java
+++ b/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java
@@ -27,6 +27,7 @@ import java.util.List;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.Collections2;
+import org.apache.arrow.vector.BaseVariableWidthViewVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.TypeLayout;
import org.apache.arrow.vector.complex.StructVector;
@@ -54,7 +55,14 @@ public class StructVectorLoader {
/**
* Construct with a schema.
- *
+ * <p>
+ * The schema must list the child fields of the struct vector rather than
+ * a single field for the struct itself; it can be derived from the struct
+ * vector's field.
+ * For example:
+ * <pre>{@code
+ * Schema schema = new Schema(structVector.getField().getChildren());
+ * }</pre>
* @param schema buffers are added based on schema.
*/
public StructVectorLoader(Schema schema) {
@@ -90,7 +98,7 @@ public class StructVectorLoader {
.fromCompressionType(recordBatch.getBodyCompression().getCodec());
decompressionNeeded = codecType !=
CompressionUtil.CodecType.NO_COMPRESSION;
CompressionCodec codec = decompressionNeeded ?
factory.createCodec(codecType) : NoCompressionCodec.INSTANCE;
- Iterator<Long> variadicBufferCounts = null;
+ Iterator<Long> variadicBufferCounts = Collections.emptyIterator();
if (recordBatch.getVariadicBufferCounts() != null &&
!recordBatch.getVariadicBufferCounts().isEmpty()) {
variadicBufferCounts = recordBatch.getVariadicBufferCounts().iterator();
}
@@ -98,9 +106,10 @@ public class StructVectorLoader {
loadBuffers(fieldVector, fieldVector.getField(), buffers, nodes, codec,
variadicBufferCounts);
}
result.loadFieldBuffers(new ArrowFieldNode(recordBatch.getLength(), 0),
Collections.singletonList(null));
- if (nodes.hasNext() || buffers.hasNext()) {
- throw new IllegalArgumentException("not all nodes and buffers were
consumed. nodes: " +
- Collections2.toList(nodes).toString() + " buffers: " +
Collections2.toList(buffers).toString());
+ if (nodes.hasNext() || buffers.hasNext() ||
variadicBufferCounts.hasNext()) {
+ throw new IllegalArgumentException("not all nodes, buffers and
variadicBufferCounts were consumed. nodes: " +
+ Collections2.toString(nodes) + " buffers: " +
Collections2.toString(buffers) + " variadicBufferCounts: " +
+ Collections2.toString(variadicBufferCounts));
}
return result;
}
@@ -109,10 +118,14 @@ public class StructVectorLoader {
CompressionCodec codec, Iterator<Long> variadicBufferCounts) {
checkArgument(nodes.hasNext(), "no more field nodes for field %s and
vector %s", field, vector);
ArrowFieldNode fieldNode = nodes.next();
- // variadicBufferLayoutCount will be 0 for vectors of type except
BaseVariableWidthViewVector
+ // only a BaseVariableWidthViewVector consumes an entry of
variadicBufferCounts; every other vector type keeps a count of 0
long variadicBufferLayoutCount = 0;
- if (variadicBufferCounts != null) {
- variadicBufferLayoutCount = variadicBufferCounts.next();
+ if (vector instanceof BaseVariableWidthViewVector) {
+ if (variadicBufferCounts.hasNext()) {
+ variadicBufferLayoutCount = variadicBufferCounts.next();
+ } else {
+ throw new IllegalStateException("No variadicBufferCounts available for
BaseVariableWidthViewVector");
+ }
}
int bufferLayoutCount = (int) (variadicBufferLayoutCount +
TypeLayout.getTypeBufferCount(field.getType()));
List<ArrowBuf> ownBuffers = new ArrayList<>(bufferLayoutCount);
diff --git a/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java
b/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java
index 8d015157eb..82539acf6f 100644
--- a/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java
+++ b/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java
@@ -109,7 +109,10 @@ public class StructVectorUnloader {
List<ArrowBuf> fieldBuffers = vector.getFieldBuffers();
long variadicBufferCount = getVariadicBufferCount(vector);
int expectedBufferCount = (int)
(TypeLayout.getTypeBufferCount(vector.getField().getType()) +
variadicBufferCount);
- variadicBufferCounts.add(variadicBufferCount);
+ // one entry per view vector, even when zero, so the loader can pair counts with fields by position
+ if (vector instanceof org.apache.arrow.vector.BaseVariableWidthViewVector) {
+ variadicBufferCounts.add(variadicBufferCount);
+ }
if (fieldBuffers.size() != expectedBufferCount) {
throw new IllegalArgumentException(String.format("wrong number of
buffers for field %s in vector %s. found: %s",
vector.getField(), vector.getClass().getSimpleName(), fieldBuffers));
diff --git a/java/c/src/test/java/org/apache/arrow/c/DictionaryTest.java
b/java/c/src/test/java/org/apache/arrow/c/DictionaryTest.java
index d892781756..aa1264e484 100644
--- a/java/c/src/test/java/org/apache/arrow/c/DictionaryTest.java
+++ b/java/c/src/test/java/org/apache/arrow/c/DictionaryTest.java
@@ -17,6 +17,8 @@
package org.apache.arrow.c;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.ByteArrayInputStream;
@@ -29,6 +31,7 @@ import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.CDataDictionaryProvider;
import org.apache.arrow.c.Data;
+import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.FieldVector;
@@ -36,13 +39,19 @@ import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ViewVarCharVector;
import org.apache.arrow.vector.compare.VectorEqualsVisitor;
+import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@@ -216,4 +225,55 @@ public class DictionaryTest {
return new ArrowStreamReader(in, allocator);
}
+ private void createStructVector(StructVector vector) {
+ final ViewVarCharVector child1 = vector.addOrGet("f0",
+ FieldType.nullable(MinorType.VIEWVARCHAR.getType()),
ViewVarCharVector.class);
+ final IntVector child2 = vector.addOrGet("f1",
+ FieldType.nullable(MinorType.INT.getType()), IntVector.class);
+
+ // Write two byte-array values of different lengths to the view child "f0"
+ child1.allocateNew();
+ child1.set(0, "01234567890".getBytes());
+ child1.set(1, "012345678901234567".getBytes());
+ vector.setIndexDefined(0);
+
+ // Write two int values to child "f1"
+ child2.allocateNew();
+ child2.set(0, 10);
+ child2.set(1, 11);
+ vector.setIndexDefined(1);
+
+ vector.setValueCount(2);
+ }
+
+ @Test
+ public void testVectorLoadUnloadOnStructVector() {
+ try (final StructVector structVector1 = StructVector.empty("struct",
allocator)) {
+ createStructVector(structVector1);
+ Field field1 = structVector1.getField();
+ Schema schema = new Schema(field1.getChildren());
+ StructVectorUnloader vectorUnloader = new
StructVectorUnloader(structVector1);
+
+ try (
+ ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch();
+ BufferAllocator finalVectorsAllocator =
allocator.newChildAllocator("struct", 0, Long.MAX_VALUE);
+ ) {
+ // exactly one view child ("f0"), expected to report one variadic data buffer
+ assertFalse(recordBatch.getVariadicBufferCounts().isEmpty());
+ assertEquals(1, recordBatch.getVariadicBufferCounts().size());
+ assertEquals(1, recordBatch.getVariadicBufferCounts().get(0));
+
+ StructVectorLoader vectorLoader = new StructVectorLoader(schema);
+ try (StructVector structVector2 =
vectorLoader.load(finalVectorsAllocator, recordBatch)) {
+ // TODO: compare the whole struct once
https://github.com/apache/arrow/issues/41933 is fixed; compare children until then
+ // assertTrue(VectorEqualsVisitor.vectorEquals(structVector1,
structVector2), "vectors are not equivalent");
+
assertTrue(VectorEqualsVisitor.vectorEquals(structVector1.getChild("f0"),
structVector2.getChild("f0")),
+ "vectors are not equivalent");
+
assertTrue(VectorEqualsVisitor.vectorEquals(structVector1.getChild("f1"),
structVector2.getChild("f1")),
+ "vectors are not equivalent");
+ }
+ }
+ }
+ }
+
}
diff --git
a/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java
b/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java
index 9590e70f46..dec536ae6c 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java
@@ -20,6 +20,7 @@ package org.apache.arrow.vector;
import static org.apache.arrow.util.Preconditions.checkArgument;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.Iterator;
import java.util.List;
@@ -80,7 +81,7 @@ public class VectorLoader {
CompressionUtil.CodecType.fromCompressionType(recordBatch.getBodyCompression().getCodec());
decompressionNeeded = codecType !=
CompressionUtil.CodecType.NO_COMPRESSION;
CompressionCodec codec = decompressionNeeded ?
factory.createCodec(codecType) : NoCompressionCodec.INSTANCE;
- Iterator<Long> variadicBufferCounts = null;
+ Iterator<Long> variadicBufferCounts = Collections.emptyIterator();
if (recordBatch.getVariadicBufferCounts() != null &&
!recordBatch.getVariadicBufferCounts().isEmpty()) {
variadicBufferCounts = recordBatch.getVariadicBufferCounts().iterator();
}
@@ -89,9 +90,10 @@ public class VectorLoader {
loadBuffers(fieldVector, fieldVector.getField(), buffers, nodes, codec,
variadicBufferCounts);
}
root.setRowCount(recordBatch.getLength());
- if (nodes.hasNext() || buffers.hasNext()) {
- throw new IllegalArgumentException("not all nodes and buffers were
consumed. nodes: " +
- Collections2.toString(nodes) + " buffers: " +
Collections2.toString(buffers));
+ if (nodes.hasNext() || buffers.hasNext() ||
variadicBufferCounts.hasNext()) {
+ throw new IllegalArgumentException("not all nodes, buffers and
variadicBufferCounts were consumed. nodes: " +
+ Collections2.toString(nodes) + " buffers: " +
Collections2.toString(buffers) + " variadicBufferCounts: " +
+ Collections2.toString(variadicBufferCounts));
}
}
@@ -104,10 +106,14 @@ public class VectorLoader {
Iterator<Long> variadicBufferCounts) {
checkArgument(nodes.hasNext(), "no more field nodes for field %s and
vector %s", field, vector);
ArrowFieldNode fieldNode = nodes.next();
- // variadicBufferLayoutCount will be 0 for vectors of type except
BaseVariableWidthViewVector
+ // only a BaseVariableWidthViewVector consumes an entry of
variadicBufferCounts; every other vector type keeps a count of 0
long variadicBufferLayoutCount = 0;
- if (variadicBufferCounts != null) {
- variadicBufferLayoutCount = variadicBufferCounts.next();
+ if (vector instanceof BaseVariableWidthViewVector) {
+ if (variadicBufferCounts.hasNext()) {
+ variadicBufferLayoutCount = variadicBufferCounts.next();
+ } else {
+ throw new IllegalStateException("No variadicBufferCounts available for
BaseVariableWidthViewVector");
+ }
}
int bufferLayoutCount = (int) (variadicBufferLayoutCount +
TypeLayout.getTypeBufferCount(field.getType()));
List<ArrowBuf> ownBuffers = new ArrayList<>(bufferLayoutCount);
diff --git
a/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java
b/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java
index 8528099b6d..6e7ab34eba 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java
@@ -103,7 +103,10 @@ public class VectorUnloader {
List<ArrowBuf> fieldBuffers = vector.getFieldBuffers();
long variadicBufferCount = getVariadicBufferCount(vector);
int expectedBufferCount = (int)
(TypeLayout.getTypeBufferCount(vector.getField().getType()) +
variadicBufferCount);
- variadicBufferCounts.add(variadicBufferCount);
+ // one entry per view vector, even when zero, so the loader can pair counts with fields by position
+ if (vector instanceof BaseVariableWidthViewVector) {
+ variadicBufferCounts.add(variadicBufferCount);
+ }
if (fieldBuffers.size() != expectedBufferCount) {
throw new IllegalArgumentException(String.format(
"wrong number of buffers for field %s in vector %s. found: %s",
diff --git
a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java
b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java
index fda14b24a4..b0d316070a 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java
@@ -3441,4 +3441,60 @@ public class TestValueVector {
}
target.close();
}
+
+ @Test
+ public void testVectorLoadUnloadOnNonVariadicVectors() {
+
+ try (final IntVector vector1 = new IntVector("myvector", allocator)) {
+
+ setVector(vector1, 1, 2, 3, 4, 5, 6);
+ vector1.setValueCount(15);
+
+ /* Check the source vector contents before unloading */
+ assertEquals(1, vector1.get(0));
+ assertEquals(2, vector1.get(1));
+ assertEquals(3, vector1.get(2));
+ assertEquals(4, vector1.get(3));
+ assertEquals(5, vector1.get(4));
+ assertEquals(6, vector1.get(5));
+
+ Field field = vector1.getField();
+ String fieldName = field.getName();
+
+ List<Field> fields = new ArrayList<>();
+ List<FieldVector> fieldVectors = new ArrayList<>();
+
+ fields.add(field);
+ fieldVectors.add(vector1);
+
+ Schema schema = new Schema(fields);
+
+ VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema,
fieldVectors, vector1.getValueCount());
+ VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1);
+
+ try (
+ ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch();
+ BufferAllocator finalVectorsAllocator =
allocator.newChildAllocator("new vector", 0, Long.MAX_VALUE);
+ VectorSchemaRoot schemaRoot2 = VectorSchemaRoot.create(schema,
finalVectorsAllocator);
+ ) {
+
+ // an IntVector has no variadic buffers, so the batch must carry no
variadicBufferCounts entries
+ assertTrue(recordBatch.getVariadicBufferCounts().isEmpty());
+
+ VectorLoader vectorLoader = new VectorLoader(schemaRoot2);
+ vectorLoader.load(recordBatch);
+
+ IntVector vector2 = (IntVector) schemaRoot2.getVector(fieldName);
+ vector2.setValueCount(25);
+
+ /* Check the reloaded vector contents */
+ assertEquals(1, vector2.get(0));
+ assertEquals(2, vector2.get(1));
+ assertEquals(3, vector2.get(2));
+ assertEquals(4, vector2.get(3));
+ assertEquals(5, vector2.get(4));
+ assertEquals(6, vector2.get(5));
+ }
+ }
+ }
}
diff --git
a/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java
b/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java
index 1ba3bc3576..817941ecb4 100644
---
a/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java
+++
b/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java
@@ -23,6 +23,7 @@ import static
org.apache.arrow.vector.TestUtils.newViewVarCharVector;
import static
org.apache.arrow.vector.testing.ValueVectorDataPopulator.setVector;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
@@ -2212,6 +2213,85 @@ public class TestVarCharViewVector {
}
}
+ @Test
+ public void testVectorLoadUnloadOnMixedTypes() {
+
+ try (final IntVector vector1 = new IntVector("myvector", allocator);
+ final ViewVarCharVector vector2 = new
ViewVarCharVector("myviewvector", allocator)) {
+
+ final int valueCount = 15;
+
+ setVector(vector1, 1, 2, 3, 4, 5, 6);
+ vector1.setValueCount(valueCount);
+
+ setVector(vector2, STR1, STR2, STR3, STR4, STR5, STR6);
+ vector2.setValueCount(valueCount);
+
+ /* Check the source IntVector contents before unloading */
+ assertEquals(1, vector1.get(0));
+ assertEquals(2, vector1.get(1));
+ assertEquals(3, vector1.get(2));
+ assertEquals(4, vector1.get(3));
+ assertEquals(5, vector1.get(4));
+ assertEquals(6, vector1.get(5));
+
+ Field field1 = vector1.getField();
+ String fieldName1 = field1.getName();
+
+ Field field2 = vector2.getField();
+ String fieldName2 = field2.getName();
+
+ List<Field> fields = new ArrayList<>(2);
+ List<FieldVector> fieldVectors = new ArrayList<>(2);
+
+ fields.add(field1);
+ fields.add(field2);
+ fieldVectors.add(vector1);
+ fieldVectors.add(vector2);
+
+ Schema schema = new Schema(fields);
+
+ VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema,
fieldVectors, valueCount);
+ VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1);
+
+ try (
+ ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch();
+ BufferAllocator finalVectorsAllocator =
allocator.newChildAllocator("new vector", 0, Long.MAX_VALUE);
+ VectorSchemaRoot schemaRoot2 = VectorSchemaRoot.create(schema,
finalVectorsAllocator);
+ ) {
+
+ // only the ViewVarCharVector column may contribute a variadicBufferCounts entry; the IntVector never does
+ assertFalse(recordBatch.getVariadicBufferCounts().isEmpty());
+ assertEquals(1, recordBatch.getVariadicBufferCounts().size());
+
+ VectorLoader vectorLoader = new VectorLoader(schemaRoot2);
+ vectorLoader.load(recordBatch);
+
+ IntVector vector3 = (IntVector) schemaRoot2.getVector(fieldName1);
+ vector3.setValueCount(25);
+
+ /* Check the reloaded IntVector contents */
+ assertEquals(1, vector3.get(0));
+ assertEquals(2, vector3.get(1));
+ assertEquals(3, vector3.get(2));
+ assertEquals(4, vector3.get(3));
+ assertEquals(5, vector3.get(4));
+ assertEquals(6, vector3.get(5));
+
+ ViewVarCharVector vector4 = (ViewVarCharVector)
schemaRoot2.getVector(fieldName2);
+ vector4.setValueCount(25);
+
+ /* Check the reloaded ViewVarCharVector contents */
+ assertArrayEquals(STR1, vector4.get(0));
+ assertArrayEquals(STR2, vector4.get(1));
+ assertArrayEquals(STR3, vector4.get(2));
+ assertArrayEquals(STR4, vector4.get(3));
+ assertArrayEquals(STR5, vector4.get(4));
+ assertArrayEquals(STR6, vector4.get(5));
+ }
+ }
+ }
+
private String generateRandomString(int length) {
Random random = new Random();
StringBuilder sb = new StringBuilder(length);