This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6730c1a392 GH-44088: [Java] Fix copyFrom in BaseVariableWidthViewVector (#44078)
6730c1a392 is described below

commit 6730c1a39260e74d97ef3eb7f25c7207b4157f26
Author: ViggoC <[email protected]>
AuthorDate: Mon Sep 23 14:42:14 2024 +0800

    GH-44088: [Java] Fix copyFrom in BaseVariableWidthViewVector (#44078)
    
    Fix bugs in `copyFromSafe` and `handleSafe`
    * GitHub Issue: #44088
    
    Authored-by: chenweiguo.vc <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 .../arrow/vector/BaseVariableWidthViewVector.java  | 102 +++++++-----------
 .../apache/arrow/vector/TestVarCharViewVector.java | 119 +++++++++++++++++++--
 2 files changed, 152 insertions(+), 69 deletions(-)

diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthViewVector.java
index aee5233f9d..15d2182783 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthViewVector.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthViewVector.java
@@ -565,6 +565,7 @@ public abstract class BaseVariableWidthViewVector extends BaseValueVector
     viewBuffer.getReferenceManager().release();
     viewBuffer = newBuf;
     lastValueAllocationSizeInBytes = viewBuffer.capacity();
+    lastValueCapacity = getValueCapacity();
   }
 
   /**
@@ -1248,10 +1249,7 @@ public abstract class BaseVariableWidthViewVector extends BaseValueVector
    */
   @Override
   public void setNull(int index) {
-    // We need to check and reallocate the validity buffer
-    while (index >= getValueCapacity()) {
-      reallocValidityBuffer();
-    }
+    handleSafe(index, 0);
     BitVectorHelper.unsetBit(validityBuffer, index);
   }
 
@@ -1460,25 +1458,12 @@ public abstract class BaseVariableWidthViewVector extends BaseValueVector
   }
 
   protected final void handleSafe(int index, int dataLength) {
-    final long lastSetCapacity = lastSet < 0 ? 0 : (long) index * ELEMENT_SIZE;
-    final long targetCapacity = roundUpToMultipleOf16(lastSetCapacity + 
dataLength);
-    // for views, we need each buffer with 16 byte alignment, so we need to 
check the last written
-    // index
-    // in the viewBuffer and allocate a new buffer which has 16 byte alignment 
for adding new
-    // values.
-    long writePosition = (long) index * ELEMENT_SIZE;
-    if (viewBuffer.capacity() <= writePosition || viewBuffer.capacity() < 
targetCapacity) {
-      /*
-       * Everytime we want to increase the capacity of the viewBuffer, we need 
to make sure that the new capacity
-       * meets 16 byte alignment.
-       * If the targetCapacity is larger than the writePosition, we may not 
necessarily
-       * want to allocate the targetCapacity to viewBuffer since when it is 
>={@link #INLINE_SIZE} either way
-       * we are writing to the dataBuffer.
-       */
-      reallocViewBuffer(Math.max(writePosition, targetCapacity));
+    final long targetCapacity = roundUpToMultipleOf16((long) index * 
ELEMENT_SIZE + dataLength);
+    if (viewBuffer.capacity() < targetCapacity) {
+      reallocViewBuffer(targetCapacity);
     }
 
-    while (index >= getValueCapacity()) {
+    while (index >= getValidityBufferValueCapacity()) {
       reallocValidityBuffer();
     }
   }
@@ -1498,26 +1483,7 @@ public abstract class BaseVariableWidthViewVector extends BaseValueVector
       BitVectorHelper.unsetBit(validityBuffer, thisIndex);
     } else {
       final int viewLength = from.getDataBuffer().getInt((long) fromIndex * 
ELEMENT_SIZE);
-      BitVectorHelper.setBit(validityBuffer, thisIndex);
-      final int start = thisIndex * ELEMENT_SIZE;
-      final int copyStart = fromIndex * ELEMENT_SIZE;
-      from.getDataBuffer().getBytes(start, viewBuffer, copyStart, 
ELEMENT_SIZE);
-      if (viewLength > INLINE_SIZE) {
-        final int bufIndex =
-            from.getDataBuffer()
-                .getInt(((long) fromIndex * ELEMENT_SIZE) + LENGTH_WIDTH + 
PREFIX_WIDTH);
-        final int dataOffset =
-            from.getDataBuffer()
-                .getInt(
-                    ((long) fromIndex * ELEMENT_SIZE)
-                        + LENGTH_WIDTH
-                        + PREFIX_WIDTH
-                        + BUF_INDEX_WIDTH);
-        final ArrowBuf dataBuf = ((BaseVariableWidthViewVector) 
from).dataBuffers.get(bufIndex);
-        final ArrowBuf thisDataBuf = allocateOrGetLastDataBuffer(viewLength);
-        thisDataBuf.setBytes(thisDataBuf.writerIndex(), dataBuf, dataOffset, 
viewLength);
-        thisDataBuf.writerIndex(thisDataBuf.writerIndex() + viewLength);
-      }
+      copyFromNotNull(fromIndex, thisIndex, from, viewLength);
     }
     lastSet = thisIndex;
   }
@@ -1539,30 +1505,44 @@ public abstract class BaseVariableWidthViewVector extends BaseValueVector
     } else {
       final int viewLength = from.getDataBuffer().getInt((long) fromIndex * 
ELEMENT_SIZE);
       handleSafe(thisIndex, viewLength);
-      BitVectorHelper.setBit(validityBuffer, thisIndex);
-      final int start = thisIndex * ELEMENT_SIZE;
-      final int copyStart = fromIndex * ELEMENT_SIZE;
-      from.getDataBuffer().getBytes(start, viewBuffer, copyStart, 
ELEMENT_SIZE);
-      if (viewLength > INLINE_SIZE) {
-        final int bufIndex =
-            from.getDataBuffer()
-                .getInt(((long) fromIndex * ELEMENT_SIZE) + LENGTH_WIDTH + 
PREFIX_WIDTH);
-        final int dataOffset =
-            from.getDataBuffer()
-                .getInt(
-                    ((long) fromIndex * ELEMENT_SIZE)
-                        + LENGTH_WIDTH
-                        + PREFIX_WIDTH
-                        + BUF_INDEX_WIDTH);
-        final ArrowBuf dataBuf = ((BaseVariableWidthViewVector) 
from).dataBuffers.get(bufIndex);
-        final ArrowBuf thisDataBuf = allocateOrGetLastDataBuffer(viewLength);
-        thisDataBuf.setBytes(thisDataBuf.writerIndex(), dataBuf, dataOffset, 
viewLength);
-        thisDataBuf.writerIndex(thisDataBuf.writerIndex() + viewLength);
-      }
+      copyFromNotNull(fromIndex, thisIndex, from, viewLength);
     }
     lastSet = thisIndex;
   }
 
+  private void copyFromNotNull(int fromIndex, int thisIndex, ValueVector from, 
int viewLength) {
+    BitVectorHelper.setBit(validityBuffer, thisIndex);
+    final int start = thisIndex * ELEMENT_SIZE;
+    final int copyStart = fromIndex * ELEMENT_SIZE;
+    if (viewLength > INLINE_SIZE) {
+      final int bufIndex =
+          from.getDataBuffer()
+              .getInt(((long) fromIndex * ELEMENT_SIZE) + LENGTH_WIDTH + 
PREFIX_WIDTH);
+      final int dataOffset =
+          from.getDataBuffer()
+              .getInt(
+                  ((long) fromIndex * ELEMENT_SIZE)
+                      + LENGTH_WIDTH
+                      + PREFIX_WIDTH
+                      + BUF_INDEX_WIDTH);
+      final ArrowBuf dataBuf = ((BaseVariableWidthViewVector) 
from).dataBuffers.get(bufIndex);
+      final ArrowBuf thisDataBuf = allocateOrGetLastDataBuffer(viewLength);
+
+      viewBuffer.setBytes(start, from.getDataBuffer(), copyStart, LENGTH_WIDTH 
+ PREFIX_WIDTH);
+      int writePosition = start + LENGTH_WIDTH + PREFIX_WIDTH;
+      // set buf id
+      viewBuffer.setInt(writePosition, dataBuffers.size() - 1);
+      writePosition += BUF_INDEX_WIDTH;
+      // set offset
+      viewBuffer.setInt(writePosition, (int) thisDataBuf.writerIndex());
+
+      thisDataBuf.setBytes(thisDataBuf.writerIndex(), dataBuf, dataOffset, 
viewLength);
+      thisDataBuf.writerIndex(thisDataBuf.writerIndex() + viewLength);
+    } else {
+      from.getDataBuffer().getBytes(copyStart, viewBuffer, start, 
ELEMENT_SIZE);
+    }
+  }
+
   @Override
   public ArrowBufPointer getDataPointer(int index) {
     return getDataPointer(index, new ArrowBufPointer());
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java
index 308431fdeb..232eec9ef1 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVarCharViewVector.java
@@ -33,7 +33,9 @@ import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.Random;
 import java.util.function.BiConsumer;
@@ -261,6 +263,39 @@ public class TestVarCharViewVector {
     }
   }
 
+  @Test
+  public void testSetSafe() {
+    try (final ViewVarCharVector viewVarCharVector = new 
ViewVarCharVector("myvector", allocator)) {
+      viewVarCharVector.allocateNew(1, 1);
+      byte[] str6 = generateRandomString(40).getBytes();
+      final List<byte[]> strings = List.of(STR0, STR1, STR2, STR3, STR4, STR5, 
str6);
+
+      // set data to a position out of capacity index
+      Map<Integer, byte[]> expected = new HashMap<>();
+      for (byte[] string : strings) {
+        int cap = viewVarCharVector.getValueCapacity();
+        expected.put(cap, string);
+        viewVarCharVector.setSafe(cap, string);
+      }
+      int nullIndex = viewVarCharVector.getValueCapacity();
+      viewVarCharVector.setNull(nullIndex);
+      int valueCount = nullIndex + 1;
+      viewVarCharVector.setValueCount(valueCount);
+      assertEquals(viewVarCharVector.getNullCount(), valueCount - 
strings.size());
+
+      assertEquals(128, viewVarCharVector.getValueCapacity());
+      assertEquals(2, viewVarCharVector.dataBuffers.size());
+
+      for (int i = 0; i < viewVarCharVector.getValueCapacity(); i++) {
+        if (expected.containsKey(i)) {
+          assertArrayEquals(expected.get(i), viewVarCharVector.get(i));
+        } else {
+          assertNull(viewVarCharVector.get(i));
+        }
+      }
+    }
+  }
+
   @Test
   public void testMixedAllocation() {
     try (final ViewVarCharVector viewVarCharVector = new 
ViewVarCharVector("myvector", allocator)) {
@@ -1749,12 +1784,12 @@ public class TestVarCharViewVector {
         } else if (i % 3 == 1) {
           assertArrayEquals(
               Integer.toString(i).getBytes(StandardCharsets.UTF_8),
-              vector.get(i),
+              vector2.get(i),
               "unexpected value at index: " + i);
         } else {
           assertArrayEquals(
               (i + prefixString).getBytes(StandardCharsets.UTF_8),
-              vector.get(i),
+              vector2.get(i),
               "unexpected value at index: " + i);
         }
       }
@@ -1769,12 +1804,12 @@ public class TestVarCharViewVector {
         } else if (i % 3 == 1) {
           assertArrayEquals(
               Integer.toString(i).getBytes(StandardCharsets.UTF_8),
-              vector.get(i),
+              vector2.get(i),
               "unexpected value at index: " + i);
         } else {
           assertArrayEquals(
               (i + prefixString).getBytes(StandardCharsets.UTF_8),
-              vector.get(i),
+              vector2.get(i),
               "unexpected value at index: " + i);
         }
       }
@@ -1846,12 +1881,12 @@ public class TestVarCharViewVector {
         } else if (i % 3 == 1) {
           assertArrayEquals(
               Integer.toString(i).getBytes(StandardCharsets.UTF_8),
-              vector.get(i),
+              vector2.get(i),
               "unexpected value at index: " + i);
         } else {
           assertArrayEquals(
               (i + prefixString).getBytes(StandardCharsets.UTF_8),
-              vector.get(i),
+              vector2.get(i),
               "unexpected value at index: " + i);
         }
       }
@@ -1867,15 +1902,83 @@ public class TestVarCharViewVector {
         } else if (i % 3 == 1) {
           assertArrayEquals(
               Integer.toString(i).getBytes(StandardCharsets.UTF_8),
-              vector.get(i),
+              vector2.get(i),
               "unexpected value at index: " + i);
         } else {
           assertArrayEquals(
               (i + prefixString).getBytes(StandardCharsets.UTF_8),
-              vector.get(i),
+              vector2.get(i),
+              "unexpected value at index: " + i);
+        }
+      }
+
+      // make it reallocate
+      int valueCapacity = vector2.getValueCapacity();
+      for (int i = 0; i < numberOfValues; i++) {
+        int thisIndex = i + valueCapacity;
+        vector2.copyFromSafe(i, thisIndex, vector);
+        if (i % 3 == 0) {
+          assertNull(vector2.getObject(thisIndex));
+        } else if (i % 3 == 1) {
+          assertArrayEquals(
+              Integer.toString(i).getBytes(StandardCharsets.UTF_8),
+              vector2.get(thisIndex),
+              "unexpected value at index: " + i);
+        } else {
+          assertArrayEquals(
+              (i + prefixString).getBytes(StandardCharsets.UTF_8),
+              vector2.get(thisIndex),
               "unexpected value at index: " + i);
         }
       }
+
+      // test target vector with different initialCapacity
+      try (final BaseVariableWidthViewVector vector3 = 
vectorCreator.apply(allocator)) {
+        vector3.setInitialCapacity(16);
+        vector3.allocateNew();
+        for (int i = 0; i < numberOfValues; i++) {
+          vector3.copyFromSafe(i, i, vector);
+          if (i % 3 == 0) {
+            assertNull(vector3.getObject(i));
+          } else {
+            assertArrayEquals(vector.get(i), vector3.get(i));
+          }
+        }
+      }
+
+      // test overwrite a used vector by copy
+      try (final BaseVariableWidthViewVector targetVector = 
vectorCreator.apply(allocator)) {
+
+        targetVector.setInitialCapacity(initialCapacity);
+        targetVector.allocateNew();
+
+        // source vector: null, short, long...
+        // target vector: long, null, short...
+        for (int i = 0; i < numberOfValues; i++) {
+          if (i % 3 == 0) {
+            // long strings
+            byte[] b = (i + prefixString).getBytes(StandardCharsets.UTF_8);
+            targetVector.set(i, b, 0, b.length);
+          } else if (i % 3 == 1) {
+            // null values
+            targetVector.setNull(i);
+          } else {
+            // short strings
+            byte[] b = Integer.toString(i).getBytes(StandardCharsets.UTF_8);
+            targetVector.set(i, b, 0, b.length);
+          }
+        }
+        targetVector.setValueCount(numberOfValues);
+
+        for (int i = 0; i < numberOfValues; i++) {
+          targetVector.copyFromSafe(i, i, vector);
+          if (i % 3 == 0) {
+            assertNull(targetVector.getObject(i));
+          } else {
+            assertArrayEquals(vector.get(i), targetVector.get(i));
+          }
+        }
+      }
     }
   }
 

Reply via email to