This is an automated email from the ASF dual-hosted git repository.

lwz9103 pushed a commit to branch liquid
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git

commit 307796100e92b4892dc401a17e67ed5cb9685579
Author: Wenzheng Liu <[email protected]>
AuthorDate: Thu May 30 13:54:13 2024 +0800

    [KE] Support sum0 function (#17)
    
    (cherry picked from commit 9aaf196a1970d88b8a690afb56722803970af8cd)
---
 .../execution/kap/GlutenKapExpressionsSuite.scala  | 131 +++++++++++++++++++++
 .../sql/catalyst/expressions/KapExpressions.scala  |  87 ++++++++++++++
 .../gluten/KapExpressionsTransformer.scala         |  74 ++++++++++++
 .../CommonAggregateFunctionParser.cpp              |   1 +
 4 files changed, 293 insertions(+)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala
new file mode 100644
index 0000000000..827096138d
--- /dev/null
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution.kap
+
+import org.apache.gluten.execution.GlutenClickHouseWholeStageTransformerSuite
+import org.apache.gluten.utils.UTSystemParameters
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase
+import org.apache.spark.sql.catalyst.expressions.Sum0
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+
class GlutenKapExpressionsSuite
  extends GlutenClickHouseWholeStageTransformerSuite
  with AdaptiveSparkPlanHelper {

  /**
   * Gluten + ClickHouse backend configuration: columnar sort shuffle, AQE off,
   * and the KAP extension-expression transformer wired in so `sum0` is offloaded.
   */
  override protected def sparkConf: SparkConf = {
    super.sparkConf
      .set("spark.sql.shuffle.partitions", "10")
      .set("spark.sql.files.maxPartitionBytes", "32m")
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
      .set("spark.io.compression.codec", "LZ4")
      .set("spark.sql.autoBroadcastJoinThreshold", "10MB")
      .set("spark.sql.adaptive.enabled", "false")
      .set("spark.sql.execution.useObjectHashAggregateExec", "true")
      .set(
        "spark.gluten.sql.columnar.extended.expressions.transformer",
        "org.apache.spark.sql.catalyst.expressions.gluten.KapExpressionsTransformer")
      .set("spark.gluten.sql.columnar.backend.ch.shuffle.hash.algorithm", "sparkMurmurHash3_32")
      .set("spark.gluten.sql.columnar.backend.ch.runtime_config.use_local_format", "true")
  }

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Register the extension expression so `sum0` resolves in SQL text.
    val (expressionInfo, builder) = FunctionRegistryBase.build[Sum0]("sum0", None)
    spark.sessionState.functionRegistry.registerFunction(
      FunctionIdentifier.apply("sum0"),
      expressionInfo,
      builder
    )
  }

  test("test sum0") {
    val viewName = "sum0_table"
    withTempView(viewName) {
      val sum0DataDir = s"${UTSystemParameters.testDataPath}/$fileFormat/index/sum0"
      // Expose the numerically-named parquet columns under readable aliases.
      spark
        .sql(s"""
                | select
                | `84` as week_beg_dt,
                | `129` as meta_category_name,
                | `149` as lstg_format_name,
                | `180` as site_name,
                | `100000` as trans_cnt,
                | `100001` as gmv,
                | `100002` as price_cnt,
                | `100003` as seller_cnt,
                | `100004` as total_items
                | from parquet.`$sum0DataDir`
                |""".stripMargin)
        .createOrReplaceTempView(viewName)

      // Plain scan + order, sanity-checks the view before any aggregation.
      val scanQuery =
        s"""
           |select * from sum0_table order by week_beg_dt limit 10
           |""".stripMargin
      compareResultsAgainstVanillaSpark(scanQuery, compareResult = true, df => df.show(10))

      // Global aggregation mixing sum and sum0.
      val globalAggQuery =
        s"""
           |select sum(gmv), sum0(trans_cnt), sum0(price_cnt), sum0(seller_cnt), sum(total_items)
           |from sum0_table
           |where lstg_format_name='FP-GTC'
           |and week_beg_dt between '2013-05-01' and DATE '2013-08-01'
           |""".stripMargin
      compareResultsAgainstVanillaSpark(globalAggQuery, compareResult = true, _ => {})

      // Grouped aggregation mixing sum and sum0.
      val groupByQuery =
        s"""
           |select site_name, sum(gmv), sum0(trans_cnt), sum(total_items) from sum0_table
           |where lstg_format_name='FP-GTC'
           |and week_beg_dt = '2013-01-01'
           |group by site_name order by site_name
           |""".stripMargin
      compareResultsAgainstVanillaSpark(groupByQuery, compareResult = true, _ => {})

      // aggregate empty data with group by (month 13 matches no rows)
      val emptyGroupByQuery =
        s"""
           |select site_name, sum(gmv), sum0(trans_cnt), sum(total_items) from sum0_table
           |where extract(month from week_beg_dt) = 13
           |group by site_name order by site_name
           |""".stripMargin
      compareResultsAgainstVanillaSpark(emptyGroupByQuery, compareResult = true, _ => {})

      // aggregate empty data without group by (sum0 should yield 0, not null)
      val emptyGlobalQuery =
        s"""
           |select sum(gmv), sum0(trans_cnt), sum(total_items) from sum0_table
           |where extract(month from week_beg_dt) = 13
           |""".stripMargin
      compareResultsAgainstVanillaSpark(emptyGlobalQuery, compareResult = true, _ => {})

      // Union of a non-empty and an empty aggregation branch.
      val unionQuery =
        s"""
           |select sum(gmv), sum0(trans_cnt), sum(total_items) from sum0_table
           |where extract(month from week_beg_dt) = 12
           |union all
           |select sum(gmv), sum0(trans_cnt), sum(total_items) from sum0_table
           |where extract(month from week_beg_dt) = 13
           |""".stripMargin
      compareResultsAgainstVanillaSpark(unionQuery, compareResult = true, _ => {})
    }
  }
}
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala
 
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala
new file mode 100644
index 0000000000..27281d74b7
--- /dev/null
+++ 
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.types._
+
@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns the sum calculated from values of a group. " +
    "It differs in that when no non null values are applied zero is returned instead of null")
case class Sum0(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {

  override def children: Seq[Expression] = child :: Nil

  // NOTE(review): kept `true` for schema parity with Spark's Sum even though the
  // buffer is initialized to zero and never becomes null — confirm before flipping,
  // as it affects the aggregation buffer schema.
  override def nullable: Boolean = true

  /** Result type: widened decimal for decimals, Long for integrals, Double otherwise. */
  override def dataType: DataType = resultType

  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

  override def checkInputDataTypes(): TypeCheckResult =
    // Fixed: error message previously named "function sum" instead of "function sum0".
    TypeUtils.checkForNumericExpr(child.dataType, "function sum0")

  private lazy val resultType = child.dataType match {
    case DecimalType.Fixed(precision, scale) =>
      // Widen precision by 10 to reduce overflow risk, mirroring Spark's Sum.
      DecimalType.bounded(precision + 10, scale)
    case _: IntegralType => LongType
    case _ => DoubleType
  }

  private lazy val sumDataType = resultType

  private lazy val sum = AttributeReference("sum", sumDataType)()

  // Typed zero literal. Also the initial buffer value, which is exactly what makes
  // sum0 evaluate to 0 instead of null on empty or all-null input.
  private lazy val zero = Cast(Literal(0), sumDataType)

  override lazy val aggBufferAttributes = sum :: Nil

  // Unlike Sum (which starts at null), the buffer starts at zero.
  override lazy val initialValues: Seq[Expression] = Seq(/* sum = */ zero)

  override lazy val updateExpressions: Seq[Expression] = {
    if (child.nullable) {
      // Null input rows leave the buffer unchanged: the outer Coalesce falls back to `sum`
      // when Add produces null because the cast child is null.
      Seq(
        /* sum = */
        Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
      )
    } else {
      Seq(
        /* sum = */
        Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType))
      )
    }
  }

  override lazy val mergeExpressions: Seq[Expression] = {
    // A null partial buffer on the right leaves the left buffer unchanged.
    Seq(
      /* sum = */
      Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left))
    )
  }

  override lazy val evaluateExpression: Expression = sum

  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
    super.legacyWithNewChildren(newChildren)
}
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala
 
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala
new file mode 100644
index 0000000000..af9fe6ae1e
--- /dev/null
+++ 
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.gluten
+
+import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.expression._
+import org.apache.gluten.extension.ExpressionExtensionTrait
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.types.DataType
+
+import scala.collection.mutable.ListBuffer
+
case class KapExpressionsTransformer() extends ExpressionExtensionTrait {

  /** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */
  def expressionSigList: Seq[Sig] = Seq(
    Sig[Sum0]("sum0")
  )

  /**
   * Appends the output attribute(s) for an extension aggregate expression to
   * `aggregateAttr` and returns the advanced result index.
   *
   * Partial/PartialMerge stages expose the intermediate aggregation buffer
   * attributes; Final/Complete stages expose the single result attribute.
   */
  override def getAttrsIndexForExtensionAggregateExpr(
      aggregateFunc: AggregateFunction,
      mode: AggregateMode,
      exp: AggregateExpression,
      aggregateAttributeList: Seq[Attribute],
      aggregateAttr: ListBuffer[Attribute],
      resIndex: Int): Int = {
    // Fixed idiom: expression-oriented returns replace the mutable `var resIdx`.
    exp.mode match {
      case Partial | PartialMerge =>
        val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
        aggregateAttr ++= aggBufferAttr.map(ConverterUtils.getAttrFromExpr)
        resIndex + aggBufferAttr.size
      case Final | Complete =>
        aggregateAttr += aggregateAttributeList(resIndex)
        resIndex + 1
      case other =>
        throw new GlutenNotSupportException(s"Unsupported aggregate mode: $other.")
    }
  }

  /**
   * Resolves the substrait function name registered for this aggregate's class,
   * plus its child data types.
   *
   * @throws UnsupportedOperationException if the class has no registered mapping
   */
  override def buildCustomAggregateFunction(
      aggregateFunc: AggregateFunction): (Option[String], Seq[DataType]) = {
    // Fixed idiom: the original wrapped this lookup in a degenerate
    // `match { case _ => ... }` that matched nothing.
    val substraitAggFuncName = extensionExpressionsMapping.get(aggregateFunc.getClass)
    if (substraitAggFuncName.isEmpty) {
      throw new UnsupportedOperationException(
        s"Aggregate function ${aggregateFunc.getClass} is not supported.")
    }
    (substraitAggFuncName, aggregateFunc.children.map(_.dataType))
  }
}
diff --git 
a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
 
b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
index 2369d6cdd7..7bdab862f5 100644
--- 
a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
+++ 
b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
@@ -26,6 +26,7 @@ REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(KeBitmapAndValue, 
ke_bitmap_and_value,
 REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(KeBitmapAndIds, ke_bitmap_and_ids, 
ke_bitmap_and_ids)
 
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Sum, sum, sum)
// Registers the Spark-side "sum0" aggregate (sum returning 0 instead of null on
// empty input) with the ClickHouse backend's common aggregate parser.
// NOTE(review): presumably the arguments are (ParserName, substrait_name, ch_name),
// following the surrounding entries — confirm against the macro definition.
+REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Sum0, sum0, sum0)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Avg, avg, avg)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Min, min, min)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Max, max, max)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to