This is an automated email from the ASF dual-hosted git repository.

kazuyukitanimura pushed a commit to branch optimize-null-count
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git

commit 50d686eb85edaee39e3fb9b8f90ef2b0520b8b31
Author: Kazuyuki Tanimura <ktanim...@apple.com>
AuthorDate: Fri Aug 9 02:15:27 2024 -0700

    fix: Optimize not to call getNullCount as much as possible
---
 .../org/apache/arrow/c/CometArrayExporter.java     | 120 +++++++++++++++++++++
 .../org/apache/comet/parquet/ColumnReader.java     |  13 +--
 .../apache/comet/parquet/MetadataColumnReader.java |   5 +-
 .../main/java/org/apache/comet/parquet/Utils.java  |   8 ++
 .../apache/comet/vector/CometDecodedVector.java    |   7 +-
 .../apache/comet/vector/CometDictionaryVector.java |   7 +-
 .../org/apache/comet/vector/CometPlainVector.java  |   8 +-
 .../java/org/apache/comet/vector/CometVector.java  |  14 ++-
 .../scala/org/apache/comet/vector/NativeUtil.scala |  19 +++-
 9 files changed, 178 insertions(+), 23 deletions(-)

diff --git a/common/src/main/java/org/apache/arrow/c/CometArrayExporter.java b/common/src/main/java/org/apache/arrow/c/CometArrayExporter.java
new file mode 100644
index 00000000..f76e0e18
--- /dev/null
+++ b/common/src/main/java/org/apache/arrow/c/CometArrayExporter.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.arrow.c;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.arrow.c.jni.JniWrapper;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+
+import static org.apache.arrow.c.Data.exportField;
+import static org.apache.arrow.c.NativeUtil.NULL;
+import static org.apache.arrow.c.NativeUtil.addressOrNull;
+import static org.apache.arrow.util.Preconditions.checkNotNull;
+
+public final class CometArrayExporter {
+  // Copied from Data.exportVector and changed to take nullCount from outside
+  public static void exportVector(
+      BufferAllocator allocator,
+      FieldVector vector,
+      DictionaryProvider provider,
+      ArrowArray out,
+      ArrowSchema outSchema,
+      long nullCount) {
+    exportField(allocator, vector.getField(), provider, outSchema);
+    export(allocator, out, vector, provider, nullCount);
+  }
+
+  private static void export(
+      BufferAllocator allocator,
+      ArrowArray array,
+      FieldVector vector,
+      DictionaryProvider dictionaryProvider,
+      long nullCount) {
+    List<FieldVector> children = vector.getChildrenFromFields();
+    List<ArrowBuf> buffers = vector.getFieldBuffers();
+    int valueCount = vector.getValueCount();
+    DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();
+
+    ArrayExporter.ExportedArrayPrivateData data = new 
ArrayExporter.ExportedArrayPrivateData();
+    try {
+      if (children != null) {
+        data.children = new ArrayList<>(children.size());
+        data.children_ptrs = allocator.buffer((long) children.size() * 
Long.BYTES);
+        for (int i = 0; i < children.size(); i++) {
+          ArrowArray child = ArrowArray.allocateNew(allocator);
+          data.children.add(child);
+          data.children_ptrs.writeLong(child.memoryAddress());
+        }
+      }
+
+      if (buffers != null) {
+        data.buffers = new ArrayList<>(buffers.size());
+        data.buffers_ptrs = allocator.buffer((long) buffers.size() * 
Long.BYTES);
+        vector.exportCDataBuffers(data.buffers, data.buffers_ptrs, NULL);
+      }
+
+      if (dictionaryEncoding != null) {
+        Dictionary dictionary = 
dictionaryProvider.lookup(dictionaryEncoding.getId());
+        checkNotNull(dictionary, "Dictionary lookup failed on export of 
dictionary encoded array");
+
+        data.dictionary = ArrowArray.allocateNew(allocator);
+        FieldVector dictionaryVector = dictionary.getVector();
+        // Since the dictionary index tracks the nullCount, the nullCount of the values can be 0
+        export(allocator, data.dictionary, dictionaryVector, 
dictionaryProvider, 0);
+      }
+
+      ArrowArray.Snapshot snapshot = new ArrowArray.Snapshot();
+      snapshot.length = valueCount;
+      snapshot.null_count = nullCount;
+      snapshot.offset = 0;
+      snapshot.n_buffers = (data.buffers != null) ? data.buffers.size() : 0;
+      snapshot.n_children = (data.children != null) ? data.children.size() : 0;
+      snapshot.buffers = addressOrNull(data.buffers_ptrs);
+      snapshot.children = addressOrNull(data.children_ptrs);
+      snapshot.dictionary = addressOrNull(data.dictionary);
+      snapshot.release = NULL;
+      array.save(snapshot);
+
+      // sets release and private data
+      JniWrapper.get().exportArray(array.memoryAddress(), data);
+    } catch (Exception e) {
+      data.close();
+      throw e;
+    }
+
+    // Export children
+    if (children != null) {
+      for (int i = 0; i < children.size(); i++) {
+        FieldVector childVector = children.get(i);
+        ArrowArray child = data.children.get(i);
+        // TODO: getNullCount is slow, avoid calling it if possible
+        int cNullCount = childVector.getNullCount();
+        export(allocator, child, childVector, dictionaryProvider, cNullCount);
+      }
+    }
+  }
+}
diff --git a/common/src/main/java/org/apache/comet/parquet/ColumnReader.java b/common/src/main/java/org/apache/comet/parquet/ColumnReader.java
index 9e594804..820e0b50 100644
--- a/common/src/main/java/org/apache/comet/parquet/ColumnReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/ColumnReader.java
@@ -48,6 +48,8 @@ import org.apache.comet.vector.CometDictionaryVector;
 import org.apache.comet.vector.CometPlainVector;
 import org.apache.comet.vector.CometVector;
 
+import static org.apache.comet.parquet.Utils.getNullCount;
+
 public class ColumnReader extends AbstractColumnReader {
   protected static final Logger LOG = 
LoggerFactory.getLogger(ColumnReader.class);
 
@@ -205,11 +207,12 @@ public class ColumnReader extends AbstractColumnReader {
 
     try (ArrowArray array = ArrowArray.wrap(addresses[0]);
         ArrowSchema schema = ArrowSchema.wrap(addresses[1])) {
+      int nullCount = getNullCount(array);
       FieldVector vector = importer.importVector(array, schema);
 
       DictionaryEncoding dictionaryEncoding = 
vector.getField().getDictionary();
 
-      CometPlainVector cometVector = new CometPlainVector(vector, 
useDecimal128);
+      CometPlainVector cometVector = new CometPlainVector(vector, 
useDecimal128, false, nullCount);
 
       // Update whether the current vector contains any null values. This is used in the following
       // batch(s) to determine whether we can skip loading the native vector.
@@ -232,8 +235,9 @@ public class ColumnReader extends AbstractColumnReader {
       // We should already re-initiate `CometDictionary` here because `Data.importVector` API will
       // release the previous dictionary vector and create a new one.
       Dictionary arrowDictionary = 
importer.getProvider().lookup(dictionaryEncoding.getId());
+      // Since the dictionary index tracks the nullCount, dictionaryVector.nullCount can be 0
       CometPlainVector dictionaryVector =
-          new CometPlainVector(arrowDictionary.getVector(), useDecimal128, 
isUuid);
+          new CometPlainVector(arrowDictionary.getVector(), useDecimal128, 
isUuid, 0);
       if (dictionary != null) {
         dictionary.setDictionaryVector(dictionaryVector);
       } else {
@@ -243,9 +247,6 @@ public class ColumnReader extends AbstractColumnReader {
       currentVector =
           new CometDictionaryVector(
               cometVector, dictionary, importer.getProvider(), useDecimal128, 
false, isUuid);
-
-      currentVector =
-          new CometDictionaryVector(cometVector, dictionary, 
importer.getProvider(), useDecimal128);
       return currentVector;
     }
   }
@@ -255,7 +256,7 @@ public class ColumnReader extends AbstractColumnReader {
     if (page == null) {
       throw new RuntimeException("overreading: returned DataPage is null");
     }
-    ;
+
     int pageValueCount = page.getValueCount();
     page.accept(
         new DataPage.Visitor<Void>() {
diff --git a/common/src/main/java/org/apache/comet/parquet/MetadataColumnReader.java b/common/src/main/java/org/apache/comet/parquet/MetadataColumnReader.java
index b8722ca7..6922d94f 100644
--- a/common/src/main/java/org/apache/comet/parquet/MetadataColumnReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/MetadataColumnReader.java
@@ -31,6 +31,8 @@ import org.apache.spark.sql.types.DataType;
 import org.apache.comet.vector.CometPlainVector;
 import org.apache.comet.vector.CometVector;
 
+import static org.apache.comet.parquet.Utils.getNullCount;
+
 /** A metadata column reader that can be extended by {@link RowIndexColumnReader} etc. */
 public class MetadataColumnReader extends AbstractColumnReader {
   private final BufferAllocator allocator = new RootAllocator();
@@ -53,8 +55,9 @@ public class MetadataColumnReader extends AbstractColumnReader {
       long[] addresses = Native.currentBatch(nativeHandle);
       try (ArrowArray array = ArrowArray.wrap(addresses[0]);
           ArrowSchema schema = ArrowSchema.wrap(addresses[1])) {
+        int nullCount = getNullCount(array);
         FieldVector fieldVector = Data.importVector(allocator, array, schema, 
null);
-        vector = new CometPlainVector(fieldVector, useDecimal128);
+        vector = new CometPlainVector(fieldVector, useDecimal128, false, 
nullCount);
       }
     }
     vector.setNumValues(total);
diff --git a/common/src/main/java/org/apache/comet/parquet/Utils.java b/common/src/main/java/org/apache/comet/parquet/Utils.java
index f73251e2..54137ffb 100644
--- a/common/src/main/java/org/apache/comet/parquet/Utils.java
+++ b/common/src/main/java/org/apache/comet/parquet/Utils.java
@@ -19,11 +19,13 @@
 
 package org.apache.comet.parquet;
 
+import org.apache.arrow.c.ArrowArray;
 import org.apache.arrow.c.CometSchemaImporter;
 import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.schema.LogicalTypeAnnotation;
 import org.apache.parquet.schema.PrimitiveType;
 import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.Platform;
 
 public class Utils {
   public static ColumnReader getColumnReader(
@@ -257,4 +259,10 @@ public class Utils {
         throw new UnsupportedOperationException("Unsupported TimeUnit " + tu);
     }
   }
+
+  public static int getNullCount(ArrowArray array) {
+    // The second long value in the c interface is the null count
+    // https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowArray.null_count
+    return (int) Platform.getLong(null, array.memoryAddress() + 8L);
+  }
 }
diff --git a/common/src/main/java/org/apache/comet/vector/CometDecodedVector.java b/common/src/main/java/org/apache/comet/vector/CometDecodedVector.java
index f699134f..fc46b3d1 100644
--- a/common/src/main/java/org/apache/comet/vector/CometDecodedVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometDecodedVector.java
@@ -41,14 +41,15 @@ public abstract class CometDecodedVector extends CometVector {
   protected boolean isUuid;
 
   protected CometDecodedVector(ValueVector vector, Field valueField, boolean 
useDecimal128) {
-    this(vector, valueField, useDecimal128, false);
+    // TODO: getNullCount is slow, avoid calling it if possible
+    this(vector, valueField, useDecimal128, false, vector.getNullCount());
   }
 
   protected CometDecodedVector(
-      ValueVector vector, Field valueField, boolean useDecimal128, boolean 
isUuid) {
+      ValueVector vector, Field valueField, boolean useDecimal128, boolean 
isUuid, int nullCount) {
     super(Utils.fromArrowField(valueField), useDecimal128);
     this.valueVector = vector;
-    this.numNulls = valueVector.getNullCount();
+    this.numNulls = nullCount;
     this.numValues = valueVector.getValueCount();
     this.hasNull = numNulls != 0;
     this.isUuid = isUuid;
diff --git a/common/src/main/java/org/apache/comet/vector/CometDictionaryVector.java b/common/src/main/java/org/apache/comet/vector/CometDictionaryVector.java
index a49255e7..dbd2037a 100644
--- a/common/src/main/java/org/apache/comet/vector/CometDictionaryVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometDictionaryVector.java
@@ -49,7 +49,12 @@ public class CometDictionaryVector extends CometDecodedVector {
       boolean useDecimal128,
       boolean isAlias,
       boolean isUuid) {
-    super(indices.valueVector, values.getValueVector().getField(), 
useDecimal128, isUuid);
+    super(
+        indices.valueVector,
+        values.getValueVector().getField(),
+        useDecimal128,
+        isUuid,
+        indices.numNulls());
     Preconditions.checkArgument(
         indices.valueVector instanceof IntVector, "'indices' should be a 
IntVector");
     this.values = values;
diff --git a/common/src/main/java/org/apache/comet/vector/CometPlainVector.java b/common/src/main/java/org/apache/comet/vector/CometPlainVector.java
index 65cc876b..60fcd948 100644
--- a/common/src/main/java/org/apache/comet/vector/CometPlainVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometPlainVector.java
@@ -39,11 +39,13 @@ public class CometPlainVector extends CometDecodedVector {
   private int booleanByteCacheIndex = -1;
 
   public CometPlainVector(ValueVector vector, boolean useDecimal128) {
-    this(vector, useDecimal128, false);
+    // TODO: getNullCount is slow, avoid calling it if possible
+    this(vector, useDecimal128, false, vector.getNullCount());
   }
 
-  public CometPlainVector(ValueVector vector, boolean useDecimal128, boolean 
isUuid) {
-    super(vector, vector.getField(), useDecimal128, isUuid);
+  public CometPlainVector(
+      ValueVector vector, boolean useDecimal128, boolean isUuid, int 
nullCount) {
+    super(vector, vector.getField(), useDecimal128, isUuid, nullCount);
     // NullType doesn't have data buffer.
     if (vector instanceof NullVector) {
       this.valueBufferAddress = -1;
diff --git a/common/src/main/java/org/apache/comet/vector/CometVector.java b/common/src/main/java/org/apache/comet/vector/CometVector.java
index 8e4c4edf..d7dffe5c 100644
--- a/common/src/main/java/org/apache/comet/vector/CometVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometVector.java
@@ -233,7 +233,10 @@ public abstract class CometVector extends ColumnVector {
    * @return `CometVector` implementation
    */
   protected static CometVector getVector(
-      ValueVector vector, boolean useDecimal128, DictionaryProvider 
dictionaryProvider) {
+      ValueVector vector,
+      boolean useDecimal128,
+      DictionaryProvider dictionaryProvider,
+      int nullCount) {
     if (vector instanceof StructVector) {
       return new CometStructVector(vector, useDecimal128);
     } else if (vector instanceof MapVector) {
@@ -242,14 +245,15 @@ public abstract class CometVector extends ColumnVector {
       return new CometListVector(vector, useDecimal128);
     } else {
       DictionaryEncoding dictionaryEncoding = 
vector.getField().getDictionary();
-      CometPlainVector cometVector = new CometPlainVector(vector, 
useDecimal128);
+      CometPlainVector cometVector = new CometPlainVector(vector, 
useDecimal128, false, nullCount);
 
       if (dictionaryEncoding == null) {
         return cometVector;
       } else {
         Dictionary dictionary = 
dictionaryProvider.lookup(dictionaryEncoding.getId());
+        // Since the dictionary index tracks the nullCount, dictionaryVector.nullCount can be 0
         CometPlainVector dictionaryVector =
-            new CometPlainVector(dictionary.getVector(), useDecimal128);
+            new CometPlainVector(dictionary.getVector(), useDecimal128, false, 
0);
         CometDictionary cometDictionary = new 
CometDictionary(dictionaryVector);
 
         return new CometDictionaryVector(
@@ -259,6 +263,8 @@ public abstract class CometVector extends ColumnVector {
   }
 
   protected static CometVector getVector(ValueVector vector, boolean 
useDecimal128) {
-    return getVector(vector, useDecimal128, null);
+    // This is currently called only for CometStructVector and CometListVector
+    // If necessary set the proper nullCount
+    return getVector(vector, useDecimal128, null, 0);
   }
 }
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 89f79c9c..43c9b7e8 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -21,7 +21,8 @@ package org.apache.comet.vector
 
 import scala.collection.mutable
 
-import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, 
CDataDictionaryProvider, Data}
+import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, 
CDataDictionaryProvider}
+import org.apache.arrow.c.CometArrayExporter.exportVector
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.dictionary.DictionaryProvider
 import org.apache.spark.SparkException
@@ -29,6 +30,7 @@ import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import org.apache.comet.CometArrowAllocator
+import org.apache.comet.parquet.Utils.getNullCount;
 
 class NativeUtil {
   import Utils._
@@ -64,12 +66,16 @@ class NativeUtil {
 
           val arrowSchema = ArrowSchema.allocateNew(allocator)
           val arrowArray = ArrowArray.allocateNew(allocator)
-          Data.exportVector(
+          exportVector(
             allocator,
             getFieldVector(valueVector, "export"),
             provider,
             arrowArray,
-            arrowSchema)
+            arrowSchema,
+            // TODO: Somehow calling valueVector.getNullCount seems to be faster than a.numNulls,
+            //       but it should be the other way around. Investigate why
+            valueVector.getNullCount)
+          // a.numNulls())
 
           exportedVectors += arrowArray.memoryAddress()
           exportedVectors += arrowSchema.memoryAddress()
@@ -98,13 +104,15 @@ class NativeUtil {
     for (i <- arrayAddress.indices by 2) {
       val arrowSchema = ArrowSchema.wrap(arrayAddress(i + 1))
       val arrowArray = ArrowArray.wrap(arrayAddress(i))
+      val nullCount = getNullCount(arrowArray)
 
       // Native execution should always have 'useDecimal128' set to true since it doesn't support
       // other cases.
       arrayVectors += CometVector.getVector(
         importer.importVector(arrowArray, arrowSchema, dictionaryProvider),
         true,
-        dictionaryProvider)
+        dictionaryProvider,
+        nullCount)
 
       arrowArray.close()
       arrowSchema.close()
@@ -145,7 +153,8 @@ object NativeUtil {
     val vectors = (0 until arrowRoot.getFieldVectors.size()).map { i =>
       val vector = arrowRoot.getFieldVectors.get(i)
       // Native shuffle always uses decimal128.
-      CometVector.getVector(vector, true, provider)
+      // TODO: getNullCount is slow, avoid calling it if possible
+      CometVector.getVector(vector, true, provider, vector.getNullCount)
     }
     new ColumnarBatch(vectors.toArray, arrowRoot.getRowCount)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to