This is an automated email from the ASF dual-hosted git repository. lwz9103 pushed a commit to branch liquid in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
commit 307796100e92b4892dc401a17e67ed5cb9685579 Author: Wenzheng Liu <[email protected]> AuthorDate: Thu May 30 13:54:13 2024 +0800 [KE] Support sum0 function (#17) (cherry picked from commit 9aaf196a1970d88b8a690afb56722803970af8cd) --- .../execution/kap/GlutenKapExpressionsSuite.scala | 131 +++++++++++++++++++++ .../sql/catalyst/expressions/KapExpressions.scala | 87 ++++++++++++++ .../gluten/KapExpressionsTransformer.scala | 74 ++++++++++++ .../CommonAggregateFunctionParser.cpp | 1 + 4 files changed, 293 insertions(+) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala new file mode 100644 index 0000000000..827096138d --- /dev/null +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution.kap + +import org.apache.gluten.execution.GlutenClickHouseWholeStageTransformerSuite +import org.apache.gluten.utils.UTSystemParameters + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase +import org.apache.spark.sql.catalyst.expressions.Sum0 +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + +class GlutenKapExpressionsSuite + extends GlutenClickHouseWholeStageTransformerSuite + with AdaptiveSparkPlanHelper { + + /** Run Gluten + ClickHouse Backend with SortShuffleManager */ + override protected def sparkConf: SparkConf = { + super.sparkConf + .set("spark.sql.shuffle.partitions", "10") + .set("spark.sql.files.maxPartitionBytes", "32m") + .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager") + .set("spark.io.compression.codec", "LZ4") + .set("spark.sql.autoBroadcastJoinThreshold", "10MB") + .set("spark.sql.adaptive.enabled", "false") + .set("spark.sql.execution.useObjectHashAggregateExec", "true") + .set( + "spark.gluten.sql.columnar.extended.expressions.transformer", + "org.apache.spark.sql.catalyst.expressions.gluten.KapExpressionsTransformer") + .set("spark.gluten.sql.columnar.backend.ch.shuffle.hash.algorithm", "sparkMurmurHash3_32") + .set("spark.gluten.sql.columnar.backend.ch.runtime_config.use_local_format", "true") + } + + override def beforeAll(): Unit = { + super.beforeAll() + // register the extension expressions + val (sum0Expr, sum0Builder) = FunctionRegistryBase.build[Sum0]("sum0", None) + spark.sessionState.functionRegistry.registerFunction( + FunctionIdentifier.apply("sum0"), + sum0Expr, + sum0Builder + ) + } + + test("test sum0") { + val viewName = "sum0_table" + withTempView(viewName) { + val sum0DataDir = s"${UTSystemParameters.testDataPath}/$fileFormat/index/sum0" + spark + .sql(s""" + | select + | `84` as week_beg_dt, + | `129` as meta_category_name, + | `149` as lstg_format_name, + | `180` as site_name, + | `100000` as trans_cnt, + | `100001` as gmv, + | `100002` as price_cnt, + | `100003` as seller_cnt, + | `100004` as total_items + | from parquet.`$sum0DataDir` + |""".stripMargin) + .createOrReplaceTempView(viewName) + val sql1 = + s""" + |select * from sum0_table order by week_beg_dt limit 10 + |""".stripMargin + compareResultsAgainstVanillaSpark(sql1, compareResult = true, df => { df.show(10) }) + + val sql2 = + s""" + |select sum(gmv), sum0(trans_cnt), sum0(price_cnt), sum0(seller_cnt), sum(total_items) + |from sum0_table + |where lstg_format_name='FP-GTC' + |and week_beg_dt between '2013-05-01' and DATE '2013-08-01' + |""".stripMargin + compareResultsAgainstVanillaSpark(sql2, compareResult = true, _ => {}) + + val sql3 = + s""" + |select site_name, sum(gmv), sum0(trans_cnt), sum(total_items) from sum0_table + |where lstg_format_name='FP-GTC' + |and week_beg_dt = '2013-01-01' + |group by site_name order by site_name + |""".stripMargin + compareResultsAgainstVanillaSpark(sql3, compareResult = true, _ => {}) + + // aggregate empty data with group by + val sql4 = + s""" + |select site_name, sum(gmv), sum0(trans_cnt), sum(total_items) from sum0_table + |where extract(month from week_beg_dt) = 13 + |group by site_name order by site_name + |""".stripMargin + compareResultsAgainstVanillaSpark(sql4, compareResult = true, _ => {}) + + // aggregate empty data without group by + val sql5 = + s""" + |select sum(gmv), sum0(trans_cnt), sum(total_items) from sum0_table + |where extract(month from week_beg_dt) = 13 + |""".stripMargin + compareResultsAgainstVanillaSpark(sql5, compareResult = true, _ => {}) + + val sql6 = + s""" + |select sum(gmv), sum0(trans_cnt), sum(total_items) from sum0_table + |where extract(month from week_beg_dt) = 12 + |union all + |select sum(gmv), sum0(trans_cnt), sum(total_items) from sum0_table + |where extract(month from week_beg_dt) = 13 + |""".stripMargin + compareResultsAgainstVanillaSpark(sql6, compareResult = true, _ => {}) + } + } +} diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala new file mode 100644 index 0000000000..27281d74b7 --- /dev/null +++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the sum calculated from values of a group. " + + "It differs in that when no non null values are applied zero is returned instead of null") +case class Sum0(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sum") + + private lazy val resultType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision + 10, scale) + case _: IntegralType => LongType + case _ => DoubleType + } + + private lazy val sumDataType = resultType + + private lazy val sum = AttributeReference("sum", sumDataType)() + + private lazy val zero = Cast(Literal(0), sumDataType) + + override lazy val aggBufferAttributes = sum :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( + // /* sum = */ Literal.create(0, sumDataType) + // /* sum = */ Literal.create(null, sumDataType) + Cast(Literal(0), sumDataType) + ) + + override lazy val updateExpressions: Seq[Expression] = { + if (child.nullable) { + Seq( + /* sum = */ + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) + ) + } else { + Seq( + /* sum = */ + Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)) + ) + } + } + + override lazy val mergeExpressions: Seq[Expression] = { + Seq( + /* sum = */ + Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left)) + ) + } + + override lazy val evaluateExpression: Expression = sum + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + super.legacyWithNewChildren(newChildren) +} diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala new file mode 100644 index 0000000000..af9fe6ae1e --- /dev/null +++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.gluten + +import org.apache.gluten.exception.GlutenNotSupportException +import org.apache.gluten.expression._ +import org.apache.gluten.extension.ExpressionExtensionTrait + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types.DataType + +import scala.collection.mutable.ListBuffer + +case class KapExpressionsTransformer() extends ExpressionExtensionTrait { + + /** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */ + def expressionSigList: Seq[Sig] = Seq( + Sig[Sum0]("sum0") + ) + + override def getAttrsIndexForExtensionAggregateExpr( + aggregateFunc: AggregateFunction, + mode: AggregateMode, + exp: AggregateExpression, + aggregateAttributeList: Seq[Attribute], + aggregateAttr: ListBuffer[Attribute], + resIndex: Int): Int = { + var resIdx = resIndex + exp.mode match { + case Partial | PartialMerge => + val aggBufferAttr = aggregateFunc.inputAggBufferAttributes + for (index <- aggBufferAttr.indices) { + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) + aggregateAttr += attr + } + resIdx += aggBufferAttr.size + resIdx + case Final | Complete => + aggregateAttr += aggregateAttributeList(resIdx) + resIdx += 1 + resIdx + case other => + throw new GlutenNotSupportException(s"Unsupported aggregate mode: $other.") + } + } + + override def buildCustomAggregateFunction( + aggregateFunc: AggregateFunction): (Option[String], Seq[DataType]) = { + val substraitAggFuncName = aggregateFunc match { + case _ => + extensionExpressionsMapping.get(aggregateFunc.getClass) + } + if (substraitAggFuncName.isEmpty) { + throw new UnsupportedOperationException( + s"Aggregate function ${aggregateFunc.getClass} is not supported.") + } + (substraitAggFuncName, aggregateFunc.children.map(child => child.dataType)) + } +} diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp index 2369d6cdd7..7bdab862f5 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp @@ -26,6 +26,7 @@ REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(KeBitmapAndValue, ke_bitmap_and_value, REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(KeBitmapAndIds, ke_bitmap_and_ids, ke_bitmap_and_ids) REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Sum, sum, sum) +REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Sum0, sum0, sum0) REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Avg, avg, avg) REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Min, min, min) REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Max, max, max) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
