This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new fc3e6e9b0 feat: Add support for `abs` (#2689)
fc3e6e9b0 is described below
commit fc3e6e9b0fbdfbc075db7da6faacc0ac7b63ace9
Author: Andy Grove <[email protected]>
AuthorDate: Wed Nov 5 17:17:08 2025 -0700
feat: Add support for `abs` (#2689)
---
.github/workflows/pr_build_linux.yml | 1 +
.github/workflows/pr_build_macos.yml | 1 +
docs/source/user-guide/latest/configs.md | 1 +
docs/source/user-guide/latest/expressions.md | 1 +
native/core/src/execution/planner.rs | 13 -
native/proto/src/proto/expr.proto | 6 -
native/spark-expr/src/comet_scalar_funcs.rs | 5 +
native/spark-expr/src/math_funcs/abs.rs | 890 +++++++++++++++++++++
native/spark-expr/src/math_funcs/mod.rs | 1 +
.../org/apache/comet/serde/QueryPlanSerde.scala | 3 +-
.../main/scala/org/apache/comet/serde/math.scala | 34 +-
.../org/apache/comet/CometExpressionSuite.scala | 70 --
.../apache/comet/CometMathExpressionSuite.scala | 93 +++
13 files changed, 1027 insertions(+), 92 deletions(-)
diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index 2867f61da..02b544e2d 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -134,6 +134,7 @@ jobs:
org.apache.comet.CometCastSuite
org.apache.comet.CometExpressionSuite
org.apache.comet.CometExpressionCoverageSuite
+ org.apache.comet.CometMathExpressionSuite
org.apache.comet.CometNativeSuite
org.apache.comet.CometSparkSessionExtensionsSuite
org.apache.comet.CometStringExpressionSuite
diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml
index 0fd1cb606..3a1b82d04 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -99,6 +99,7 @@ jobs:
org.apache.comet.CometCastSuite
org.apache.comet.CometExpressionSuite
org.apache.comet.CometExpressionCoverageSuite
+ org.apache.comet.CometMathExpressionSuite
org.apache.comet.CometNativeSuite
org.apache.comet.CometSparkSessionExtensionsSuite
org.apache.comet.CometStringExpressionSuite
diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md
index 537d0d774..6caaa53b1 100644
--- a/docs/source/user-guide/latest/configs.md
+++ b/docs/source/user-guide/latest/configs.md
@@ -164,6 +164,7 @@ These settings can be used to determine which parts of the plan are accelerated
<!--BEGIN:CONFIG_TABLE[enable_expr]-->
| Config | Description | Default Value |
|--------|-------------|---------------|
+| `spark.comet.expression.Abs.enabled` | Enable Comet acceleration for `Abs` | true |
| `spark.comet.expression.Acos.enabled` | Enable Comet acceleration for `Acos` | true |
| `spark.comet.expression.Add.enabled` | Enable Comet acceleration for `Add` | true |
| `spark.comet.expression.Alias.enabled` | Enable Comet acceleration for `Alias` | true |
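
The new key follows Comet's standard per-expression config pattern, so native
execution of `abs` can be toggled at runtime. A minimal sketch of how a user
might exercise it (assumes a SparkSession with the Comet extensions loaded;
table and column names are hypothetical):

    // Fall back to Spark for Abs, then restore native execution (the default).
    spark.conf.set("spark.comet.expression.Abs.enabled", "false")
    spark.sql("SELECT abs(c) FROM tbl").collect()
    spark.conf.set("spark.comet.expression.Abs.enabled", "true")
    spark.sql("SELECT abs(c) FROM tbl").collect()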
diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md
index b1bcc9dd5..fe42f49a4 100644
--- a/docs/source/user-guide/latest/expressions.md
+++ b/docs/source/user-guide/latest/expressions.md
@@ -119,6 +119,7 @@ incompatible expressions.
| Expression | SQL | Spark-Compatible? | Compatibility Notes |
|----------------|-----------|-------------------|-----------------------------------|
+| Abs | `abs` | Yes | |
| Acos | `acos` | Yes | |
| Add | `+` | Yes | |
| Asin | `asin` | Yes | |
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index a37f928e9..a33df705b 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -674,19 +674,6 @@ impl PhysicalPlanner {
let op = DataFusionOperator::BitwiseShiftLeft;
Ok(Arc::new(BinaryExpr::new(left, op, right)))
}
-            // https://github.com/apache/datafusion-comet/issues/666
-            // ExprStruct::Abs(expr) => {
-            //     let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
-            //     let return_type = child.data_type(&input_schema)?;
-            //     let args = vec![child];
-            //     let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
-            //     let comet_abs = Arc::new(ScalarUDF::new_from_impl(Abs::new(
-            //         eval_mode,
-            //         return_type.to_string(),
-            //     )?));
-            //     let expr = ScalarFunctionExpr::new("abs", comet_abs, args, return_type);
-            //     Ok(Arc::new(expr))
-            // }
ExprStruct::CaseWhen(case_when) => {
let when_then_pairs = case_when
.when
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 5853bc613..c9037dcd6 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -70,7 +70,6 @@ message Expr {
IfExpr if = 44;
NormalizeNaNAndZero normalize_nan_and_zero = 45;
TruncTimestamp truncTimestamp = 47;
- Abs abs = 49;
Subquery subquery = 50;
UnboundReference unbound = 51;
BloomFilterMightContain bloom_filter_might_contain = 52;
@@ -351,11 +350,6 @@ message TruncTimestamp {
string timezone = 3;
}
-message Abs {
- Expr child = 1;
- EvalMode eval_mode = 2;
-}
-
message Subquery {
int64 id = 1;
DataType datatype = 2;
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index fc0c096b1..021bb1c78 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -16,6 +16,7 @@
// under the License.
use crate::hash_funcs::*;
+use crate::math_funcs::abs::abs;
use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
@@ -180,6 +181,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
let func = Arc::new(spark_modulo);
make_comet_scalar_udf!("spark_modulo", func, without data_type,
fail_on_error)
}
+ "abs" => {
+ let func = Arc::new(abs);
+ make_comet_scalar_udf!("abs", func, without data_type)
+ }
_ => registry.udf(fun_name).map_err(|e| {
DataFusionError::Execution(format!(
"Function {fun_name} not found in the registry: {e}",
diff --git a/native/spark-expr/src/math_funcs/abs.rs b/native/spark-expr/src/math_funcs/abs.rs
new file mode 100644
index 000000000..5a16398ec
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/abs.rs
@@ -0,0 +1,890 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::arithmetic_overflow_error;
+use arrow::array::*;
+use arrow::datatypes::*;
+use arrow::error::ArrowError;
+use datafusion::common::{exec_err, DataFusionError, Result, ScalarValue};
+use datafusion::logical_expr::ColumnarValue;
+use std::sync::Arc;
+
+macro_rules! legacy_compute_op {
+    ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{
+        let n = $ARRAY.as_any().downcast_ref::<$TYPE>();
+        match n {
+            Some(array) => {
+                let res: $RESULT = arrow::compute::kernels::arity::unary(array, |x| x.$FUNC());
+                Ok(res)
+            }
+            _ => Err(DataFusionError::Internal(format!(
+                "Invalid data type for abs"
+            ))),
+        }
+    }};
+}
+
+macro_rules! ansi_compute_op {
+    ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident, $NATIVE:ident, $FROM_TYPE:expr) => {{
+        let n = $ARRAY.as_any().downcast_ref::<$TYPE>();
+        match n {
+            Some(array) => {
+                match arrow::compute::kernels::arity::try_unary(array, |x| {
+                    if x == $NATIVE::MIN {
+                        Err(ArrowError::ArithmeticOverflow($FROM_TYPE.to_string()))
+                    } else {
+                        Ok(x.$FUNC())
+                    }
+                }) {
+                    Ok(res) => Ok(ColumnarValue::Array(Arc::<PrimitiveArray<$RESULT>>::new(
+                        res,
+                    ))),
+                    Err(_) => Err(arithmetic_overflow_error($FROM_TYPE).into()),
+                }
+            }
+            _ => Err(DataFusionError::Internal("Invalid data type".to_string())),
+        }
+    }};
+}
+
+/// This function mimics SparkSQL's [Abs]: https://github.com/apache/spark/blob/v4.0.1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala#L148
+/// Spark's [ANSI-compliant]: https://spark.apache.org/docs/latest/sql-ref-ansi-compliance.html#arithmetic-operations dialect mode throws org.apache.spark.SparkArithmeticException
+/// when abs causes overflow.
+pub fn abs(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    if args.is_empty() || args.len() > 2 {
+        return exec_err!("abs takes 1 or 2 arguments, but got: {}", args.len());
+    }
+
+    let fail_on_error = if args.len() == 2 {
+        match &args[1] {
+            ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error,
+            _ => {
+                return exec_err!(
+                    "The second argument must be boolean scalar, but got: {:?}",
+                    args[1]
+                );
+            }
+        }
+    } else {
+        false
+    };
+
+    match &args[0] {
+        ColumnarValue::Array(array) => match array.data_type() {
+            DataType::Null
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64 => Ok(args[0].clone()),
+            DataType::Int8 => {
+                if !fail_on_error {
+                    let result = legacy_compute_op!(array, wrapping_abs, Int8Array, Int8Array);
+                    Ok(ColumnarValue::Array(Arc::new(result?)))
+                } else {
+                    ansi_compute_op!(array, abs, Int8Array, Int8Type, i8, "Int8")
+                }
+            }
+            DataType::Int16 => {
+                if !fail_on_error {
+                    let result = legacy_compute_op!(array, wrapping_abs, Int16Array, Int16Array);
+                    Ok(ColumnarValue::Array(Arc::new(result?)))
+                } else {
+                    ansi_compute_op!(array, abs, Int16Array, Int16Type, i16, "Int16")
+                }
+            }
+            DataType::Int32 => {
+                if !fail_on_error {
+                    let result = legacy_compute_op!(array, wrapping_abs, Int32Array, Int32Array);
+                    Ok(ColumnarValue::Array(Arc::new(result?)))
+                } else {
+                    ansi_compute_op!(array, abs, Int32Array, Int32Type, i32, "Int32")
+                }
+            }
+            DataType::Int64 => {
+                if !fail_on_error {
+                    let result = legacy_compute_op!(array, wrapping_abs, Int64Array, Int64Array);
+                    Ok(ColumnarValue::Array(Arc::new(result?)))
+                } else {
+                    ansi_compute_op!(array, abs, Int64Array, Int64Type, i64, "Int64")
+                }
+            }
+            DataType::Float32 => {
+                let result = legacy_compute_op!(array, abs, Float32Array, Float32Array);
+                Ok(ColumnarValue::Array(Arc::new(result?)))
+            }
+            DataType::Float64 => {
+                let result = legacy_compute_op!(array, abs, Float64Array, Float64Array);
+                Ok(ColumnarValue::Array(Arc::new(result?)))
+            }
+            DataType::Decimal128(precision, scale) => {
+                if !fail_on_error {
+                    let result =
+                        legacy_compute_op!(array, wrapping_abs, Decimal128Array, Decimal128Array)?;
+                    let result = result.with_data_type(DataType::Decimal128(*precision, *scale));
+                    Ok(ColumnarValue::Array(Arc::new(result)))
+                } else {
+                    // Need to pass precision and scale from input, so not using ansi_compute_op
+                    let input = array.as_any().downcast_ref::<Decimal128Array>();
+                    match input {
+                        Some(i) => {
+                            match arrow::compute::kernels::arity::try_unary(i, |x| {
+                                if x == i128::MIN {
+                                    Err(ArrowError::ArithmeticOverflow("Decimal128".to_string()))
+                                } else {
+                                    Ok(x.abs())
+                                }
+                            }) {
+                                Ok(res) => Ok(ColumnarValue::Array(Arc::<
+                                    PrimitiveArray<Decimal128Type>,
+                                >::new(
+                                    res.with_data_type(DataType::Decimal128(*precision, *scale)),
+                                ))),
+                                Err(_) => Err(arithmetic_overflow_error("Decimal128").into()),
+                            }
+                        }
+                        _ => Err(DataFusionError::Internal("Invalid data type".to_string())),
+                    }
+                }
+            }
+            DataType::Decimal256(precision, scale) => {
+                if !fail_on_error {
+                    let result =
+                        legacy_compute_op!(array, wrapping_abs, Decimal256Array, Decimal256Array)?;
+                    let result = result.with_data_type(DataType::Decimal256(*precision, *scale));
+                    Ok(ColumnarValue::Array(Arc::new(result)))
+                } else {
+                    // Need to pass precision and scale from input, so not using ansi_compute_op
+                    let input = array.as_any().downcast_ref::<Decimal256Array>();
+                    match input {
+                        Some(i) => {
+                            match arrow::compute::kernels::arity::try_unary(i, |x| {
+                                if x == i256::MIN {
+                                    Err(ArrowError::ArithmeticOverflow("Decimal256".to_string()))
+                                } else {
+                                    Ok(x.wrapping_abs()) // i256 doesn't define abs() method
+                                }
+                            }) {
+                                Ok(res) => Ok(ColumnarValue::Array(Arc::<
+                                    PrimitiveArray<Decimal256Type>,
+                                >::new(
+                                    res.with_data_type(DataType::Decimal256(*precision, *scale)),
+                                ))),
+                                Err(_) => Err(arithmetic_overflow_error("Decimal256").into()),
+                            }
+                        }
+                        _ => Err(DataFusionError::Internal("Invalid data type".to_string())),
+                    }
+                }
+            }
+            dt => exec_err!("Not supported datatype for ABS: {dt}"),
+        },
+        ColumnarValue::Scalar(sv) => match sv {
+            ScalarValue::Null
+            | ScalarValue::UInt8(_)
+            | ScalarValue::UInt16(_)
+            | ScalarValue::UInt32(_)
+            | ScalarValue::UInt64(_) => Ok(args[0].clone()),
+            ScalarValue::Int8(a) => match a {
+                None => Ok(args[0].clone()),
+                Some(v) => match v.checked_abs() {
+                    Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(abs_val)))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(*v))))
+                        } else {
+                            Err(arithmetic_overflow_error("Int8").into())
+                        }
+                    }
+                },
+            },
+            ScalarValue::Int16(a) => match a {
+                None => Ok(args[0].clone()),
+                Some(v) => match v.checked_abs() {
+                    Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(abs_val)))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(*v))))
+                        } else {
+                            Err(arithmetic_overflow_error("Int16").into())
+                        }
+                    }
+                },
+            },
+            ScalarValue::Int32(a) => match a {
+                None => Ok(args[0].clone()),
+                Some(v) => match v.checked_abs() {
+                    Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(abs_val)))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(*v))))
+                        } else {
+                            Err(arithmetic_overflow_error("Int32").into())
+                        }
+                    }
+                },
+            },
+            ScalarValue::Int64(a) => match a {
+                None => Ok(args[0].clone()),
+                Some(v) => match v.checked_abs() {
+                    Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(abs_val)))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(*v))))
+                        } else {
+                            Err(arithmetic_overflow_error("Int64").into())
+                        }
+                    }
+                },
+            },
+            ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Float32(
+                a.map(|x| x.abs()),
+            ))),
+            ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(
+                a.map(|x| x.abs()),
+            ))),
+            ScalarValue::Decimal128(a, precision, scale) => match a {
+                None => Ok(args[0].clone()),
+                Some(v) => match v.checked_abs() {
+                    Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                        Some(abs_val),
+                        *precision,
+                        *scale,
+                    ))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                                Some(*v),
+                                *precision,
+                                *scale,
+                            )))
+                        } else {
+                            Err(arithmetic_overflow_error("Decimal128").into())
+                        }
+                    }
+                },
+            },
+            ScalarValue::Decimal256(a, precision, scale) => match a {
+                None => Ok(args[0].clone()),
+                Some(v) => match v.checked_abs() {
+                    Some(abs_val) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
+                        Some(abs_val),
+                        *precision,
+                        *scale,
+                    ))),
+                    None => {
+                        if !fail_on_error {
+                            // return the original value
+                            Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
+                                Some(*v),
+                                *precision,
+                                *scale,
+                            )))
+                        } else {
+                            Err(arithmetic_overflow_error("Decimal256").into())
+                        }
+                    }
+                },
+            },
+            dt => exec_err!("Not supported datatype for ABS: {dt}"),
+        },
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion::common::cast::{
+        as_decimal128_array, as_decimal256_array, as_float32_array, as_float64_array,
+        as_int16_array, as_int32_array, as_int64_array, as_int8_array, as_uint64_array,
+    };
+
+    fn with_fail_on_error<F: Fn(bool) -> Result<()>>(test_fn: F) {
+        for fail_on_error in [true, false] {
+            test_fn(fail_on_error).expect("test should pass on error successfully");
+        }
+    }
+
+    // Unsigned types, return as is
+    #[test]
+    fn test_abs_u8_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::UInt8(Some(u8::MAX)));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::UInt8(Some(result)))) => {
+                    assert_eq!(result, u8::MAX);
+                    Ok(())
+                }
+                Err(e) => {
+                    unreachable!("Didn't expect error, but got: {e:?}")
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i8_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::Int8(Some(i8::MIN)));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(result)))) => {
+                    assert_eq!(result, i8::MIN);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        unreachable!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i16_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::Int16(Some(i16::MIN)));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(result)))) => {
+                    assert_eq!(result, i16::MIN);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        unreachable!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i32_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::Int32(Some(i32::MIN)));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) => {
+                    assert_eq!(result, i32::MIN);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i64_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MIN)));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) => {
+                    assert_eq!(result, i64::MIN);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_decimal128_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::Decimal128(Some(i128::MIN), 18, 10));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                    Some(result),
+                    precision,
+                    scale,
+                ))) => {
+                    assert_eq!(result, i128::MIN);
+                    assert_eq!(precision, 18);
+                    assert_eq!(scale, 10);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_decimal256_scalar() {
+        with_fail_on_error(|fail_on_error| {
+            let args = ColumnarValue::Scalar(ScalarValue::Decimal256(Some(i256::MIN), 10, 2));
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
+                    Some(result),
+                    precision,
+                    scale,
+                ))) => {
+                    assert_eq!(result, i256::MIN);
+                    assert_eq!(precision, 10);
+                    assert_eq!(scale, 2);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i8_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Int8Array::from(vec![Some(-1), Some(i8::MIN), Some(i8::MAX), None]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Int8Array::from(vec![Some(1), Some(i8::MIN), Some(i8::MAX), None]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_int8_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i16_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Int16Array::from(vec![Some(-1), Some(i16::MIN), Some(i16::MAX), None]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Int16Array::from(vec![Some(1), Some(i16::MIN), Some(i16::MAX), None]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_int16_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i32_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Int32Array::from(vec![Some(-1), Some(i32::MIN), Some(i32::MAX), None]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Int32Array::from(vec![Some(1), Some(i32::MIN), Some(i32::MAX), None]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_int32_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_i64_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Int64Array::from(vec![Some(-1), Some(i64::MIN), Some(i64::MAX), None]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Int64Array::from(vec![Some(1), Some(i64::MIN), Some(i64::MAX), None]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_int64_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_f32_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Float32Array::from(vec![
+                Some(-1f32),
+                Some(f32::MIN),
+                Some(f32::MAX),
+                None,
+                Some(f32::NAN),
+                Some(f32::NEG_INFINITY),
+                Some(f32::INFINITY),
+                Some(-0.0),
+                Some(0.0),
+            ]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Float32Array::from(vec![
+                Some(1f32),
+                Some(f32::MAX),
+                Some(f32::MAX),
+                None,
+                Some(f32::NAN),
+                Some(f32::INFINITY),
+                Some(f32::INFINITY),
+                Some(0.0),
+                Some(0.0),
+            ]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_float32_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_f64_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Float64Array::from(vec![Some(-1f64), Some(f64::MIN), Some(f64::MAX), None]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected =
+                Float64Array::from(vec![Some(1f64), Some(f64::MAX), Some(f64::MAX), None]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_float64_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_decimal128_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Decimal128Array::from(vec![Some(i128::MIN), None])
+                .with_precision_and_scale(38, 37)?;
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Decimal128Array::from(vec![Some(i128::MIN), None])
+                .with_precision_and_scale(38, 37)?;
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_decimal128_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_decimal256_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = Decimal256Array::from(vec![Some(i256::MIN), None])
+                .with_precision_and_scale(5, 2)?;
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = Decimal256Array::from(vec![Some(i256::MIN), None])
+                .with_precision_and_scale(5, 2)?;
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_decimal256_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_u64_array() {
+        with_fail_on_error(|fail_on_error| {
+            let input = UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]);
+            let args = ColumnarValue::Array(Arc::new(input));
+            let expected = UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]);
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            match abs(&[args, fail_on_error_arg]) {
+                Ok(ColumnarValue::Array(result)) => {
+                    let actual = as_uint64_array(&result)?;
+                    assert_eq!(actual, &expected);
+                    Ok(())
+                }
+                Err(e) => {
+                    if fail_on_error {
+                        assert!(
+                            e.to_string().contains("ARITHMETIC_OVERFLOW"),
+                            "Error message did not match. Actual message: {e}"
+                        );
+                        Ok(())
+                    } else {
+                        panic!("Didn't expect error, but got: {e:?}")
+                    }
+                }
+                _ => unreachable!(),
+            }
+        });
+    }
+
+    #[test]
+    fn test_abs_null_scalars() {
+        // Test that NULL scalars return NULL (no panic) for all signed types
+        with_fail_on_error(|fail_on_error| {
+            let fail_on_error_arg =
+                ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error)));
+
+            // Test Int8
+            let args = ColumnarValue::Scalar(ScalarValue::Int8(None));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int8(None))) => {}
+                _ => panic!("Expected NULL Int8, got different result"),
+            }
+
+            // Test Int16
+            let args = ColumnarValue::Scalar(ScalarValue::Int16(None));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int16(None))) => {}
+                _ => panic!("Expected NULL Int16, got different result"),
+            }
+
+            // Test Int32
+            let args = ColumnarValue::Scalar(ScalarValue::Int32(None));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))) => {}
+                _ => panic!("Expected NULL Int32, got different result"),
+            }
+
+            // Test Int64
+            let args = ColumnarValue::Scalar(ScalarValue::Int64(None));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))) => {}
+                _ => panic!("Expected NULL Int64, got different result"),
+            }
+
+            // Test Decimal128
+            let args = ColumnarValue::Scalar(ScalarValue::Decimal128(None, 10, 2));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(None, 10, 2))) => {}
+                _ => panic!("Expected NULL Decimal128, got different result"),
+            }
+
+            // Test Decimal256
+            let args = ColumnarValue::Scalar(ScalarValue::Decimal256(None, 10, 2));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(None, 10, 2))) => {}
+                _ => panic!("Expected NULL Decimal256, got different result"),
+            }
+
+            // Test Float32
+            let args = ColumnarValue::Scalar(ScalarValue::Float32(None));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Float32(None))) => {}
+                _ => panic!("Expected NULL Float32, got different result"),
+            }
+
+            // Test Float64
+            let args = ColumnarValue::Scalar(ScalarValue::Float64(None));
+            match abs(&[args.clone(), fail_on_error_arg.clone()]) {
+                Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))) => {}
+                _ => panic!("Expected NULL Float64, got different result"),
+            }
+
+            Ok(())
+        });
+    }
+}
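
Taken together, the kernel above implements both Spark behaviors that the new
test suite asserts: legacy mode wraps on MIN values (wrapping_abs), while ANSI
mode (fail_on_error = true) surfaces an ARITHMETIC_OVERFLOW error. A rough
Spark-side sketch of the same semantics (hypothetical table `tbl` whose INT
column `c` contains Int.MinValue):

    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT abs(c) FROM tbl").collect() // legacy: Int.MinValue comes back unchanged
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.sql("SELECT abs(c) FROM tbl").collect() // ANSI: SparkArithmeticException (ARITHMETIC_OVERFLOW)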
diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs
index 873b290eb..7df87eb9f 100644
--- a/native/spark-expr/src/math_funcs/mod.rs
+++ b/native/spark-expr/src/math_funcs/mod.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+pub(crate) mod abs;
mod ceil;
pub(crate) mod checked_arithmetic;
mod div;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 570c07cb0..63e18c145 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -137,7 +137,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[Subtract] -> CometSubtract,
classOf[Tan] -> CometScalarFunction("tan"),
classOf[UnaryMinus] -> CometUnaryMinus,
- classOf[Unhex] -> CometUnhex)
+ classOf[Unhex] -> CometUnhex,
+ classOf[Abs] -> CometAbs)
private val mapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
classOf[GetMapValue] -> CometMapExtract,
diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala
index bfcd242d7..68b6e8d11 100644
--- a/spark/src/main/scala/org/apache/comet/serde/math.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/math.scala
@@ -19,8 +19,8 @@
package org.apache.comet.serde
-import org.apache.spark.sql.catalyst.expressions.{Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Unhex}
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Unhex}
+import org.apache.spark.sql.types.{DecimalType, NumericType}
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType}
@@ -144,6 +144,36 @@ object CometUnhex extends CometExpressionSerde[Unhex] with MathExprBase {
}
}
+object CometAbs extends CometExpressionSerde[Abs] with MathExprBase {
+
+  override def getSupportLevel(expr: Abs): SupportLevel = {
+    expr.child.dataType match {
+      case _: NumericType =>
+        Compatible()
+      case _ =>
+        // Spark supports NumericType, DayTimeIntervalType, and YearMonthIntervalType
+        Unsupported(Some("Only integral, floating-point, and decimal types are supported"))
+    }
+  }
+
+  override def convert(
+      expr: Abs,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(expr.child, inputs, binding)
+    val failOnErrorExpr = exprToProtoInternal(Literal(expr.failOnError), inputs, binding)
+
+    val optExpr =
+      scalarFunctionExprToProtoWithReturnType(
+        "abs",
+        expr.dataType,
+        false,
+        childExpr,
+        failOnErrorExpr)
+    optExprWithInfo(optExpr, expr, expr.child)
+  }
+}
+
sealed trait MathExprBase {
protected def nullIfNegative(expression: Expression): Expression = {
val zero = Literal.default(expression.dataType)
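
Note that `getSupportLevel` above limits CometAbs to NumericType, while Spark's
own Abs also accepts ANSI interval types, so interval inputs keep falling back
to Spark. A hypothetical example of a query that would not be accelerated:

    // Runs in Spark, not Comet: DayTimeIntervalType is declared Unsupported above.
    spark.sql("SELECT abs(INTERVAL '-1' DAY)").show()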
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 7b6ed1945..d50274938 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -21,8 +21,6 @@ package org.apache.comet
import java.time.{Duration, Period}
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.TypeTag
import scala.util.Random
import org.scalactic.source.Position
@@ -1430,74 +1428,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
testDoubleScalarExpr("expm1")
}
-  // https://github.com/apache/datafusion-comet/issues/666
-  ignore("abs") {
-    Seq(true, false).foreach { dictionaryEnabled =>
-      withTempDir { dir =>
-        val path = new Path(dir.toURI.toString, "test.parquet")
-        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 100)
-        withParquetTable(path.toString, "tbl") {
-          Seq(2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 15, 16, 17).foreach { col =>
-            checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl")
-          }
-        }
-      }
-    }
-  }
-
-  // https://github.com/apache/datafusion-comet/issues/666
-  ignore("abs Overflow ansi mode") {
-
-    def testAbsAnsiOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
-      withParquetTable(data, "tbl") {
-        checkSparkMaybeThrows(sql("select abs(_1), abs(_2) from tbl")) match {
-          case (Some(sparkExc), Some(cometExc)) =>
-            val cometErrorPattern =
-              """.+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r
-            assert(cometErrorPattern.findFirstIn(cometExc.getMessage).isDefined)
-            assert(sparkExc.getMessage.contains("overflow"))
-          case _ => fail("Exception should be thrown")
-        }
-      }
-    }
-
-    def testAbsAnsi[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
-      withParquetTable(data, "tbl") {
-        checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl")
-      }
-    }
-
-    withSQLConf(
-      SQLConf.ANSI_ENABLED.key -> "true",
-      CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
-      testAbsAnsiOverflow(Seq((Byte.MaxValue, Byte.MinValue)))
-      testAbsAnsiOverflow(Seq((Short.MaxValue, Short.MinValue)))
-      testAbsAnsiOverflow(Seq((Int.MaxValue, Int.MinValue)))
-      testAbsAnsiOverflow(Seq((Long.MaxValue, Long.MinValue)))
-      testAbsAnsi(Seq((Float.MaxValue, Float.MinValue)))
-      testAbsAnsi(Seq((Double.MaxValue, Double.MinValue)))
-    }
-  }
-
-  // https://github.com/apache/datafusion-comet/issues/666
-  ignore("abs Overflow legacy mode") {
-
-    def testAbsLegacyOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
-        withParquetTable(data, "tbl") {
-          checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl")
-        }
-      }
-    }
-
-    testAbsLegacyOverflow(Seq((Byte.MaxValue, Byte.MinValue)))
-    testAbsLegacyOverflow(Seq((Short.MaxValue, Short.MinValue)))
-    testAbsLegacyOverflow(Seq((Int.MaxValue, Int.MinValue)))
-    testAbsLegacyOverflow(Seq((Long.MaxValue, Long.MinValue)))
-    testAbsLegacyOverflow(Seq((Float.MaxValue, Float.MinValue)))
-    testAbsLegacyOverflow(Seq((Double.MaxValue, Double.MinValue)))
-  }
-
test("ceil and floor") {
Seq("true", "false").foreach { dictionary =>
withSQLConf(
diff --git a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
new file mode 100644
index 000000000..c95047a0e
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import scala.util.Random
+
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+
+import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
+
+class CometMathExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+
+ test("abs") {
+ val df = createTestData(generateNegativeZero = false)
+ df.createOrReplaceTempView("tbl")
+ for (field <- df.schema.fields) {
+ val col = field.name
+ checkSparkAnswerAndOperator(s"SELECT $col, abs($col) FROM tbl ORDER BY
$col")
+ }
+ }
+
+ test("abs - negative zero") {
+ val df = createTestData(generateNegativeZero = true)
+ df.createOrReplaceTempView("tbl")
+ for (field <- df.schema.fields.filter(f =>
+ f.dataType == DataTypes.FloatType || f.dataType ==
DataTypes.DoubleType)) {
+ val col = field.name
+ checkSparkAnswerAndOperator(
+ s"SELECT $col, abs($col) FROM tbl WHERE CAST($col as string) = '-0.0'
ORDER BY $col")
+ }
+ }
+
+ test("abs (ANSI mode)") {
+ val df = createTestData(generateNegativeZero = false)
+ df.createOrReplaceTempView("tbl")
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ for (field <- df.schema.fields) {
+ val col = field.name
+ checkSparkMaybeThrows(sql(s"SELECT $col, abs($col) FROM tbl ORDER BY
$col")) match {
+ case (Some(sparkExc), Some(cometExc)) =>
+ val cometErrorPattern =
+ """.+[ARITHMETIC_OVERFLOW].+overflow. If necessary set
"spark.sql.ansi.enabled" to "false" to bypass this error.""".r
+
assert(cometErrorPattern.findFirstIn(cometExc.getMessage).isDefined)
+ assert(sparkExc.getMessage.contains("overflow"))
+ case (Some(_), None) =>
+ fail("Exception should be thrown")
+ case (None, Some(cometExc)) =>
+ throw cometExc
+ case _ =>
+ }
+ }
+ }
+ }
+
+  private def createTestData(generateNegativeZero: Boolean) = {
+    val r = new Random(42)
+    val schema = StructType(
+      Seq(
+        StructField("c0", DataTypes.ByteType, nullable = true),
+        StructField("c1", DataTypes.ShortType, nullable = true),
+        StructField("c2", DataTypes.IntegerType, nullable = true),
+        StructField("c3", DataTypes.LongType, nullable = true),
+        StructField("c4", DataTypes.FloatType, nullable = true),
+        StructField("c5", DataTypes.DoubleType, nullable = true),
+        StructField("c6", DataTypes.createDecimalType(10, 2), nullable = true)))
+    FuzzDataGenerator.generateDataFrame(
+      r,
+      spark,
+      schema,
+      1000,
+      DataGenOptions(generateNegativeZero = generateNegativeZero))
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]