This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 03c062646 feat: implement_ansi_eval_mode_arithmetic (#2136)
03c062646 is described below
commit 03c062646ba183069b2fd252df23d00e1f2e51c8
Author: B Vadlamani <[email protected]>
AuthorDate: Thu Sep 25 13:59:00 2025 -0700
feat: implement_ansi_eval_mode_arithmetic (#2136)
---
dev/diffs/3.4.3.diff | 46 +++++++-
dev/diffs/3.5.6.diff | 46 +++++++-
native/core/src/execution/planner.rs | 23 ++--
native/spark-expr/src/comet_scalar_funcs.rs | 36 +++++-
native/spark-expr/src/lib.rs | 5 +-
.../src/math_funcs/checked_arithmetic.rs | 126 +++++++++++++++++++--
.../scala/org/apache/comet/serde/arithmetic.scala | 35 +-----
.../org/apache/comet/CometExpressionSuite.scala | 106 +++++++++++++++++
8 files changed, 357 insertions(+), 66 deletions(-)
diff --git a/dev/diffs/3.4.3.diff b/dev/diffs/3.4.3.diff
index 1c0ca867d..ab9ac0888 100644
--- a/dev/diffs/3.4.3.diff
+++ b/dev/diffs/3.4.3.diff
@@ -193,6 +193,19 @@ index 41fd4de2a09..44cd244d3b0 100644
-- Test aggregate operator with codegen on and off.
--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
--CONFIG_DIM1
spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+index 3a409eea348..38fed024c98 100644
+--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
++++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+@@ -69,6 +69,8 @@ SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 %
smallint('2')) = smallint('1
+ -- any evens
+ SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) =
smallint('0');
+
++-- https://github.com/apache/datafusion-comet/issues/2215
++--SET spark.comet.exec.enabled=false
+ -- [SPARK-28024] Incorrect value when out of range
+ SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i;
+
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
index fac23b4a26f..2b73732c33f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
@@ -881,7 +894,7 @@ index b5b34922694..a72403780c4 100644
protected val baseResourcePath = {
// use the same way as `SQLQueryTestSuite` to get the resource path
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
-index 525d97e4998..5e04319dd97 100644
+index 525d97e4998..843f0472c23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1508,7 +1508,8 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
@@ -894,7 +907,27 @@ index 525d97e4998..5e04319dd97 100644
AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external
sort") {
sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
}
-@@ -4467,7 +4468,11 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
+@@ -4429,7 +4430,8 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
+ }
+
+ test("SPARK-39166: Query context of binary arithmetic should be serialized
to executors" +
+- " when WSCG is off") {
++ " when WSCG is off",
++ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+ SQLConf.ANSI_ENABLED.key -> "true") {
+ withTable("t") {
+@@ -4450,7 +4452,8 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
+ }
+
+ test("SPARK-39175: Query context of Cast should be serialized to executors"
+
+- " when WSCG is off") {
++ " when WSCG is off",
++ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+ SQLConf.ANSI_ENABLED.key -> "true") {
+ withTable("t") {
+@@ -4467,14 +4470,19 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
val msg = intercept[SparkException] {
sql(query).collect()
}.getMessage
@@ -907,6 +940,15 @@ index 525d97e4998..5e04319dd97 100644
}
}
}
+ }
+
+ test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal
overflow error should " +
+- "be serialized to executors when WSCG is off") {
++ "be serialized to executors when WSCG is off",
++ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+ SQLConf.ANSI_ENABLED.key -> "true") {
+ withTable("t") {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 48ad10992c5..51d1ee65422 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
diff --git a/dev/diffs/3.5.6.diff b/dev/diffs/3.5.6.diff
index f3909d074..63f0d3eb0 100644
--- a/dev/diffs/3.5.6.diff
+++ b/dev/diffs/3.5.6.diff
@@ -172,6 +172,19 @@ index 41fd4de2a09..44cd244d3b0 100644
-- Test aggregate operator with codegen on and off.
--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
--CONFIG_DIM1
spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+index 3a409eea348..38fed024c98 100644
+--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
++++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+@@ -69,6 +69,8 @@ SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 %
smallint('2')) = smallint('1
+ -- any evens
+ SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) =
smallint('0');
+
++-- https://github.com/apache/datafusion-comet/issues/2215
++--SET spark.comet.exec.enabled=false
+ -- [SPARK-28024] Incorrect value when out of range
+ SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i;
+
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
index fac23b4a26f..2b73732c33f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
@@ -866,7 +879,7 @@ index c26757c9cff..d55775f09d7 100644
protected val baseResourcePath = {
// use the same way as `SQLQueryTestSuite` to get the resource path
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
-index 793a0da6a86..e48e74091cb 100644
+index 793a0da6a86..181bfc16e4b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1521,7 +1521,8 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
@@ -879,7 +892,27 @@ index 793a0da6a86..e48e74091cb 100644
AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external
sort") {
sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
}
-@@ -4497,7 +4498,11 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
+@@ -4459,7 +4460,8 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
+ }
+
+ test("SPARK-39166: Query context of binary arithmetic should be serialized
to executors" +
+- " when WSCG is off") {
++ " when WSCG is off",
++ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+ SQLConf.ANSI_ENABLED.key -> "true") {
+ withTable("t") {
+@@ -4480,7 +4482,8 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
+ }
+
+ test("SPARK-39175: Query context of Cast should be serialized to executors"
+
+- " when WSCG is off") {
++ " when WSCG is off",
++ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+ SQLConf.ANSI_ENABLED.key -> "true") {
+ withTable("t") {
+@@ -4497,14 +4500,19 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
val msg = intercept[SparkException] {
sql(query).collect()
}.getMessage
@@ -892,6 +925,15 @@ index 793a0da6a86..e48e74091cb 100644
}
}
}
+ }
+
+ test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal
overflow error should " +
+- "be serialized to executors when WSCG is off") {
++ "be serialized to executors when WSCG is off",
++ IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+ SQLConf.ANSI_ENABLED.key -> "true") {
+ withTable("t") {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index fa1a64460fc..1d2e215d6a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 64efa31d5..517c037e9 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -62,8 +62,9 @@ use datafusion::{
prelude::SessionContext,
};
use datafusion_comet_spark_expr::{
- create_comet_physical_fun, create_modulo_expr, create_negate_expr,
BinaryOutputStyle,
- BloomFilterAgg, BloomFilterMightContain, EvalMode, SparkHour, SparkMinute,
SparkSecond,
+ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
create_modulo_expr,
+ create_negate_expr, BinaryOutputStyle, BloomFilterAgg,
BloomFilterMightContain, EvalMode,
+ SparkHour, SparkMinute, SparkSecond,
};
use crate::execution::operators::ExecutionError::GeneralError;
@@ -242,8 +243,6 @@ impl PhysicalPlanner {
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
match spark_expr.expr_struct.as_ref().unwrap() {
ExprStruct::Add(expr) => {
- // TODO respect ANSI eval mode
- // https://github.com/apache/datafusion-comet/issues/536
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
self.create_binary_expr(
expr.left.as_ref().unwrap(),
@@ -255,8 +254,6 @@ impl PhysicalPlanner {
)
}
ExprStruct::Subtract(expr) => {
- // TODO respect ANSI eval mode
- // https://github.com/apache/datafusion-comet/issues/535
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
self.create_binary_expr(
expr.left.as_ref().unwrap(),
@@ -268,8 +265,6 @@ impl PhysicalPlanner {
)
}
ExprStruct::Multiply(expr) => {
- // TODO respect ANSI eval mode
- // https://github.com/apache/datafusion-comet/issues/534
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
self.create_binary_expr(
expr.left.as_ref().unwrap(),
@@ -281,8 +276,6 @@ impl PhysicalPlanner {
)
}
ExprStruct::Divide(expr) => {
- // TODO respect ANSI eval mode
- // https://github.com/apache/datafusion-comet/issues/533
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
self.create_binary_expr(
expr.left.as_ref().unwrap(),
@@ -1010,21 +1003,25 @@ impl PhysicalPlanner {
}
_ => {
let data_type = return_type.map(to_arrow_datatype).unwrap();
- if eval_mode == EvalMode::Try && data_type.is_integer() {
+ if [EvalMode::Try, EvalMode::Ansi].contains(&eval_mode)
+ && (data_type.is_integer()
+ || (data_type.is_floating() && op ==
DataFusionOperator::Divide))
+ {
let op_str = match op {
DataFusionOperator::Plus => "checked_add",
DataFusionOperator::Minus => "checked_sub",
DataFusionOperator::Multiply => "checked_mul",
DataFusionOperator::Divide => "checked_div",
_ => {
- todo!("Operator yet to be implemented!");
+ todo!("ANSI mode for Operator yet to be
implemented!");
}
};
- let fun_expr = create_comet_physical_fun(
+ let fun_expr = create_comet_physical_fun_with_eval_mode(
op_str,
data_type.clone(),
&self.session_ctx.state(),
None,
+ eval_mode,
)?;
Ok(Arc::new(ScalarFunctionExpr::new(
op_str,
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 93a820ba9..f96ddffce 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -21,7 +21,7 @@ use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub,
spark_decimal_div,
spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan,
spark_make_decimal,
- spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
spark_unscaled_value,
+ spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
spark_unscaled_value, EvalMode,
SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkDateTrunc,
SparkStringSpace,
};
use arrow::datatypes::DataType;
@@ -64,6 +64,15 @@ macro_rules! make_comet_scalar_udf {
);
Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
}};
+ ($name:expr, $func:ident, $data_type:ident, $eval_mode:ident) => {{
+ let scalar_func = CometScalarFunction::new(
+ $name.to_string(),
+ Signature::variadic_any(Volatility::Immutable),
+ $data_type.clone(),
+ Arc::new(move |args| $func(args, &$data_type, $eval_mode)),
+ );
+ Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
+ }};
}
/// Create a physical scalar function.
@@ -72,6 +81,23 @@ pub fn create_comet_physical_fun(
data_type: DataType,
registry: &dyn FunctionRegistry,
fail_on_error: Option<bool>,
+) -> Result<Arc<ScalarUDF>, DataFusionError> {
+ create_comet_physical_fun_with_eval_mode(
+ fun_name,
+ data_type,
+ registry,
+ fail_on_error,
+ EvalMode::Legacy,
+ )
+}
+
+/// Create a physical scalar function with eval mode. Goal is to deprecate
above function once all the operators have ANSI support
+pub fn create_comet_physical_fun_with_eval_mode(
+ fun_name: &str,
+ data_type: DataType,
+ registry: &dyn FunctionRegistry,
+ fail_on_error: Option<bool>,
+ eval_mode: EvalMode,
) -> Result<Arc<ScalarUDF>, DataFusionError> {
match fun_name {
"ceil" => {
@@ -117,16 +143,16 @@ pub fn create_comet_physical_fun(
)
}
"checked_add" => {
- make_comet_scalar_udf!("checked_add", checked_add, data_type)
+ make_comet_scalar_udf!("checked_add", checked_add, data_type,
eval_mode)
}
"checked_sub" => {
- make_comet_scalar_udf!("checked_sub", checked_sub, data_type)
+ make_comet_scalar_udf!("checked_sub", checked_sub, data_type,
eval_mode)
}
"checked_mul" => {
- make_comet_scalar_udf!("checked_mul", checked_mul, data_type)
+ make_comet_scalar_udf!("checked_mul", checked_mul, data_type,
eval_mode)
}
"checked_div" => {
- make_comet_scalar_udf!("checked_div", checked_div, data_type)
+ make_comet_scalar_udf!("checked_div", checked_div, data_type,
eval_mode)
}
"murmur3_hash" => {
let func = Arc::new(spark_murmur3_hash);
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index af5677a9b..7bdc7ff51 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -64,7 +64,10 @@ pub use conditional_funcs::*;
pub use conversion_funcs::*;
pub use nondetermenistic_funcs::*;
-pub use comet_scalar_funcs::{create_comet_physical_fun,
register_all_comet_functions};
+pub use comet_scalar_funcs::{
+ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
+ register_all_comet_functions,
+};
pub use datetime_funcs::{
spark_date_add, spark_date_sub, SparkDateTrunc, SparkHour, SparkMinute,
SparkSecond,
TimestampTruncExpr,
diff --git a/native/spark-expr/src/math_funcs/checked_arithmetic.rs b/native/spark-expr/src/math_funcs/checked_arithmetic.rs
index 0312cdb0b..bb4118f86 100644
--- a/native/spark-expr/src/math_funcs/checked_arithmetic.rs
+++ b/native/spark-expr/src/math_funcs/checked_arithmetic.rs
@@ -18,7 +18,11 @@
use arrow::array::{Array, ArrowNativeTypeOp, PrimitiveArray, PrimitiveBuilder};
use arrow::array::{ArrayRef, AsArray};
-use arrow::datatypes::{ArrowPrimitiveType, DataType, Int32Type, Int64Type};
+use crate::{divide_by_zero_error, EvalMode, SparkError};
+use arrow::datatypes::{
+ ArrowPrimitiveType, DataType, Float16Type, Float32Type, Float64Type,
Int16Type, Int32Type,
+ Int64Type, Int8Type,
+};
use datafusion::common::DataFusionError;
use datafusion::physical_plan::ColumnarValue;
use std::sync::Arc;
@@ -27,6 +31,7 @@ pub fn try_arithmetic_kernel<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
op: &str,
+ is_ansi_mode: bool,
) -> Result<ArrayRef, DataFusionError>
where
T: ArrowPrimitiveType,
@@ -39,7 +44,19 @@ where
if left.is_null(i) || right.is_null(i) {
builder.append_null();
} else {
-
builder.append_option(left.value(i).add_checked(right.value(i)).ok());
+ match left.value(i).add_checked(right.value(i)) {
+ Ok(v) => builder.append_value(v),
+ Err(_e) => {
+ if is_ansi_mode {
+ return Err(SparkError::ArithmeticOverflow {
+ from_type: String::from("integer"),
+ }
+ .into());
+ } else {
+ builder.append_null();
+ }
+ }
+ }
}
}
}
@@ -48,7 +65,19 @@ where
if left.is_null(i) || right.is_null(i) {
builder.append_null();
} else {
-
builder.append_option(left.value(i).sub_checked(right.value(i)).ok());
+ match left.value(i).sub_checked(right.value(i)) {
+ Ok(v) => builder.append_value(v),
+ Err(_e) => {
+ if is_ansi_mode {
+ return Err(SparkError::ArithmeticOverflow {
+ from_type: String::from("integer"),
+ }
+ .into());
+ } else {
+ builder.append_null();
+ }
+ }
+ }
}
}
}
@@ -57,7 +86,19 @@ where
if left.is_null(i) || right.is_null(i) {
builder.append_null();
} else {
-
builder.append_option(left.value(i).mul_checked(right.value(i)).ok());
+ match left.value(i).mul_checked(right.value(i)) {
+ Ok(v) => builder.append_value(v),
+ Err(_e) => {
+ if is_ansi_mode {
+ return Err(SparkError::ArithmeticOverflow {
+ from_type: String::from("integer"),
+ }
+ .into());
+ } else {
+ builder.append_null();
+ }
+ }
+ }
}
}
}
@@ -66,7 +107,23 @@ where
if left.is_null(i) || right.is_null(i) {
builder.append_null();
} else {
-
builder.append_option(left.value(i).div_checked(right.value(i)).ok());
+ match left.value(i).div_checked(right.value(i)) {
+ Ok(v) => builder.append_value(v),
+ Err(_e) => {
+ if is_ansi_mode {
+ return if right.value(i).is_zero() {
+ Err(divide_by_zero_error().into())
+ } else {
+ return Err(SparkError::ArithmeticOverflow {
+ from_type: String::from("integer"),
+ }
+ .into());
+ };
+ } else {
+ builder.append_null();
+ }
+ }
+ }
}
}
}
@@ -84,39 +141,55 @@ where
pub fn checked_add(
args: &[ColumnarValue],
data_type: &DataType,
+ eval_mode: EvalMode,
) -> Result<ColumnarValue, DataFusionError> {
- checked_arithmetic_internal(args, data_type, "checked_add")
+ checked_arithmetic_internal(args, data_type, "checked_add", eval_mode)
}
pub fn checked_sub(
args: &[ColumnarValue],
data_type: &DataType,
+ eval_mode: EvalMode,
) -> Result<ColumnarValue, DataFusionError> {
- checked_arithmetic_internal(args, data_type, "checked_sub")
+ checked_arithmetic_internal(args, data_type, "checked_sub", eval_mode)
}
pub fn checked_mul(
args: &[ColumnarValue],
data_type: &DataType,
+ eval_mode: EvalMode,
) -> Result<ColumnarValue, DataFusionError> {
- checked_arithmetic_internal(args, data_type, "checked_mul")
+ checked_arithmetic_internal(args, data_type, "checked_mul", eval_mode)
}
pub fn checked_div(
args: &[ColumnarValue],
data_type: &DataType,
+ eval_mode: EvalMode,
) -> Result<ColumnarValue, DataFusionError> {
- checked_arithmetic_internal(args, data_type, "checked_div")
+ checked_arithmetic_internal(args, data_type, "checked_div", eval_mode)
}
fn checked_arithmetic_internal(
args: &[ColumnarValue],
data_type: &DataType,
op: &str,
+ eval_mode: EvalMode,
) -> Result<ColumnarValue, DataFusionError> {
let left = &args[0];
let right = &args[1];
+ let is_ansi_mode = match eval_mode {
+ EvalMode::Try => false,
+ EvalMode::Ansi => true,
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "Unsupported mode : {:?}",
+ eval_mode
+ )))
+ }
+ };
+
let (left_arr, right_arr): (ArrayRef, ArrayRef) = match (left, right) {
(ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l),
Arc::clone(r)),
(ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
@@ -128,17 +201,50 @@ fn checked_arithmetic_internal(
(ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) =>
(l.to_array()?, r.to_array()?),
};
- // Rust only supports checked_arithmetic on Int32 and Int64
+ // Rust only supports checked_arithmetic on numeric types
let result_array = match data_type {
+ DataType::Int8 => try_arithmetic_kernel::<Int8Type>(
+ left_arr.as_primitive::<Int8Type>(),
+ right_arr.as_primitive::<Int8Type>(),
+ op,
+ is_ansi_mode,
+ ),
+ DataType::Int16 => try_arithmetic_kernel::<Int16Type>(
+ left_arr.as_primitive::<Int16Type>(),
+ right_arr.as_primitive::<Int16Type>(),
+ op,
+ is_ansi_mode,
+ ),
DataType::Int32 => try_arithmetic_kernel::<Int32Type>(
left_arr.as_primitive::<Int32Type>(),
right_arr.as_primitive::<Int32Type>(),
op,
+ is_ansi_mode,
),
DataType::Int64 => try_arithmetic_kernel::<Int64Type>(
left_arr.as_primitive::<Int64Type>(),
right_arr.as_primitive::<Int64Type>(),
op,
+ is_ansi_mode,
+ ),
+ // Spark always casts division operands to floats
+ DataType::Float16 if (op == "checked_div") =>
try_arithmetic_kernel::<Float16Type>(
+ left_arr.as_primitive::<Float16Type>(),
+ right_arr.as_primitive::<Float16Type>(),
+ op,
+ is_ansi_mode,
+ ),
+ DataType::Float32 if (op == "checked_div") =>
try_arithmetic_kernel::<Float32Type>(
+ left_arr.as_primitive::<Float32Type>(),
+ right_arr.as_primitive::<Float32Type>(),
+ op,
+ is_ansi_mode,
+ ),
+ DataType::Float64 if (op == "checked_div") =>
try_arithmetic_kernel::<Float64Type>(
+ left_arr.as_primitive::<Float64Type>(),
+ right_arr.as_primitive::<Float64Type>(),
+ op,
+ is_ansi_mode,
),
_ => Err(DataFusionError::Internal(format!(
"Unsupported data type: {:?}",
diff --git a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
index 0f1eeb758..4507dc107 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
@@ -87,14 +87,6 @@ trait MathBase {
object CometAdd extends CometExpressionSerde[Add] with MathBase {
- override def getSupportLevel(expr: Add): SupportLevel = {
- if (expr.evalMode == EvalMode.ANSI) {
- Incompatible(Some("ANSI mode is not supported"))
- } else {
- Compatible(None)
- }
- }
-
override def convert(
expr: Add,
inputs: Seq[Attribute],
@@ -117,14 +109,6 @@ object CometAdd extends CometExpressionSerde[Add] with MathBase {
object CometSubtract extends CometExpressionSerde[Subtract] with MathBase {
- override def getSupportLevel(expr: Subtract): SupportLevel = {
- if (expr.evalMode == EvalMode.ANSI) {
- Incompatible(Some("ANSI mode is not supported"))
- } else {
- Compatible(None)
- }
- }
-
override def convert(
expr: Subtract,
inputs: Seq[Attribute],
@@ -147,14 +131,6 @@ object CometSubtract extends CometExpressionSerde[Subtract] with MathBase {
object CometMultiply extends CometExpressionSerde[Multiply] with MathBase {
- override def getSupportLevel(expr: Multiply): SupportLevel = {
- if (expr.evalMode == EvalMode.ANSI) {
- Incompatible(Some("ANSI mode is not supported"))
- } else {
- Compatible(None)
- }
- }
-
override def convert(
expr: Multiply,
inputs: Seq[Attribute],
@@ -177,14 +153,6 @@ object CometMultiply extends CometExpressionSerde[Multiply] with MathBase {
object CometDivide extends CometExpressionSerde[Divide] with MathBase {
- override def getSupportLevel(expr: Divide): SupportLevel = {
- if (expr.evalMode == EvalMode.ANSI) {
- Incompatible(Some("ANSI mode is not supported"))
- } else {
- Compatible(None)
- }
- }
-
override def convert(
expr: Divide,
inputs: Seq[Attribute],
@@ -192,7 +160,8 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {
// Datafusion now throws an exception for dividing by zero
// See https://github.com/apache/arrow-datafusion/pull/6792
// For now, use NullIf to swap zeros with nulls.
- val rightExpr = nullIfWhenPrimitive(expr.right)
+ val rightExpr =
+ if (expr.evalMode != EvalMode.ANSI) nullIfWhenPrimitive(expr.right) else
expr.right
if (!supportedDataType(expr.left.dataType)) {
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
return None
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 600f4e45b..daf0e45cc 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -55,6 +55,11 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ val ARITHMETIC_OVERFLOW_EXCEPTION_MSG =
+ """org.apache.comet.CometNativeException: [ARITHMETIC_OVERFLOW] integer
overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this
error"""
+ val DIVIDE_BY_ZERO_EXCEPTION_MSG =
+ """org.apache.comet.CometNativeException: [DIVIDE_BY_ZERO] Division by
zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""
+
test("compare true/false to negative zero") {
Seq(false, true).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
@@ -2864,6 +2869,107 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("ANSI support for add") {
+ val data = Seq((Integer.MAX_VALUE, 1), (Integer.MIN_VALUE, -1))
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withParquetTable(data, "tbl") {
+ val res = spark.sql("""
+ |SELECT
+ | _1 + _2
+ | from tbl
+ | """.stripMargin)
+
+ checkSparkMaybeThrows(res) match {
+ case (Some(sparkExc), Some(cometExc)) =>
+
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
+ assert(sparkExc.getMessage.contains("overflow"))
+ case _ => fail("Exception should be thrown")
+ }
+ }
+ }
+ }
+
+ test("ANSI support for subtract") {
+ val data = Seq((Integer.MIN_VALUE, 1))
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withParquetTable(data, "tbl") {
+ val res = spark.sql("""
+ |SELECT
+ | _1 - _2
+ | from tbl
+ | """.stripMargin)
+ checkSparkMaybeThrows(res) match {
+ case (Some(sparkExc), Some(cometExc)) =>
+
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
+ assert(sparkExc.getMessage.contains("overflow"))
+ case _ => fail("Exception should be thrown")
+ }
+ }
+ }
+ }
+
+ test("ANSI support for multiply") {
+ val data = Seq((Integer.MAX_VALUE, 10))
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withParquetTable(data, "tbl") {
+ val res = spark.sql("""
+ |SELECT
+ | _1 * _2
+ | from tbl
+ | """.stripMargin)
+
+ checkSparkMaybeThrows(res) match {
+ case (Some(sparkExc), Some(cometExc)) =>
+
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
+ assert(sparkExc.getMessage.contains("overflow"))
+ case _ => fail("Exception should be thrown")
+ }
+ }
+ }
+ }
+
+ test("ANSI support for divide (division by zero)") {
+ // TODO : Support ANSI mode in Integral divide -
+ val data = Seq((Integer.MIN_VALUE, 0))
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withParquetTable(data, "tbl") {
+ val res = spark.sql("""
+ |SELECT
+ | _1 / _2
+ | from tbl
+ | """.stripMargin)
+
+ checkSparkMaybeThrows(res) match {
+ case (Some(sparkExc), Some(cometExc)) =>
+ assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
+ assert(sparkExc.getMessage.contains("Division by zero"))
+ case _ => fail("Exception should be thrown")
+ }
+ }
+ }
+ }
+
+ test("ANSI support for divide (division by zero) float division") {
+ // TODO : Support ANSI mode in Integral divide -
+ val data = Seq((Float.MinPositiveValue, 0.0))
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withParquetTable(data, "tbl") {
+ val res = spark.sql("""
+ |SELECT
+ | _1 / _2
+ | from tbl
+ | """.stripMargin)
+
+ checkSparkMaybeThrows(res) match {
+ case (Some(sparkExc), Some(cometExc)) =>
+ assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
+ assert(sparkExc.getMessage.contains("Division by zero"))
+ case _ => fail("Exception should be thrown")
+ }
+ }
+ }
+ }
+
test("test integral divide overflow for decimal") {
if (isSpark40Plus) {
Seq(true, false)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]