This is an automated email from the ASF dual-hosted git repository.
rui pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new e9034ff5e [GLUTEN-4946][CH] Fix avg(bigint) overflow (#5048)
e9034ff5e is described below
commit e9034ff5e2ec4cce8cd5defaf2ade9b44b8c8aa3
Author: loudongfeng <[email protected]>
AuthorDate: Mon Mar 25 12:55:00 2024 +0800
[GLUTEN-4946][CH] Fix avg(bigint) overflow (#5048)
---
.../clickhouse/CHSparkPlanExecApi.scala | 2 +
.../catalyst/CHAggregateFunctionRewriteRule.scala | 60 ++++++++++++++++++++++
.../execution/GlutenFunctionValidateSuite.scala | 21 ++++++++
.../main/scala/io/glutenproject/GlutenConfig.scala | 8 +++
4 files changed, 91 insertions(+)
diff --git
a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 29af5a0e5..4b6ee1909 100644
---
a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -35,6 +35,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{GenShuffleWriterParameters,
GlutenShuffleWriterWrapper, HashPartitioningWrapper}
import org.apache.spark.shuffle.utils.CHShuffleUtil
import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.sql.catalyst.CHAggregateFunctionRewriteRule
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
@@ -518,6 +519,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
override def genExtendedOptimizers(): List[SparkSession =>
Rule[LogicalPlan]] = {
    List(
      spark => new CommonSubexpressionEliminateRule(spark,
spark.sessionState.conf),
      // Registers the avg-rewrite rule so avg over integral/float children is cast to
      // Double before aggregation (CH sums in the input type and can overflow).
+      spark => CHAggregateFunctionRewriteRule(spark),
      _ => CountDistinctWithoutExpand
    )
  }
diff --git
a/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala
b/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala
new file mode 100644
index 000000000..623db7993
--- /dev/null
+++
b/backends-clickhouse/src/main/scala/org/apache/spark/sql/catalyst/CHAggregateFunctionRewriteRule.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst
+
+import io.glutenproject.GlutenConfig
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Cast
+import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Average}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types._
+
/**
 * Rewrites `Average` aggregates so their child is cast to Double.
 *
 * ClickHouse accumulates avg()'s intermediate sum in the input type, whereas Spark
 * accumulates in Double; without this cast, e.g. avg(bigint) can overflow natively.
 * Distinct averages are left untouched, and the rewrite only fires when both the
 * cast-avg and columnar hash-agg features are enabled.
 *
 * @param spark the active session (required by the extended-optimizer factory signature)
 */
case class CHAggregateFunctionRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {

  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
    case agg: Aggregate =>
      agg.transformExpressions {
        case avgExpr @ AggregateExpression(avg: Average, _, _, _, _)
            if GlutenConfig.getConf.enableCastAvgAggregateFunction &&
              GlutenConfig.getConf.enableColumnarHashAgg &&
              !avgExpr.isDistinct && needsDoubleCast(avg.child.dataType) =>
          // Rebuild the expression around the cast child; keep mode, filter and
          // resultId unchanged so downstream attribute references stay valid.
          AggregateExpression(
            avg.copy(child = Cast(avg.child, DoubleType)),
            avgExpr.mode,
            avgExpr.isDistinct,
            avgExpr.filter,
            avgExpr.resultId
          )
      }
  }

  // Input types whose native intermediate sum can overflow (integrals) or lose
  // precision (Float) when CH sums in the input type rather than Double.
  private def needsDoubleCast(dataType: DataType): Boolean = dataType match {
    case ShortType | IntegerType | LongType | FloatType => true
    case _ => false
  }
}
diff --git
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
index 818fe4e5f..1a27e68fe 100644
---
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
+++
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenFunctionValidateSuite.scala
@@ -639,4 +639,25 @@ class GlutenFunctionValidateSuite extends
GlutenClickHouseWholeStageTransformerS
val sql = "select cast(concat(' ', cast(id as string)) as bigint) from
range(10)"
runQueryAndCompare(sql)(checkOperatorMatch[ProjectExecTransformer])
}
+
+ test("avg(bigint) overflow") {
+ withSQLConf(
+ "spark.gluten.sql.columnar.forceShuffledHashJoin" -> "false",
+ "spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+ withTable("myitem") {
+ sql("create table big_int(id bigint) using parquet")
+ sql("""
+ |insert into big_int values (9223372036854775807),
+ |(9223372036854775807),
+ |(9223372036854775807),
+ |(9223372036854775807)
+ |""".stripMargin)
+ val q = "select avg(id) from big_int"
+
runQueryAndCompare(q)(checkOperatorMatch[CHHashAggregateExecTransformer])
+ val disinctSQL = "select count(distinct id), avg(distinct id), avg(id)
from big_int"
+
runQueryAndCompare(disinctSQL)(checkOperatorMatch[CHHashAggregateExecTransformer])
+ }
+ }
+ }
+
}
diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
index 48e37cdb3..4119a09fc 100644
--- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
+++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
@@ -353,6 +353,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
def enableColumnarProjectCollapse: Boolean =
conf.getConf(ENABLE_COLUMNAR_PROJECT_COLLAPSE)
def awsSdkLogLevel: String = conf.getConf(AWS_SDK_LOG_LEVEL)
+
+ // Whether avg's child is cast to Double before aggregation so the native
+ // intermediate sum cannot overflow (see spark.gluten.sql.columnar.cast.avg).
+ def enableCastAvgAggregateFunction: Boolean =
conf.getConf(COLUMNAR_NATIVE_CAST_AGGREGATE_ENABLED)
}
object GlutenConfig {
@@ -1691,4 +1693,10 @@ object GlutenConfig {
.doc("Force fallback for orc char type scan.")
.booleanConf
.createWithDefault(true)
+
+ val COLUMNAR_NATIVE_CAST_AGGREGATE_ENABLED =
+ buildConf("spark.gluten.sql.columnar.cast.avg")
+ .internal()
+ .booleanConf
+ .createWithDefault(true)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]