This is an automated email from the ASF dual-hosted git repository.
loneylee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 262d9c176e [GLUTEN-9020][CH] Support delta DV BitmapAggregator (#9138)
262d9c176e is described below
commit 262d9c176ecb32d8d186b066454f49723177411c
Author: Shuai li <[email protected]>
AuthorDate: Wed Apr 2 10:07:51 2025 +0800
[GLUTEN-9020][CH] Support delta DV BitmapAggregator (#9138)
* [GLUTEN-9020][CH] Support delta DV `bitmapaggregator`
---
.../DeltaExpressionExtensionTransformer.scala | 79 ++++++++++++++
.../gluten/sql/shims/delta32/Delta32Shims.scala | 5 +
.../GlutenDeltaParquetDeletionVectorSuite.scala | 83 ++++++++++++++
.../apache/gluten/component/CHDeltaComponent.scala | 3 +
.../org/apache/gluten/sql/shims/DeltaShims.scala | 2 +
.../backendsapi/clickhouse/CHListenerApi.scala | 2 +-
.../gluten/backendsapi/clickhouse/CHRuleApi.scala | 3 +-
.../clickhouse/CHSparkPlanExecApi.scala | 10 +-
.../execution/CHHashAggregateExecTransformer.scala | 18 ++--
.../apache/gluten/expression/CHExpressions.scala | 10 +-
.../extension/ExpressionExtensionTrait.scala | 25 ++++-
.../AggregateFunctionDVRoaringBitmap.cpp | 50 +++++++++
.../AggregateFunctionDVRoaringBitmap.h | 119 +++++++++++++++++++++
cpp-ch/local-engine/Common/CHUtil.cpp | 2 +
.../CommonAggregateFunctionParser.cpp | 1 +
.../Delta/Bitmap/DeltaDVRoaringBitmapArray.cpp | 82 ++++++++++----
.../Delta/Bitmap/DeltaDVRoaringBitmapArray.h | 9 +-
.../Storages/SubstraitSource/Delta/DeltaReader.cpp | 4 +-
.../Iceberg/PositionalDeleteFileReader.cpp | 2 +-
.../tests/gtest_clickhouse_roaring_bitmap.cpp | 14 +--
20 files changed, 460 insertions(+), 63 deletions(-)
diff --git
a/backends-clickhouse/src-delta-32/main/scala/org/apache/gluten/extension/DeltaExpressionExtensionTransformer.scala
b/backends-clickhouse/src-delta-32/main/scala/org/apache/gluten/extension/DeltaExpressionExtensionTransformer.scala
new file mode 100644
index 0000000000..8f9439e7d9
--- /dev/null
+++
b/backends-clickhouse/src-delta-32/main/scala/org/apache/gluten/extension/DeltaExpressionExtensionTransformer.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.expression.{ConverterUtils, ExpressionTransformer,
Sig}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.aggregation.BitmapAggregator
+import org.apache.spark.sql.types.DataType
+
+import scala.collection.mutable.ListBuffer
+
+/**
+ * Exposes Delta's `BitmapAggregator` to the ClickHouse backend as the native
+ * `bitmapaggregator` aggregate function. Only aggregate support is provided; scalar expression
+ * replacement is not supported by this extension.
+ */
+case class DeltaExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging {
+
+  /** Maps Delta's `BitmapAggregator` to the native `bitmapaggregator` function name. */
+  override def expressionSigList: Seq[Sig] = Sig[BitmapAggregator]("bitmapaggregator") :: Nil
+
+  /**
+   * Appends the attributes produced by `exp` to `aggregateAttr`.
+   *
+   * Partial/PartialMerge modes append the function's input aggregate-buffer attributes;
+   * Final/Complete modes append the single result attribute at `resIndex`.
+   *
+   * @return
+   *   `resIndex` advanced by the number of attributes appended
+   * @throws GlutenNotSupportException
+   *   for any other aggregate mode
+   */
+  override def getAttrsIndexForExtensionAggregateExpr(
+      aggregateFunc: AggregateFunction,
+      mode: AggregateMode,
+      exp: AggregateExpression,
+      aggregateAttributeList: Seq[Attribute],
+      aggregateAttr: ListBuffer[Attribute],
+      resIndex: Int): Int = {
+    exp.mode match {
+      case Partial | PartialMerge =>
+        val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
+        aggregateAttr ++= aggBufferAttr.map(attr => ConverterUtils.getAttrFromExpr(attr))
+        resIndex + aggBufferAttr.size
+      case Final | Complete =>
+        aggregateAttr += aggregateAttributeList(resIndex)
+        resIndex + 1
+      case other =>
+        throw new GlutenNotSupportException(s"Unsupported aggregate mode: $other.")
+    }
+  }
+
+  /**
+   * Returns the Substrait name of the custom aggregate function together with the data types of
+   * its children.
+   *
+   * @throws UnsupportedOperationException
+   *   if `aggregateFunc` is not registered in `expressionSigList`
+   */
+  override def buildCustomAggregateFunction(
+      aggregateFunc: AggregateFunction): (Option[String], Seq[DataType]) = {
+    val substraitAggFuncName = extensionExpressionsMapping.get(aggregateFunc.getClass)
+    if (substraitAggFuncName.isEmpty) {
+      throw new UnsupportedOperationException(
+        s"Aggregate function ${aggregateFunc.getClass} is not supported.")
+    }
+    (substraitAggFuncName, aggregateFunc.children.map(_.dataType))
+  }
+
+  /** This extension only provides aggregate functions, so every scalar expression is rejected. */
+  override def replaceWithExtensionExpressionTransformer(
+      substraitExprName: String,
+      expr: Expression,
+      attributeSeq: Seq[Attribute]): ExpressionTransformer =
+    throw new UnsupportedOperationException(
+      s"${expr.getClass} or $expr is not currently supported.")
+}
diff --git
a/backends-clickhouse/src-delta-32/main/scala/org/apache/gluten/sql/shims/delta32/Delta32Shims.scala
b/backends-clickhouse/src-delta-32/main/scala/org/apache/gluten/sql/shims/delta32/Delta32Shims.scala
index f4735ff9ec..52da7b7883 100644
---
a/backends-clickhouse/src-delta-32/main/scala/org/apache/gluten/sql/shims/delta32/Delta32Shims.scala
+++
b/backends-clickhouse/src-delta-32/main/scala/org/apache/gluten/sql/shims/delta32/Delta32Shims.scala
@@ -17,6 +17,7 @@
package org.apache.gluten.sql.shims.delta32
import org.apache.gluten.execution.GlutenPlan
+import org.apache.gluten.extension.{DeltaExpressionExtensionTransformer,
ExpressionExtensionTrait}
import org.apache.gluten.sql.shims.DeltaShims
import org.apache.spark.sql.delta.DeltaParquetFileFormat
@@ -40,6 +41,10 @@ class Delta32Shims extends DeltaShims {
DeltaOptimizedWriterTransformer.from(plan)
}
+  /** Registers the Delta-specific expression extension (Delta's `BitmapAggregator`). */
+  override def registerExpressionExtension(): Unit =
+    ExpressionExtensionTrait.registerExpressionExtension(DeltaExpressionExtensionTransformer())
+
/**
* decode ZeroMQ Base85 encoded file path
*
diff --git
a/backends-clickhouse/src-delta-32/test/scala/org/apache/spark/gluten/delta/GlutenDeltaParquetDeletionVectorSuite.scala
b/backends-clickhouse/src-delta-32/test/scala/org/apache/spark/gluten/delta/GlutenDeltaParquetDeletionVectorSuite.scala
index b7b5b8ed95..b96e02af2c 100644
---
a/backends-clickhouse/src-delta-32/test/scala/org/apache/spark/gluten/delta/GlutenDeltaParquetDeletionVectorSuite.scala
+++
b/backends-clickhouse/src-delta-32/test/scala/org/apache/spark/gluten/delta/GlutenDeltaParquetDeletionVectorSuite.scala
@@ -19,8 +19,13 @@ package org.apache.spark.gluten.delta
import org.apache.gluten.execution.{FileSourceScanExecTransformer,
GlutenClickHouseTPCHAbstractSuite}
import org.apache.spark.SparkConf
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.aggregation.BitmapAggregator
+import org.apache.spark.sql.delta.deletionvectors.RoaringBitmapArrayFormat
import org.apache.spark.sql.delta.files.TahoeFileIndex
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
+import org.apache.spark.sql.functions.col
// Some sqls' line length exceeds 100
// scalastyle:off line.size.limit
@@ -180,6 +185,84 @@ class GlutenDeltaParquetDeletionVectorSuite
assert(scanExec.nonEmpty)
}
+  test("test ObjectHashAggregateExec(bitmapaggregator) no fallback") {
+    val table_name = "dv_fallback"
+    withTable(table_name) {
+      withSQLConf("spark.sql.adaptive.enabled" -> "false") {
+        spark.sql(s"""
+                     |CREATE TABLE IF NOT EXISTS $table_name
+                     |($q1SchemaString)
+                     |USING delta
+                     |TBLPROPERTIES (delta.enableDeletionVectors='true')
+                     |LOCATION '$basePath/$table_name'
+                     |""".stripMargin)
+
+        spark.sql(s"""
+                     | insert into table $table_name select * from lineitem
+                     |""".stripMargin)
+
+        // Wraps Delta's BitmapAggregator over `indexColumn` as an aggregate Column.
+        def createBitmapSetAggregator(indexColumn: Column): Column = {
+          val func = new BitmapAggregator(indexColumn.expr, RoaringBitmapArrayFormat.Portable)
+          new Column(func.toAggregateExpression(isDistinct = false))
+        }
+
+        val aggregated = sql(s"select l_orderkey,l_shipdate from $table_name")
+          .groupBy(col("l_shipdate"))
+          .agg(createBitmapSetAggregator(col("l_orderkey")))
+        aggregated.collect()
+
+        // The aggregate must be offloaded instead of falling back to ObjectHashAggregateExec.
+        val fallbackAggregates = aggregated.queryExecution.executedPlan.collect {
+          case agg: ObjectHashAggregateExec => agg
+        }
+        assert(fallbackAggregates.isEmpty, s"bitmapaggregator fell back: $fallbackAggregates")
+      }
+    }
+  }
+
+  test("test delta DV write") {
+    val table_name = "dv_write_test"
+    withTable(table_name) {
+      spark.sql(s"""
+                   |CREATE TABLE IF NOT EXISTS $table_name
+                   |($q1SchemaString)
+                   |USING delta
+                   |TBLPROPERTIES (delta.enableDeletionVectors='true')
+                   |LOCATION '$basePath/$table_name'
+                   |""".stripMargin)
+
+      spark.sql(s"""
+                   | insert into table $table_name select * from lineitem
+                   |""".stripMargin)
+
+      spark.sql(s"""
+                   | delete from $table_name
+                   | where mod(l_orderkey, 3) = 1 and l_orderkey < 100
+                   |""".stripMargin)
+
+      val result = spark
+        .sql(s"""
+                | select sum(l_linenumber) from $table_name
+                |""".stripMargin)
+        .collect()
+      assertResult(1802335)(result(0).get(0))
+
+      spark.sql(s"""
+                   | update $table_name
+                   | set l_orderkey = 1 where l_orderkey > 0
+                   |""".stripMargin)
+
+      // Read the table back after the update without printing to stdout (was a debug `.show()`).
+      spark.sql(s"select count(*) from $table_name").collect()
+
+      val result2 = spark
+        .sql(s"""
+                | select sum(l_orderkey) from $table_name
+                |""".stripMargin)
+        .collect()
+      assertResult(600536)(result2(0).get(0))
+    }
+  }
+
test("test parquet partition table delete with the delta DV") {
withSQLConf(("spark.sql.sources.partitionOverwriteMode", "dynamic")) {
spark.sql(s"""
diff --git
a/backends-clickhouse/src-delta/main/scala/org/apache/gluten/component/CHDeltaComponent.scala
b/backends-clickhouse/src-delta/main/scala/org/apache/gluten/component/CHDeltaComponent.scala
index 48a4c88518..1b187e4a4b 100644
---
a/backends-clickhouse/src-delta/main/scala/org/apache/gluten/component/CHDeltaComponent.scala
+++
b/backends-clickhouse/src-delta/main/scala/org/apache/gluten/component/CHDeltaComponent.scala
@@ -21,6 +21,7 @@ import org.apache.gluten.execution.OffloadDeltaNode
import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
import org.apache.gluten.extension.columnar.validator.Validators
import org.apache.gluten.extension.injector.Injector
+import org.apache.gluten.sql.shims.DeltaShimLoader
class CHDeltaComponent extends Component {
override def name(): String = "ch-delta"
@@ -34,5 +35,7 @@ class CHDeltaComponent extends Component {
val offload = Seq(OffloadDeltaNode())
HeuristicTransform.Simple(Validators.newValidator(c.glutenConf,
offload), offload)
}
+
+ DeltaShimLoader.getDeltaShims.registerExpressionExtension()
}
}
diff --git
a/backends-clickhouse/src-delta/main/scala/org/apache/gluten/sql/shims/DeltaShims.scala
b/backends-clickhouse/src-delta/main/scala/org/apache/gluten/sql/shims/DeltaShims.scala
index 09de110817..ee263910dd 100644
---
a/backends-clickhouse/src-delta/main/scala/org/apache/gluten/sql/shims/DeltaShims.scala
+++
b/backends-clickhouse/src-delta/main/scala/org/apache/gluten/sql/shims/DeltaShims.scala
@@ -33,6 +33,8 @@ trait DeltaShims {
s"Can't transform ColumnarDeltaOptimizedWriterExec from
${plan.getClass.getSimpleName}")
}
+  /** Hook for version-specific shims to register expression extensions; no-op by default. */
+  def registerExpressionExtension(): Unit = ()
+
def convertRowIndexFilterIdEncoded(
partitionColsCnt: Int,
file: PartitionedFile,
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
index 77d571aa19..4b661b3c2d 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
@@ -53,7 +53,7 @@ class CHListenerApi extends ListenerApi with Logging {
pc.conf.get(GlutenConfig.EXTENDED_EXPRESSION_TRAN_CONF.key, "")
)
if (expressionExtensionTransformer != null) {
- ExpressionExtensionTrait.expressionExtensionTransformer =
expressionExtensionTransformer
+
ExpressionExtensionTrait.registerExpressionExtension(expressionExtensionTransformer)
}
}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index 1b56c003dc..fe145364e2 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -189,7 +189,8 @@ object CHRuleApi {
// case s: SerializeFromObjectExec => true
// case d: DeserializeToObjectExec => true
// case o: ObjectHashAggregateExec => true
- case rddScanExec: RDDScanExec if rddScanExec.nodeName.contains("Delta
Table State") => true
+//        case rddScanExec: RDDScanExec
+//            if rddScanExec.nodeName.contains("Delta Table State") => true
case f: FileSourceScanExec if includedDeltaOperator(f) => true
case v2CommandExec: V2CommandExec => true
case commandResultExec: CommandResultExec => true
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index f9fcffed82..3a4267c4b1 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -583,7 +583,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with
Logging {
Sig[CollectSet](ExpressionNames.COLLECT_SET),
Sig[MonotonicallyIncreasingID](MONOTONICALLY_INCREASING_ID)
) ++
-
ExpressionExtensionTrait.expressionExtensionTransformer.expressionSigList ++
+ ExpressionExtensionTrait.expressionExtensionSigList ++
SparkShimLoader.getSparkShims.bloomFilterExpressionMappings()
}
@@ -592,12 +592,12 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with
Logging {
substraitExprName: String,
expr: Expression,
attributeSeq: Seq[Attribute]): Option[ExpressionTransformer] = expr
match {
- case e
- if
ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping
- .contains(e.getClass) =>
+ case e if
ExpressionExtensionTrait.findExpressionExtension(e.getClass).nonEmpty =>
// Use extended expression transformer to replace custom expression first
Some(
- ExpressionExtensionTrait.expressionExtensionTransformer
+ ExpressionExtensionTrait
+ .findExpressionExtension(e.getClass)
+ .get
.replaceWithExtensionExpressionTransformer(substraitExprName, e,
attributeSeq))
case _ => None
}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
index f58a41bbea..88913bb2b1 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
@@ -359,12 +359,11 @@ case class CHHashAggregateExecTransformer(
} else {
val aggExpr = aggExpressions(columnIndex - groupingExprs.length)
val aggregateFunc = aggExpr.aggregateFunction
+ val expressionExtensionTransformer =
+
ExpressionExtensionTrait.findExpressionExtension(aggregateFunc.getClass)
var aggFunctionName =
- if (
-
ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping
- .contains(aggregateFunc.getClass)
- ) {
- ExpressionExtensionTrait.expressionExtensionTransformer
+ if (expressionExtensionTransformer.nonEmpty) {
+ expressionExtensionTransformer.get
.buildCustomAggregateFunction(aggregateFunc)
._1
.get
@@ -571,12 +570,11 @@ case class CHHashAggregateExecPullOutHelper(
index: Int): Int = {
var resIndex = index
val aggregateFunc = exp.aggregateFunction
+ val expressionExtensionTransformer =
+ ExpressionExtensionTrait.findExpressionExtension(aggregateFunc.getClass)
// First handle the custom aggregate functions
- if (
-
ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping.contains(
- aggregateFunc.getClass)
- ) {
- ExpressionExtensionTrait.expressionExtensionTransformer
+ if (expressionExtensionTransformer.nonEmpty) {
+ expressionExtensionTransformer.get
.getAttrsIndexForExtensionAggregateExpr(
aggregateFunc,
exp.mode,
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala
index af1ac52b1e..fa8a5763a6 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala
@@ -27,13 +27,11 @@ object CHExpressions {
// Since https://github.com/apache/incubator-gluten/pull/1937.
def createAggregateFunction(args: java.lang.Object, aggregateFunc:
AggregateFunction): Long = {
val functionMap = args.asInstanceOf[java.util.HashMap[String,
java.lang.Long]]
- if (
-
ExpressionExtensionTrait.expressionExtensionTransformer.extensionExpressionsMapping.contains(
- aggregateFunc.getClass)
- ) {
+ val expressionExtensionTransformer =
+ ExpressionExtensionTrait.findExpressionExtension(aggregateFunc.getClass)
+ if (expressionExtensionTransformer.nonEmpty) {
val (substraitAggFuncName, inputTypes) =
-
ExpressionExtensionTrait.expressionExtensionTransformer.buildCustomAggregateFunction(
- aggregateFunc)
+
expressionExtensionTransformer.get.buildCustomAggregateFunction(aggregateFunc)
assert(substraitAggFuncName.isDefined)
return ExpressionBuilder.newScalarFunction(
functionMap,
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala
index c64f26869e..f2e8d4d8bd 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ExpressionExtensionTrait.scala
@@ -32,7 +32,7 @@ trait ExpressionExtensionTrait {
def expressionSigList: Seq[Sig]
/** Generate the extension expressions mapping map */
- def extensionExpressionsMapping: Map[Class[_], String] =
+ lazy val extensionExpressionsMapping: Map[Class[_], String] =
expressionSigList.map(s => (s.expClass, s.name)).toMap[Class[_], String]
/** Replace extension expression to transformer. */
@@ -63,9 +63,26 @@ trait ExpressionExtensionTrait {
}
}
-object ExpressionExtensionTrait {
- var expressionExtensionTransformer: ExpressionExtensionTrait =
- DefaultExpressionExtensionTransformer()
+object ExpressionExtensionTrait extends Logging {
+  // Registered extensions. The Seq is replaced wholesale under `synchronized` in
+  // registerExpressionExtension but read without locking, so @volatile is needed for readers
+  // on other threads to observe the update.
+  @volatile private var expressionExtensionTransformers: Seq[ExpressionExtensionTrait] =
+    Seq.empty
+
+  // Flattened `expressionSigList` of every registered extension, recomputed on registration.
+  @volatile private var expressionExtensionSig: Seq[Sig] = Seq.empty
+
+  /** Signatures of all expressions provided by registered extensions. */
+  def expressionExtensionSigList: Seq[Sig] = expressionExtensionSig
+
+  /** Returns the first registered extension whose mapping contains `clazz`, if any. */
+  def findExpressionExtension(clazz: Class[_]): Option[ExpressionExtensionTrait] =
+    expressionExtensionTransformers.find(_.extensionExpressionsMapping.contains(clazz))
+
+  /**
+   * Registers `expressionExtension`. A second registration of the same extension class is
+   * ignored with a warning.
+   */
+  def registerExpressionExtension(expressionExtension: ExpressionExtensionTrait): Unit =
+    synchronized {
+      if (expressionExtensionTransformers.exists(_.getClass == expressionExtension.getClass)) {
+        logWarning(s"${expressionExtension.getClass} has been registered. It will be ignored.")
+      } else {
+        val updated = expressionExtensionTransformers :+ expressionExtension
+        expressionExtensionSig = updated.flatMap(_.expressionSigList)
+        expressionExtensionTransformers = updated
+      }
+    }
case class DefaultExpressionExtensionTransformer() extends
ExpressionExtensionTrait with Logging {
diff --git
a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionDVRoaringBitmap.cpp
b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionDVRoaringBitmap.cpp
new file mode 100644
index 0000000000..a9d68626d8
--- /dev/null
+++
b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionDVRoaringBitmap.cpp
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <AggregateFunctions/AggregateFunctionFactory.h>
+
+#include "AggregateFunctionDVRoaringBitmap.h"
+
+#include <AggregateFunctions/FactoryHelpers.h>
+
+namespace DB
+{
+/// Forward declaration: the creator below only takes a `const DB::Settings *`.
+struct Settings;
+
+/// Declarations of ClickHouse error codes whose definitions live elsewhere (`extern`).
+namespace ErrorCodes
+{
+extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+extern const int BAD_ARGUMENTS;
+extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+namespace local_engine
+{
+
+/// Creates the native `bitmapaggregator` aggregate function (counterpart of Delta's BitmapAggregator).
+/// Requires exactly one Int64 argument: AggregateFunctionDVRoaringBitmap::add() reads the argument column as
+/// ColumnVector<Int64> via assert_cast, so any other type must be rejected here rather than miscast later.
+DB::AggregateFunctionPtr createAggregateFunctionDVRoaringBitmap(
+    const std::string & name, const DB::DataTypes & argument_types, const DB::Array & parameters, const DB::Settings *)
+{
+    if (argument_types.size() != 1)
+        throw DB::Exception(
+            DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
+            "Aggregate function {} requires exactly one argument, got {}",
+            name,
+            argument_types.size());
+
+    if (argument_types[0]->getTypeId() != DB::TypeIndex::Int64)
+        throw DB::Exception(
+            DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
+            "Argument of aggregate function {} must be Int64, got {}",
+            name,
+            argument_types[0]->getName());
+
+    return DB::AggregateFunctionPtr(
+        new AggregateFunctionDVRoaringBitmap<Int64, AggregateFunctionDVRoaringBitmapData>(argument_types, parameters));
+}
+
+/// Registers `bitmapaggregator`; the name must match the Substrait name emitted by the Scala side.
+void registerAggregateFunctionDVRoaringBitmap(DB::AggregateFunctionFactory & factory)
+{
+    factory.registerFunction("bitmapaggregator", createAggregateFunctionDVRoaringBitmap);
+}
+}
diff --git
a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionDVRoaringBitmap.h
b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionDVRoaringBitmap.h
new file mode 100644
index 0000000000..a60184a5be
--- /dev/null
+++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionDVRoaringBitmap.h
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <AggregateFunctions/IAggregateFunction.h>
+#include <Columns/ColumnArray.h>
+#include <Columns/ColumnString.h>
+#include <DataTypes/DataTypeString.h>
+#include <DataTypes/DataTypeTuple.h>
+#include <DataTypes/DataTypesNumber.h>
+#include <IO/ReadBufferFromString.h>
+#include <IO/ReadHelpers.h>
+#include <IO/WriteHelpers.h>
+#include <Storages/SubstraitSource/Delta/Bitmap/DeltaDVRoaringBitmapArray.h>
+
+
+namespace local_engine
+{
+
+/// Aggregation state of `bitmapaggregator`: the values collected so far, stored as a DeltaDVRoaringBitmapArray.
+struct AggregateFunctionDVRoaringBitmapData
+{
+    AggregateFunctionDVRoaringBitmapData() = default;
+
+    DeltaDVRoaringBitmapArray roaring_bitmap_array;
+
+    /// Appends one result row: cardinality, last value (column default when the array has none),
+    /// and the serialized bitmap array.
+    void insertResultInto(DB::ColumnInt64 & cardinality, DB::ColumnInt64 & last, DB::ColumnString & bitmap)
+    {
+        cardinality.getData().push_back(roaring_bitmap_array.rb_size());
+        if (auto last_value = roaring_bitmap_array.last(); last_value.has_value())
+            last.getData().push_back(*last_value);
+        else
+            last.insertDefault();
+
+        bitmap.insert(roaring_bitmap_array.serialize());
+    }
+
+    /// The state is written length-prefixed so read() can load it into one contiguous buffer.
+    /// DeltaDVRoaringBitmapArray::deserialize passes raw `buf.position()` pointers to CRoaring. Reading directly
+    /// from a chunked ReadBuffer could therefore run past the working buffer when a bitmap straddles a chunk boundary.
+    void write(DB::WriteBuffer & buf) const { DB::writeStringBinary(roaring_bitmap_array.serialize(), buf); }
+
+    void read(DB::ReadBuffer & buf)
+    {
+        String serialized;
+        DB::readStringBinary(serialized, buf);
+        DB::ReadBufferFromString serialized_buf(serialized);
+        roaring_bitmap_array.deserialize(serialized_buf);
+    }
+};
+
+
+/// Native `bitmapaggregator`: collects the argument values into a roaring bitmap array and returns
+/// Tuple(cardinality Int64, last Int64, bitmap String).
+template <typename T, typename Data>
+class AggregateFunctionDVRoaringBitmap final
+    : public DB::IAggregateFunctionDataHelper<Data, AggregateFunctionDVRoaringBitmap<T, Data>>
+{
+    using Base = DB::IAggregateFunctionDataHelper<Data, AggregateFunctionDVRoaringBitmap<T, Data>>;
+
+public:
+    AggregateFunctionDVRoaringBitmap(const DB::DataTypes & argument_types_, const DB::Array & parameters_)
+        : Base(argument_types_, parameters_, createResultType())
+    {
+    }
+
+    /// Result tuple layout: (cardinality, last, serialized bitmap array).
+    static DB::DataTypePtr createResultType()
+    {
+        DB::DataTypes types;
+        types.emplace_back(std::make_shared<DB::DataTypeInt64>()); /// cardinality
+        types.emplace_back(std::make_shared<DB::DataTypeInt64>()); /// last
+        types.emplace_back(std::make_shared<DB::DataTypeString>()); /// bitmap
+        return std::make_shared<DB::DataTypeTuple>(types);
+    }
+
+    bool allocatesMemoryInArena() const override { return false; }
+
+    void add(DB::AggregateDataPtr __restrict place, const DB::IColumn ** columns, size_t row_num, DB::Arena *) const override
+    {
+        const auto & values = assert_cast<const DB::ColumnVector<T> &>(*columns[0]).getData();
+        this->data(place).roaring_bitmap_array.rb_add(values[row_num]);
+    }
+
+    void merge(DB::AggregateDataPtr __restrict place, DB::ConstAggregateDataPtr rhs, DB::Arena *) const override
+    {
+        this->data(place).roaring_bitmap_array.rb_merge(this->data(rhs).roaring_bitmap_array);
+    }
+
+    void insertResultInto(DB::AggregateDataPtr __restrict place, DB::IColumn & to, DB::Arena *) const override
+    {
+        auto & to_tuple = assert_cast<DB::ColumnTuple &>(to);
+        auto & cardinality = assert_cast<DB::ColumnInt64 &>(to_tuple.getColumn(0));
+        auto & last = assert_cast<DB::ColumnInt64 &>(to_tuple.getColumn(1));
+        auto & bitmap = assert_cast<DB::ColumnString &>(to_tuple.getColumn(2));
+        this->data(place).insertResultInto(cardinality, last, bitmap);
+    }
+
+    String getName() const override { return "bitmapaggregator"; }
+
+    void serialize(DB::ConstAggregateDataPtr place, DB::WriteBuffer & buf, std::optional<size_t> /*version*/) const override
+    {
+        this->data(place).write(buf);
+    }
+
+    void deserialize(DB::AggregateDataPtr place, DB::ReadBuffer & buf, std::optional<size_t> /*version*/, DB::Arena *) const override
+    {
+        this->data(place).read(buf);
+    }
+};
+}
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp
b/cpp-ch/local-engine/Common/CHUtil.cpp
index e2d5b01786..945344827d 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -895,6 +895,7 @@ extern void
registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCom
extern void registerAggregateFunctionsBloomFilter(AggregateFunctionFactory &);
extern void registerAggregateFunctionSparkAvg(AggregateFunctionFactory &);
extern void registerAggregateFunctionRowNumGroup(AggregateFunctionFactory &);
+extern void registerAggregateFunctionDVRoaringBitmap(AggregateFunctionFactory
&);
extern void registerFunctions(FunctionFactory &);
@@ -909,6 +910,7 @@ void registerAllFunctions()
registerAggregateFunctionSparkAvg(agg_factory);
registerAggregateFunctionRowNumGroup(agg_factory);
DB::registerAggregateFunctionUniqHyperLogLogPlusPlus(agg_factory);
+ registerAggregateFunctionDVRoaringBitmap(agg_factory);
/// register aggregate function combinators from local_engine
auto & combinator_factory = AggregateFunctionCombinatorFactory::instance();
diff --git
a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
index d88885a312..ca448edb57 100644
---
a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
+++
b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
@@ -44,4 +44,5 @@ REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(PercentRank,
percent_rank, percent_ran
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Rank, rank, rank)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(RowNumber, row_number, row_number)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CountDistinct, count_distinct,
uniqExact)
+REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(BitmapAggregator, bitmapaggregator,
bitmapaggregator)
}
diff --git
a/cpp-ch/local-engine/Storages/SubstraitSource/Delta/Bitmap/DeltaDVRoaringBitmapArray.cpp
b/cpp-ch/local-engine/Storages/SubstraitSource/Delta/Bitmap/DeltaDVRoaringBitmapArray.cpp
index 58dd911354..e238a2b362 100644
---
a/cpp-ch/local-engine/Storages/SubstraitSource/Delta/Bitmap/DeltaDVRoaringBitmapArray.cpp
+++
b/cpp-ch/local-engine/Storages/SubstraitSource/Delta/Bitmap/DeltaDVRoaringBitmapArray.cpp
@@ -49,11 +49,11 @@ UInt64
DeltaDVRoaringBitmapArray::compose_from_high_low_bytes(UInt32 high, UInt3
return (static_cast<uint64_t>(high) << 32) | low;
}
-DeltaDVRoaringBitmapArray::DeltaDVRoaringBitmapArray(const DB::ContextPtr &
context_) : context(context_)
+DeltaDVRoaringBitmapArray::DeltaDVRoaringBitmapArray()
{
}
-void DeltaDVRoaringBitmapArray::rb_read(const String & file_path, Int32
offset, Int32 data_size)
+void DeltaDVRoaringBitmapArray::rb_read(const String & file_path, Int32
offset, Int32 data_size, DB::ContextPtr context)
{
substrait::ReadRel::LocalFiles::FileOrFiles file_info;
file_info.set_uri_file(file_path);
@@ -74,30 +74,12 @@ void DeltaDVRoaringBitmapArray::rb_read(const String &
file_path, Int32 offset,
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "The size of the
deletion vector is mismatch.");
int checksum_value = static_cast<Int32>(crc32_z(0L,
reinterpret_cast<unsigned char *>(in->position()), size));
+ deserialize(*in);
- int magic_num;
- readBinaryLittleEndian(magic_num, *in);
- if (magic_num != 1681511377)
- throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "The magic num is
mismatch.");
-
- int64_t bitmap_array_size;
- readBinaryLittleEndian(bitmap_array_size, *in);
-
- roaring_bitmap_array.reserve(bitmap_array_size);
- for (size_t i = 0; i < bitmap_array_size; ++i)
- {
- int bitmap_index;
- readBinaryLittleEndian(bitmap_index, *in);
- roaring::Roaring r = roaring::Roaring::read(in->position());
- size_t current_bitmap_size = r.getSizeInBytes();
- in->ignore(current_bitmap_size);
- roaring_bitmap_array.push_back(r);
- }
int expected_checksum;
readBinaryBigEndian(expected_checksum, *in);
if (expected_checksum != checksum_value)
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Checksum
mismatch.");
-
}
UInt64 DeltaDVRoaringBitmapArray::rb_size() const
@@ -124,8 +106,7 @@ void DeltaDVRoaringBitmapArray::rb_clear()
bool DeltaDVRoaringBitmapArray::rb_is_empty() const
{
- return std::ranges::all_of(roaring_bitmap_array.begin(),
roaring_bitmap_array.end(),
- [](const auto& rb) { return rb.isEmpty(); });
+ return std::ranges::all_of(roaring_bitmap_array.begin(),
roaring_bitmap_array.end(), [](const auto & rb) { return rb.isEmpty(); });
}
void DeltaDVRoaringBitmapArray::rb_add(Int64 x)
@@ -171,4 +152,59 @@ bool DeltaDVRoaringBitmapArray::operator==(const
DeltaDVRoaringBitmapArray & oth
return roaring_bitmap_array == other.roaring_bitmap_array;
}
+
+
+/// Serialize the bitmap array into the layout that deserialize() reads:
+/// Int32 magic number 1681511377 and Int64 number of bitmaps (both little-endian), then for each
+/// bitmap its Int32 key (little-endian, equal to its position in the array) followed by the bitmap
+/// in Roaring's portable serialization format.
+/// Bitmaps are run-optimized on a copy before writing, so the stored bitmaps are left untouched.
+String DeltaDVRoaringBitmapArray::serialize() const
+{
+    DB::WriteBufferFromOwnString out;
+    constexpr Int32 magic_number = 1681511377;
+    writeBinaryLittleEndian(magic_number, out);
+    const Int64 num_bitmaps = static_cast<Int64>(roaring_bitmap_array.size());
+    writeBinaryLittleEndian(num_bitmaps, out);
+
+    /// Scratch buffer reused across bitmaps to avoid one heap allocation per bitmap.
+    std::vector<char> bytes;
+    for (size_t i = 0; i < roaring_bitmap_array.size(); ++i)
+    {
+        const Int32 key = static_cast<Int32>(i);
+        writeBinaryLittleEndian(key, out);
+        /// runOptimize() mutates the bitmap, so work on a copy to keep this method const.
+        roaring::Roaring bitmap = roaring_bitmap_array[i];
+        bitmap.runOptimize();
+        bytes.resize(bitmap.getSizeInBytes());
+        bitmap.write(bytes.data());
+        out.write(bytes.data(), bytes.size());
+    }
+
+    return out.str();
+}
+
+/// Deserialize a bitmap array from `buf`, replacing the current contents.
+/// Expected layout (the one serialize() writes): Int32 magic number 1681511377 and Int64 number of
+/// bitmaps (both little-endian), then for each bitmap its Int32 key (little-endian) followed by the
+/// bitmap in Roaring's portable format. Keys must be non-negative and strictly increasing; skipped
+/// keys are filled with empty bitmaps so each bitmap lands at the array position equal to its key.
+/// Throws DB::Exception(BAD_ARGUMENTS) on a bad header or key; errors raised by the Roaring parser
+/// for a malformed bitmap propagate unchanged.
+void DeltaDVRoaringBitmapArray::deserialize(DB::ReadBuffer & buf)
+{
+    Int32 magic_num;
+    readBinaryLittleEndian(magic_num, buf);
+    if (magic_num != 1681511377)
+        throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "The magic num is mismatch.");
+
+    Int64 bitmap_array_size;
+    readBinaryLittleEndian(bitmap_array_size, buf);
+    /// Keys are non-negative Int32 values, so at most INT32_MAX + 1 distinct bitmaps can exist.
+    if (bitmap_array_size < 0 || bitmap_array_size > static_cast<Int64>(INT32_MAX) + 1)
+        throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Invalid number of bitmaps: {}", bitmap_array_size);
+
+    roaring_bitmap_array.clear();
+    roaring_bitmap_array.reserve(static_cast<size_t>(bitmap_array_size));
+    for (Int64 i = 0; i < bitmap_array_size; ++i)
+    {
+        Int32 bitmap_index;
+        readBinaryLittleEndian(bitmap_index, buf);
+        if (bitmap_index < 0 || static_cast<size_t>(bitmap_index) < roaring_bitmap_array.size())
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Invalid or non-increasing bitmap key: {}", bitmap_index);
+        /// Sparse input may omit keys; pad with empty bitmaps so positions keep matching keys.
+        roaring_bitmap_array.resize(static_cast<size_t>(bitmap_index));
+
+        /// Roaring parses directly from the buffer's memory. Make sure the working buffer is not
+        /// exhausted, and bound the parser by the bytes actually available instead of reading past them.
+        buf.nextIfAtEnd();
+        roaring::Roaring r = roaring::Roaring::readSafe(buf.position(), buf.available());
+        buf.ignore(r.getSizeInBytes());
+        roaring_bitmap_array.push_back(std::move(r));
+    }
+}
+
+/// Return the largest value stored in the array, or std::nullopt if no bitmap holds a value.
+/// The bitmap at array position i holds the low 32 bits of values whose high 32 bits equal i,
+/// so the result is built from the last non-empty bitmap's position and its maximum element.
+std::optional<Int64> DeltaDVRoaringBitmapArray::last()
+{
+    for (size_t pos = roaring_bitmap_array.size(); pos > 0; --pos)
+    {
+        const auto & bitmap = roaring_bitmap_array[pos - 1];
+        if (!bitmap.isEmpty())
+            /// The high part is the bitmap's own position (pos - 1), not the array size.
+            return compose_from_high_low_bytes(static_cast<UInt32>(pos - 1), bitmap.maximum());
+    }
+    return std::nullopt;
+}
+
+
}
\ No newline at end of file
diff --git
a/cpp-ch/local-engine/Storages/SubstraitSource/Delta/Bitmap/DeltaDVRoaringBitmapArray.h
b/cpp-ch/local-engine/Storages/SubstraitSource/Delta/Bitmap/DeltaDVRoaringBitmapArray.h
index 5b9fd93f58..067ec5db32 100644
---
a/cpp-ch/local-engine/Storages/SubstraitSource/Delta/Bitmap/DeltaDVRoaringBitmapArray.h
+++
b/cpp-ch/local-engine/Storages/SubstraitSource/Delta/Bitmap/DeltaDVRoaringBitmapArray.h
@@ -16,6 +16,7 @@
*/
#pragma once
+#include <IO/ReadBuffer.h>
#include <Interpreters/Context_fwd.h>
#include <base/types.h>
#include <boost/core/noncopyable.hpp>
@@ -31,7 +32,6 @@ class DeltaDVRoaringBitmapArray : private boost::noncopyable
{
static constexpr Int64 MAX_REPRESENTABLE_VALUE
= (static_cast<UInt64>(INT32_MAX - 1) << 32) |
(static_cast<UInt64>(INT32_MIN) & 0xFFFFFFFFL);
- DB::ContextPtr context;
std::vector<roaring::Roaring> roaring_bitmap_array;
static std::pair<UInt32, UInt32> decompose_high_low_bytes(UInt64 value);
@@ -40,17 +40,20 @@ class DeltaDVRoaringBitmapArray : private boost::noncopyable
void rb_shrink_bitmaps(Int32 new_length);
public:
- explicit DeltaDVRoaringBitmapArray(const DB::ContextPtr & context_);
+ explicit DeltaDVRoaringBitmapArray();
~DeltaDVRoaringBitmapArray() = default;
bool operator==(const DeltaDVRoaringBitmapArray & other) const;
UInt64 rb_size() const;
- void rb_read(const String & file_path, Int32 offset, Int32 data_size);
+ void rb_read(const String & file_path, Int32 offset, Int32 data_size,
DB::ContextPtr context);
bool rb_contains(Int64 x) const;
bool rb_is_empty() const;
void rb_clear();
void rb_add(Int64 value);
void rb_merge(const DeltaDVRoaringBitmapArray & that);
void rb_or(const DeltaDVRoaringBitmapArray & that);
+ String serialize() const;
+ void deserialize(DB::ReadBuffer & buf);
+ std::optional<Int64> last();
};
}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/Storages/SubstraitSource/Delta/DeltaReader.cpp
b/cpp-ch/local-engine/Storages/SubstraitSource/Delta/DeltaReader.cpp
index fbec57fb87..cee84b750d 100644
--- a/cpp-ch/local-engine/Storages/SubstraitSource/Delta/DeltaReader.cpp
+++ b/cpp-ch/local-engine/Storages/SubstraitSource/Delta/DeltaReader.cpp
@@ -81,8 +81,8 @@ DeltaReader::DeltaReader(
{
if (bitmap_config)
{
- bitmap_array =
std::make_unique<DeltaDVRoaringBitmapArray>(file->getContext());
- bitmap_array->rb_read(bitmap_config->path_or_inline_dv,
bitmap_config->offset, bitmap_config->size_in_bytes);
+ bitmap_array = std::make_unique<DeltaDVRoaringBitmapArray>();
+ bitmap_array->rb_read(bitmap_config->path_or_inline_dv,
bitmap_config->offset, bitmap_config->size_in_bytes, file->getContext());
}
}
diff --git
a/cpp-ch/local-engine/Storages/SubstraitSource/Iceberg/PositionalDeleteFileReader.cpp
b/cpp-ch/local-engine/Storages/SubstraitSource/Iceberg/PositionalDeleteFileReader.cpp
index 3db44699dc..8251e6b749 100644
---
a/cpp-ch/local-engine/Storages/SubstraitSource/Iceberg/PositionalDeleteFileReader.cpp
+++
b/cpp-ch/local-engine/Storages/SubstraitSource/Iceberg/PositionalDeleteFileReader.cpp
@@ -41,7 +41,7 @@ std::unique_ptr<DeltaDVRoaringBitmapArray> createBitmapExpr(
{
assert(!position_delete_files.empty());
- std::unique_ptr<DeltaDVRoaringBitmapArray> result =
std::make_unique<DeltaDVRoaringBitmapArray>(context);
+ std::unique_ptr<DeltaDVRoaringBitmapArray> result =
std::make_unique<DeltaDVRoaringBitmapArray>();
for (auto deleteIndex : position_delete_files)
{
diff --git a/cpp-ch/local-engine/tests/gtest_clickhouse_roaring_bitmap.cpp
b/cpp-ch/local-engine/tests/gtest_clickhouse_roaring_bitmap.cpp
index 758bc8e610..3dd13c9ccd 100644
--- a/cpp-ch/local-engine/tests/gtest_clickhouse_roaring_bitmap.cpp
+++ b/cpp-ch/local-engine/tests/gtest_clickhouse_roaring_bitmap.cpp
@@ -138,8 +138,8 @@ TEST(Delta_DV, DeltaDVRoaringBitmapArray)
const std::string
file_uri(test::gtest_uri("deletion_vector_multiple.bin"));
const std::string
file_uri1(test::gtest_uri("deletion_vector_only_one.bin"));
- DeltaDVRoaringBitmapArray bitmap_array(context);
- bitmap_array.rb_read(file_uri, 426433, 426424);
+ DeltaDVRoaringBitmapArray bitmap_array{};
+ bitmap_array.rb_read(file_uri, 426433, 426424, context);
EXPECT_TRUE(bitmap_array.rb_contains(5));
EXPECT_TRUE(bitmap_array.rb_contains(3618));
EXPECT_TRUE(bitmap_array.rb_contains(155688));
@@ -157,8 +157,8 @@ TEST(Delta_DV, DeltaDVRoaringBitmapArray)
bitmap_array.rb_add(10000000000);
EXPECT_TRUE(bitmap_array.rb_contains(10000000000));
- DeltaDVRoaringBitmapArray bitmap_array1(context);
- bitmap_array1.rb_read(file_uri1, 1, 539);
+ DeltaDVRoaringBitmapArray bitmap_array1{};
+ bitmap_array1.rb_read(file_uri1, 1, 539, context);
EXPECT_TRUE(bitmap_array1.rb_contains(0));
EXPECT_TRUE(bitmap_array1.rb_contains(1003));
EXPECT_TRUE(bitmap_array1.rb_contains(880));
@@ -173,8 +173,8 @@ TEST(Delta_DV, DeltaDVRoaringBitmapArray)
const std::string
file_uri2(test::gtest_uri("deletion_vector_long_values.bin"));
- DeltaDVRoaringBitmapArray bitmap_array2(context);
- bitmap_array2.rb_read(file_uri2, 1, 4047);
+ DeltaDVRoaringBitmapArray bitmap_array2{};
+ bitmap_array2.rb_read(file_uri2, 1, 4047, context);
EXPECT_FALSE(bitmap_array2.rb_is_empty());
EXPECT_EQ(2098, bitmap_array2.rb_size());
EXPECT_TRUE(bitmap_array2.rb_contains(0));
@@ -212,7 +212,7 @@ TEST(Delta_DV, DeltaDVRoaringBitmapArray)
bitmap_array2.rb_clear();
EXPECT_TRUE(bitmap_array2.rb_is_empty());
- DeltaDVRoaringBitmapArray bitmap_array3(context);
+ DeltaDVRoaringBitmapArray bitmap_array3{};
bitmap_array3.rb_add(3000000000);
bitmap_array3.rb_add(5000000000);
bitmap_array3.rb_add(10000000000);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]