This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new fbe7f80  feat: Support `First`/`Last` aggregate functions (#97)
fbe7f80 is described below

commit fbe7f80cc4045516822e8dfa8e89a944c757c2f2
Author: Huaxin Gao <huaxin.ga...@gmail.com>
AuthorDate: Fri Feb 23 08:39:54 2024 -0800

    feat: Support `First`/`Last` aggregate functions (#97)
    
    
    Co-authored-by: Huaxin Gao <huaxin....@apple.com>
---
 core/src/execution/datafusion/planner.rs           | 25 +++++-
 core/src/execution/proto/expr.proto                | 14 ++++
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 38 ++++++++-
 .../apache/comet/exec/CometAggregateSuite.scala    | 95 ++++++++++++++++++----
 4 files changed, 156 insertions(+), 16 deletions(-)

diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 2feaace..66a29cb 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -42,7 +42,8 @@ use datafusion_common::ScalarValue;
 use datafusion_physical_expr::{
     execution_props::ExecutionProps,
     expressions::{
-        CaseExpr, CastExpr, Count, InListExpr, IsNullExpr, Max, Min, NegativeExpr, NotExpr, Sum,
+        CaseExpr, CastExpr, Count, FirstValue, InListExpr, IsNullExpr, LastValue, Max, Min,
+        NegativeExpr, NotExpr, Sum,
     },
     AggregateExpr, ScalarFunctionExpr,
 };
@@ -900,6 +901,28 @@ impl PhysicalPlanner {
                     }
                 }
             }
+            AggExprStruct::First(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?;
+                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                Ok(Arc::new(FirstValue::new(
+                    child,
+                    "first",
+                    datatype,
+                    vec![],
+                    vec![],
+                )))
+            }
+            AggExprStruct::Last(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?;
+                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                Ok(Arc::new(LastValue::new(
+                    child,
+                    "last",
+                    datatype,
+                    vec![],
+                    vec![],
+                )))
+            }
         }
     }
 
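Note: FirstValue::new and LastValue::new above appear to be constructed with
empty ordering-expression and ordering-type vectors, so DataFusion returns
whichever row it encounters first (or last) within each group; the result is
only deterministic when the input is already ordered. A minimal Scala sketch
of the pattern the tests below rely on (table and column names are
illustrative, not from this commit):

    // Pre-sorting through a temp view pins down which row is "first"/"last",
    // making FIRST/LAST results comparable between Spark and Comet.
    import spark.implicits._ // assumes a SparkSession in scope named `spark`
    val df = Seq((1, 10), (1, 20), (2, 30)).toDF("g", "x")
    df.createOrReplaceTempView("tbl")
    spark.sql("CREATE OR REPLACE TEMP VIEW v AS SELECT g, x FROM tbl ORDER BY x")
    spark.sql("SELECT g, FIRST(x), LAST(x) FROM v GROUP BY g").show()
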
diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index 53035b8..a80335c 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -85,6 +85,8 @@ message AggExpr {
     Min min = 4;
     Max max = 5;
     Avg avg = 6;
+    First first = 7;
+    Last last = 8;
   }
 }
 
@@ -115,6 +117,18 @@ message Avg {
  bool fail_on_error = 4; // currently unused (useful for deciding Ansi vs Legacy mode)
 }
 
+message First {
+  Expr child = 1;
+  DataType datatype = 2;
+  bool ignore_nulls = 3;
+}
+
+message Last {
+  Expr child = 1;
+  DataType datatype = 2;
+  bool ignore_nulls = 3;
+}
+
 message Literal {
   oneof value {
     bool bool_val = 1;
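Note: both new messages reserve an ignore_nulls field, but no setter for it is
called in the Scala serde below; as a proto3 bool it therefore defaults to
false, which matches the only case that gets serialized (ignoreNulls == false).
A hedged sketch of building one of these messages with the generated builders
(childExpr and dataType stand in for already-serialized protos; illustrative
only, mirroring the QueryPlanSerde change below):

    // Build the First aggregate message and wrap it in an AggExpr.
    val firstBuilder = ExprOuterClass.First.newBuilder()
    firstBuilder.setChild(childExpr)   // serialized child expression
    firstBuilder.setDatatype(dataType) // serialized result data type
    // ignore_nulls is left at its proto3 default of false.
    val aggExpr = ExprOuterClass.AggExpr.newBuilder().setFirst(firstBuilder).build()
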
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 938e49f..fcc0ca9 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Final, Max, Min, Partial, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Final, First, Last, Max, Min, Partial, Sum}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
@@ -287,6 +287,42 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         } else {
           None
         }
+      case first @ First(child, ignoreNulls)
+          if !ignoreNulls => // DataFusion doesn't support ignoreNulls true
+        val childExpr = exprToProto(child, inputs)
+        val dataType = serializeDataType(first.dataType)
+
+        if (childExpr.isDefined && dataType.isDefined) {
+          val firstBuilder = ExprOuterClass.First.newBuilder()
+          firstBuilder.setChild(childExpr.get)
+          firstBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setFirst(firstBuilder)
+              .build())
+        } else {
+          None
+        }
+      case last @ Last(child, ignoreNulls)
+          if !ignoreNulls => // DataFusion doesn't support ignoreNulls true
+        val childExpr = exprToProto(child, inputs)
+        val dataType = serializeDataType(last.dataType)
+
+        if (childExpr.isDefined && dataType.isDefined) {
+          val lastBuilder = ExprOuterClass.Last.newBuilder()
+          lastBuilder.setChild(childExpr.get)
+          lastBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setLast(lastBuilder)
+              .build())
+        } else {
+          None
+        }
 
       case fn =>
         emitWarning(s"unsupported Spark aggregate function: $fn")
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 9b7fb3c..04735b5 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -438,8 +438,12 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
               (0 until numValues).map(i => (i, Random.nextInt() % numGroups)),
               "tbl",
               dictionaryEnabled) {
-              checkSparkAnswer(
-                "SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), 
COUNT(_1), COUNT(DISTINCT _1), AVG(_1) FROM tbl GROUP BY _2")
+              withView("v") {
+                sql("CREATE TEMP VIEW v AS SELECT _1, _2 FROM tbl ORDER BY _1")
+                checkSparkAnswer(
+                  "SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), 
COUNT(_1)," +
+                    " COUNT(DISTINCT _1), AVG(_1), FIRST(_1), LAST(_1) FROM v 
GROUP BY _2")
+              }
             }
           }
         }
@@ -458,6 +462,11 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
               val path = new Path(dir.toURI.toString, "test.parquet")
               makeParquetFile(path, numValues, numGroups, dictionaryEnabled)
               withParquetTable(path.toUri.toString, "tbl") {
+                withView("v") {
+                  sql("CREATE TEMP VIEW v AS SELECT _g1, _g2, _3 FROM tbl 
ORDER BY _3")
+                  checkSparkAnswer("SELECT _g1, _g2, FIRST(_3) FROM v GROUP BY 
_g1, _g2")
+                  checkSparkAnswer("SELECT _g1, _g2, LAST(_3) FROM v GROUP BY 
_g1, _g2")
+                }
                 checkSparkAnswer("SELECT _g1, _g2, SUM(_3) FROM tbl GROUP BY 
_g1, _g2")
                 checkSparkAnswer("SELECT _g1, _g2, COUNT(_3) FROM tbl GROUP BY 
_g1, _g2")
                 checkSparkAnswer("SELECT _g1, _g2, SUM(DISTINCT _3) FROM tbl 
GROUP BY _g1, _g2")
@@ -491,6 +500,12 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
               val path = new Path(dir.toURI.toString, "test.parquet")
               makeParquetFile(path, numValues, numGroups, dictionaryEnabled)
               withParquetTable(path.toUri.toString, "tbl") {
+                withView("v") {
+                  sql("CREATE TEMP VIEW v AS SELECT _g3, _g4, _3, _4 FROM tbl 
ORDER BY _3, _4")
+                  checkSparkAnswer(
+                    "SELECT _g3, _g4, FIRST(_3), FIRST(_4) FROM v GROUP BY 
_g3, _g4")
+                  checkSparkAnswer("SELECT _g3, _g4, LAST(_3), LAST(_4) FROM v 
GROUP BY _g3, _g4")
+                }
                 checkSparkAnswer("SELECT _g3, _g4, SUM(_3), SUM(_4) FROM tbl 
GROUP BY _g3, _g4")
                 checkSparkAnswer(
                   "SELECT _g3, _g4, SUM(DISTINCT _3), SUM(DISTINCT _4) FROM 
tbl GROUP BY _g3, _g4")
@@ -524,6 +539,11 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
                 // Test all combinations of different aggregation & group-by types
                 (1 to 4).foreach { col =>
                   (1 to 14).foreach { gCol =>
+                    withView("v") {
+                      sql(s"CREATE TEMP VIEW v AS SELECT _g$gCol, _$col FROM 
tbl ORDER BY _$col")
+                      checkSparkAnswer(s"SELECT _g$gCol, FIRST(_$col) FROM v 
GROUP BY _g$gCol")
+                      checkSparkAnswer(s"SELECT _g$gCol, LAST(_$col) FROM v 
GROUP BY _g$gCol")
+                    }
                     checkSparkAnswer(s"SELECT _g$gCol, SUM(_$col) FROM tbl 
GROUP BY _g$gCol")
                     checkSparkAnswer(
                       s"SELECT _g$gCol, SUM(DISTINCT _$col) FROM tbl GROUP BY 
_g$gCol")
@@ -771,9 +791,9 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("distinct") {
     withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
-      Seq(true, false).foreach { bosonColumnShuffleEnabled =>
+      Seq(true, false).foreach { cometColumnShuffleEnabled =>
         withSQLConf(
-          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> bosonColumnShuffleEnabled.toString) {
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
           Seq(true, false).foreach { dictionary =>
             withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
               val table = "test"
@@ -782,40 +802,40 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
                 sql(
                   s"insert into $table values(1, 1, 1), (1, 1, 1), (1, 3, 1), 
(1, 4, 2), (5, 3, 2)")
 
-                var expectedNumOfBosonAggregates = 2
+                var expectedNumOfCometAggregates = 2
 
                 checkSparkAnswerAndNumOfAggregates(
                   s"SELECT DISTINCT(col2) FROM $table",
-                  expectedNumOfBosonAggregates)
+                  expectedNumOfCometAggregates)
 
-                expectedNumOfBosonAggregates = 4
+                expectedNumOfCometAggregates = 4
 
                 checkSparkAnswerAndNumOfAggregates(
                   s"SELECT COUNT(distinct col2) FROM $table",
-                  expectedNumOfBosonAggregates)
+                  expectedNumOfCometAggregates)
 
                 checkSparkAnswerAndNumOfAggregates(
                   s"SELECT COUNT(distinct col2), col1 FROM $table group by 
col1",
-                  expectedNumOfBosonAggregates)
+                  expectedNumOfCometAggregates)
 
                 checkSparkAnswerAndNumOfAggregates(
                   s"SELECT SUM(distinct col2) FROM $table",
-                  expectedNumOfBosonAggregates)
+                  expectedNumOfCometAggregates)
 
                 checkSparkAnswerAndNumOfAggregates(
                   s"SELECT SUM(distinct col2), col1 FROM $table group by col1",
-                  expectedNumOfBosonAggregates)
+                  expectedNumOfCometAggregates)
 
                 checkSparkAnswerAndNumOfAggregates(
                   "SELECT COUNT(distinct col2), SUM(distinct col2), col1, 
COUNT(distinct col2)," +
                     s" SUM(distinct col2) FROM $table group by col1",
-                  expectedNumOfBosonAggregates)
+                  expectedNumOfCometAggregates)
 
-                expectedNumOfBosonAggregates = 1
+                expectedNumOfCometAggregates = 1
                 checkSparkAnswerAndNumOfAggregates(
                   "SELECT COUNT(col2), MIN(col2), COUNT(DISTINCT col2), 
SUM(col2)," +
                     s" SUM(DISTINCT col2), COUNT(DISTINCT col2), col1 FROM 
$table group by col1",
-                  expectedNumOfBosonAggregates)
+                  expectedNumOfCometAggregates)
               }
             }
           }
@@ -824,6 +844,53 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("first/last") {
+    withSQLConf(
+      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      Seq(true, false).foreach { dictionary =>
+        withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+          val table = "test"
+          withTable(table) {
+            sql(s"create table $table(col1 int, col2 int, col3 int) using 
parquet")
+            sql(
+              s"insert into $table values(4, 1, 1), (4, 1, 1), (3, 3, 1)," +
+                " (2, 4, 2), (1, 3, 2), (null, 1, 1)")
+            withView("t") {
+              sql("CREATE VIEW t AS SELECT col1, col3 FROM test ORDER BY col1")
+
+              var expectedNumOfCometAggregates = 2
+              checkSparkAnswerAndNumOfAggregates(
+                "SELECT FIRST(col1), LAST(col1) FROM t",
+                expectedNumOfCometAggregates)
+
+              checkSparkAnswerAndNumOfAggregates(
+                "SELECT FIRST(col1), LAST(col1), MIN(col1), COUNT(col1) FROM 
t",
+                expectedNumOfCometAggregates)
+
+              checkSparkAnswerAndNumOfAggregates(
+                "SELECT FIRST(col1), LAST(col1), col3 FROM t GROUP BY col3",
+                expectedNumOfCometAggregates)
+
+              checkSparkAnswerAndNumOfAggregates(
+                "SELECT FIRST(col1), LAST(col1), MIN(col1), COUNT(col1), col3 
FROM t GROUP BY col3",
+                expectedNumOfCometAggregates)
+
+              expectedNumOfCometAggregates = 0
+              checkSparkAnswerAndNumOfAggregates(
+                "SELECT FIRST(col1, true), LAST(col1) FROM t",
+                expectedNumOfCometAggregates)
+
+              checkSparkAnswerAndNumOfAggregates(
+                "SELECT FIRST(col1), LAST(col1, true), col3 FROM t GROUP BY 
col3",
+                expectedNumOfCometAggregates)
+            }
+          }
+        }
+      }
+    }
+  }
+
  protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)
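
Note: the helper body is truncated in this message after checkSparkAnswer(df).
A plausible sketch of the remaining step, counting Comet aggregate operators
in the executed plan (the operator name here is an assumption, not verified
against the repository):

    // Hypothetical continuation: collect(...) from AdaptiveSparkPlanHelper
    // walks adaptive plans; CometHashAggregateExec is assumed to be Comet's
    // native aggregate operator, so 2 means both partial and final stages
    // ran natively and 0 means Spark handled the aggregate.
    val numCometAggregates = collect(df.queryExecution.executedPlan) {
      case agg: CometHashAggregateExec => agg
    }.length
    assert(numCometAggregates == numAggregates)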
