This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 562a8772 Refactor UnaryExpr and MathExpr in protobuf (#1056)
562a8772 is described below

commit 562a877283d5b0d8eb6dfceb1470c100dbc0adfe
Author: Andy Grove <[email protected]>
AuthorDate: Wed Nov 6 08:16:18 2024 -0700

    Refactor UnaryExpr and MathExpr in protobuf (#1056)
---
 native/proto/src/proto/expr.proto                  |  68 +-----
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 245 ++++++++-------------
 2 files changed, 101 insertions(+), 212 deletions(-)

diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index b5b975a5..3a8193f4 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -28,10 +28,10 @@ message Expr {
   oneof expr_struct {
     Literal literal = 2;
     BoundReference bound = 3;
-    Add add = 4;
-    Subtract subtract = 5;
-    Multiply multiply = 6;
-    Divide divide = 7;
+    MathExpr add = 4;
+    MathExpr subtract = 5;
+    MathExpr multiply = 6;
+    MathExpr divide = 7;
     Cast cast = 8;
     BinaryExpr eq = 9;
     BinaryExpr neq = 10;
@@ -39,13 +39,13 @@ message Expr {
     BinaryExpr gt_eq = 12;
     BinaryExpr lt = 13;
     BinaryExpr lt_eq = 14;
-    IsNull is_null = 15;
-    IsNotNull is_not_null = 16;
+    UnaryExpr is_null = 15;
+    UnaryExpr is_not_null = 16;
     BinaryExpr and = 17;
     BinaryExpr or = 18;
     SortOrder sort_order = 19;
     Substring substring = 20;
-    StringSpace string_space = 21;
+    UnaryExpr string_space = 21;
     Hour hour = 22;
     Minute minute = 23;
     Second second = 24;
@@ -61,10 +61,10 @@ message Expr {
     BinaryExpr bitwiseAnd = 34;
     BinaryExpr bitwiseOr = 35;
     BinaryExpr bitwiseXor = 36;
-    Remainder remainder = 37;
+    MathExpr remainder = 37;
     CaseWhen caseWhen = 38;
     In in = 39;
-    Not not = 40;
+    UnaryExpr not = 40;
     UnaryMinus unary_minus = 41;
     BinaryExpr bitwiseShiftRight = 42;
     BinaryExpr bitwiseShiftLeft = 43;
@@ -72,7 +72,7 @@ message Expr {
     NormalizeNaNAndZero normalize_nan_and_zero = 45;
     TruncDate truncDate = 46;
     TruncTimestamp truncTimestamp = 47;
-    BitwiseNot bitwiseNot = 48;
+    UnaryExpr bitwiseNot = 48;
     Abs abs = 49;
     Subquery subquery = 50;
     UnboundReference unbound = 51;
@@ -220,35 +220,7 @@ message Literal {
    bool is_null = 12;
 }
 
-message Add {
-  Expr left = 1;
-  Expr right = 2;
-  bool fail_on_error = 3;
-  DataType return_type = 4;
-}
-
-message Subtract {
-  Expr left = 1;
-  Expr right = 2;
-  bool fail_on_error = 3;
-  DataType return_type = 4;
-}
-
-message Multiply {
-  Expr left = 1;
-  Expr right = 2;
-  bool fail_on_error = 3;
-  DataType return_type = 4;
-}
-
-message Divide {
-  Expr left = 1;
-  Expr right = 2;
-  bool fail_on_error = 3;
-  DataType return_type = 4;
-}
-
-message Remainder {
+message MathExpr {
   Expr left = 1;
   Expr right = 2;
   bool fail_on_error = 3;
@@ -274,11 +246,7 @@ message BinaryExpr {
   Expr right = 2;
 }
 
-message IsNull {
-  Expr child = 1;
-}
-
-message IsNotNull {
+message UnaryExpr {
   Expr child = 1;
 }
 
@@ -305,10 +273,6 @@ message Substring {
   int32 len = 3;
 }
 
-message StringSpace {
-  Expr child = 1;
-}
-
 message ToJson {
   Expr child = 1;
   string timezone = 2;
@@ -368,10 +332,6 @@ message NormalizeNaNAndZero {
   DataType datatype = 2;
 }
 
-message Not {
-  Expr child = 1;
-}
-
 message UnaryMinus {
   Expr child = 1;
   bool fail_on_error = 2;
@@ -394,10 +354,6 @@ message TruncTimestamp {
   string timezone = 3;
 }
 
-message BitwiseNot {
-  Expr child = 1;
-}
-
 message Abs {
   Expr child = 1;
   EvalMode eval_mode = 2;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index ef4d2563..8bdc886d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -932,26 +932,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           handleCast(child, inputs, dt, timeZoneId, evalMode(c))
 
         case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
-          val leftExpr = exprToProtoInternal(left, inputs)
-          val rightExpr = exprToProtoInternal(right, inputs)
-
-          if (leftExpr.isDefined && rightExpr.isDefined) {
-            val addBuilder = ExprOuterClass.Add.newBuilder()
-            addBuilder.setLeft(leftExpr.get)
-            addBuilder.setRight(rightExpr.get)
-            addBuilder.setFailOnError(getFailOnError(add))
-            serializeDataType(add.dataType).foreach { t =>
-              addBuilder.setReturnType(t)
-            }
-
-            Some(
+          createMathExpression(left, right, inputs, add.dataType, getFailOnError(add)).map {
+            expr =>
               ExprOuterClass.Expr
                 .newBuilder()
-                .setAdd(addBuilder)
-                .build())
-          } else {
-            withInfo(add, left, right)
-            None
+                .setAdd(expr)
+                .build()
           }
 
         case add @ Add(left, _, _) if !supportedDataType(left.dataType) =>
@@ -959,26 +945,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           None
 
         case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) =>
-          val leftExpr = exprToProtoInternal(left, inputs)
-          val rightExpr = exprToProtoInternal(right, inputs)
-
-          if (leftExpr.isDefined && rightExpr.isDefined) {
-            val builder = ExprOuterClass.Subtract.newBuilder()
-            builder.setLeft(leftExpr.get)
-            builder.setRight(rightExpr.get)
-            builder.setFailOnError(getFailOnError(sub))
-            serializeDataType(sub.dataType).foreach { t =>
-              builder.setReturnType(t)
-            }
-
-            Some(
+          createMathExpression(left, right, inputs, sub.dataType, getFailOnError(sub)).map {
+            expr =>
               ExprOuterClass.Expr
                 .newBuilder()
-                .setSubtract(builder)
-                .build())
-          } else {
-            withInfo(sub, left, right)
-            None
+                .setSubtract(expr)
+                .build()
           }
 
         case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) =>
@@ -987,26 +959,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
         case mul @ Multiply(left, right, _)
             if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
-          val leftExpr = exprToProtoInternal(left, inputs)
-          val rightExpr = exprToProtoInternal(right, inputs)
-
-          if (leftExpr.isDefined && rightExpr.isDefined) {
-            val builder = ExprOuterClass.Multiply.newBuilder()
-            builder.setLeft(leftExpr.get)
-            builder.setRight(rightExpr.get)
-            builder.setFailOnError(getFailOnError(mul))
-            serializeDataType(mul.dataType).foreach { t =>
-              builder.setReturnType(t)
-            }
-
-            Some(
+          createMathExpression(left, right, inputs, mul.dataType, getFailOnError(mul)).map {
+            expr =>
               ExprOuterClass.Expr
                 .newBuilder()
-                .setMultiply(builder)
-                .build())
-          } else {
-            withInfo(mul, left, right)
-            None
+                .setMultiply(expr)
+                .build()
           }
 
         case mul @ Multiply(left, _, _) =>
@@ -1020,30 +978,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
         case div @ Divide(left, right, _)
             if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
-          val leftExpr = exprToProtoInternal(left, inputs)
           // Datafusion now throws an exception for dividing by zero
           // See https://github.com/apache/arrow-datafusion/pull/6792
           // For now, use NullIf to swap zeros with nulls.
-          val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs)
-
-          if (leftExpr.isDefined && rightExpr.isDefined) {
-            val builder = ExprOuterClass.Divide.newBuilder()
-            builder.setLeft(leftExpr.get)
-            builder.setRight(rightExpr.get)
-            builder.setFailOnError(getFailOnError(div))
-            serializeDataType(div.dataType).foreach { t =>
-              builder.setReturnType(t)
-            }
+          val rightExpr = nullIfWhenPrimitive(right)
 
-            Some(
+          createMathExpression(left, rightExpr, inputs, div.dataType, getFailOnError(div)).map {
+            expr =>
               ExprOuterClass.Expr
                 .newBuilder()
-                .setDivide(builder)
-                .build())
-          } else {
-            withInfo(div, left, right)
-            None
+                .setDivide(expr)
+                .build()
           }
+
         case div @ Divide(left, _, _) =>
           if (!supportedDataType(left.dataType)) {
             withInfo(div, s"Unsupported datatype ${left.dataType}")
@@ -1055,27 +1002,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
         case rem @ Remainder(left, right, _)
             if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
-          val leftExpr = exprToProtoInternal(left, inputs)
-          val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs)
-
-          if (leftExpr.isDefined && rightExpr.isDefined) {
-            val builder = ExprOuterClass.Remainder.newBuilder()
-            builder.setLeft(leftExpr.get)
-            builder.setRight(rightExpr.get)
-            builder.setFailOnError(getFailOnError(rem))
-            serializeDataType(rem.dataType).foreach { t =>
-              builder.setReturnType(t)
-            }
+          val rightExpr = nullIfWhenPrimitive(right)
 
-            Some(
+          createMathExpression(left, rightExpr, inputs, rem.dataType, getFailOnError(rem)).map {
+            expr =>
               ExprOuterClass.Expr
                 .newBuilder()
-                .setRemainder(builder)
-                .build())
-          } else {
-            withInfo(rem, left, right)
-            None
+                .setRemainder(expr)
+                .build()
           }
+
         case rem @ Remainder(left, _, _) =>
           if (!supportedDataType(left.dataType)) {
             withInfo(rem, s"Unsupported datatype ${left.dataType}")
@@ -1346,20 +1282,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           }
 
         case StringSpace(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-
-          if (childExpr.isDefined) {
-            val builder = ExprOuterClass.StringSpace.newBuilder()
-            builder.setChild(childExpr.get)
-
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setStringSpace(builder)
-                .build())
-          } else {
-            withInfo(expr, child)
-            None
+          createUnaryExpr(child, inputs).map { expr =>
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setStringSpace(expr)
+              .build()
           }
 
         case Hour(child, timeZoneId) =>
@@ -1495,37 +1422,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           optExprWithInfo(optExpr, expr, child)
 
         case IsNull(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-
-          if (childExpr.isDefined) {
-            val castBuilder = ExprOuterClass.IsNull.newBuilder()
-            castBuilder.setChild(childExpr.get)
-
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setIsNull(castBuilder)
-                .build())
-          } else {
-            withInfo(expr, child)
-            None
+          createUnaryExpr(child, inputs).map { expr =>
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setIsNull(expr)
+              .build()
           }
 
         case IsNotNull(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-
-          if (childExpr.isDefined) {
-            val castBuilder = ExprOuterClass.IsNotNull.newBuilder()
-            castBuilder.setChild(childExpr.get)
-
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setIsNotNull(castBuilder)
-                .build())
-          } else {
-            withInfo(expr, child)
-            None
+          createUnaryExpr(child, inputs).map { expr =>
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setIsNotNull(expr)
+              .build()
           }
 
         case IsNaN(child) =>
@@ -2021,20 +1930,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           }
 
         case BitwiseNot(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-
-          if (childExpr.isDefined) {
-            val builder = ExprOuterClass.BitwiseNot.newBuilder()
-            builder.setChild(childExpr.get)
-
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setBitwiseNot(builder)
-                .build())
-          } else {
-            withInfo(expr, child)
-            None
+          createUnaryExpr(child, inputs).map { expr =>
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setBitwiseNot(expr)
+              .build()
           }
 
         case BitwiseOr(left, right) =>
@@ -2101,18 +2001,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           in(expr, value, list, inputs, true)
 
         case Not(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          if (childExpr.isDefined) {
-            val builder = ExprOuterClass.Not.newBuilder()
-            builder.setChild(childExpr.get)
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setNot(builder)
-                .build())
-          } else {
-            withInfo(expr, child)
-            None
+          createUnaryExpr(child, inputs).map { expr =>
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setNot(expr)
+              .build()
           }
 
         case UnaryMinus(child, failOnError) =>
@@ -2394,6 +2287,22 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       }
     }
 
+    def createUnaryExpr(
+        child: Expression,
+        inputs: Seq[Attribute]): Option[ExprOuterClass.UnaryExpr] = {
+      val childExpr = exprToProtoInternal(child, inputs)
+      if (childExpr.isDefined) {
+        Some(
+          ExprOuterClass.UnaryExpr
+            .newBuilder()
+            .setChild(childExpr.get)
+            .build())
+      } else {
+        withInfo(expr, child)
+        None
+      }
+    }
+
     def createBinaryExpr(
         left: Expression,
         right: Expression,
@@ -2413,6 +2322,30 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       }
     }
 
+    def createMathExpression(
+        left: Expression,
+        right: Expression,
+        inputs: Seq[Attribute],
+        dataType: DataType,
+        failOnError: Boolean): Option[ExprOuterClass.MathExpr] = {
+      val leftExpr = exprToProtoInternal(left, inputs)
+      val rightExpr = exprToProtoInternal(right, inputs)
+
+      if (leftExpr.isDefined && rightExpr.isDefined) {
+        val builder = ExprOuterClass.MathExpr.newBuilder()
+        builder.setLeft(leftExpr.get)
+        builder.setRight(rightExpr.get)
+        builder.setFailOnError(failOnError)
+        serializeDataType(dataType).foreach { t =>
+          builder.setReturnType(t)
+        }
+        Some(builder.build())
+      } else {
+        withInfo(expr, left, right)
+        None
+      }
+    }
+
     def trim(
         expr: Expression, // parent expression
         srcStr: Expression,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to