This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7bc2452b35 GH-41902: [Java] Variadic Buffer Counts Incorrect (#41930)
7bc2452b35 is described below

commit 7bc2452b350867b3ddc9de9ceceeef0e4d722941
Author: Vibhatha Lakmal Abeykoon <[email protected]>
AuthorDate: Tue Jun 4 12:35:13 2024 +0530

    GH-41902: [Java] Variadic Buffer Counts Incorrect (#41930)
    
    ### Rationale for this change
    
    In the initial PR that added `variadicBufferCounts` to the Java 
implementation, vectors without variadic buffers were still given a non-empty 
`variadicBufferCounts` holding 0-valued entries. This caused CI failures in Arrow Rust.
    
    ### What changes are included in this PR?
    
    This PR ensures that vectors without variadic buffers produce an empty 
`variadicBufferCounts` attribute in the Java `ArrowRecordBatch`. 
It also includes an upgrade to JUnit 5.
    
    ### Are these changes tested?
    
    Yes, by the existing tests and by newly added tests.
    
    ### Are there any user-facing changes?
    
    No
    * GitHub Issue: #41902
    
    Authored-by: Vibhatha Abeykoon <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 .../org/apache/arrow/c/StructVectorLoader.java     | 29 +++++---
 .../org/apache/arrow/c/StructVectorUnloader.java   |  5 +-
 .../java/org/apache/arrow/c/DictionaryTest.java    | 60 ++++++++++++++++
 .../java/org/apache/arrow/vector/VectorLoader.java | 20 ++++--
 .../org/apache/arrow/vector/VectorUnloader.java    |  5 +-
 .../org/apache/arrow/vector/TestValueVector.java   | 56 +++++++++++++++
 .../apache/arrow/vector/TestVarCharViewVector.java | 80 ++++++++++++++++++++++
 7 files changed, 238 insertions(+), 17 deletions(-)

diff --git a/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java 
b/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java
index 27acf84d30..1b0c59163a 100644
--- a/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java
+++ b/java/c/src/main/java/org/apache/arrow/c/StructVectorLoader.java
@@ -27,6 +27,7 @@ import java.util.List;
 import org.apache.arrow.memory.ArrowBuf;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.util.Collections2;
+import org.apache.arrow.vector.BaseVariableWidthViewVector;
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.TypeLayout;
 import org.apache.arrow.vector.complex.StructVector;
@@ -54,7 +55,14 @@ public class StructVectorLoader {
 
   /**
    * Construct with a schema.
-   *
+   * <p>
+   * The schema referred to here can be obtained from the struct vector.
+   * The schema here should be the children of a struct vector, not a schema
+   * containing the struct field itself.
+   * For example:
+   * <code>
+   * Schema schema = new Schema(structVector.getField().getChildren());
+   * </code>
    * @param schema buffers are added based on schema.
    */
   public StructVectorLoader(Schema schema) {
@@ -90,7 +98,7 @@ public class StructVectorLoader {
         .fromCompressionType(recordBatch.getBodyCompression().getCodec());
     decompressionNeeded = codecType != 
CompressionUtil.CodecType.NO_COMPRESSION;
     CompressionCodec codec = decompressionNeeded ? 
factory.createCodec(codecType) : NoCompressionCodec.INSTANCE;
-    Iterator<Long> variadicBufferCounts = null;
+    Iterator<Long> variadicBufferCounts = Collections.emptyIterator();
     if (recordBatch.getVariadicBufferCounts() != null && 
!recordBatch.getVariadicBufferCounts().isEmpty()) {
       variadicBufferCounts = recordBatch.getVariadicBufferCounts().iterator();
     }
@@ -98,9 +106,10 @@ public class StructVectorLoader {
       loadBuffers(fieldVector, fieldVector.getField(), buffers, nodes, codec, 
variadicBufferCounts);
     }
     result.loadFieldBuffers(new ArrowFieldNode(recordBatch.getLength(), 0), 
Collections.singletonList(null));
-    if (nodes.hasNext() || buffers.hasNext()) {
-      throw new IllegalArgumentException("not all nodes and buffers were 
consumed. nodes: " + 
-        Collections2.toList(nodes).toString() + " buffers: " + 
Collections2.toList(buffers).toString());
+    if (nodes.hasNext() || buffers.hasNext() || 
variadicBufferCounts.hasNext()) {
+      throw new IllegalArgumentException("not all nodes, buffers and 
variadicBufferCounts were consumed. nodes: " +
+        Collections2.toString(nodes) + " buffers: " + 
Collections2.toString(buffers) + " variadicBufferCounts: " +
+          Collections2.toString(variadicBufferCounts));
     }
     return result;
   }
@@ -109,10 +118,14 @@ public class StructVectorLoader {
       CompressionCodec codec, Iterator<Long> variadicBufferCounts) {
     checkArgument(nodes.hasNext(), "no more field nodes for field %s and 
vector %s", field, vector);
     ArrowFieldNode fieldNode = nodes.next();
-    // variadicBufferLayoutCount will be 0 for vectors of type except 
BaseVariableWidthViewVector
+    // variadicBufferLayoutCount will be 0 for vectors of a type except 
BaseVariableWidthViewVector
     long variadicBufferLayoutCount = 0;
-    if (variadicBufferCounts != null) {
-      variadicBufferLayoutCount = variadicBufferCounts.next();
+    if (vector instanceof BaseVariableWidthViewVector) {
+      if (variadicBufferCounts.hasNext()) {
+        variadicBufferLayoutCount = variadicBufferCounts.next();
+      } else {
+        throw new IllegalStateException("No variadicBufferCounts available for 
BaseVariableWidthViewVector");
+      }
     }
     int bufferLayoutCount = (int) (variadicBufferLayoutCount + 
TypeLayout.getTypeBufferCount(field.getType()));
     List<ArrowBuf> ownBuffers = new ArrayList<>(bufferLayoutCount);
diff --git a/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java 
b/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java
index 8d015157eb..82539acf6f 100644
--- a/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java
+++ b/java/c/src/main/java/org/apache/arrow/c/StructVectorUnloader.java
@@ -109,7 +109,10 @@ public class StructVectorUnloader {
     List<ArrowBuf> fieldBuffers = vector.getFieldBuffers();
     long variadicBufferCount = getVariadicBufferCount(vector);
     int expectedBufferCount = (int) 
(TypeLayout.getTypeBufferCount(vector.getField().getType()) + 
variadicBufferCount);
-    variadicBufferCounts.add(variadicBufferCount);
+    // only update variadicBufferCounts for vectors that have variadic buffers
+    if (variadicBufferCount > 0) {
+      variadicBufferCounts.add(variadicBufferCount);
+    }
     if (fieldBuffers.size() != expectedBufferCount) {
       throw new IllegalArgumentException(String.format("wrong number of 
buffers for field %s in vector %s. found: %s",
           vector.getField(), vector.getClass().getSimpleName(), fieldBuffers));
diff --git a/java/c/src/test/java/org/apache/arrow/c/DictionaryTest.java 
b/java/c/src/test/java/org/apache/arrow/c/DictionaryTest.java
index d892781756..aa1264e484 100644
--- a/java/c/src/test/java/org/apache/arrow/c/DictionaryTest.java
+++ b/java/c/src/test/java/org/apache/arrow/c/DictionaryTest.java
@@ -17,6 +17,8 @@
 
 package org.apache.arrow.c;
 
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 import java.io.ByteArrayInputStream;
@@ -29,6 +31,7 @@ import org.apache.arrow.c.ArrowArray;
 import org.apache.arrow.c.ArrowSchema;
 import org.apache.arrow.c.CDataDictionaryProvider;
 import org.apache.arrow.c.Data;
+import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.util.AutoCloseables;
 import org.apache.arrow.vector.FieldVector;
@@ -36,13 +39,19 @@ import org.apache.arrow.vector.IntVector;
 import org.apache.arrow.vector.ValueVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ViewVarCharVector;
 import org.apache.arrow.vector.compare.VectorEqualsVisitor;
+import org.apache.arrow.vector.complex.StructVector;
 import org.apache.arrow.vector.dictionary.Dictionary;
 import org.apache.arrow.vector.dictionary.DictionaryEncoder;
 import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
 import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.Types.MinorType;
 import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
@@ -216,4 +225,55 @@ public class DictionaryTest {
     return new ArrowStreamReader(in, allocator);
   }
 
+  private void createStructVector(StructVector vector) {
+    final ViewVarCharVector child1 = vector.addOrGet("f0",
+        FieldType.nullable(MinorType.VIEWVARCHAR.getType()), 
ViewVarCharVector.class);
+    final IntVector child2 = vector.addOrGet("f1",
+        FieldType.nullable(MinorType.INT.getType()), IntVector.class);
+
+    // Write the values to child 1
+    child1.allocateNew();
+    child1.set(0, "01234567890".getBytes());
+    child1.set(1, "012345678901234567".getBytes());
+    vector.setIndexDefined(0);
+
+    // Write the values to child 2
+    child2.allocateNew();
+    child2.set(0, 10);
+    child2.set(1, 11);
+    vector.setIndexDefined(1);
+
+    vector.setValueCount(2);
+  }
+
+  @Test
+  public void testVectorLoadUnloadOnStructVector() {
+    try (final StructVector structVector1 = StructVector.empty("struct", 
allocator)) {
+      createStructVector(structVector1);
+      Field field1 = structVector1.getField();
+      Schema schema = new Schema(field1.getChildren());
+      StructVectorUnloader vectorUnloader = new 
StructVectorUnloader(structVector1);
+
+      try (
+          ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch();
+          BufferAllocator finalVectorsAllocator = 
allocator.newChildAllocator("struct", 0, Long.MAX_VALUE);
+      ) {
+        // validating recordBatch contains an output for variadicBufferCounts
+        assertFalse(recordBatch.getVariadicBufferCounts().isEmpty());
+        assertEquals(1, recordBatch.getVariadicBufferCounts().size());
+        assertEquals(1, recordBatch.getVariadicBufferCounts().get(0));
+
+        StructVectorLoader vectorLoader = new StructVectorLoader(schema);
+        try (StructVector structVector2 = 
vectorLoader.load(finalVectorsAllocator, recordBatch)) {
+          // Improve this after fixing 
https://github.com/apache/arrow/issues/41933
+          // assertTrue(VectorEqualsVisitor.vectorEquals(structVector1, 
structVector2), "vectors are not equivalent");
+          
assertTrue(VectorEqualsVisitor.vectorEquals(structVector1.getChild("f0"), 
structVector2.getChild("f0")),
+              "vectors are not equivalent");
+          
assertTrue(VectorEqualsVisitor.vectorEquals(structVector1.getChild("f1"), 
structVector2.getChild("f1")),
+              "vectors are not equivalent");
+        }
+      }
+    }
+  }
+
 }
diff --git 
a/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java 
b/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java
index 9590e70f46..dec536ae6c 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java
@@ -20,6 +20,7 @@ package org.apache.arrow.vector;
 import static org.apache.arrow.util.Preconditions.checkArgument;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 
@@ -80,7 +81,7 @@ public class VectorLoader {
         
CompressionUtil.CodecType.fromCompressionType(recordBatch.getBodyCompression().getCodec());
     decompressionNeeded = codecType != 
CompressionUtil.CodecType.NO_COMPRESSION;
     CompressionCodec codec = decompressionNeeded ? 
factory.createCodec(codecType) : NoCompressionCodec.INSTANCE;
-    Iterator<Long> variadicBufferCounts = null;
+    Iterator<Long> variadicBufferCounts = Collections.emptyIterator();;
     if (recordBatch.getVariadicBufferCounts() != null && 
!recordBatch.getVariadicBufferCounts().isEmpty()) {
       variadicBufferCounts = recordBatch.getVariadicBufferCounts().iterator();
     }
@@ -89,9 +90,10 @@ public class VectorLoader {
       loadBuffers(fieldVector, fieldVector.getField(), buffers, nodes, codec, 
variadicBufferCounts);
     }
     root.setRowCount(recordBatch.getLength());
-    if (nodes.hasNext() || buffers.hasNext()) {
-      throw new IllegalArgumentException("not all nodes and buffers were 
consumed. nodes: " +
-          Collections2.toString(nodes) + " buffers: " + 
Collections2.toString(buffers));
+    if (nodes.hasNext() || buffers.hasNext() || 
variadicBufferCounts.hasNext()) {
+      throw new IllegalArgumentException("not all nodes, buffers and 
variadicBufferCounts were consumed. nodes: " +
+          Collections2.toString(nodes) + " buffers: " + 
Collections2.toString(buffers) + " variadicBufferCounts: " +
+          Collections2.toString(variadicBufferCounts));
     }
   }
 
@@ -104,10 +106,14 @@ public class VectorLoader {
       Iterator<Long> variadicBufferCounts) {
     checkArgument(nodes.hasNext(), "no more field nodes for field %s and 
vector %s", field, vector);
     ArrowFieldNode fieldNode = nodes.next();
-    // variadicBufferLayoutCount will be 0 for vectors of type except 
BaseVariableWidthViewVector
+    // variadicBufferLayoutCount will be 0 for vectors of a type except 
BaseVariableWidthViewVector
     long variadicBufferLayoutCount = 0;
-    if (variadicBufferCounts != null) {
-      variadicBufferLayoutCount = variadicBufferCounts.next();
+    if (vector instanceof BaseVariableWidthViewVector) {
+      if (variadicBufferCounts.hasNext()) {
+        variadicBufferLayoutCount = variadicBufferCounts.next();
+      } else {
+        throw new IllegalStateException("No variadicBufferCounts available for 
BaseVariableWidthViewVector");
+      }
     }
     int bufferLayoutCount = (int) (variadicBufferLayoutCount + 
TypeLayout.getTypeBufferCount(field.getType()));
     List<ArrowBuf> ownBuffers = new ArrayList<>(bufferLayoutCount);
diff --git 
a/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java 
b/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java
index 8528099b6d..6e7ab34eba 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java
@@ -103,7 +103,10 @@ public class VectorUnloader {
     List<ArrowBuf> fieldBuffers = vector.getFieldBuffers();
     long variadicBufferCount = getVariadicBufferCount(vector);
     int expectedBufferCount = (int) 
(TypeLayout.getTypeBufferCount(vector.getField().getType()) + 
variadicBufferCount);
-    variadicBufferCounts.add(variadicBufferCount);
+    // only update variadicBufferCounts for vectors that have variadic buffers
+    if (variadicBufferCount > 0) {
+      variadicBufferCounts.add(variadicBufferCount);
+    }
     if (fieldBuffers.size() != expectedBufferCount) {
       throw new IllegalArgumentException(String.format(
           "wrong number of buffers for field %s in vector %s. found: %s",
diff --git 
a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java 
b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java
index fda14b24a4..b0d316070a 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java
@@ -3441,4 +3441,60 @@ public class TestValueVector {
     }
     target.close();
   }
+
+  @Test
+  public void testVectorLoadUnloadOnNonVariadicVectors() {
+
+    try (final IntVector vector1 = new IntVector("myvector", allocator)) {
+
+      setVector(vector1, 1, 2, 3, 4, 5, 6);
+      vector1.setValueCount(15);
+
+      /* Check the vector output */
+      assertEquals(1, vector1.get(0));
+      assertEquals(2, vector1.get(1));
+      assertEquals(3, vector1.get(2));
+      assertEquals(4, vector1.get(3));
+      assertEquals(5, vector1.get(4));
+      assertEquals(6, vector1.get(5));
+
+      Field field = vector1.getField();
+      String fieldName = field.getName();
+
+      List<Field> fields = new ArrayList<>();
+      List<FieldVector> fieldVectors = new ArrayList<>();
+
+      fields.add(field);
+      fieldVectors.add(vector1);
+
+      Schema schema = new Schema(fields);
+
+      VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, 
fieldVectors, vector1.getValueCount());
+      VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1);
+
+      try (
+          ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch();
+          BufferAllocator finalVectorsAllocator = 
allocator.newChildAllocator("new vector", 0, Long.MAX_VALUE);
+          VectorSchemaRoot schemaRoot2 = VectorSchemaRoot.create(schema, 
finalVectorsAllocator);
+      ) {
+
+        // validating recordBatch doesn't contain an output for 
variadicBufferCounts
+        assertTrue(recordBatch.getVariadicBufferCounts().isEmpty());
+
+        VectorLoader vectorLoader = new VectorLoader(schemaRoot2);
+        vectorLoader.load(recordBatch);
+
+        IntVector vector2 = (IntVector) schemaRoot2.getVector(fieldName);
+        vector2.setValueCount(25);
+
+        /* Check the vector output */
+        assertEquals(1, vector2.get(0));
+        assertEquals(2, vector2.get(1));
+        assertEquals(3, vector2.get(2));
+        assertEquals(4, vector2.get(3));
+        assertEquals(5, vector2.get(4));
+        assertEquals(6, vector2.get(5));
+      }
+    }
+  }
 }
diff --git 
a/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java 
b/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java
index 1ba3bc3576..817941ecb4 100644
--- 
a/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java
+++ 
b/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java
@@ -23,6 +23,7 @@ import static 
org.apache.arrow.vector.TestUtils.newViewVarCharVector;
 import static 
org.apache.arrow.vector.testing.ValueVectorDataPopulator.setVector;
 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertSame;
@@ -2212,6 +2213,85 @@ public class TestVarCharViewVector {
     }
   }
 
+  @Test
+  public void testVectorLoadUnloadOnMixedTypes() {
+
+    try (final IntVector vector1 = new IntVector("myvector", allocator);
+        final ViewVarCharVector vector2 = new 
ViewVarCharVector("myviewvector", allocator)) {
+
+      final int valueCount = 15;
+
+      setVector(vector1, 1, 2, 3, 4, 5, 6);
+      vector1.setValueCount(valueCount);
+
+      setVector(vector2, STR1, STR2, STR3, STR4, STR5, STR6);
+      vector1.setValueCount(valueCount);
+
+      /* Check the vector output */
+      assertEquals(1, vector1.get(0));
+      assertEquals(2, vector1.get(1));
+      assertEquals(3, vector1.get(2));
+      assertEquals(4, vector1.get(3));
+      assertEquals(5, vector1.get(4));
+      assertEquals(6, vector1.get(5));
+
+      Field field1 = vector1.getField();
+      String fieldName1 = field1.getName();
+
+      Field field2 = vector2.getField();
+      String fieldName2 = field2.getName();
+
+      List<Field> fields = new ArrayList<>(2);
+      List<FieldVector> fieldVectors = new ArrayList<>(2);
+
+      fields.add(field1);
+      fields.add(field2);
+      fieldVectors.add(vector1);
+      fieldVectors.add(vector2);
+
+      Schema schema = new Schema(fields);
+
+      VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, 
fieldVectors, valueCount);
+      VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1);
+
+      try (
+          ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch();
+          BufferAllocator finalVectorsAllocator = 
allocator.newChildAllocator("new vector", 0, Long.MAX_VALUE);
+          VectorSchemaRoot schemaRoot2 = VectorSchemaRoot.create(schema, 
finalVectorsAllocator);
+      ) {
+
+        // validating recordBatch contains an output for variadicBufferCounts
+        assertFalse(recordBatch.getVariadicBufferCounts().isEmpty());
+        assertEquals(1, recordBatch.getVariadicBufferCounts().size());
+
+        VectorLoader vectorLoader = new VectorLoader(schemaRoot2);
+        vectorLoader.load(recordBatch);
+
+        IntVector vector3 = (IntVector) schemaRoot2.getVector(fieldName1);
+        vector3.setValueCount(25);
+
+        /* Check the vector output */
+        assertEquals(1, vector3.get(0));
+        assertEquals(2, vector3.get(1));
+        assertEquals(3, vector3.get(2));
+        assertEquals(4, vector3.get(3));
+        assertEquals(5, vector3.get(4));
+        assertEquals(6, vector3.get(5));
+
+        ViewVarCharVector vector4 = (ViewVarCharVector) 
schemaRoot2.getVector(fieldName2);
+        vector4.setValueCount(25);
+
+        /* Check the vector output */
+        assertArrayEquals(STR1, vector4.get(0));
+        assertArrayEquals(STR2, vector4.get(1));
+        assertArrayEquals(STR3, vector4.get(2));
+        assertArrayEquals(STR4, vector4.get(3));
+        assertArrayEquals(STR5, vector4.get(4));
+        assertArrayEquals(STR6, vector4.get(5));
+      }
+    }
+  }
+
   private String generateRandomString(int length) {
     Random random = new Random();
     StringBuilder sb = new StringBuilder(length);

Reply via email to