This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new ed5de4b  feat: Support bitwise aggregate functions (#197)
ed5de4b is described below

commit ed5de4bcb8f8b64fb273e77dcb87f07b9f417984
Author: Huaxin Gao <[email protected]>
AuthorDate: Thu Mar 14 08:33:36 2024 -0700

    feat: Support bitwise aggregate functions (#197)
---
 core/src/execution/datafusion/planner.rs           | 21 ++++++--
 core/src/execution/proto/expr.proto                | 18 +++++++
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 60 +++++++++++++++++++++-
 .../apache/comet/exec/CometAggregateSuite.scala    | 50 ++++++++++++++++++
 4 files changed, 145 insertions(+), 4 deletions(-)

diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index ef2787f..d52ec80 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -27,9 +27,9 @@ use datafusion::{
     physical_expr::{
         execution_props::ExecutionProps,
         expressions::{
-            in_list, BinaryExpr, CaseExpr, CastExpr, Column, Count, FirstValue, InListExpr,
-            IsNotNullExpr, IsNullExpr, LastValue, Literal as DataFusionLiteral, Max, Min,
-            NegativeExpr, NotExpr, Sum, UnKnownColumn,
+            in_list, BinaryExpr, BitAnd, BitOr, BitXor, CaseExpr, CastExpr, Column, Count,
+            FirstValue, InListExpr, IsNotNullExpr, IsNullExpr, LastValue,
+            Literal as DataFusionLiteral, Max, Min, NegativeExpr, NotExpr, Sum, UnKnownColumn,
         },
         functions::create_physical_expr,
         AggregateExpr, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr,
@@ -940,6 +940,21 @@ impl PhysicalPlanner {
                     vec![],
                 )))
             }
+            AggExprStruct::BitAndAgg(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?;
+                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                Ok(Arc::new(BitAnd::new(child, "bit_and", datatype)))
+            }
+            AggExprStruct::BitOrAgg(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?;
+                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                Ok(Arc::new(BitOr::new(child, "bit_or", datatype)))
+            }
+            AggExprStruct::BitXorAgg(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?;
+                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                Ok(Arc::new(BitXor::new(child, "bit_xor", datatype)))
+            }
         }
     }
 
diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index 8aa81b7..e8d35d1 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -88,6 +88,9 @@ message AggExpr {
     Avg avg = 6;
     First first = 7;
     Last last = 8;
+    BitAndAgg bitAndAgg = 9;
+    BitOrAgg bitOrAgg = 10;
+    BitXorAgg bitXorAgg = 11;
   }
 }
 
@@ -130,6 +133,21 @@ message Last {
   bool ignore_nulls = 3;
 }
 
+message BitAndAgg {
+  Expr child = 1;
+  DataType datatype = 2;
+}
+
+message BitOrAgg {
+  Expr child = 1;
+  DataType datatype = 2;
+}
+
+message BitXorAgg {
+  Expr child = 1;
+  DataType datatype = 2;
+}
+
 message Literal {
   oneof value {
     bool bool_val = 1;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 5da926e..87b4dff 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Final, First, Last, Max, Min, Partial, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, Final, First, Last, Max, Min, Partial, Sum}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
@@ -188,6 +188,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
     }
   }
 
+  private def bitwiseAggTypeSupported(dt: DataType): Boolean = {
+    dt match {
+      case _: IntegerType | LongType | ShortType | ByteType => true
+      case _ => false
+    }
+  }
+
   def aggExprToProto(
       aggExpr: AggregateExpression,
       inputs: Seq[Attribute],
@@ -328,6 +335,57 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         } else {
           None
         }
+      case bitAnd @ BitAndAgg(child) if bitwiseAggTypeSupported(bitAnd.dataType) =>
+        val childExpr = exprToProto(child, inputs, binding)
+        val dataType = serializeDataType(bitAnd.dataType)
+
+        if (childExpr.isDefined && dataType.isDefined) {
+          val bitAndBuilder = ExprOuterClass.BitAndAgg.newBuilder()
+          bitAndBuilder.setChild(childExpr.get)
+          bitAndBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setBitAndAgg(bitAndBuilder)
+              .build())
+        } else {
+          None
+        }
+      case bitOr @ BitOrAgg(child) if bitwiseAggTypeSupported(bitOr.dataType) =>
+        val childExpr = exprToProto(child, inputs, binding)
+        val dataType = serializeDataType(bitOr.dataType)
+
+        if (childExpr.isDefined && dataType.isDefined) {
+          val bitOrBuilder = ExprOuterClass.BitOrAgg.newBuilder()
+          bitOrBuilder.setChild(childExpr.get)
+          bitOrBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setBitOrAgg(bitOrBuilder)
+              .build())
+        } else {
+          None
+        }
+      case bitXor @ BitXorAgg(child) if bitwiseAggTypeSupported(bitXor.dataType) =>
+        val childExpr = exprToProto(child, inputs, binding)
+        val dataType = serializeDataType(bitXor.dataType)
+
+        if (childExpr.isDefined && dataType.isDefined) {
+          val bitXorBuilder = ExprOuterClass.BitXorAgg.newBuilder()
+          bitXorBuilder.setChild(childExpr.get)
+          bitXorBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setBitXorAgg(bitXorBuilder)
+              .build())
+        } else {
+          None
+        }
 
       case fn =>
         emitWarning(s"unsupported Spark aggregate function: $fn")
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 1dac14d..982d39f 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -947,6 +947,56 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("bitwise aggregate") {
+    withSQLConf(
+      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      Seq(true, false).foreach { dictionary =>
+        withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+          val table = "test"
+          withTable(table) {
+            sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet")
+            sql(
+              s"insert into $table values(4, 1, 1, 3), (4, 1, 1, 3), (3, 3, 1, 4)," +
+                " (2, 4, 2, 5), (1, 3, 2, 6), (null, 1, 1, 7)")
+            val expectedNumOfCometAggregates = 2
+            checkSparkAnswerAndNumOfAggregates(
+              "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," +
+                " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," +
+                " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," +
+                " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4) FROM test",
+              expectedNumOfCometAggregates)
+
+            // Make sure the combination of BITWISE aggregates and other aggregates work OK
+            checkSparkAnswerAndNumOfAggregates(
+              "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," +
+                " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," +
+                " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," +
+                " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4), MIN(col1), COUNT(col1) FROM test",
+              expectedNumOfCometAggregates)
+
+            checkSparkAnswerAndNumOfAggregates(
+              "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," +
+                " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," +
+                " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," +
+                " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4), col3 FROM test GROUP BY col3",
+              expectedNumOfCometAggregates)
+
+            // Make sure the combination of BITWISE aggregates and other aggregates work OK
+            // with group by
+            checkSparkAnswerAndNumOfAggregates(
+              "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," +
+                " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," +
+                " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," +
+                " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4)," +
+                " MIN(col1), COUNT(col1), col3 FROM test GROUP BY col3",
+              expectedNumOfCometAggregates)
+          }
+        }
+      }
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)

Reply via email to