This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new ed5de4b feat: Support bitwise aggregate functions (#197)
ed5de4b is described below
commit ed5de4bcb8f8b64fb273e77dcb87f07b9f417984
Author: Huaxin Gao <[email protected]>
AuthorDate: Thu Mar 14 08:33:36 2024 -0700
feat: Support bitwise aggregate functions (#197)
---
core/src/execution/datafusion/planner.rs | 21 ++++++--
core/src/execution/proto/expr.proto | 18 +++++++
.../org/apache/comet/serde/QueryPlanSerde.scala | 60 +++++++++++++++++++++-
.../apache/comet/exec/CometAggregateSuite.scala | 50 ++++++++++++++++++
4 files changed, 145 insertions(+), 4 deletions(-)
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index ef2787f..d52ec80 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -27,9 +27,9 @@ use datafusion::{
physical_expr::{
execution_props::ExecutionProps,
expressions::{
- in_list, BinaryExpr, CaseExpr, CastExpr, Column, Count, FirstValue, InListExpr,
- IsNotNullExpr, IsNullExpr, LastValue, Literal as DataFusionLiteral, Max, Min,
- NegativeExpr, NotExpr, Sum, UnKnownColumn,
+ in_list, BinaryExpr, BitAnd, BitOr, BitXor, CaseExpr, CastExpr, Column, Count,
+ FirstValue, InListExpr, IsNotNullExpr, IsNullExpr, LastValue,
+ Literal as DataFusionLiteral, Max, Min, NegativeExpr, NotExpr, Sum, UnKnownColumn,
},
functions::create_physical_expr,
AggregateExpr, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr,
@@ -940,6 +940,21 @@ impl PhysicalPlanner {
vec![],
)))
}
+ AggExprStruct::BitAndAgg(expr) => {
+ let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?;
+ let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+ Ok(Arc::new(BitAnd::new(child, "bit_and", datatype)))
+ }
+ AggExprStruct::BitOrAgg(expr) => {
+ let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?;
+ let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+ Ok(Arc::new(BitOr::new(child, "bit_or", datatype)))
+ }
+ AggExprStruct::BitXorAgg(expr) => {
+ let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?;
+ let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+ Ok(Arc::new(BitXor::new(child, "bit_xor", datatype)))
+ }
}
}
diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index 8aa81b7..e8d35d1 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -88,6 +88,9 @@ message AggExpr {
Avg avg = 6;
First first = 7;
Last last = 8;
+ BitAndAgg bitAndAgg = 9;
+ BitOrAgg bitOrAgg = 10;
+ BitXorAgg bitXorAgg = 11;
}
}
@@ -130,6 +133,21 @@ message Last {
bool ignore_nulls = 3;
}
+message BitAndAgg {
+ Expr child = 1;
+ DataType datatype = 2;
+}
+
+message BitOrAgg {
+ Expr child = 1;
+ DataType datatype = 2;
+}
+
+message BitXorAgg {
+ Expr child = 1;
+ DataType datatype = 2;
+}
+
message Literal {
oneof value {
bool bool_val = 1;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 5da926e..87b4dff 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Final, First, Last, Max, Min, Partial, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, Final, First, Last, Max, Min, Partial, Sum}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
@@ -188,6 +188,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
}
}
+ private def bitwiseAggTypeSupported(dt: DataType): Boolean = {
+ dt match {
+ case _: IntegerType | LongType | ShortType | ByteType => true
+ case _ => false
+ }
+ }
+
def aggExprToProto(
aggExpr: AggregateExpression,
inputs: Seq[Attribute],
@@ -328,6 +335,57 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
} else {
None
}
+ case bitAnd @ BitAndAgg(child) if bitwiseAggTypeSupported(bitAnd.dataType) =>
+ val childExpr = exprToProto(child, inputs, binding)
+ val dataType = serializeDataType(bitAnd.dataType)
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val bitAndBuilder = ExprOuterClass.BitAndAgg.newBuilder()
+ bitAndBuilder.setChild(childExpr.get)
+ bitAndBuilder.setDatatype(dataType.get)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setBitAndAgg(bitAndBuilder)
+ .build())
+ } else {
+ None
+ }
+ case bitOr @ BitOrAgg(child) if bitwiseAggTypeSupported(bitOr.dataType) =>
+ val childExpr = exprToProto(child, inputs, binding)
+ val dataType = serializeDataType(bitOr.dataType)
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val bitOrBuilder = ExprOuterClass.BitOrAgg.newBuilder()
+ bitOrBuilder.setChild(childExpr.get)
+ bitOrBuilder.setDatatype(dataType.get)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setBitOrAgg(bitOrBuilder)
+ .build())
+ } else {
+ None
+ }
+ case bitXor @ BitXorAgg(child) if bitwiseAggTypeSupported(bitXor.dataType) =>
+ val childExpr = exprToProto(child, inputs, binding)
+ val dataType = serializeDataType(bitXor.dataType)
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val bitXorBuilder = ExprOuterClass.BitXorAgg.newBuilder()
+ bitXorBuilder.setChild(childExpr.get)
+ bitXorBuilder.setDatatype(dataType.get)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setBitXorAgg(bitXorBuilder)
+ .build())
+ } else {
+ None
+ }
case fn =>
emitWarning(s"unsupported Spark aggregate function: $fn")
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 1dac14d..982d39f 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -947,6 +947,56 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("bitwise aggregate") {
+ withSQLConf(
+ CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+ CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+ Seq(true, false).foreach { dictionary =>
+ withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+ val table = "test"
+ withTable(table) {
+ sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet")
+ sql(
+ s"insert into $table values(4, 1, 1, 3), (4, 1, 1, 3), (3, 3, 1, 4)," +
+ " (2, 4, 2, 5), (1, 3, 2, 6), (null, 1, 1, 7)")
+ val expectedNumOfCometAggregates = 2
+ checkSparkAnswerAndNumOfAggregates(
+ "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," +
+ " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," +
+ " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," +
+ " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4) FROM test",
+ expectedNumOfCometAggregates)
+
+ // Make sure the combination of BITWISE aggregates and other aggregates work OK
+ checkSparkAnswerAndNumOfAggregates(
+ "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," +
+ " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," +
+ " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," +
+ " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4), MIN(col1), COUNT(col1) FROM test",
+ expectedNumOfCometAggregates)
+
+ checkSparkAnswerAndNumOfAggregates(
+ "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," +
+ " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," +
+ " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," +
+ " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4), col3 FROM test GROUP BY col3",
+ expectedNumOfCometAggregates)
+
+ // Make sure the combination of BITWISE aggregates and other aggregates work OK
+ // with group by
+ checkSparkAnswerAndNumOfAggregates(
+ "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," +
+ " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," +
+ " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," +
+ " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4)," +
+ " MIN(col1), COUNT(col1), col3 FROM test GROUP BY col3",
+ expectedNumOfCometAggregates)
+ }
+ }
+ }
+ }
+ }
+
protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
val df = sql(query)
checkSparkAnswer(df)