This is an automated email from the ASF dual-hosted git repository.
mahongbin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new b07b36960 [Gluten-4706] Add a mode to execute count distinct directly instead of Expand+Count (#4708)
b07b36960 is described below
commit b07b36960238db62647b6aeefc930b7fc3f271c7
Author: Hongbin Ma <[email protected]>
AuthorDate: Tue Mar 12 16:38:47 2024 +0800
[Gluten-4706] Add a mode to execute count distinct directly instead of Expand+Count (#4708)
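
For reference, a minimal usage sketch of the new mode (hedged: only the config
key below comes from this commit; the session wiring is illustrative and
assumes a Gluten-enabled Spark session):

    // Enable the mode; count(distinct ...) is then planned as a single
    // CountDistinct aggregate instead of Expand + Count.
    spark.conf.set("spark.gluten.sql.countDistinctWithoutExpand", "true")
    spark.sql("select count(distinct a) from values (1, 1), (2, 2) as data(a, b)").show()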
---
.../clickhouse/CHSparkPlanExecApi.scala | 6 +-
.../execution/CHHashAggregateExecTransformer.scala | 6 ++
.../execution/GlutenClickHouseHiveTableSuite.scala | 71 +++++++------
.../GlutenClickHouseNativeWriteTableSuite.scala | 38 +------
...lutenClickHouseWholeStageTransformerSuite.scala | 17 ++-
.../GlutenClickhouseCountDistinctSuite.scala | 118 +++++++++++++++++++++
.../extension/CustomAggExpressionTransformer.scala | 2 +-
.../AggregateFunctionPartialMerge.h | 4 +-
.../local-engine/Parser/AggregateFunctionParser.h | 14 ++-
.../CommonAggregateFunctionParser.cpp | 2 +-
.../expression/ExpressionMappings.scala | 1 +
.../extension/CountDistinctWithoutExpand.scala | 49 +++++++++
.../expressions/aggregate/CountDistinct.scala | 59 +++++++++++
.../main/scala/io/glutenproject/GlutenConfig.scala | 13 +++
.../glutenproject/expression/ExpressionNames.scala | 1 +
15 files changed, 322 insertions(+), 79 deletions(-)
diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 43d8ed1bc..f32116728 100644
--- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -22,6 +22,7 @@ import io.glutenproject.execution._
import io.glutenproject.expression._
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.extension.{FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage}
+import io.glutenproject.extension.CountDistinctWithoutExpand
import io.glutenproject.extension.columnar.AddTransformHintRule
import io.glutenproject.extension.columnar.MiscColumnarRules.TransformPreOverrides
import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode}
@@ -503,7 +504,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
* @return
*/
override def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]] = {
- List(spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf))
+ List(
+ spark => new CommonSubexpressionEliminateRule(spark, spark.sessionState.conf),
+ _ => CountDistinctWithoutExpand
+ )
}
/**
diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala
index c6b99a6bc..85b967446 100644
--- a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala
@@ -358,6 +358,12 @@ case class CHHashAggregateExecTransformer(
(makeStructType(fields), attr.nullable)
case expr if "bloom_filter_agg".equals(expr.prettyName) =>
(makeStructTypeSingleOne(expr.children.head.dataType, attr.nullable), attr.nullable)
+ case cd: CountDistinct =>
+ var fields = Seq[(DataType, Boolean)]()
+ for (child <- cd.children) {
+ fields = fields :+ (child.dataType, child.nullable)
+ }
+ (makeStructType(fields), false)
case _ =>
(makeStructTypeSingleOne(attr.dataType, attr.nullable), attr.nullable)
}
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
index 938f8e6d1..b40a0fe0d 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseHiveTableSuite.scala
@@ -19,8 +19,10 @@ package io.glutenproject.execution
import io.glutenproject.GlutenConfig
import io.glutenproject.utils.UTSystemParameters
-import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf}
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.SPARK_VERSION_SHORT
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.delta.DeltaLog
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
@@ -29,11 +31,11 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.hadoop.fs.Path
import java.io.{File, PrintWriter}
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
import scala.reflect.ClassTag
-case class AllDataTypesWithComplextType(
+case class AllDataTypesWithComplexType(
string_field: String = null,
int_field: java.lang.Integer = null,
long_field: java.lang.Long = null,
@@ -51,6 +53,36 @@ case class AllDataTypesWithComplextType(
mapValueContainsNull: Map[Int, Option[Long]] = null
)
+object AllDataTypesWithComplexType {
+ def genTestData(): Seq[AllDataTypesWithComplexType] = {
+ (0 to 199).map {
+ i =>
+ if (i % 100 == 1) {
+ AllDataTypesWithComplexType()
+ } else {
+ AllDataTypesWithComplexType(
+ s"$i",
+ i,
+ i.toLong,
+ i.toFloat,
+ i.toDouble,
+ i.toShort,
+ i.toByte,
+ i % 2 == 0,
+ new java.math.BigDecimal(i + ".56"),
+ Date.valueOf(new Date(System.currentTimeMillis()).toLocalDate.plusDays(i % 10)),
+ Timestamp.valueOf(
+ new Timestamp(System.currentTimeMillis()).toLocalDateTime.plusDays(i % 10)),
+ Seq.apply(i + 1, i + 2, i + 3),
+ Seq.apply(Option.apply(i + 1), Option.empty, Option.apply(i + 3)),
+ Map.apply((i + 1, i + 2), (i + 3, i + 4)),
+ Map.empty
+ )
+ }
+ }
+ }
+}
+
class GlutenClickHouseHiveTableSuite
extends GlutenClickHouseWholeStageTransformerSuite
with AdaptiveSparkPlanHelper {
@@ -170,38 +202,13 @@ class GlutenClickHouseHiveTableSuite
"map_field map<int, long>," +
"map_field_with_null map<int, long>) stored as %s".format(fileFormat)
- def genTestData(): Seq[AllDataTypesWithComplextType] = {
- (0 to 199).map {
- i =>
- if (i % 100 == 1) {
- AllDataTypesWithComplextType()
- } else {
- AllDataTypesWithComplextType(
- s"$i",
- i,
- i.toLong,
- i.toFloat,
- i.toDouble,
- i.toShort,
- i.toByte,
- i % 2 == 0,
- new java.math.BigDecimal(i + ".56"),
- new java.sql.Date(System.currentTimeMillis()),
- new Timestamp(System.currentTimeMillis()),
- Seq.apply(i + 1, i + 2, i + 3),
- Seq.apply(Option.apply(i + 1), Option.empty, Option.apply(i + 3)),
- Map.apply((i + 1, i + 2), (i + 3, i + 4)),
- Map.empty
- )
- }
- }
- }
-
protected def initializeTable(
table_name: String,
table_create_sql: String,
partitions: Seq[String]): Unit = {
- spark.createDataFrame(genTestData()).createOrReplaceTempView("tmp_t")
+ spark
+ .createDataFrame(AllDataTypesWithComplexType.genTestData())
+ .createOrReplaceTempView("tmp_t")
val truncate_sql = "truncate table %s".format(table_name)
val drop_sql = "drop table if exists %s".format(table_name)
spark.sql(drop_sql)
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala
index 5e9c2d1fc..724fb2721 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala
@@ -17,6 +17,7 @@
package io.glutenproject.execution
import io.glutenproject.GlutenConfig
+import io.glutenproject.execution.AllDataTypesWithComplexType.genTestData
import io.glutenproject.utils.UTSystemParameters
import org.apache.spark.SparkConf
@@ -27,8 +28,6 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.scalatest.BeforeAndAfterAll
-import java.sql.{Date, Timestamp}
-
class GlutenClickHouseNativeWriteTableSuite
extends GlutenClickHouseWholeStageTransformerSuite
with AdaptiveSparkPlanHelper
@@ -90,41 +89,6 @@ class GlutenClickHouseNativeWriteTableSuite
private val table_name_vanilla_template = "hive_%s_test_written_by_vanilla"
private val formats = Array("orc", "parquet")
- def genTestData(): Seq[AllDataTypesWithComplextType] = {
- (0 to 199).map {
- i =>
- if (i % 100 == 1) {
- AllDataTypesWithComplextType()
- } else {
- AllDataTypesWithComplextType(
- s"$i",
- i,
- i.toLong,
- i.toFloat,
- i.toDouble,
- i.toShort,
- i.toByte,
- i % 2 == 0,
- new java.math.BigDecimal(i + ".56"),
- Date.valueOf(new Date(System.currentTimeMillis()).toLocalDate.plusDays(i % 10)),
- Timestamp.valueOf(
- new Timestamp(System.currentTimeMillis()).toLocalDateTime.plusDays(i % 10)),
- Seq.apply(i + 1, i + 2, i + 3),
- Seq.apply(Option.apply(i + 1), Option.empty, Option.apply(i + 3)),
- Map.apply((i + 1, i + 2), (i + 3, i + 4)),
- Map.empty
- )
- }
- }
- }
-
- protected def initializeTable(table_name: String, table_create_sql: String): Unit = {
- spark.createDataFrame(genTestData()).createOrReplaceTempView("tmp_t")
- spark.sql(s"drop table IF EXISTS $table_name")
- spark.sql(table_create_sql)
- spark.sql("insert into %s select * from tmp_t".format(table_name))
- }
-
override protected def afterAll(): Unit = {
DeltaLog.clearCache()
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseWholeStageTransformerSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseWholeStageTransformerSuite.scala
index 28da546fc..1890d00e3 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseWholeStageTransformerSuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseWholeStageTransformerSuite.scala
@@ -16,6 +16,12 @@
*/
package io.glutenproject.execution
+import io.glutenproject.GlutenConfig
+import io.glutenproject.utils.UTSystemParameters
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
+
import org.apache.commons.io.FileUtils
import java.io.File
@@ -37,6 +43,15 @@ class GlutenClickHouseWholeStageTransformerSuite extends WholeStageTransformerSu
}
}
+ override protected def sparkConf: SparkConf =
+ super.sparkConf
+ .set(GlutenConfig.GLUTEN_LIB_PATH, UTSystemParameters.getClickHouseLibPath())
+ .set(
+ "spark.gluten.sql.columnar.backend.ch.use.v2",
+ ClickHouseConfig.DEFAULT_USE_DATASOURCE_V2)
+ .set("spark.gluten.sql.enable.native.validation", "false")
+ .set("spark.sql.warehouse.dir", warehouse)
+
override def beforeAll(): Unit = {
// prepare working paths
val basePathDir = new File(basePath)
@@ -56,6 +71,6 @@ class GlutenClickHouseWholeStageTransformerSuite extends WholeStageTransformerSu
protected val hiveMetaStoreDB = metaStorePathAbsolute + "/metastore_db"
override protected val backend: String = "ch"
- override protected val resourcePath: String = ""
+ final override protected val resourcePath: String = "" // CH does not need this
override protected val fileFormat: String = "parquet"
}
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseCountDistinctSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseCountDistinctSuite.scala
new file mode 100644
index 000000000..4ba9882b6
--- /dev/null
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseCountDistinctSuite.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.execution
+
+import io.glutenproject.execution.AllDataTypesWithComplexType.genTestData
+
+import org.apache.spark.SparkConf
+class GlutenClickhouseCountDistinctSuite extends GlutenClickHouseWholeStageTransformerSuite {
+
+ override protected def sparkConf: SparkConf = {
+ super.sparkConf
+ .set("spark.gluten.sql.countDistinctWithoutExpand", "true")
+ .set("spark.sql.adaptive.enabled", "false")
+ }
+
+ test("check count distinct correctness") {
+ // simple case
+ var sql = "select count(distinct(a)) from values (1,1,1), (2,2,2) as data(a,b,c)"
+ compareResultsAgainstVanillaSpark(sql, true, { _ => })
+
+ // with null
+ sql = "select count(distinct(a)) from " +
+ "values (1,1,1), (2,2,2), (1,3,3), (null,4,4), (null,5,5) as data(a,b,c)"
+ compareResultsAgainstVanillaSpark(sql, true, { _ => })
+
+ // three CD
+ sql = "select count(distinct(b)), count(distinct(a)),count(distinct c)
from " +
+ "values (0, null,1), (0,null,1), (1, 1,1), (2, 2, 1) ,(2,2,2) as
data(a,b,c)"
+ compareResultsAgainstVanillaSpark(sql, true, { _ => })
+
+ // count distinct with multiple args
+ sql = "select count(distinct(a,b)), count(distinct(a,b,c)) from " +
+ "values (0, null,1), (0,null,1), (1, 1,1), (2, 2, 1) ,(2,2,2) as
data(a,b,c)"
+ compareResultsAgainstVanillaSpark(sql, true, { _ => })
+ }
+
+ test("check count distinct execution plan") {
+ val sql =
+ "select count(distinct(b)), count(distinct a, b) from " +
+ "values (0, null,1), (1, 1,1), (2, 2,1), (1, 2,1) ,(2,2,2) as
data(a,b,c) group by c"
+
+ val df = spark.sql(sql)
+ WholeStageTransformerSuite.checkFallBack(df)
+
+ val planExecs = df.queryExecution.executedPlan.collect {
+ case aggTransformer: HashAggregateExecBaseTransformer => aggTransformer
+ }
+
+ planExecs.head.aggregateExpressions.foreach {
+ expr => assert(expr.toString().startsWith("countdistinct"))
+ }
+ planExecs(1).aggregateExpressions.foreach {
+ expr => assert(expr.toString().startsWith("partial_countdistinct"))
+ }
+ }
+
+ test("check all data types") {
+ spark.createDataFrame(genTestData()).createOrReplaceTempView("all_data_types")
+
+ // Vanilla Spark does not support count distinct on map fields
+ for (
+ field <- AllDataTypesWithComplexType().getClass.getDeclaredFields.filterNot(
+ p => p.getName.startsWith("map"))
+ ) {
+ val sql = s"select count(distinct(${field.getName})) from all_data_types"
+ compareResultsAgainstVanillaSpark(sql, true, { _ => })
+ spark.sql(sql).show
+ }
+
+ // just check that these queries run successfully
+ for (
+ field <- AllDataTypesWithComplexType().getClass.getDeclaredFields.filter(
+ p => p.getName.startsWith("map"))
+ ) {
+ val sql = s"select count(distinct(${field.getName})) from all_data_types"
+ spark.sql(sql).show
+ }
+ }
+
+ test("check count distinct with agg fallback") {
+ // the skewness aggregate is not supported and will cause a fallback
+ val sql = "select count(distinct(a,b)) , skewness(b) from " +
+ "values (0, null,1), (0,null,1), (1, 1,1), (2, 2, 1) ,(2,2,2),(3,3,3) as
data(a,b,c)"
+ assertThrows[UnsupportedOperationException] {
+ spark.sql(sql).show
+ }
+ }
+
+ test("check count distinct with expr fallback") {
+ // try_add is not supported; it will cause a fallback after a project operator
+ val sql = s"""
+ select count(distinct(a,b)) , try_add(c,b) from
+ values (0, null,1), (0,null,2), (1, 1,4) as data(a,b,c) group by try_add(c,b)
+ """;
+ val df = spark.sql(sql)
+ WholeStageTransformerSuite.checkFallBack(df, noFallback = false)
+ }
+
+ test("check count distinct with filter") {
+ val sql = "select count(distinct(a,b)) FILTER (where c <3) from " +
+ "values (0, null,1), (0,null,1), (1, 1,1), (2, 2, 1) ,(2,2,2),(3,3,3) as
data(a,b,c)"
+ compareResultsAgainstVanillaSpark(sql, true, { _ => })
+ }
+}
diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala
index 76ec0d7ad..9fbc2aa3b 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala
@@ -16,7 +16,7 @@
*/
package io.glutenproject.execution.extension
-import io.glutenproject.expression._
+import io.glutenproject.expression.Sig
import io.glutenproject.extension.ExpressionExtensionTrait
import org.apache.spark.sql.catalyst.expressions._
diff --git a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h
index bf2becf32..f77c253f5 100644
--- a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h
+++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h
@@ -50,7 +50,8 @@ private:
public:
AggregateFunctionPartialMerge(const AggregateFunctionPtr & nested_, const DataTypePtr & argument, const Array & params_)
- : IAggregateFunctionHelper<AggregateFunctionPartialMerge>({argument}, params_, createResultType(nested_)), nested_func(nested_)
+ : IAggregateFunctionHelper<AggregateFunctionPartialMerge>({argument}, params_, createResultType(nested_))
+ , nested_func(nested_)
{
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(argument.get());
@@ -115,5 +116,4 @@ public:
AggregateFunctionPtr getNestedFunction() const override { return nested_func; }
};
-
}
diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
index a9840eeef..e2444361f 100644
--- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
+++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
@@ -29,7 +29,6 @@
namespace local_engine
{
-
class AggregateFunctionParser
{
public:
@@ -78,7 +77,11 @@ public:
}
};
- AggregateFunctionParser(SerializedPlanParser * plan_parser_) : plan_parser(plan_parser_) { }
+ AggregateFunctionParser(SerializedPlanParser * plan_parser_)
+ : plan_parser(plan_parser_)
+ {
+ }
+
virtual ~AggregateFunctionParser() = default;
virtual String getName() const = 0;
@@ -93,6 +96,7 @@ public:
/// Do some preprojections for the function arguments, and return the necessary arguments for the CH function.
virtual DB::ActionsDAG::NodeRawConstPtrs
parseFunctionArguments(const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const;
+
DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(const CommonFunctionInfo & func_info, DB::ActionsDAGPtr & actions_dag) const
{
return parseFunctionArguments(func_info, getCHFunctionName(func_info), actions_dag);
@@ -106,7 +110,9 @@ public:
/// Make a postprojection for the function result.
virtual const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
- const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag, bool withNullability) const;
+ const CommonFunctionInfo & func_info,
+ const DB::ActionsDAG::Node * func_node,
+ DB::ActionsDAGPtr & actions_dag, bool withNullability) const;
/// Parameters are only used in aggregate functions at present. e.g. percentiles(0.5)(x).
/// 0.5 is the parameter of percentiles function.
@@ -159,6 +165,7 @@ protected:
SerializedPlanParser * plan_parser;
Poco::Logger * logger = &Poco::Logger::get("AggregateFunctionParserFactory");
};
+
using AggregateFunctionParserPtr = std::shared_ptr<AggregateFunctionParser>;
using AggregateFunctionParserCreator = std::function<AggregateFunctionParserPtr(SerializedPlanParser *)>;
@@ -200,5 +207,4 @@ struct AggregateFunctionParserRegister
{
AggregateFunctionParserRegister() { AggregateFunctionParserFactory::instance().registerAggregateFunctionParser<Parser>(Parser::name); }
};
-
}
diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
index fc334a214..afe65f793 100644
--- a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
@@ -47,5 +47,5 @@ REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(RowNumber, row_number, row_number)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Ntile, ntile, ntile)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(PercentRank, percent_rank, percent_rank)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CumeDist, cume_dist, cume_dist)
-
+REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CountDistinct, count_distinct, uniqExact)
}
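
The registration added above maps Gluten's count_distinct to ClickHouse's
uniqExact aggregate. A rough equivalence sketch (illustrative only; assumes a
registered view `data` with a column `a`):

    // With the mode enabled, the first query runs natively via uniqExact(a);
    // both return the same exact distinct count (NULLs excluded by count).
    val viaGluten  = spark.sql("select count(distinct a) from data")
    val viaVanilla = spark.sql("select count(a) from (select distinct a from data)")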
diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
index b6a874756..a063c244b 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/ExpressionMappings.scala
@@ -251,6 +251,7 @@ object ExpressionMappings {
Sig[Sum](SUM),
Sig[Average](AVG),
Sig[Count](COUNT),
+ Sig[CountDistinct](COUNT_DISTINCT),
Sig[Min](MIN),
Sig[Max](MAX),
Sig[MaxBy](MAX_BY),
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/CountDistinctWithoutExpand.scala b/gluten-core/src/main/scala/io/glutenproject/extension/CountDistinctWithoutExpand.scala
new file mode 100644
index 000000000..c7f57030f
--- /dev/null
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/CountDistinctWithoutExpand.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.glutenproject.extension
+
+import io.glutenproject.GlutenConfig
+
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, CountDistinct}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE_EXPRESSION
+
+/**
+ * By converting Count (with isDistinct=true) to a UDAF named CountDistinct, we can avoid the
+ * Expand operator in the physical plan.
+ *
+ * This rule has no effect unless spark.gluten.enabled and
+ * spark.gluten.sql.countDistinctWithoutExpand are both true.
+ */
+object CountDistinctWithoutExpand extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ if (
+ GlutenConfig.getConf.enableGluten && GlutenConfig.getConf.enableCountDistinctWithoutExpand
+ ) {
+ plan.transformAllExpressionsWithPruning(_.containsPattern(AGGREGATE_EXPRESSION)) {
+ case ae: AggregateExpression if ae.isDistinct && ae.aggregateFunction.isInstanceOf[Count] =>
+ ae.copy(
+ aggregateFunction =
+ CountDistinct.apply(ae.aggregateFunction.asInstanceOf[Count].children),
+ isDistinct = false)
+ }
+ } else {
+ plan
+ }
+ }
+}
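
In catalyst terms, the rewrite above amounts to the following standalone sketch
(not part of the commit; the attribute `a` is made up for illustration):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, CountDistinct}
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    // Before: count(distinct a); SparkPlanner would lower this to Expand + Count.
    val before = AggregateExpression(Count(Seq(a)), Complete, isDistinct = true)
    // After: the CountDistinct UDAF with the distinct flag cleared, so the
    // aggregate reaches the native backend as a single function.
    val after = before.copy(aggregateFunction = CountDistinct(Seq(a)), isDistinct = false)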
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountDistinct.scala b/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountDistinct.scala
new file mode 100644
index 000000000..54918fcc7
--- /dev/null
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountDistinct.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+/**
+ * By default, Spark executes count distinct as Expand + Count. This works reliably but may be
+ * slower. We allow users to inject a custom count distinct function to speed up the execution;
+ * see the optimizer rule CountDistinctWithoutExpand.
+ */
+case class CountDistinct(children: Seq[Expression]) extends DeclarativeAggregate {
+
+ override def nullable: Boolean = false
+
+ // Return data type.
+ override def dataType: DataType = LongType
+
+ protected lazy val cd = AttributeReference("count_distinct", LongType, nullable = false)()
+
+ override lazy val aggBufferAttributes = cd :: Nil
+
+ override lazy val initialValues =
+ throw new UnsupportedOperationException(
+ "count distinct does not have non-columnar implementation")
+
+ override lazy val mergeExpressions =
+ throw new UnsupportedOperationException(
+ "count distinct does not have non-columnar implementation")
+
+ override lazy val evaluateExpression =
+ throw new UnsupportedOperationException(
+ "count distinct does not have non-columnar implementation")
+
+ override def defaultResult: Option[Literal] = Option(Literal(0L))
+
+ override lazy val updateExpressions =
+ throw new UnsupportedOperationException(
+ "count distinct does not have non-columnar implementation")
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): CountDistinct =
+ copy(children = newChildren)
+}
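
Note that every row-based hook (initialValues, update/merge/evaluate
expressions) throws: CountDistinct is a columnar-only placeholder that must be
executed by the native backend. A quick illustration (a sketch, not from the
commit):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.expressions.aggregate.CountDistinct
    import org.apache.spark.sql.types.StringType

    val s = AttributeReference("s", StringType)()
    val cd = CountDistinct(Seq(s))
    assert(!cd.nullable && cd.dataType.typeName == "long") // always a non-null count
    // cd.updateExpressions // would throw UnsupportedOperationException: columnar-only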
diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
index 87f5fc4f0..8a8bd4152 100644
--- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
+++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
@@ -94,6 +94,9 @@ class GlutenConfig(conf: SQLConf) extends Logging {
def enableCommonSubexpressionEliminate: Boolean =
conf.getConf(ENABLE_COMMON_SUBEXPRESSION_ELIMINATE)
+ def enableCountDistinctWithoutExpand: Boolean =
+ conf.getConf(ENABLE_COUNT_DISTINCT_WITHOUT_EXPAND)
+
def veloxOrcScanEnabled: Boolean =
conf.getConf(VELOX_ORC_SCAN_ENABLED)
@@ -1537,6 +1540,16 @@ object GlutenConfig {
.booleanConf
.createWithDefault(true)
+ val ENABLE_COUNT_DISTINCT_WITHOUT_EXPAND =
+ buildConf("spark.gluten.sql.countDistinctWithoutExpand")
+ .internal()
+ .doc(
+ "Convert Count Distinct to a UDAF called count_distinct to " +
+ "prevent SparkPlanner from converting it to Expand+Count. WARNING: " +
+ "when enabled, count distinct queries cannot fall back to vanilla Spark!")
+ .booleanConf
+ .createWithDefault(false)
+
val COLUMNAR_VELOX_BLOOM_FILTER_EXPECTED_NUM_ITEMS =
buildConf("spark.gluten.sql.columnar.backend.velox.bloomFilter.expectedNumItems")
.internal()
diff --git a/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala b/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala
index c688c8abd..d2fc4f9ec 100644
--- a/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala
+++ b/shims/common/src/main/scala/io/glutenproject/expression/ExpressionNames.scala
@@ -22,6 +22,7 @@ object ExpressionNames {
final val SUM = "sum"
final val AVG = "avg"
final val COUNT = "count"
+ final val COUNT_DISTINCT = "count_distinct"
final val MIN = "min"
final val MAX = "max"
final val MAX_BY = "max_by"
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]