This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 562a8772 Refactor UnaryExpr and MathExpr in protobuf (#1056)
562a8772 is described below
commit 562a877283d5b0d8eb6dfceb1470c100dbc0adfe
Author: Andy Grove <[email protected]>
AuthorDate: Wed Nov 6 08:16:18 2024 -0700
Refactor UnaryExpr and MathExpr in protobuf (#1056)
---
native/proto/src/proto/expr.proto | 68 +-----
.../org/apache/comet/serde/QueryPlanSerde.scala | 245 ++++++++-------------
2 files changed, 101 insertions(+), 212 deletions(-)
diff --git a/native/proto/src/proto/expr.proto
b/native/proto/src/proto/expr.proto
index b5b975a5..3a8193f4 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -28,10 +28,10 @@ message Expr {
oneof expr_struct {
Literal literal = 2;
BoundReference bound = 3;
- Add add = 4;
- Subtract subtract = 5;
- Multiply multiply = 6;
- Divide divide = 7;
+ MathExpr add = 4;
+ MathExpr subtract = 5;
+ MathExpr multiply = 6;
+ MathExpr divide = 7;
Cast cast = 8;
BinaryExpr eq = 9;
BinaryExpr neq = 10;
@@ -39,13 +39,13 @@ message Expr {
BinaryExpr gt_eq = 12;
BinaryExpr lt = 13;
BinaryExpr lt_eq = 14;
- IsNull is_null = 15;
- IsNotNull is_not_null = 16;
+ UnaryExpr is_null = 15;
+ UnaryExpr is_not_null = 16;
BinaryExpr and = 17;
BinaryExpr or = 18;
SortOrder sort_order = 19;
Substring substring = 20;
- StringSpace string_space = 21;
+ UnaryExpr string_space = 21;
Hour hour = 22;
Minute minute = 23;
Second second = 24;
@@ -61,10 +61,10 @@ message Expr {
BinaryExpr bitwiseAnd = 34;
BinaryExpr bitwiseOr = 35;
BinaryExpr bitwiseXor = 36;
- Remainder remainder = 37;
+ MathExpr remainder = 37;
CaseWhen caseWhen = 38;
In in = 39;
- Not not = 40;
+ UnaryExpr not = 40;
UnaryMinus unary_minus = 41;
BinaryExpr bitwiseShiftRight = 42;
BinaryExpr bitwiseShiftLeft = 43;
@@ -72,7 +72,7 @@ message Expr {
NormalizeNaNAndZero normalize_nan_and_zero = 45;
TruncDate truncDate = 46;
TruncTimestamp truncTimestamp = 47;
- BitwiseNot bitwiseNot = 48;
+ UnaryExpr bitwiseNot = 48;
Abs abs = 49;
Subquery subquery = 50;
UnboundReference unbound = 51;
@@ -220,35 +220,7 @@ message Literal {
bool is_null = 12;
}
-message Add {
- Expr left = 1;
- Expr right = 2;
- bool fail_on_error = 3;
- DataType return_type = 4;
-}
-
-message Subtract {
- Expr left = 1;
- Expr right = 2;
- bool fail_on_error = 3;
- DataType return_type = 4;
-}
-
-message Multiply {
- Expr left = 1;
- Expr right = 2;
- bool fail_on_error = 3;
- DataType return_type = 4;
-}
-
-message Divide {
- Expr left = 1;
- Expr right = 2;
- bool fail_on_error = 3;
- DataType return_type = 4;
-}
-
-message Remainder {
+message MathExpr {
Expr left = 1;
Expr right = 2;
bool fail_on_error = 3;
@@ -274,11 +246,7 @@ message BinaryExpr {
Expr right = 2;
}
-message IsNull {
- Expr child = 1;
-}
-
-message IsNotNull {
+// Generic single-child expression. Shared by the is_null, is_not_null,
+// string_space, not and bitwiseNot variants of Expr.
+message UnaryExpr {
Expr child = 1;
}
@@ -305,10 +273,6 @@ message Substring {
int32 len = 3;
}
-message StringSpace {
- Expr child = 1;
-}
-
message ToJson {
Expr child = 1;
string timezone = 2;
@@ -368,10 +332,6 @@ message NormalizeNaNAndZero {
DataType datatype = 2;
}
-message Not {
- Expr child = 1;
-}
-
message UnaryMinus {
Expr child = 1;
bool fail_on_error = 2;
@@ -394,10 +354,6 @@ message TruncTimestamp {
string timezone = 3;
}
-message BitwiseNot {
- Expr child = 1;
-}
-
message Abs {
Expr child = 1;
EvalMode eval_mode = 2;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index ef4d2563..8bdc886d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -932,26 +932,12 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
handleCast(child, inputs, dt, timeZoneId, evalMode(c))
case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val addBuilder = ExprOuterClass.Add.newBuilder()
- addBuilder.setLeft(leftExpr.get)
- addBuilder.setRight(rightExpr.get)
- addBuilder.setFailOnError(getFailOnError(add))
- serializeDataType(add.dataType).foreach { t =>
- addBuilder.setReturnType(t)
- }
-
- Some(
+ createMathExpression(left, right, inputs, add.dataType,
getFailOnError(add)).map {
+ expr =>
ExprOuterClass.Expr
.newBuilder()
- .setAdd(addBuilder)
- .build())
- } else {
- withInfo(add, left, right)
- None
+ .setAdd(expr)
+ .build()
}
case add @ Add(left, _, _) if !supportedDataType(left.dataType) =>
@@ -959,26 +945,12 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
None
case sub @ Subtract(left, right, _) if
supportedDataType(left.dataType) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.Subtract.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
- builder.setFailOnError(getFailOnError(sub))
- serializeDataType(sub.dataType).foreach { t =>
- builder.setReturnType(t)
- }
-
- Some(
+ createMathExpression(left, right, inputs, sub.dataType,
getFailOnError(sub)).map {
+ expr =>
ExprOuterClass.Expr
.newBuilder()
- .setSubtract(builder)
- .build())
- } else {
- withInfo(sub, left, right)
- None
+ .setSubtract(expr)
+ .build()
}
case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) =>
@@ -987,26 +959,12 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
case mul @ Multiply(left, right, _)
if supportedDataType(left.dataType) &&
!decimalBeforeSpark34(left.dataType) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(right, inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.Multiply.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
- builder.setFailOnError(getFailOnError(mul))
- serializeDataType(mul.dataType).foreach { t =>
- builder.setReturnType(t)
- }
-
- Some(
+ createMathExpression(left, right, inputs, mul.dataType,
getFailOnError(mul)).map {
+ expr =>
ExprOuterClass.Expr
.newBuilder()
- .setMultiply(builder)
- .build())
- } else {
- withInfo(mul, left, right)
- None
+ .setMultiply(expr)
+ .build()
}
case mul @ Multiply(left, _, _) =>
@@ -1020,30 +978,19 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
case div @ Divide(left, right, _)
if supportedDataType(left.dataType) &&
!decimalBeforeSpark34(left.dataType) =>
- val leftExpr = exprToProtoInternal(left, inputs)
// Datafusion now throws an exception for dividing by zero
// See https://github.com/apache/arrow-datafusion/pull/6792
// For now, use NullIf to swap zeros with nulls.
- val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right),
inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.Divide.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
- builder.setFailOnError(getFailOnError(div))
- serializeDataType(div.dataType).foreach { t =>
- builder.setReturnType(t)
- }
+ val rightExpr = nullIfWhenPrimitive(right)
- Some(
+ createMathExpression(left, rightExpr, inputs, div.dataType,
getFailOnError(div)).map {
+ expr =>
ExprOuterClass.Expr
.newBuilder()
- .setDivide(builder)
- .build())
- } else {
- withInfo(div, left, right)
- None
+ .setDivide(expr)
+ .build()
}
+
case div @ Divide(left, _, _) =>
if (!supportedDataType(left.dataType)) {
withInfo(div, s"Unsupported datatype ${left.dataType}")
@@ -1055,27 +1002,16 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
case rem @ Remainder(left, right, _)
if supportedDataType(left.dataType) &&
!decimalBeforeSpark34(left.dataType) =>
- val leftExpr = exprToProtoInternal(left, inputs)
- val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right),
inputs)
-
- if (leftExpr.isDefined && rightExpr.isDefined) {
- val builder = ExprOuterClass.Remainder.newBuilder()
- builder.setLeft(leftExpr.get)
- builder.setRight(rightExpr.get)
- builder.setFailOnError(getFailOnError(rem))
- serializeDataType(rem.dataType).foreach { t =>
- builder.setReturnType(t)
- }
+ val rightExpr = nullIfWhenPrimitive(right)
- Some(
+ createMathExpression(left, rightExpr, inputs, rem.dataType,
getFailOnError(rem)).map {
+ expr =>
ExprOuterClass.Expr
.newBuilder()
- .setRemainder(builder)
- .build())
- } else {
- withInfo(rem, left, right)
- None
+ .setRemainder(expr)
+ .build()
}
+
case rem @ Remainder(left, _, _) =>
if (!supportedDataType(left.dataType)) {
withInfo(rem, s"Unsupported datatype ${left.dataType}")
@@ -1346,20 +1282,11 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
}
case StringSpace(child) =>
- val childExpr = exprToProtoInternal(child, inputs)
-
- if (childExpr.isDefined) {
- val builder = ExprOuterClass.StringSpace.newBuilder()
- builder.setChild(childExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setStringSpace(builder)
- .build())
- } else {
- withInfo(expr, child)
- None
+ createUnaryExpr(child, inputs).map { expr =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setStringSpace(expr)
+ .build()
}
case Hour(child, timeZoneId) =>
@@ -1495,37 +1422,19 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
optExprWithInfo(optExpr, expr, child)
case IsNull(child) =>
- val childExpr = exprToProtoInternal(child, inputs)
-
- if (childExpr.isDefined) {
- val castBuilder = ExprOuterClass.IsNull.newBuilder()
- castBuilder.setChild(childExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setIsNull(castBuilder)
- .build())
- } else {
- withInfo(expr, child)
- None
+ createUnaryExpr(child, inputs).map { expr =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setIsNull(expr)
+ .build()
}
case IsNotNull(child) =>
- val childExpr = exprToProtoInternal(child, inputs)
-
- if (childExpr.isDefined) {
- val castBuilder = ExprOuterClass.IsNotNull.newBuilder()
- castBuilder.setChild(childExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setIsNotNull(castBuilder)
- .build())
- } else {
- withInfo(expr, child)
- None
+ createUnaryExpr(child, inputs).map { expr =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setIsNotNull(expr)
+ .build()
}
case IsNaN(child) =>
@@ -2021,20 +1930,11 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
}
case BitwiseNot(child) =>
- val childExpr = exprToProtoInternal(child, inputs)
-
- if (childExpr.isDefined) {
- val builder = ExprOuterClass.BitwiseNot.newBuilder()
- builder.setChild(childExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setBitwiseNot(builder)
- .build())
- } else {
- withInfo(expr, child)
- None
+ createUnaryExpr(child, inputs).map { expr =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setBitwiseNot(expr)
+ .build()
}
case BitwiseOr(left, right) =>
@@ -2101,18 +2001,11 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
in(expr, value, list, inputs, true)
case Not(child) =>
- val childExpr = exprToProtoInternal(child, inputs)
- if (childExpr.isDefined) {
- val builder = ExprOuterClass.Not.newBuilder()
- builder.setChild(childExpr.get)
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setNot(builder)
- .build())
- } else {
- withInfo(expr, child)
- None
+ createUnaryExpr(child, inputs).map { expr =>
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setNot(expr)
+ .build()
}
case UnaryMinus(child, failOnError) =>
@@ -2394,6 +2287,22 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
}
}
+ /**
+ * Serializes `child` with `exprToProtoInternal` and wraps the result in the
+ * shared `UnaryExpr` proto message (used for StringSpace, IsNull, IsNotNull,
+ * BitwiseNot and Not).
+ *
+ * @param child the single child expression to serialize
+ * @param inputs input attributes forwarded to `exprToProtoInternal`
+ * @return the `UnaryExpr` message, or None if the child could not be serialized
+ */
+ def createUnaryExpr(
+ child: Expression,
+ inputs: Seq[Attribute]): Option[ExprOuterClass.UnaryExpr] = {
+ val childExpr = exprToProtoInternal(child, inputs)
+ if (childExpr.isDefined) {
+ Some(
+ ExprOuterClass.UnaryExpr
+ .newBuilder()
+ .setChild(childExpr.get)
+ .build())
+ } else {
+ // NOTE(review): `expr` is not a parameter of this method, so it binds to
+ // an `expr` from the enclosing scope (not visible in this diff). Before this
+ // refactor each match arm called `withInfo(expr, child)` itself. If the
+ // enclosing `expr` is not the expression being converted (e.g. it is the
+ // root expression), the fallback reason is now attached to the wrong node.
+ // Consider passing the parent expression in explicitly — TODO confirm.
+ withInfo(expr, child)
+ None
+ }
+ }
+
def createBinaryExpr(
left: Expression,
right: Expression,
@@ -2413,6 +2322,30 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
}
}
+ /**
+ * Serializes both operands with `exprToProtoInternal` and builds the shared
+ * `MathExpr` proto message (used for Add, Subtract, Multiply, Divide and
+ * Remainder).
+ *
+ * @param left left operand to serialize
+ * @param right right operand to serialize (callers may pass a rewritten operand)
+ * @param inputs input attributes forwarded to `exprToProtoInternal`
+ * @param dataType result type; stored as `return_type` only if it can be serialized
+ * @param failOnError value written to the `fail_on_error` field
+ * @return the `MathExpr` message, or None if either operand could not be serialized
+ */
+ def createMathExpression(
+ left: Expression,
+ right: Expression,
+ inputs: Seq[Attribute],
+ dataType: DataType,
+ failOnError: Boolean): Option[ExprOuterClass.MathExpr] = {
+ val leftExpr = exprToProtoInternal(left, inputs)
+ val rightExpr = exprToProtoInternal(right, inputs)
+
+ if (leftExpr.isDefined && rightExpr.isDefined) {
+ val builder = ExprOuterClass.MathExpr.newBuilder()
+ builder.setLeft(leftExpr.get)
+ builder.setRight(rightExpr.get)
+ builder.setFailOnError(failOnError)
+ // return_type is left unset when serializeDataType yields nothing.
+ serializeDataType(dataType).foreach { t =>
+ builder.setReturnType(t)
+ }
+ Some(builder.build())
+ } else {
+ // NOTE(review): `expr` is not a parameter here; it binds to an `expr` from
+ // the enclosing scope (not visible in this diff). The pre-refactor code
+ // recorded the reason on the specific node (`withInfo(add, left, right)`,
+ // `withInfo(div, left, right)`, ...), so this likely attaches fallback info
+ // to a different expression than before. Also, for Divide/Remainder the
+ // `right` passed in is the NullIf-wrapped operand, not the original child.
+ // Consider taking the parent expression as a parameter — TODO confirm.
+ withInfo(expr, left, right)
+ None
+ }
+ }
+
def trim(
expr: Expression, // parent expression
srcStr: Expression,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]