This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new e07f24c8 feat: Support Ansi mode in abs function (#500)
e07f24c8 is described below

commit e07f24c88293b4882bb8fd688acee9809478e07c
Author: Pablo Langa <[email protected]>
AuthorDate: Tue Jun 11 09:28:02 2024 -0400

    feat: Support Ansi mode in abs function (#500)
    
    * change proto msg
    
    * QueryPlanSerde with eval mode
    
    * Move eval mode
    
    * Add abs in planner
    
    * CometAbsFunc wrapper
    
    * Add error management
    
    * Add tests
    
    * Add license
    
    * spotless apply
    
    * format
    
    * Fix clippy
    
    * error msg for all spark versions
    
    * Fix benches
    
    * Use enum to ansi mode
    
    * Fix format
    
    * Add more tests
    
    * Format
    
    * Refactor
    
    * refactor
    
    * fix merge
    
    * fix merge
---
 core/benches/cast_from_string.rs                   |  2 +-
 core/benches/cast_numeric.rs                       |  2 +-
 core/src/execution/datafusion/expressions/abs.rs   | 87 ++++++++++++++++++++++
 core/src/execution/datafusion/expressions/cast.rs  |  9 +--
 core/src/execution/datafusion/expressions/mod.rs   | 29 ++++++++
 .../execution/datafusion/expressions/negative.rs   |  8 +-
 core/src/execution/datafusion/planner.rs           | 19 +++--
 core/src/execution/proto/expr.proto                |  1 +
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 14 ++--
 .../org/apache/comet/CometExpressionSuite.scala    | 54 ++++++++++++++
 10 files changed, 195 insertions(+), 30 deletions(-)

diff --git a/core/benches/cast_from_string.rs b/core/benches/cast_from_string.rs
index 5bfaebf3..9a9ab18c 100644
--- a/core/benches/cast_from_string.rs
+++ b/core/benches/cast_from_string.rs
@@ -17,7 +17,7 @@
 
 use arrow_array::{builder::StringBuilder, RecordBatch};
 use arrow_schema::{DataType, Field, Schema};
-use comet::execution::datafusion::expressions::cast::{Cast, EvalMode};
+use comet::execution::datafusion::expressions::{cast::Cast, EvalMode};
 use criterion::{criterion_group, criterion_main, Criterion};
 use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
 use std::sync::Arc;
diff --git a/core/benches/cast_numeric.rs b/core/benches/cast_numeric.rs
index 398be694..35f24ce5 100644
--- a/core/benches/cast_numeric.rs
+++ b/core/benches/cast_numeric.rs
@@ -17,7 +17,7 @@
 
 use arrow_array::{builder::Int32Builder, RecordBatch};
 use arrow_schema::{DataType, Field, Schema};
-use comet::execution::datafusion::expressions::cast::{Cast, EvalMode};
+use comet::execution::datafusion::expressions::{cast::Cast, EvalMode};
 use criterion::{criterion_group, criterion_main, Criterion};
 use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
 use std::sync::Arc;
diff --git a/core/src/execution/datafusion/expressions/abs.rs b/core/src/execution/datafusion/expressions/abs.rs
new file mode 100644
index 00000000..4eb8c7c1
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/abs.rs
@@ -0,0 +1,87 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::DataType;
+use arrow_schema::ArrowError;
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature};
+use datafusion_common::DataFusionError;
+use datafusion_functions::math;
+use std::{any::Any, sync::Arc};
+
+use crate::execution::operators::ExecutionError;
+
+use super::{arithmetic_overflow_error, EvalMode};
+
+#[derive(Debug)]
+pub struct CometAbsFunc {
+    inner_abs_func: Arc<dyn ScalarUDFImpl>,
+    eval_mode: EvalMode,
+    data_type_name: String,
+}
+
+impl CometAbsFunc {
+    pub fn new(eval_mode: EvalMode, data_type_name: String) -> Result<Self, ExecutionError> {
+        if let EvalMode::Legacy | EvalMode::Ansi = eval_mode {
+            Ok(Self {
+                inner_abs_func: math::abs().inner(),
+                eval_mode,
+                data_type_name,
+            })
+        } else {
+            Err(ExecutionError::GeneralError(format!(
+                "Invalid EvalMode: \"{:?}\"",
+                eval_mode
+            )))
+        }
+    }
+}
+
+impl ScalarUDFImpl for CometAbsFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "abs"
+    }
+
+    fn signature(&self) -> &Signature {
+        self.inner_abs_func.signature()
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType, DataFusionError> {
+        self.inner_abs_func.return_type(arg_types)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+        match self.inner_abs_func.invoke(args) {
+            Err(DataFusionError::ArrowError(ArrowError::ComputeError(msg), trace))
+                if msg.contains("overflow") =>
+            {
+                if self.eval_mode == EvalMode::Legacy {
+                    Ok(args[0].clone())
+                } else {
+                    let msg = arithmetic_overflow_error(&self.data_type_name).to_string();
+                    Err(DataFusionError::ArrowError(
+                        ArrowError::ComputeError(msg),
+                        trace,
+                    ))
+                }
+            }
+            other => other,
+        }
+    }
+}
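
For illustration, a minimal sketch of the wrapper's behavior (not part of the commit; it assumes the comet crate exports these paths as in the benches above, and that DataFusion's built-in abs reports integer overflow as an Arrow compute error, which is what the match arm in invoke expects):

    use std::sync::Arc;
    use arrow_array::{ArrayRef, Int32Array};
    use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl};
    use comet::execution::datafusion::expressions::{abs::CometAbsFunc, EvalMode};

    fn main() {
        // abs(i32::MIN) has no 32-bit representation, so it overflows.
        let input: ArrayRef = Arc::new(Int32Array::from(vec![i32::MIN]));
        let args = [ColumnarValue::Array(input)];

        // Legacy mode swallows the overflow and returns the input unchanged.
        let legacy = CometAbsFunc::new(EvalMode::Legacy, "integer".to_string()).unwrap();
        assert!(legacy.invoke(&args).is_ok());

        // ANSI mode maps the same overflow to an ARITHMETIC_OVERFLOW error.
        let ansi = CometAbsFunc::new(EvalMode::Ansi, "integer".to_string()).unwrap();
        assert!(ansi.invoke(&args).is_err());
    }

The "integer" argument is a hypothetical data_type_name; in the planner below it is derived from the child's return type and only appears in the error message.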
diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index 04562646..4dae62dc 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -52,6 +52,8 @@ use crate::{
     },
 };
 
+use super::EvalMode;
+
 static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
 
 static CAST_OPTIONS: CastOptions = CastOptions {
@@ -61,13 +63,6 @@ static CAST_OPTIONS: CastOptions = CastOptions {
         .with_timestamp_format(TIMESTAMP_FORMAT),
 };
 
-#[derive(Debug, Hash, PartialEq, Clone, Copy)]
-pub enum EvalMode {
-    Legacy,
-    Ansi,
-    Try,
-}
-
 #[derive(Debug, Hash)]
 pub struct Cast {
     pub child: Arc<dyn PhysicalExpr>,
diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs
index 05230b4c..5d5f58e0 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -24,6 +24,10 @@ pub mod if_expr;
 mod normalize_nan;
 pub mod scalar_funcs;
 pub use normalize_nan::NormalizeNaNAndZero;
+use prost::DecodeError;
+
+use crate::{errors::CometError, execution::spark_expression};
+pub mod abs;
 pub mod avg;
 pub mod avg_decimal;
 pub mod bloom_filter_might_contain;
@@ -39,3 +43,28 @@ pub mod temporal;
 pub mod unbound;
 mod utils;
 pub mod variance;
+
+#[derive(Debug, Hash, PartialEq, Clone, Copy)]
+pub enum EvalMode {
+    Legacy,
+    Ansi,
+    Try,
+}
+
+impl TryFrom<i32> for EvalMode {
+    type Error = DecodeError;
+
+    fn try_from(value: i32) -> Result<Self, Self::Error> {
+        match spark_expression::EvalMode::try_from(value)? {
+            spark_expression::EvalMode::Legacy => Ok(EvalMode::Legacy),
+            spark_expression::EvalMode::Try => Ok(EvalMode::Try),
+            spark_expression::EvalMode::Ansi => Ok(EvalMode::Ansi),
+        }
+    }
+}
+
+fn arithmetic_overflow_error(from_type: &str) -> CometError {
+    CometError::ArithmeticOverflow {
+        from_type: from_type.to_string(),
+    }
+}
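
For illustration, a sketch of the new conversion path (not part of the commit; it assumes the protobuf-generated enum is reachable as comet::execution::spark_expression, mirroring the crate-internal path used above):

    use comet::execution::datafusion::expressions::EvalMode;
    use comet::execution::spark_expression;

    fn main() {
        // A known protobuf tag maps onto the native enum...
        let raw = spark_expression::EvalMode::Ansi as i32;
        assert_eq!(EvalMode::try_from(raw).unwrap(), EvalMode::Ansi);

        // ...while a tag no variant claims surfaces as a prost DecodeError.
        assert!(EvalMode::try_from(i32::MAX).is_err());
    }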
diff --git a/core/src/execution/datafusion/expressions/negative.rs b/core/src/execution/datafusion/expressions/negative.rs
index a85cde89..cd0e9bcc 100644
--- a/core/src/execution/datafusion/expressions/negative.rs
+++ b/core/src/execution/datafusion/expressions/negative.rs
@@ -33,6 +33,8 @@ use std::{
     sync::Arc,
 };
 
+use super::arithmetic_overflow_error;
+
 pub fn create_negate_expr(
     expr: Arc<dyn PhysicalExpr>,
     fail_on_error: bool,
@@ -48,12 +50,6 @@ pub struct NegativeExpr {
     fail_on_error: bool,
 }
 
-fn arithmetic_overflow_error(from_type: &str) -> CometError {
-    CometError::ArithmeticOverflow {
-        from_type: from_type.to_string(),
-    }
-}
-
 macro_rules! check_overflow {
     ($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{
         let typed_array = $array
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index e5193215..d92bf578 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -24,7 +24,6 @@ use datafusion::{
     arrow::{compute::SortOptions, datatypes::SchemaRef},
     common::DataFusionError,
     execution::FunctionRegistry,
-    functions::math,
     logical_expr::Operator as DataFusionOperator,
     physical_expr::{
         execution_props::ExecutionProps,
@@ -51,6 +50,7 @@ use datafusion_common::{
     tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter},
     JoinType as DFJoinType, ScalarValue,
 };
+use datafusion_expr::ScalarUDF;
 use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
 use itertools::Itertools;
 use jni::objects::GlobalRef;
@@ -65,7 +65,7 @@ use crate::{
                 avg_decimal::AvgDecimal,
                 bitwise_not::BitwiseNotExpr,
                 bloom_filter_might_contain::BloomFilterMightContain,
-                cast::{Cast, EvalMode},
+                cast::Cast,
                 checkoverflow::CheckOverflow,
                 correlation::Correlation,
                 covariance::Covariance,
@@ -97,6 +97,8 @@ use crate::{
     },
 };
 
+use super::expressions::{abs::CometAbsFunc, EvalMode};
+
 // For clippy error on type_complexity.
 type ExecResult<T> = Result<T, ExecutionError>;
 type PhyAggResult = Result<Vec<Arc<dyn AggregateExpr>>, ExecutionError>;
@@ -356,11 +358,7 @@ impl PhysicalPlanner {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
                 let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
                 let timezone = expr.timezone.clone();
-                let eval_mode = match spark_expression::EvalMode::try_from(expr.eval_mode)? {
-                    spark_expression::EvalMode::Legacy => EvalMode::Legacy,
-                    spark_expression::EvalMode::Try => EvalMode::Try,
-                    spark_expression::EvalMode::Ansi => EvalMode::Ansi,
-                };
+                let eval_mode = expr.eval_mode.try_into()?;
 
                 Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone)))
             }
@@ -499,7 +497,12 @@ impl PhysicalPlanner {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema.clone())?;
                 let return_type = child.data_type(&input_schema)?;
                 let args = vec![child];
-                let expr = ScalarFunctionExpr::new("abs", math::abs(), args, return_type);
+                let eval_mode = expr.eval_mode.try_into()?;
+                let comet_abs = Arc::new(ScalarUDF::new_from_impl(CometAbsFunc::new(
+                    eval_mode,
+                    return_type.to_string(),
+                )?));
+                let expr = ScalarFunctionExpr::new("abs", comet_abs, args, return_type);
                 Ok(Arc::new(expr))
             }
             ExprStruct::CaseWhen(case_when) => {
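
For illustration (not part of the commit): because CometAbsFunc::new validates its eval mode, a TRY-mode abs arriving in the protobuf plan now fails during planning, propagated by the ? above, rather than surfacing later at execution time:

    use comet::execution::datafusion::expressions::{abs::CometAbsFunc, EvalMode};

    fn main() {
        // Abs supports only LEGACY and ANSI; TRY is rejected up front.
        assert!(CometAbsFunc::new(EvalMode::Try, "integer".to_string()).is_err());
    }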
diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index 5192bbd4..093b07b3 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -480,6 +480,7 @@ message BitwiseNot {
 
 message Abs {
   Expr child = 1;
+  EvalMode eval_mode = 2;
 }
 
 message Subquery {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 5a0ad38d..c1c8b5c5 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1476,15 +1476,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
             None
           }
 
-        case Abs(child, _) =>
+        case Abs(child, failOnErr) =>
           val childExpr = exprToProtoInternal(child, inputs)
           if (childExpr.isDefined) {
-            val abs =
-              ExprOuterClass.Abs
-                .newBuilder()
-                .setChild(childExpr.get)
-                .build()
-            Some(Expr.newBuilder().setAbs(abs).build())
+            val evalModeStr =
+              if (failOnErr) ExprOuterClass.EvalMode.ANSI else ExprOuterClass.EvalMode.LEGACY
+            val absBuilder = ExprOuterClass.Abs.newBuilder()
+            absBuilder.setChild(childExpr.get)
+            absBuilder.setEvalMode(evalModeStr)
+            Some(Expr.newBuilder().setAbs(absBuilder).build())
           } else {
             withInfo(expr, child)
             None
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 10fbc468..a2b6edd0 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -19,6 +19,9 @@
 
 package org.apache.comet
 
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
+
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -850,6 +853,57 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("abs Overflow ansi mode") {
+
+    def testAbsAnsiOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
+      withParquetTable(data, "tbl") {
+        checkSparkMaybeThrows(sql("select abs(_1), abs(_2) from tbl")) match {
+          case (Some(sparkExc), Some(cometExc)) =>
+            val cometErrorPattern =
+              """.+[ARITHMETIC_OVERFLOW].+overflow. If necessary set 
"spark.sql.ansi.enabled" to "false" to bypass this error.""".r
+            assert(cometErrorPattern.findFirstIn(cometExc.getMessage).isDefined)
+            assert(sparkExc.getMessage.contains("overflow"))
+          case _ => fail("Exception should be thrown")
+        }
+      }
+    }
+
+    def testAbsAnsi[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
+      withParquetTable(data, "tbl") {
+        checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl")
+      }
+    }
+
+    withSQLConf(
+      SQLConf.ANSI_ENABLED.key -> "true",
+      CometConf.COMET_ANSI_MODE_ENABLED.key -> "true") {
+      testAbsAnsiOverflow(Seq((Byte.MaxValue, Byte.MinValue)))
+      testAbsAnsiOverflow(Seq((Short.MaxValue, Short.MinValue)))
+      testAbsAnsiOverflow(Seq((Int.MaxValue, Int.MinValue)))
+      testAbsAnsiOverflow(Seq((Long.MaxValue, Long.MinValue)))
+      testAbsAnsi(Seq((Float.MaxValue, Float.MinValue)))
+      testAbsAnsi(Seq((Double.MaxValue, Double.MinValue)))
+    }
+  }
+
+  test("abs Overflow legacy mode") {
+
+    def testAbsLegacyOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+        withParquetTable(data, "tbl") {
+          checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl")
+        }
+      }
+    }
+
+    testAbsLegacyOverflow(Seq((Byte.MaxValue, Byte.MinValue)))
+    testAbsLegacyOverflow(Seq((Short.MaxValue, Short.MinValue)))
+    testAbsLegacyOverflow(Seq((Int.MaxValue, Int.MinValue)))
+    testAbsLegacyOverflow(Seq((Long.MaxValue, Long.MinValue)))
+    testAbsLegacyOverflow(Seq((Float.MaxValue, Float.MinValue)))
+    testAbsLegacyOverflow(Seq((Double.MaxValue, Double.MinValue)))
+  }
+
   test("ceil and floor") {
     Seq("true", "false").foreach { dictionary =>
       withSQLConf(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
