This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 37935604bf GH-38242: [Java] Fix incorrect internal struct accounting 
for DenseUnionVector#getBufferSizeFor (#38305)
37935604bf is described below

commit 37935604bf168a3b2d52f3cc5b0edf83b5783309
Author: Dan Stone <[email protected]>
AuthorDate: Mon Oct 23 12:30:27 2023 +0100

    GH-38242: [Java] Fix incorrect internal struct accounting for 
DenseUnionVector#getBufferSizeFor (#38305)
    
    ### What changes are included in this PR?
    
    Fix incorrect implementation of `DenseUnionVector.getBufferSizeFor`.
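    Count how many entries carry each type id, then call `getBufferSizeFor` on
    each of the union's child vectors with the count for its own type id, instead
    of delegating to the internal struct with the union's total count.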
    
    ### Are these changes tested?
    
    Yes. A new test triggers the former out-of-bounds read (which is only
    detected when bounds checking is enabled) and checks the computed size for an
    example vector; the test fails as expected without the fix.
    
    **This PR contains a "Critical Fix".**
    
    For users of DenseUnionVector, the buffer size could be calculated
    incorrectly, and the calculation could also cause out-of-bounds buffer reads,
    which may return garbage (or potentially segfault) if bounds checking is off.
    
    * Closes: #38242
    
    Authored-by: Dan Stone <[email protected]>
    Signed-off-by: Antoine Pitrou <[email protected]>
---
 .../main/codegen/templates/DenseUnionVector.java   | 20 +++++-
 .../vector/complex/TestDenseUnionBufferSize.java   | 81 ++++++++++++++++++++++
 2 files changed, 99 insertions(+), 2 deletions(-)

diff --git a/java/vector/src/main/codegen/templates/DenseUnionVector.java 
b/java/vector/src/main/codegen/templates/DenseUnionVector.java
index 12fc52af3c..de0cf84fd8 100644
--- a/java/vector/src/main/codegen/templates/DenseUnionVector.java
+++ b/java/vector/src/main/codegen/templates/DenseUnionVector.java
@@ -730,8 +730,24 @@ public class DenseUnionVector extends 
AbstractContainerVector implements FieldVe
     if (count == 0) {
       return 0;
     }
-    return (int) (count * TYPE_WIDTH + (long) count * OFFSET_WIDTH
-        + DataSizeRoundingUtil.divideBy8Ceil(count) + 
internalStruct.getBufferSizeFor(count));
+
+    int[] counts = new int[Byte.MAX_VALUE + 1];
+    for (int i = 0; i < count; i++) {
+      byte typeId = getTypeId(i);
+      if (typeId != -1) {
+        counts[typeId] += 1;
+      }
+    }
+
+    long childBytes = 0;
+    for (int typeId = 0; typeId < childVectors.length; typeId++) {
+      ValueVector childVector = childVectors[typeId];
+      if (childVector != null) {
+        childBytes += childVector.getBufferSizeFor(counts[typeId]);
+      }
+    }
+
+    return (int) (count * TYPE_WIDTH + (long) count * OFFSET_WIDTH + 
childBytes);
   }
 
   @Override
diff --git 
a/java/vector/src/test/java/org/apache/arrow/vector/complex/TestDenseUnionBufferSize.java
 
b/java/vector/src/test/java/org/apache/arrow/vector/complex/TestDenseUnionBufferSize.java
new file mode 100644
index 0000000000..82ef7a479d
--- /dev/null
+++ 
b/java/vector/src/test/java/org/apache/arrow/vector/complex/TestDenseUnionBufferSize.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.vector.complex;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BaseValueVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VarBinaryVector;
+import org.apache.arrow.vector.holders.NullableIntHolder;
+import org.apache.arrow.vector.holders.NullableVarBinaryHolder;
+import org.apache.arrow.vector.types.UnionMode;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.junit.jupiter.api.Test;
+
+public class TestDenseUnionBufferSize {
+  @Test
+  public void testBufferSize() {
+    try (BufferAllocator allocator = new RootAllocator();
+         DenseUnionVector duv = new DenseUnionVector("duv", allocator,
+                 FieldType.nullable(new ArrowType.Union(UnionMode.Dense, 
null)), null)) {
+
+      byte aTypeId = 42;
+      byte bTypeId = 7;
+
+      duv.addVector(aTypeId, new IntVector("a", FieldType.notNullable(new 
ArrowType.Int(32, true)), allocator));
+      duv.addVector(bTypeId, new VarBinaryVector("b", 
FieldType.notNullable(new ArrowType.Binary()), allocator));
+
+      NullableIntHolder intHolder = new NullableIntHolder();
+      NullableVarBinaryHolder varBinaryHolder = new NullableVarBinaryHolder();
+
+      int aCount = BaseValueVector.INITIAL_VALUE_ALLOCATION + 1;
+      for (int i = 0; i < aCount; i++) {
+        duv.setTypeId(i, aTypeId);
+        duv.setSafe(i, intHolder);
+      }
+
+      int bCount = 2;
+      for (int i = 0; i < bCount; i++) {
+        duv.setTypeId(i + aCount, bTypeId);
+        duv.setSafe(i + aCount, varBinaryHolder);
+      }
+
+      int count = aCount + bCount;
+      duv.setValueCount(count);
+
+      // will not necessarily see an error unless bounds checking is on.
+      assertDoesNotThrow(duv::getBufferSize);
+
+      IntVector intVector = duv.getIntVector(aTypeId);
+      VarBinaryVector varBinaryVector = duv.getVarBinaryVector(bTypeId);
+
+      long overhead = DenseUnionVector.TYPE_WIDTH + 
DenseUnionVector.OFFSET_WIDTH;
+
+      assertEquals(overhead * count + intVector.getBufferSize() + 
varBinaryVector.getBufferSize(),
+              duv.getBufferSize());
+
+      assertEquals(overhead * (aCount + 1) + 
intVector.getBufferSizeFor(aCount) + varBinaryVector.getBufferSizeFor(1),
+              duv.getBufferSizeFor(aCount + 1));
+
+    }
+  }
+}

Reply via email to