This is an automated email from the ASF dual-hosted git repository.

kazuyukitanimura pushed a commit to branch optimize-null-count
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
commit 50d686eb85edaee39e3fb9b8f90ef2b0520b8b31
Author: Kazuyuki Tanimura <ktanim...@apple.com>
AuthorDate: Fri Aug 9 02:15:27 2024 -0700

    fix: Optimize not to call getNullCount as much as possible
---
 .../org/apache/arrow/c/CometArrayExporter.java     | 120 +++++++++++++++++++++
 .../org/apache/comet/parquet/ColumnReader.java     |  13 +--
 .../apache/comet/parquet/MetadataColumnReader.java |   5 +-
 .../main/java/org/apache/comet/parquet/Utils.java  |   8 ++
 .../apache/comet/vector/CometDecodedVector.java    |   7 +-
 .../apache/comet/vector/CometDictionaryVector.java |   7 +-
 .../org/apache/comet/vector/CometPlainVector.java  |   8 +-
 .../java/org/apache/comet/vector/CometVector.java  |  14 ++-
 .../scala/org/apache/comet/vector/NativeUtil.scala |  19 +++-
 9 files changed, 178 insertions(+), 23 deletions(-)
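
[Editor's note] The core of the change: the ArrowArray struct defined by the Arrow C Data
Interface already carries a null_count field, so the JVM side can read it directly instead of
recomputing it from the validity buffer via ValueVector.getNullCount(). A minimal sketch of that
read, mirroring the Utils.getNullCount helper added below (the class name here is illustrative
only, not part of the commit):

    import org.apache.arrow.c.ArrowArray;
    import org.apache.spark.unsafe.Platform;

    public final class NullCountSketch {
      // Per the Arrow C Data Interface, struct ArrowArray begins with
      // { int64_t length; int64_t null_count; ... }, so null_count is the
      // second 8-byte field after the struct's base address.
      public static int readNullCount(ArrowArray array) {
        return (int) Platform.getLong(null, array.memoryAddress() + 8L);
      }
    }
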
diff --git a/common/src/main/java/org/apache/arrow/c/CometArrayExporter.java b/common/src/main/java/org/apache/arrow/c/CometArrayExporter.java
new file mode 100644
index 00000000..f76e0e18
--- /dev/null
+++ b/common/src/main/java/org/apache/arrow/c/CometArrayExporter.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.arrow.c;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.arrow.c.jni.JniWrapper;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+
+import static org.apache.arrow.c.Data.exportField;
+import static org.apache.arrow.c.NativeUtil.NULL;
+import static org.apache.arrow.c.NativeUtil.addressOrNull;
+import static org.apache.arrow.util.Preconditions.checkNotNull;
+
+public final class CometArrayExporter {
+  // Copied from Data.exportVector and changed to take nullCount from outside
+  public static void exportVector(
+      BufferAllocator allocator,
+      FieldVector vector,
+      DictionaryProvider provider,
+      ArrowArray out,
+      ArrowSchema outSchema,
+      long nullCount) {
+    exportField(allocator, vector.getField(), provider, outSchema);
+    export(allocator, out, vector, provider, nullCount);
+  }
+
+  private static void export(
+      BufferAllocator allocator,
+      ArrowArray array,
+      FieldVector vector,
+      DictionaryProvider dictionaryProvider,
+      long nullCount) {
+    List<FieldVector> children = vector.getChildrenFromFields();
+    List<ArrowBuf> buffers = vector.getFieldBuffers();
+    int valueCount = vector.getValueCount();
+    DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();
+
+    ArrayExporter.ExportedArrayPrivateData data = new ArrayExporter.ExportedArrayPrivateData();
+    try {
+      if (children != null) {
+        data.children = new ArrayList<>(children.size());
+        data.children_ptrs = allocator.buffer((long) children.size() * Long.BYTES);
+        for (int i = 0; i < children.size(); i++) {
+          ArrowArray child = ArrowArray.allocateNew(allocator);
+          data.children.add(child);
+          data.children_ptrs.writeLong(child.memoryAddress());
+        }
+      }
+
+      if (buffers != null) {
+        data.buffers = new ArrayList<>(buffers.size());
+        data.buffers_ptrs = allocator.buffer((long) buffers.size() * Long.BYTES);
+        vector.exportCDataBuffers(data.buffers, data.buffers_ptrs, NULL);
+      }
+
+      if (dictionaryEncoding != null) {
+        Dictionary dictionary = dictionaryProvider.lookup(dictionaryEncoding.getId());
+        checkNotNull(dictionary, "Dictionary lookup failed on export of dictionary encoded array");
+
+        data.dictionary = ArrowArray.allocateNew(allocator);
+        FieldVector dictionaryVector = dictionary.getVector();
+        // Since the dictionary index tracks the nullCount, the nullCount of the values can be 0
+        export(allocator, data.dictionary, dictionaryVector, dictionaryProvider, 0);
+      }
+
+      ArrowArray.Snapshot snapshot = new ArrowArray.Snapshot();
+      snapshot.length = valueCount;
+      snapshot.null_count = nullCount;
+      snapshot.offset = 0;
+      snapshot.n_buffers = (data.buffers != null) ? data.buffers.size() : 0;
+      snapshot.n_children = (data.children != null) ? data.children.size() : 0;
+      snapshot.buffers = addressOrNull(data.buffers_ptrs);
+      snapshot.children = addressOrNull(data.children_ptrs);
+      snapshot.dictionary = addressOrNull(data.dictionary);
+      snapshot.release = NULL;
+      array.save(snapshot);
+
+      // sets release and private data
+      JniWrapper.get().exportArray(array.memoryAddress(), data);
+    } catch (Exception e) {
+      data.close();
+      throw e;
+    }
+
+    // Export children
+    if (children != null) {
+      for (int i = 0; i < children.size(); i++) {
+        FieldVector childVector = children.get(i);
+        ArrowArray child = data.children.get(i);
+        // TODO: getNullCount is slow, avoid calling it if possible
+        int cNullCount = childVector.getNullCount();
+        export(allocator, child, childVector, dictionaryProvider, cNullCount);
+      }
+    }
+  }
+}
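
[Editor's note] For context, the new exportVector overload is driven from NativeUtil.exportBatch
later in this patch, with the caller supplying a null count it already tracks. A hedged,
self-contained sketch of the call shape (the vector setup and class name are illustrative, not
part of the commit; no dictionary provider is passed because the sketch vector is not
dictionary-encoded):

    import org.apache.arrow.c.ArrowArray;
    import org.apache.arrow.c.ArrowSchema;
    import org.apache.arrow.c.CometArrayExporter;
    import org.apache.arrow.memory.BufferAllocator;
    import org.apache.arrow.memory.RootAllocator;
    import org.apache.arrow.vector.IntVector;

    public class ExportSketch {
      public static void main(String[] args) {
        try (BufferAllocator allocator = new RootAllocator();
            IntVector vector = new IntVector("v", allocator);
            ArrowArray array = ArrowArray.allocateNew(allocator);
            ArrowSchema schema = ArrowSchema.allocateNew(allocator)) {
          vector.allocateNew(2);
          vector.set(0, 42);
          vector.setNull(1);
          vector.setValueCount(2);
          // The caller already knows how many nulls it wrote, so it passes the
          // count instead of letting the exporter call vector.getNullCount().
          long knownNullCount = 1;
          CometArrayExporter.exportVector(allocator, vector, null, array, schema, knownNullCount);
        }
      }
    }
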
diff --git a/common/src/main/java/org/apache/comet/parquet/ColumnReader.java b/common/src/main/java/org/apache/comet/parquet/ColumnReader.java
index 9e594804..820e0b50 100644
--- a/common/src/main/java/org/apache/comet/parquet/ColumnReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/ColumnReader.java
@@ -48,6 +48,8 @@ import org.apache.comet.vector.CometDictionaryVector;
 import org.apache.comet.vector.CometPlainVector;
 import org.apache.comet.vector.CometVector;
 
+import static org.apache.comet.parquet.Utils.getNullCount;
+
 public class ColumnReader extends AbstractColumnReader {
   protected static final Logger LOG = LoggerFactory.getLogger(ColumnReader.class);
 
@@ -205,11 +207,12 @@ public class ColumnReader extends AbstractColumnReader {
     try (ArrowArray array = ArrowArray.wrap(addresses[0]);
         ArrowSchema schema = ArrowSchema.wrap(addresses[1])) {
+      int nullCount = getNullCount(array);
       FieldVector vector = importer.importVector(array, schema);
       DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();
 
-      CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128);
+      CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128, false, nullCount);
 
       // Update whether the current vector contains any null values. This is used in the following
       // batch(s) to determine whether we can skip loading the native vector.
@@ -232,8 +235,9 @@ public class ColumnReader extends AbstractColumnReader {
       // We should already re-initiate `CometDictionary` here because `Data.importVector` API will
       // release the previous dictionary vector and create a new one.
       Dictionary arrowDictionary = importer.getProvider().lookup(dictionaryEncoding.getId());
+      // Since the dictionary index tracks the nullCount, dictionaryVector.nullCount can be 0
       CometPlainVector dictionaryVector =
-          new CometPlainVector(arrowDictionary.getVector(), useDecimal128, isUuid);
+          new CometPlainVector(arrowDictionary.getVector(), useDecimal128, isUuid, 0);
       if (dictionary != null) {
         dictionary.setDictionaryVector(dictionaryVector);
       } else {
@@ -243,9 +247,6 @@ public class ColumnReader extends AbstractColumnReader {
       currentVector =
          new CometDictionaryVector(
              cometVector, dictionary, importer.getProvider(), useDecimal128, false, isUuid);
-
-      currentVector =
-          new CometDictionaryVector(cometVector, dictionary, importer.getProvider(), useDecimal128);
       return currentVector;
     }
   }
@@ -255,7 +256,7 @@ public class ColumnReader extends AbstractColumnReader {
     if (page == null) {
       throw new RuntimeException("overreading: returned DataPage is null");
     }
-    ;
+
     int pageValueCount = page.getValueCount();
     page.accept(
         new DataPage.Visitor<Void>() {
diff --git a/common/src/main/java/org/apache/comet/parquet/MetadataColumnReader.java b/common/src/main/java/org/apache/comet/parquet/MetadataColumnReader.java
index b8722ca7..6922d94f 100644
--- a/common/src/main/java/org/apache/comet/parquet/MetadataColumnReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/MetadataColumnReader.java
@@ -31,6 +31,8 @@ import org.apache.spark.sql.types.DataType;
 import org.apache.comet.vector.CometPlainVector;
 import org.apache.comet.vector.CometVector;
 
+import static org.apache.comet.parquet.Utils.getNullCount;
+
 /** A metadata column reader that can be extended by {@link RowIndexColumnReader} etc. */
 public class MetadataColumnReader extends AbstractColumnReader {
   private final BufferAllocator allocator = new RootAllocator();
@@ -53,8 +55,9 @@ public class MetadataColumnReader extends AbstractColumnReader {
       long[] addresses = Native.currentBatch(nativeHandle);
       try (ArrowArray array = ArrowArray.wrap(addresses[0]);
           ArrowSchema schema = ArrowSchema.wrap(addresses[1])) {
+        int nullCount = getNullCount(array);
         FieldVector fieldVector = Data.importVector(allocator, array, schema, null);
-        vector = new CometPlainVector(fieldVector, useDecimal128);
+        vector = new CometPlainVector(fieldVector, useDecimal128, false, nullCount);
       }
     }
     vector.setNumValues(total);
diff --git a/common/src/main/java/org/apache/comet/parquet/Utils.java b/common/src/main/java/org/apache/comet/parquet/Utils.java
index f73251e2..54137ffb 100644
--- a/common/src/main/java/org/apache/comet/parquet/Utils.java
+++ b/common/src/main/java/org/apache/comet/parquet/Utils.java
@@ -19,11 +19,13 @@
 
 package org.apache.comet.parquet;
 
+import org.apache.arrow.c.ArrowArray;
 import org.apache.arrow.c.CometSchemaImporter;
 import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.schema.LogicalTypeAnnotation;
 import org.apache.parquet.schema.PrimitiveType;
 import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.Platform;
 
 public class Utils {
   public static ColumnReader getColumnReader(
@@ -257,4 +259,10 @@ public class Utils {
         throw new UnsupportedOperationException("Unsupported TimeUnit " + tu);
     }
   }
+
+  public static int getNullCount(ArrowArray array) {
+    // The second long value in the c interface is the null count
+    // https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowArray.null_count
+    return (int) Platform.getLong(null, array.memoryAddress() + 8L);
+  }
 }
diff --git a/common/src/main/java/org/apache/comet/vector/CometDecodedVector.java b/common/src/main/java/org/apache/comet/vector/CometDecodedVector.java
index f699134f..fc46b3d1 100644
--- a/common/src/main/java/org/apache/comet/vector/CometDecodedVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometDecodedVector.java
@@ -41,14 +41,15 @@ public abstract class CometDecodedVector extends CometVector {
   protected boolean isUuid;
 
   protected CometDecodedVector(ValueVector vector, Field valueField, boolean useDecimal128) {
-    this(vector, valueField, useDecimal128, false);
+    // TODO: getNullCount is slow, avoid calling it if possible
+    this(vector, valueField, useDecimal128, false, vector.getNullCount());
   }
 
   protected CometDecodedVector(
-      ValueVector vector, Field valueField, boolean useDecimal128, boolean isUuid) {
+      ValueVector vector, Field valueField, boolean useDecimal128, boolean isUuid, int nullCount) {
     super(Utils.fromArrowField(valueField), useDecimal128);
     this.valueVector = vector;
-    this.numNulls = valueVector.getNullCount();
+    this.numNulls = nullCount;
     this.numValues = valueVector.getValueCount();
     this.hasNull = numNulls != 0;
     this.isUuid = isUuid;
diff --git a/common/src/main/java/org/apache/comet/vector/CometDictionaryVector.java b/common/src/main/java/org/apache/comet/vector/CometDictionaryVector.java
index a49255e7..dbd2037a 100644
--- a/common/src/main/java/org/apache/comet/vector/CometDictionaryVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometDictionaryVector.java
@@ -49,7 +49,12 @@ public class CometDictionaryVector extends CometDecodedVector {
       boolean useDecimal128,
       boolean isAlias,
       boolean isUuid) {
-    super(indices.valueVector, values.getValueVector().getField(), useDecimal128, isUuid);
+    super(
+        indices.valueVector,
+        values.getValueVector().getField(),
+        useDecimal128,
+        isUuid,
+        indices.numNulls());
     Preconditions.checkArgument(
         indices.valueVector instanceof IntVector, "'indices' should be a IntVector");
     this.values = values;
diff --git a/common/src/main/java/org/apache/comet/vector/CometPlainVector.java b/common/src/main/java/org/apache/comet/vector/CometPlainVector.java
index 65cc876b..60fcd948 100644
--- a/common/src/main/java/org/apache/comet/vector/CometPlainVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometPlainVector.java
@@ -39,11 +39,13 @@ public class CometPlainVector extends CometDecodedVector {
   private int booleanByteCacheIndex = -1;
 
   public CometPlainVector(ValueVector vector, boolean useDecimal128) {
-    this(vector, useDecimal128, false);
+    // TODO: getNullCount is slow, avoid calling it if possible
+    this(vector, useDecimal128, false, vector.getNullCount());
   }
 
-  public CometPlainVector(ValueVector vector, boolean useDecimal128, boolean isUuid) {
-    super(vector, vector.getField(), useDecimal128, isUuid);
+  public CometPlainVector(
+      ValueVector vector, boolean useDecimal128, boolean isUuid, int nullCount) {
+    super(vector, vector.getField(), useDecimal128, isUuid, nullCount);
     // NullType doesn't have data buffer.
     if (vector instanceof NullVector) {
       this.valueBufferAddress = -1;
diff --git a/common/src/main/java/org/apache/comet/vector/CometVector.java b/common/src/main/java/org/apache/comet/vector/CometVector.java
index 8e4c4edf..d7dffe5c 100644
--- a/common/src/main/java/org/apache/comet/vector/CometVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometVector.java
@@ -233,7 +233,10 @@ public abstract class CometVector extends ColumnVector {
    * @return `CometVector` implementation
    */
   protected static CometVector getVector(
-      ValueVector vector, boolean useDecimal128, DictionaryProvider dictionaryProvider) {
+      ValueVector vector,
+      boolean useDecimal128,
+      DictionaryProvider dictionaryProvider,
+      int nullCount) {
     if (vector instanceof StructVector) {
       return new CometStructVector(vector, useDecimal128);
     } else if (vector instanceof MapVector) {
@@ -242,14 +245,15 @@
       return new CometListVector(vector, useDecimal128);
     } else {
       DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();
-      CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128);
+      CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128, false, nullCount);
 
       if (dictionaryEncoding == null) {
         return cometVector;
       } else {
         Dictionary dictionary = dictionaryProvider.lookup(dictionaryEncoding.getId());
+        // Since the dictionary index tracks the nullCount, dictionaryVector.nullCount can be 0
         CometPlainVector dictionaryVector =
-            new CometPlainVector(dictionary.getVector(), useDecimal128);
+            new CometPlainVector(dictionary.getVector(), useDecimal128, false, 0);
         CometDictionary cometDictionary = new CometDictionary(dictionaryVector);
 
         return new CometDictionaryVector(
@@ -259,6 +263,8 @@
   }
 
   protected static CometVector getVector(ValueVector vector, boolean useDecimal128) {
-    return getVector(vector, useDecimal128, null);
+    // This is currently called only for CometStructVector and CometListVector
+    // If necessary set the proper nullCount
+    return getVector(vector, useDecimal128, null, 0);
   }
 }
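
[Editor's note] The recurring "pass 0 for the dictionary values" shortcut above and in
ColumnReader relies on Arrow's dictionary encoding keeping row-level nulls in the indices
vector's validity bitmap, not in the dictionary values. A small illustrative sketch of that
layout (names and setup are made up for illustration, not part of the commit):

    import org.apache.arrow.memory.BufferAllocator;
    import org.apache.arrow.memory.RootAllocator;
    import org.apache.arrow.vector.IntVector;

    public class DictionaryNullSketch {
      public static void main(String[] args) {
        try (BufferAllocator allocator = new RootAllocator();
            IntVector indices = new IntVector("indices", allocator)) {
          indices.allocateNew(3);
          indices.set(0, 0);   // row 0 -> dictionary entry 0
          indices.setNull(1);  // row 1 is null; recorded in the indices' validity bitmap
          indices.set(2, 1);   // row 2 -> dictionary entry 1
          indices.setValueCount(3);
          // The dictionary values themselves carry no per-row nulls here, which is
          // why the patch can report 0 as the dictionary vector's null count.
        }
      }
    }
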
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 89f79c9c..43c9b7e8 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -21,7 +21,8 @@ package org.apache.comet.vector
 
 import scala.collection.mutable
 
-import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data}
+import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider}
+import org.apache.arrow.c.CometArrayExporter.exportVector
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.dictionary.DictionaryProvider
 import org.apache.spark.SparkException
@@ -29,6 +30,7 @@ import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import org.apache.comet.CometArrowAllocator
+import org.apache.comet.parquet.Utils.getNullCount;
 
 class NativeUtil {
   import Utils._
@@ -64,12 +66,16 @@ class NativeUtil {
       val arrowSchema = ArrowSchema.allocateNew(allocator)
       val arrowArray = ArrowArray.allocateNew(allocator)
 
-      Data.exportVector(
+      exportVector(
         allocator,
         getFieldVector(valueVector, "export"),
         provider,
         arrowArray,
-        arrowSchema)
+        arrowSchema,
+        // TODO: Somehow calling valueVector.getNullCount seems to be faster than a.numNulls,
+        // but it should be the other way around. Investigate why
+        valueVector.getNullCount)
+      // a.numNulls())
 
       exportedVectors += arrowArray.memoryAddress()
       exportedVectors += arrowSchema.memoryAddress()
@@ -98,13 +104,15 @@
     for (i <- arrayAddress.indices by 2) {
       val arrowSchema = ArrowSchema.wrap(arrayAddress(i + 1))
       val arrowArray = ArrowArray.wrap(arrayAddress(i))
+      val nullCount = getNullCount(arrowArray)
 
       // Native execution should always have 'useDecimal128' set to true since it doesn't support
      // other cases.
       arrayVectors += CometVector.getVector(
         importer.importVector(arrowArray, arrowSchema, dictionaryProvider),
         true,
-        dictionaryProvider)
+        dictionaryProvider,
+        nullCount)
 
       arrowArray.close()
       arrowSchema.close()
@@ -145,7 +153,8 @@
     val vectors = (0 until arrowRoot.getFieldVectors.size()).map { i =>
       val vector = arrowRoot.getFieldVectors.get(i)
       // Native shuffle always uses decimal128.
-      CometVector.getVector(vector, true, provider)
+      // TODO: getNullCount is slow, avoid calling it if possible
+      CometVector.getVector(vector, true, provider, vector.getNullCount)
     }
     new ColumnarBatch(vectors.toArray, arrowRoot.getRowCount)
   }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org