This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 03c062646 feat: implement_ansi_eval_mode_arithmetic (#2136)
03c062646 is described below

commit 03c062646ba183069b2fd252df23d00e1f2e51c8
Author: B Vadlamani <[email protected]>
AuthorDate: Thu Sep 25 13:59:00 2025 -0700

    feat: implement_ansi_eval_mode_arithmetic (#2136)
---
 dev/diffs/3.4.3.diff                               |  46 +++++++-
 dev/diffs/3.5.6.diff                               |  46 +++++++-
 native/core/src/execution/planner.rs               |  23 ++--
 native/spark-expr/src/comet_scalar_funcs.rs        |  36 +++++-
 native/spark-expr/src/lib.rs                       |   5 +-
 .../src/math_funcs/checked_arithmetic.rs           | 126 +++++++++++++++++++--
 .../scala/org/apache/comet/serde/arithmetic.scala  |  35 +-----
 .../org/apache/comet/CometExpressionSuite.scala    | 106 +++++++++++++++++
 8 files changed, 357 insertions(+), 66 deletions(-)

diff --git a/dev/diffs/3.4.3.diff b/dev/diffs/3.4.3.diff
index 1c0ca867d..ab9ac0888 100644
--- a/dev/diffs/3.4.3.diff
+++ b/dev/diffs/3.4.3.diff
@@ -193,6 +193,19 @@ index 41fd4de2a09..44cd244d3b0 100644
  -- Test aggregate operator with codegen on and off.
  --CONFIG_DIM1 spark.sql.codegen.wholeStage=true
 --CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+index 3a409eea348..38fed024c98 100644
+--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
++++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+@@ -69,6 +69,8 @@ SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1
+ -- any evens
+ SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0');
+ 
++-- https://github.com/apache/datafusion-comet/issues/2215
++--SET spark.comet.exec.enabled=false
+ -- [SPARK-28024] Incorrect value when out of range
+ SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i;
+ 
 diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
 index fac23b4a26f..2b73732c33f 100644
 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
@@ -881,7 +894,7 @@ index b5b34922694..a72403780c4 100644
    protected val baseResourcePath = {
      // use the same way as `SQLQueryTestSuite` to get the resource path
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
-index 525d97e4998..5e04319dd97 100644
+index 525d97e4998..843f0472c23 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 @@ -1508,7 +1508,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
@@ -894,7 +907,27 @@ index 525d97e4998..5e04319dd97 100644
     AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
        sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
      }
-@@ -4467,7 +4468,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+@@ -4429,7 +4430,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+   }
+ 
+   test("SPARK-39166: Query context of binary arithmetic should be serialized 
to executors" +
+-    " when WSCG is off") {
++    " when WSCG is off",
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+       SQLConf.ANSI_ENABLED.key -> "true") {
+       withTable("t") {
+@@ -4450,7 +4452,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+   }
+ 
+   test("SPARK-39175: Query context of Cast should be serialized to executors" 
+
+-    " when WSCG is off") {
++    " when WSCG is off",
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+       SQLConf.ANSI_ENABLED.key -> "true") {
+       withTable("t") {
+@@ -4467,14 +4470,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
            val msg = intercept[SparkException] {
              sql(query).collect()
            }.getMessage
@@ -907,6 +940,15 @@ index 525d97e4998..5e04319dd97 100644
          }
        }
      }
+   }
+ 
+   test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal 
overflow error should " +
+-    "be serialized to executors when WSCG is off") {
++    "be serialized to executors when WSCG is off",
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+       SQLConf.ANSI_ENABLED.key -> "true") {
+       withTable("t") {
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
 index 48ad10992c5..51d1ee65422 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
diff --git a/dev/diffs/3.5.6.diff b/dev/diffs/3.5.6.diff
index f3909d074..63f0d3eb0 100644
--- a/dev/diffs/3.5.6.diff
+++ b/dev/diffs/3.5.6.diff
@@ -172,6 +172,19 @@ index 41fd4de2a09..44cd244d3b0 100644
  -- Test aggregate operator with codegen on and off.
  --CONFIG_DIM1 spark.sql.codegen.wholeStage=true
 --CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+index 3a409eea348..38fed024c98 100644
+--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
++++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+@@ -69,6 +69,8 @@ SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1
+ -- any evens
+ SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0');
+ 
++-- https://github.com/apache/datafusion-comet/issues/2215
++--SET spark.comet.exec.enabled=false
+ -- [SPARK-28024] Incorrect value when out of range
+ SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i;
+ 
 diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
 index fac23b4a26f..2b73732c33f 100644
 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
@@ -866,7 +879,7 @@ index c26757c9cff..d55775f09d7 100644
    protected val baseResourcePath = {
      // use the same way as `SQLQueryTestSuite` to get the resource path
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
-index 793a0da6a86..e48e74091cb 100644
+index 793a0da6a86..181bfc16e4b 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 @@ -1521,7 +1521,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
@@ -879,7 +892,27 @@ index 793a0da6a86..e48e74091cb 100644
     AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
        sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
      }
-@@ -4497,7 +4498,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+@@ -4459,7 +4460,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+   }
+ 
+   test("SPARK-39166: Query context of binary arithmetic should be serialized 
to executors" +
+-    " when WSCG is off") {
++    " when WSCG is off",
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+       SQLConf.ANSI_ENABLED.key -> "true") {
+       withTable("t") {
+@@ -4480,7 +4482,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+   }
+ 
+   test("SPARK-39175: Query context of Cast should be serialized to executors" 
+
+-    " when WSCG is off") {
++    " when WSCG is off",
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+       SQLConf.ANSI_ENABLED.key -> "true") {
+       withTable("t") {
+@@ -4497,14 +4500,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
            val msg = intercept[SparkException] {
              sql(query).collect()
            }.getMessage
@@ -892,6 +925,15 @@ index 793a0da6a86..e48e74091cb 100644
          }
        }
      }
+   }
+ 
+   test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal 
overflow error should " +
+-    "be serialized to executors when WSCG is off") {
++    "be serialized to executors when WSCG is off",
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+       SQLConf.ANSI_ENABLED.key -> "true") {
+       withTable("t") {
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
 index fa1a64460fc..1d2e215d6a3 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 64efa31d5..517c037e9 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -62,8 +62,9 @@ use datafusion::{
     prelude::SessionContext,
 };
 use datafusion_comet_spark_expr::{
-    create_comet_physical_fun, create_modulo_expr, create_negate_expr, BinaryOutputStyle,
-    BloomFilterAgg, BloomFilterMightContain, EvalMode, SparkHour, SparkMinute, SparkSecond,
+    create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, create_modulo_expr,
+    create_negate_expr, BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, EvalMode,
+    SparkHour, SparkMinute, SparkSecond,
 };
 
 use crate::execution::operators::ExecutionError::GeneralError;
@@ -242,8 +243,6 @@ impl PhysicalPlanner {
     ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
         match spark_expr.expr_struct.as_ref().unwrap() {
             ExprStruct::Add(expr) => {
-                // TODO respect ANSI eval mode
-                // https://github.com/apache/datafusion-comet/issues/536
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr(
                     expr.left.as_ref().unwrap(),
@@ -255,8 +254,6 @@ impl PhysicalPlanner {
                 )
             }
             ExprStruct::Subtract(expr) => {
-                // TODO respect ANSI eval mode
-                // https://github.com/apache/datafusion-comet/issues/535
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr(
                     expr.left.as_ref().unwrap(),
@@ -268,8 +265,6 @@ impl PhysicalPlanner {
                 )
             }
             ExprStruct::Multiply(expr) => {
-                // TODO respect ANSI eval mode
-                // https://github.com/apache/datafusion-comet/issues/534
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr(
                     expr.left.as_ref().unwrap(),
@@ -281,8 +276,6 @@ impl PhysicalPlanner {
                 )
             }
             ExprStruct::Divide(expr) => {
-                // TODO respect ANSI eval mode
-                // https://github.com/apache/datafusion-comet/issues/533
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr(
                     expr.left.as_ref().unwrap(),
@@ -1010,21 +1003,25 @@ impl PhysicalPlanner {
             }
             _ => {
                 let data_type = return_type.map(to_arrow_datatype).unwrap();
-                if eval_mode == EvalMode::Try && data_type.is_integer() {
+                if [EvalMode::Try, EvalMode::Ansi].contains(&eval_mode)
+                    && (data_type.is_integer()
+                        || (data_type.is_floating() && op == DataFusionOperator::Divide))
+                {
                     let op_str = match op {
                         DataFusionOperator::Plus => "checked_add",
                         DataFusionOperator::Minus => "checked_sub",
                         DataFusionOperator::Multiply => "checked_mul",
                         DataFusionOperator::Divide => "checked_div",
                         _ => {
-                            todo!("Operator yet to be implemented!");
+                            todo!("ANSI mode for Operator yet to be 
implemented!");
                         }
                     };
-                    let fun_expr = create_comet_physical_fun(
+                    let fun_expr = create_comet_physical_fun_with_eval_mode(
                         op_str,
                         data_type.clone(),
                         &self.session_ctx.state(),
                         None,
+                        eval_mode,
                     )?;
                     Ok(Arc::new(ScalarFunctionExpr::new(
                         op_str,
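
A rough sketch of the planner gating above, with simplified stand-in types (only `EvalMode` and the checked-function names mirror the diff; the enums and helper here are illustrative, not the actual DataFusion API):

    #[derive(Debug, Clone, Copy, PartialEq)]
    enum EvalMode { Legacy, Try, Ansi }

    #[derive(Debug, Clone, Copy, PartialEq)]
    enum Op { Plus, Minus, Multiply, Divide }

    /// Returns the checked-kernel name when the (mode, type, op) combination
    /// must route through a Comet scalar function instead of a plain Arrow op.
    fn checked_fun_name(mode: EvalMode, is_int: bool, is_float: bool, op: Op) -> Option<&'static str> {
        let needs_checked = [EvalMode::Try, EvalMode::Ansi].contains(&mode)
            && (is_int || (is_float && op == Op::Divide));
        if !needs_checked {
            return None; // Legacy mode keeps the plain Arrow binary expression
        }
        Some(match op {
            Op::Plus => "checked_add",
            Op::Minus => "checked_sub",
            Op::Multiply => "checked_mul",
            Op::Divide => "checked_div",
        })
    }

    fn main() {
        // ANSI integer addition routes to the checked kernel...
        assert_eq!(checked_fun_name(EvalMode::Ansi, true, false, Op::Plus), Some("checked_add"));
        // ...while Legacy keeps the native Arrow operator.
        assert_eq!(checked_fun_name(EvalMode::Legacy, true, false, Op::Plus), None);
        // Float division is checked too, so ANSI can surface DIVIDE_BY_ZERO.
        assert_eq!(checked_fun_name(EvalMode::Ansi, false, true, Op::Divide), Some("checked_div"));
    }
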
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 93a820ba9..f96ddffce 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -21,7 +21,7 @@ use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
     spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
+    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode,
     SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkDateTrunc, SparkStringSpace,
 };
 use arrow::datatypes::DataType;
@@ -64,6 +64,15 @@ macro_rules! make_comet_scalar_udf {
         );
         Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
     }};
+    ($name:expr, $func:ident, $data_type:ident, $eval_mode:ident) => {{
+        let scalar_func = CometScalarFunction::new(
+            $name.to_string(),
+            Signature::variadic_any(Volatility::Immutable),
+            $data_type.clone(),
+            Arc::new(move |args| $func(args, &$data_type, $eval_mode)),
+        );
+        Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
+    }};
 }
 
 /// Create a physical scalar function.
@@ -72,6 +81,23 @@ pub fn create_comet_physical_fun(
     data_type: DataType,
     registry: &dyn FunctionRegistry,
     fail_on_error: Option<bool>,
+) -> Result<Arc<ScalarUDF>, DataFusionError> {
+    create_comet_physical_fun_with_eval_mode(
+        fun_name,
+        data_type,
+        registry,
+        fail_on_error,
+        EvalMode::Legacy,
+    )
+}
+
+/// Create a physical scalar function with eval mode. The goal is to deprecate the function above once all operators have ANSI support.
+pub fn create_comet_physical_fun_with_eval_mode(
+    fun_name: &str,
+    data_type: DataType,
+    registry: &dyn FunctionRegistry,
+    fail_on_error: Option<bool>,
+    eval_mode: EvalMode,
 ) -> Result<Arc<ScalarUDF>, DataFusionError> {
     match fun_name {
         "ceil" => {
@@ -117,16 +143,16 @@ pub fn create_comet_physical_fun(
             )
         }
         "checked_add" => {
-            make_comet_scalar_udf!("checked_add", checked_add, data_type)
+            make_comet_scalar_udf!("checked_add", checked_add, data_type, 
eval_mode)
         }
         "checked_sub" => {
-            make_comet_scalar_udf!("checked_sub", checked_sub, data_type)
+            make_comet_scalar_udf!("checked_sub", checked_sub, data_type, 
eval_mode)
         }
         "checked_mul" => {
-            make_comet_scalar_udf!("checked_mul", checked_mul, data_type)
+            make_comet_scalar_udf!("checked_mul", checked_mul, data_type, 
eval_mode)
         }
         "checked_div" => {
-            make_comet_scalar_udf!("checked_div", checked_div, data_type)
+            make_comet_scalar_udf!("checked_div", checked_div, data_type, 
eval_mode)
         }
         "murmur3_hash" => {
             let func = Arc::new(spark_murmur3_hash);
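
Hand-expanded, the new macro arm just moves `eval_mode` into the UDF's closure so the mode chosen at plan time is visible on every invocation. A minimal sketch with simplified types (illustrative; not the real `CometScalarFunction` signature):

    #[derive(Debug, Clone, Copy)]
    enum EvalMode { Legacy, Try, Ansi }

    type ScalarFn = Box<dyn Fn(&[i64]) -> Result<Option<i64>, String>>;

    // Stand-in for checked_add: overflow is an error under ANSI, NULL under TRY.
    fn checked_add_stub(args: &[i64], eval_mode: EvalMode) -> Result<Option<i64>, String> {
        match args[0].checked_add(args[1]) {
            Some(v) => Ok(Some(v)),
            None => match eval_mode {
                EvalMode::Ansi => Err("ARITHMETIC_OVERFLOW".to_string()),
                _ => Ok(None),
            },
        }
    }

    // Mirrors `Arc::new(move |args| $func(args, &$data_type, $eval_mode))`:
    // `move` captures eval_mode by value inside the function object.
    fn make_udf(eval_mode: EvalMode) -> ScalarFn {
        Box::new(move |args| checked_add_stub(args, eval_mode))
    }

    fn main() {
        let ansi_add = make_udf(EvalMode::Ansi);
        assert_eq!(ansi_add(&[1, 2]), Ok(Some(3)));
        assert!(ansi_add(&[i64::MAX, 1]).is_err());
        let try_add = make_udf(EvalMode::Try);
        assert_eq!(try_add(&[i64::MAX, 1]), Ok(None));
    }
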
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index af5677a9b..7bdc7ff51 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -64,7 +64,10 @@ pub use conditional_funcs::*;
 pub use conversion_funcs::*;
 pub use nondetermenistic_funcs::*;
 
-pub use comet_scalar_funcs::{create_comet_physical_fun, register_all_comet_functions};
+pub use comet_scalar_funcs::{
+    create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
+    register_all_comet_functions,
+};
 pub use datetime_funcs::{
     spark_date_add, spark_date_sub, SparkDateTrunc, SparkHour, SparkMinute, SparkSecond,
     TimestampTruncExpr,
diff --git a/native/spark-expr/src/math_funcs/checked_arithmetic.rs b/native/spark-expr/src/math_funcs/checked_arithmetic.rs
index 0312cdb0b..bb4118f86 100644
--- a/native/spark-expr/src/math_funcs/checked_arithmetic.rs
+++ b/native/spark-expr/src/math_funcs/checked_arithmetic.rs
@@ -18,7 +18,11 @@
 use arrow::array::{Array, ArrowNativeTypeOp, PrimitiveArray, PrimitiveBuilder};
 use arrow::array::{ArrayRef, AsArray};
 
-use arrow::datatypes::{ArrowPrimitiveType, DataType, Int32Type, Int64Type};
+use crate::{divide_by_zero_error, EvalMode, SparkError};
+use arrow::datatypes::{
+    ArrowPrimitiveType, DataType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type,
+    Int64Type, Int8Type,
+};
 use datafusion::common::DataFusionError;
 use datafusion::physical_plan::ColumnarValue;
 use std::sync::Arc;
@@ -27,6 +31,7 @@ pub fn try_arithmetic_kernel<T>(
     left: &PrimitiveArray<T>,
     right: &PrimitiveArray<T>,
     op: &str,
+    is_ansi_mode: bool,
 ) -> Result<ArrayRef, DataFusionError>
 where
     T: ArrowPrimitiveType,
@@ -39,7 +44,19 @@ where
                 if left.is_null(i) || right.is_null(i) {
                     builder.append_null();
                 } else {
-                    builder.append_option(left.value(i).add_checked(right.value(i)).ok());
+                    match left.value(i).add_checked(right.value(i)) {
+                        Ok(v) => builder.append_value(v),
+                        Err(_e) => {
+                            if is_ansi_mode {
+                                return Err(SparkError::ArithmeticOverflow {
+                                    from_type: String::from("integer"),
+                                }
+                                .into());
+                            } else {
+                                builder.append_null();
+                            }
+                        }
+                    }
                 }
             }
         }
@@ -48,7 +65,19 @@ where
                 if left.is_null(i) || right.is_null(i) {
                     builder.append_null();
                 } else {
-                    builder.append_option(left.value(i).sub_checked(right.value(i)).ok());
+                    match left.value(i).sub_checked(right.value(i)) {
+                        Ok(v) => builder.append_value(v),
+                        Err(_e) => {
+                            if is_ansi_mode {
+                                return Err(SparkError::ArithmeticOverflow {
+                                    from_type: String::from("integer"),
+                                }
+                                .into());
+                            } else {
+                                builder.append_null();
+                            }
+                        }
+                    }
                 }
             }
         }
@@ -57,7 +86,19 @@ where
                 if left.is_null(i) || right.is_null(i) {
                     builder.append_null();
                 } else {
-                    builder.append_option(left.value(i).mul_checked(right.value(i)).ok());
+                    match left.value(i).mul_checked(right.value(i)) {
+                        Ok(v) => builder.append_value(v),
+                        Err(_e) => {
+                            if is_ansi_mode {
+                                return Err(SparkError::ArithmeticOverflow {
+                                    from_type: String::from("integer"),
+                                }
+                                .into());
+                            } else {
+                                builder.append_null();
+                            }
+                        }
+                    }
                 }
             }
         }
@@ -66,7 +107,23 @@ where
                 if left.is_null(i) || right.is_null(i) {
                     builder.append_null();
                 } else {
-                    builder.append_option(left.value(i).div_checked(right.value(i)).ok());
+                    match left.value(i).div_checked(right.value(i)) {
+                        Ok(v) => builder.append_value(v),
+                        Err(_e) => {
+                            if is_ansi_mode {
+                                return if right.value(i).is_zero() {
+                                    Err(divide_by_zero_error().into())
+                                } else {
+                                    return Err(SparkError::ArithmeticOverflow {
+                                        from_type: String::from("integer"),
+                                    }
+                                    .into());
+                                };
+                            } else {
+                                builder.append_null();
+                            }
+                        }
+                    }
                 }
             }
         }
@@ -84,39 +141,55 @@ where
 pub fn checked_add(
     args: &[ColumnarValue],
     data_type: &DataType,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
-    checked_arithmetic_internal(args, data_type, "checked_add")
+    checked_arithmetic_internal(args, data_type, "checked_add", eval_mode)
 }
 
 pub fn checked_sub(
     args: &[ColumnarValue],
     data_type: &DataType,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
-    checked_arithmetic_internal(args, data_type, "checked_sub")
+    checked_arithmetic_internal(args, data_type, "checked_sub", eval_mode)
 }
 
 pub fn checked_mul(
     args: &[ColumnarValue],
     data_type: &DataType,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
-    checked_arithmetic_internal(args, data_type, "checked_mul")
+    checked_arithmetic_internal(args, data_type, "checked_mul", eval_mode)
 }
 
 pub fn checked_div(
     args: &[ColumnarValue],
     data_type: &DataType,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
-    checked_arithmetic_internal(args, data_type, "checked_div")
+    checked_arithmetic_internal(args, data_type, "checked_div", eval_mode)
 }
 
 fn checked_arithmetic_internal(
     args: &[ColumnarValue],
     data_type: &DataType,
     op: &str,
+    eval_mode: EvalMode,
 ) -> Result<ColumnarValue, DataFusionError> {
     let left = &args[0];
     let right = &args[1];
 
+    let is_ansi_mode = match eval_mode {
+        EvalMode::Try => false,
+        EvalMode::Ansi => true,
+        _ => {
+            return Err(DataFusionError::Internal(format!(
+                "Unsupported mode : {:?}",
+                eval_mode
+            )))
+        }
+    };
+
     let (left_arr, right_arr): (ArrayRef, ArrayRef) = match (left, right) {
         (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
         (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
@@ -128,17 +201,50 @@ fn checked_arithmetic_internal(
         (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?),
     };
 
-    // Rust only supports checked_arithmetic on Int32 and Int64
+    // Rust only supports checked_arithmetic on numeric types
     let result_array = match data_type {
+        DataType::Int8 => try_arithmetic_kernel::<Int8Type>(
+            left_arr.as_primitive::<Int8Type>(),
+            right_arr.as_primitive::<Int8Type>(),
+            op,
+            is_ansi_mode,
+        ),
+        DataType::Int16 => try_arithmetic_kernel::<Int16Type>(
+            left_arr.as_primitive::<Int16Type>(),
+            right_arr.as_primitive::<Int16Type>(),
+            op,
+            is_ansi_mode,
+        ),
         DataType::Int32 => try_arithmetic_kernel::<Int32Type>(
             left_arr.as_primitive::<Int32Type>(),
             right_arr.as_primitive::<Int32Type>(),
             op,
+            is_ansi_mode,
         ),
         DataType::Int64 => try_arithmetic_kernel::<Int64Type>(
             left_arr.as_primitive::<Int64Type>(),
             right_arr.as_primitive::<Int64Type>(),
             op,
+            is_ansi_mode,
+        ),
+        // Spark always casts division operands to floats
+        DataType::Float16 if (op == "checked_div") => try_arithmetic_kernel::<Float16Type>(
+            left_arr.as_primitive::<Float16Type>(),
+            right_arr.as_primitive::<Float16Type>(),
+            op,
+            is_ansi_mode,
+        ),
+        DataType::Float32 if (op == "checked_div") => try_arithmetic_kernel::<Float32Type>(
+            left_arr.as_primitive::<Float32Type>(),
+            right_arr.as_primitive::<Float32Type>(),
+            op,
+            is_ansi_mode,
+        ),
+        DataType::Float64 if (op == "checked_div") => try_arithmetic_kernel::<Float64Type>(
+            left_arr.as_primitive::<Float64Type>(),
+            right_arr.as_primitive::<Float64Type>(),
+            op,
+            is_ansi_mode,
         ),
         _ => Err(DataFusionError::Internal(format!(
             "Unsupported data type: {:?}",
diff --git a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
index 0f1eeb758..4507dc107 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
@@ -87,14 +87,6 @@ trait MathBase {
 
 object CometAdd extends CometExpressionSerde[Add] with MathBase {
 
-  override def getSupportLevel(expr: Add): SupportLevel = {
-    if (expr.evalMode == EvalMode.ANSI) {
-      Incompatible(Some("ANSI mode is not supported"))
-    } else {
-      Compatible(None)
-    }
-  }
-
   override def convert(
       expr: Add,
       inputs: Seq[Attribute],
@@ -117,14 +109,6 @@ object CometAdd extends CometExpressionSerde[Add] with MathBase {
 
 object CometSubtract extends CometExpressionSerde[Subtract] with MathBase {
 
-  override def getSupportLevel(expr: Subtract): SupportLevel = {
-    if (expr.evalMode == EvalMode.ANSI) {
-      Incompatible(Some("ANSI mode is not supported"))
-    } else {
-      Compatible(None)
-    }
-  }
-
   override def convert(
       expr: Subtract,
       inputs: Seq[Attribute],
@@ -147,14 +131,6 @@ object CometSubtract extends CometExpressionSerde[Subtract] with MathBase {
 
 object CometMultiply extends CometExpressionSerde[Multiply] with MathBase {
 
-  override def getSupportLevel(expr: Multiply): SupportLevel = {
-    if (expr.evalMode == EvalMode.ANSI) {
-      Incompatible(Some("ANSI mode is not supported"))
-    } else {
-      Compatible(None)
-    }
-  }
-
   override def convert(
       expr: Multiply,
       inputs: Seq[Attribute],
@@ -177,14 +153,6 @@ object CometMultiply extends CometExpressionSerde[Multiply] with MathBase {
 
 object CometDivide extends CometExpressionSerde[Divide] with MathBase {
 
-  override def getSupportLevel(expr: Divide): SupportLevel = {
-    if (expr.evalMode == EvalMode.ANSI) {
-      Incompatible(Some("ANSI mode is not supported"))
-    } else {
-      Compatible(None)
-    }
-  }
-
   override def convert(
       expr: Divide,
       inputs: Seq[Attribute],
@@ -192,7 +160,8 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {
     // Datafusion now throws an exception for dividing by zero
     // See https://github.com/apache/arrow-datafusion/pull/6792
     // For now, use NullIf to swap zeros with nulls.
-    val rightExpr = nullIfWhenPrimitive(expr.right)
+    val rightExpr =
+      if (expr.evalMode != EvalMode.ANSI) nullIfWhenPrimitive(expr.right) else expr.right
     if (!supportedDataType(expr.left.dataType)) {
       withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
       return None
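
The divide serde change above is what lets DIVIDE_BY_ZERO surface: in non-ANSI modes the divisor is wrapped so `x / 0` evaluates as `x / NULL` (yielding NULL), while ANSI leaves it untouched so the native checked kernel can raise the error. A toy model of that decision (illustrative names, not the Scala serde API):

    // Mirrors nullIfWhenPrimitive: zero divisors become NULL unless ANSI.
    fn plan_divisor(divisor: i32, is_ansi: bool) -> Option<i32> {
        if !is_ansi && divisor == 0 { None } else { Some(divisor) }
    }

    fn main() {
        assert_eq!(plan_divisor(0, false), None);   // legacy/try: NullIf path => NULL result
        assert_eq!(plan_divisor(0, true), Some(0)); // ANSI: zero reaches checked_div => error
    }
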
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 600f4e45b..daf0e45cc 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -55,6 +55,11 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  val ARITHMETIC_OVERFLOW_EXCEPTION_MSG =
+    """org.apache.comet.CometNativeException: [ARITHMETIC_OVERFLOW] integer 
overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this 
error"""
+  val DIVIDE_BY_ZERO_EXCEPTION_MSG =
+    """org.apache.comet.CometNativeException: [DIVIDE_BY_ZERO] Division by 
zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""
+
   test("compare true/false to negative zero") {
     Seq(false, true).foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
@@ -2864,6 +2869,107 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("ANSI support for add") {
+    val data = Seq((Integer.MAX_VALUE, 1), (Integer.MIN_VALUE, -1))
+    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+      withParquetTable(data, "tbl") {
+        val res = spark.sql("""
+                              |SELECT
+                              |  _1 + _2
+                              |  from tbl
+                              |  """.stripMargin)
+
+        checkSparkMaybeThrows(res) match {
+          case (Some(sparkExc), Some(cometExc)) =>
+            assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
+            assert(sparkExc.getMessage.contains("overflow"))
+          case _ => fail("Exception should be thrown")
+        }
+      }
+    }
+  }
+
+  test("ANSI support for subtract") {
+    val data = Seq((Integer.MIN_VALUE, 1))
+    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+      withParquetTable(data, "tbl") {
+        val res = spark.sql("""
+                              |SELECT
+                              |  _1 - _2
+                              |  from tbl
+                              |  """.stripMargin)
+        checkSparkMaybeThrows(res) match {
+          case (Some(sparkExc), Some(cometExc)) =>
+            assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
+            assert(sparkExc.getMessage.contains("overflow"))
+          case _ => fail("Exception should be thrown")
+        }
+      }
+    }
+  }
+
+  test("ANSI support for multiply") {
+    val data = Seq((Integer.MAX_VALUE, 10))
+    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+      withParquetTable(data, "tbl") {
+        val res = spark.sql("""
+                              |SELECT
+                              |  _1 * _2
+                              |  from tbl
+                              |  """.stripMargin)
+
+        checkSparkMaybeThrows(res) match {
+          case (Some(sparkExc), Some(cometExc)) =>
+            assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
+            assert(sparkExc.getMessage.contains("overflow"))
+          case _ => fail("Exception should be thrown")
+        }
+      }
+    }
+  }
+
+  test("ANSI support for divide (division by zero)") {
+    //    TODO : Support ANSI mode in Integral divide -
+    val data = Seq((Integer.MIN_VALUE, 0))
+    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+      withParquetTable(data, "tbl") {
+        val res = spark.sql("""
+                              |SELECT
+                              |  _1 / _2
+                              |  from tbl
+                              |  """.stripMargin)
+
+        checkSparkMaybeThrows(res) match {
+          case (Some(sparkExc), Some(cometExc)) =>
+            assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
+            assert(sparkExc.getMessage.contains("Division by zero"))
+          case _ => fail("Exception should be thrown")
+        }
+      }
+    }
+  }
+
+  test("ANSI support for divide (division by zero) float division") {
+    //    TODO : Support ANSI mode in Integral divide -
+    val data = Seq((Float.MinPositiveValue, 0.0))
+    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+      withParquetTable(data, "tbl") {
+        val res = spark.sql("""
+                              |SELECT
+                              |  _1 / _2
+                              |  from tbl
+                              |  """.stripMargin)
+
+        checkSparkMaybeThrows(res) match {
+          case (Some(sparkExc), Some(cometExc)) =>
+            assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
+            assert(sparkExc.getMessage.contains("Division by zero"))
+          case _ => fail("Exception should be thrown")
+        }
+      }
+    }
+  }
+
   test("test integral divide overflow for decimal") {
     if (isSpark40Plus) {
       Seq(true, false)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
