This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 69760a37 minor: Refactor binary expr serde to reduce code duplication (#1053)
69760a37 is described below
commit 69760a37d474bd2d243ce35eef22de6a5a9b1348
Author: Andy Grove <[email protected]>
AuthorDate: Tue Nov 5 08:22:03 2024 -0700
minor: Refactor binary expr serde to reduce code duplication (#1053)
* Use one BinaryExpr definition in protobuf
* refactor And
* refactor remaining binary expressions
* update test
* update test
---
native/core/src/execution/datafusion/planner.rs | 2 +-
native/proto/src/proto/expr.proto | 137 +------
.../org/apache/comet/serde/QueryPlanSerde.scala | 438 ++++++---------------
3 files changed, 142 insertions(+), 435 deletions(-)
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 5b53cb39..e2ea3863 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -2388,7 +2388,7 @@ mod tests {
};
let expr = spark_expression::Expr {
- expr_struct: Some(Eq(Box::new(spark_expression::Equal {
+ expr_struct: Some(Eq(Box::new(spark_expression::BinaryExpr {
left: Some(Box::new(left)),
right: Some(Box::new(right)),
}))),
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 796ca5be..b5b975a5 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -33,16 +33,16 @@ message Expr {
Multiply multiply = 6;
Divide divide = 7;
Cast cast = 8;
- Equal eq = 9;
- NotEqual neq = 10;
- GreaterThan gt = 11;
- GreaterThanEqual gt_eq = 12;
- LessThan lt = 13;
- LessThanEqual lt_eq = 14;
+ BinaryExpr eq = 9;
+ BinaryExpr neq = 10;
+ BinaryExpr gt = 11;
+ BinaryExpr gt_eq = 12;
+ BinaryExpr lt = 13;
+ BinaryExpr lt_eq = 14;
IsNull is_null = 15;
IsNotNull is_not_null = 16;
- And and = 17;
- Or or = 18;
+ BinaryExpr and = 17;
+ BinaryExpr or = 18;
SortOrder sort_order = 19;
Substring substring = 20;
StringSpace string_space = 21;
@@ -50,24 +50,24 @@ message Expr {
Minute minute = 23;
Second second = 24;
CheckOverflow check_overflow = 25;
- Like like = 26;
- StartsWith startsWith = 27;
- EndsWith endsWith = 28;
- Contains contains = 29;
- RLike rlike = 30;
+ BinaryExpr like = 26;
+ BinaryExpr startsWith = 27;
+ BinaryExpr endsWith = 28;
+ BinaryExpr contains = 29;
+ BinaryExpr rlike = 30;
ScalarFunc scalarFunc = 31;
- EqualNullSafe eqNullSafe = 32;
- NotEqualNullSafe neqNullSafe = 33;
- BitwiseAnd bitwiseAnd = 34;
- BitwiseOr bitwiseOr = 35;
- BitwiseXor bitwiseXor = 36;
+ BinaryExpr eqNullSafe = 32;
+ BinaryExpr neqNullSafe = 33;
+ BinaryExpr bitwiseAnd = 34;
+ BinaryExpr bitwiseOr = 35;
+ BinaryExpr bitwiseXor = 36;
Remainder remainder = 37;
CaseWhen caseWhen = 38;
In in = 39;
Not not = 40;
UnaryMinus unary_minus = 41;
- BitwiseShiftRight bitwiseShiftRight = 42;
- BitwiseShiftLeft bitwiseShiftLeft = 43;
+ BinaryExpr bitwiseShiftRight = 42;
+ BinaryExpr bitwiseShiftLeft = 43;
IfExpr if = 44;
NormalizeNaNAndZero normalize_nan_and_zero = 45;
TruncDate truncDate = 46;
@@ -269,52 +269,7 @@ message Cast {
bool allow_incompat = 5;
}
-message Equal {
- Expr left = 1;
- Expr right = 2;
-}
-
-message NotEqual {
- Expr left = 1;
- Expr right = 2;
-}
-
-message EqualNullSafe {
- Expr left = 1;
- Expr right = 2;
-}
-
-message NotEqualNullSafe {
- Expr left = 1;
- Expr right = 2;
-}
-
-message GreaterThan {
- Expr left = 1;
- Expr right = 2;
-}
-
-message GreaterThanEqual {
- Expr left = 1;
- Expr right = 2;
-}
-
-message LessThan {
- Expr left = 1;
- Expr right = 2;
-}
-
-message LessThanEqual {
- Expr left = 1;
- Expr right = 2;
-}
-
-message And {
- Expr left = 1;
- Expr right = 2;
-}
-
-message Or {
+message BinaryExpr {
Expr left = 1;
Expr right = 2;
}
@@ -384,62 +339,12 @@ message CheckOverflow {
bool fail_on_error = 3;
}
-message Like {
- Expr left = 1;
- Expr right = 2;
-}
-
-message RLike {
- Expr left = 1;
- Expr right = 2;
-}
-
-message StartsWith {
- Expr left = 1;
- Expr right = 2;
-}
-
-message EndsWith {
- Expr left = 1;
- Expr right = 2;
-}
-
-message Contains {
- Expr left = 1;
- Expr right = 2;
-}
-
message ScalarFunc {
string func = 1;
repeated Expr args = 2;
DataType return_type = 3;
}
-message BitwiseAnd {
- Expr left = 1;
- Expr right = 2;
-}
-
-message BitwiseOr {
- Expr left = 1;
- Expr right = 2;
-}
-
-message BitwiseXor {
- Expr left = 1;
- Expr right = 2;
-}
-
-message BitwiseShiftRight {
- Expr left = 1;
- Expr right = 2;
-}
-
-message BitwiseShiftLeft {
- Expr left = 1;
- Expr right = 2;
-}
-
message CaseWhen {
// The expr field is added to be consistent with CaseExpr definition in DataFusion.
// This field is not really used. When constructing a CaseExpr, this expr field
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index abb138b0..ef4d2563 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1086,155 +1086,67 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
None
case EqualTo(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.Equal.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setEq(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setEq(builder)
+ .build()
}
case Not(EqualTo(left, right)) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.NotEqual.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setNeq(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setNeq(builder)
+ .build()
}
case EqualNullSafe(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.EqualNullSafe.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setEqNullSafe(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setEqNullSafe(builder)
+ .build()
}
case Not(EqualNullSafe(left, right)) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.NotEqualNullSafe.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setNeqNullSafe(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setNeqNullSafe(builder)
+ .build()
}
case GreaterThan(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.GreaterThan.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setGt(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setGt(builder)
+ .build()
}
case GreaterThanOrEqual(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.GreaterThanEqual.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setGtEq(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setGtEq(builder)
+ .build()
}
case LessThan(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.LessThan.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setLt(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setLt(builder)
+ .build()
}
case LessThanOrEqual(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.LessThanEqual.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setLtEq(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setLtEq(builder)
+ .build()
}
case Literal(value, dataType)
@@ -1372,22 +1284,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
case Like(left, right, escapeChar) =>
if (escapeChar == '\\') {
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.Like.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setLike(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setLike(builder)
+ .build()
}
} else {
// TODO custom escape char
@@ -1413,78 +1314,35 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
return None
}
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.RLike.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setRlike(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setRlike(builder)
+ .build()
}
- case StartsWith(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.StartsWith.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setStartsWith(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ case StartsWith(left, right) =>
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setStartsWith(builder)
+ .build()
}
case EndsWith(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.EndsWith.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setEndsWith(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setEndsWith(builder)
+ .build()
}
case Contains(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.Contains.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setContains(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setContains(builder)
+ .build()
}
case StringSpace(child) =>
@@ -1705,41 +1563,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}
case And(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.And.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setAnd(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setAnd(builder)
+ .build()
}
case Or(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.Or.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setOr(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setOr(builder)
+ .build()
}
case UnaryExpression(child) if expr.prettyName == "promote_precision" =>
@@ -2177,22 +2013,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}
case BitwiseAnd(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.BitwiseAnd.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setBitwiseAnd(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setBitwiseAnd(builder)
+ .build()
}
case BitwiseNot(child) =>
@@ -2213,45 +2038,22 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}
case BitwiseOr(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.BitwiseOr.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setBitwiseOr(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setBitwiseOr(builder)
+ .build()
}
case BitwiseXor(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.BitwiseXor.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setBitwiseXor(builder)
- .build())
- } else {
- withInfo(expr, left, right)
- None
+ createBinaryExpr(left, right, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setBitwiseXor(builder)
+ .build()
}
case ShiftRight(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
// DataFusion bitwise shift right expression requires
// same data type between left and right side
val rightExpression = if (left.dataType == LongType) {
@@ -2259,25 +2061,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
} else {
right
}
- val rightExpr = exprToProtoInternal(rightExpression, inputs)
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.BitwiseShiftRight.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setBitwiseShiftRight(builder)
- .build())
- } else {
- withInfo(expr, left, rightExpression)
- None
+ createBinaryExpr(left, rightExpression, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setBitwiseShiftRight(builder)
+ .build()
}
case ShiftLeft(left, right) =>
- val leftExpr = exprToProtoInternal(left, inputs)
// DataFusion bitwise shift right expression requires
// same data type between left and right side
val rightExpression = if (left.dataType == LongType) {
@@ -2285,21 +2077,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
} else {
right
}
- val rightExpr = exprToProtoInternal(rightExpression, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.BitwiseShiftLeft.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setBitwiseShiftLeft(builder)
- .build())
- } else {
- withInfo(expr, left, rightExpression)
- None
+ createBinaryExpr(left, rightExpression, inputs).map { builder =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setBitwiseShiftLeft(builder)
+ .build()
}
case In(value, list) =>
@@ -2611,6 +2394,25 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}
}
+ def createBinaryExpr(
+ left: Expression,
+ right: Expression,
+ inputs: Seq[Attribute]): Option[ExprOuterClass.BinaryExpr] = {
+ val leftExpr = exprToProtoInternal(left, inputs)
+ val rightExpr = exprToProtoInternal(right, inputs)
+ if (leftExpr.isDefined && rightExpr.isDefined) {
+ Some(
+ ExprOuterClass.BinaryExpr
+ .newBuilder()
+ .setLeft(leftExpr.get)
+ .setRight(rightExpr.get)
+ .build())
+ } else {
+ withInfo(expr, left, right)
+ None
+ }
+ }
+
def trim(
expr: Expression, // parent expression
srcStr: Expression,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]