This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new c1a2746 fix: incorrect result on Comet multiple column distinct count (#268)
c1a2746 is described below
commit c1a2746d74fe64574b653bcd08d801125fa1d60c
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Apr 15 15:46:07 2024 -0700
fix: incorrect result on Comet multiple column distinct count (#268)
* fix: incorrect result on Comet multiple column distinct count
* Update core/src/execution/datafusion/planner.rs
Co-authored-by: Andy Grove <[email protected]>
---------
Co-authored-by: Andy Grove <[email protected]>
---
core/src/execution/datafusion/planner.rs | 12 ++++++++++--
.../org/apache/comet/exec/CometAggregateSuite.scala | 21 ++++++++++++++++++++-
2 files changed, 30 insertions(+), 3 deletions(-)
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index ab83872..e53ebe7 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -1094,8 +1094,16 @@ impl PhysicalPlanner {
) -> Result<Arc<dyn AggregateExpr>, ExecutionError> {
match spark_expr.expr_struct.as_ref().unwrap() {
AggExprStruct::Count(expr) => {
- let child = self.create_expr(&expr.children[0], schema)?;
- Ok(Arc::new(Count::new(child, "count", DataType::Int64)))
+ let children = expr
+ .children
+ .iter()
+ .map(|child| self.create_expr(child, schema.clone()))
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok(Arc::new(Count::new_with_multiple_exprs(
+ children,
+ "count",
+ DataType::Int64,
+ )))
}
AggExprStruct::Min(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?;
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 89681d3..230ac36 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.functions.sum
+import org.apache.spark.sql.functions.{count_distinct, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.comet.CometConf
@@ -40,6 +40,25 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._
+ test("multiple column distinct count") {
+ withSQLConf(
+ CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+ CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+ val df1 = Seq(
+ ("a", "b", "c"),
+ ("a", "b", "c"),
+ ("a", "b", "d"),
+ ("x", "y", "z"),
+ ("x", "q", null.asInstanceOf[String]))
+ .toDF("key1", "key2", "key3")
+
+ checkSparkAnswer(df1.agg(count_distinct($"key1", $"key2")))
+ checkSparkAnswer(df1.agg(count_distinct($"key1", $"key2", $"key3")))
+ checkSparkAnswer(df1.groupBy($"key1").agg(count_distinct($"key2", $"key3")))
+ }
+ }
+
test("Only trigger Comet Final aggregation on Comet partial aggregation") {
withTempView("lowerCaseData") {
lowerCaseData.createOrReplaceTempView("lowerCaseData")