This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new fbcf02510 chore: extract predicate_functions expressions to folders based on spark grouping (#1218)
fbcf02510 is described below

commit fbcf0251082a43b5ee25b6c5933a9262cce44071
Author: Raz Luvaton <[email protected]>
AuthorDate: Wed Jan 8 09:10:34 2025 +0200

    chore: extract predicate_functions expressions to folders based on spark grouping (#1218)
    
    * extract predicate_functions expressions to folders based on spark grouping
    
    * code review changes
    
    ---------
    
    Co-authored-by: Andy Grove <[email protected]>
---
 native/spark-expr/src/lib.rs                       |  4 +-
 native/spark-expr/src/predicate_funcs/is_nan.rs    | 70 +++++++++++++++++++
 native/spark-expr/src/predicate_funcs/mod.rs       | 22 ++++++
 .../src/{regexp.rs => predicate_funcs/rlike.rs}    |  0
 native/spark-expr/src/scalar_funcs.rs              | 81 +++++++++++++++++++++-
 5 files changed, 173 insertions(+), 4 deletions(-)

diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index c7c54a4e9..c614e1f0a 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -30,7 +30,6 @@ pub use checkoverflow::CheckOverflow;
 
 mod kernels;
 mod list;
-mod regexp;
 pub mod scalar_funcs;
 mod schema_adapter;
 mod static_invoke;
@@ -50,6 +49,8 @@ mod unbound;
 pub use unbound::UnboundColumn;
 pub mod utils;
 pub use normalize_nan::NormalizeNaNAndZero;
+mod predicate_funcs;
+pub use predicate_funcs::{spark_isnan, RLike};
 
 mod agg_funcs;
 mod comet_scalar_funcs;
@@ -66,7 +67,6 @@ pub use datetime_funcs::*;
 pub use error::{SparkError, SparkResult};
 pub use if_expr::IfExpr;
 pub use list::{ArrayInsert, GetArrayStructFields, ListExtract};
-pub use regexp::RLike;
 pub use string_funcs::*;
 pub use struct_funcs::*;
 pub use to_json::ToJson;
diff --git a/native/spark-expr/src/predicate_funcs/is_nan.rs b/native/spark-expr/src/predicate_funcs/is_nan.rs
new file mode 100644
index 000000000..bf4d7e0f2
--- /dev/null
+++ b/native/spark-expr/src/predicate_funcs/is_nan.rs
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Float32Array, Float64Array};
+use arrow_array::{Array, BooleanArray};
+use arrow_schema::DataType;
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{DataFusionError, ScalarValue};
+use std::sync::Arc;
+
+/// Spark-compatible `isnan` expression
+pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, 
DataFusionError> {
+    fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue {
+        match is_nan.nulls() {
+            Some(nulls) => {
+                let is_not_null = nulls.inner();
+                ColumnarValue::Array(Arc::new(BooleanArray::new(
+                    is_nan.values() & is_not_null,
+                    None,
+                )))
+            }
+            None => ColumnarValue::Array(Arc::new(is_nan)),
+        }
+    }
+    let value = &args[0];
+    match value {
+        ColumnarValue::Array(array) => match array.data_type() {
+            DataType::Float64 => {
+                let array = 
array.as_any().downcast_ref::<Float64Array>().unwrap();
+                let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
+                Ok(set_nulls_to_false(is_nan))
+            }
+            DataType::Float32 => {
+                let array = 
array.as_any().downcast_ref::<Float32Array>().unwrap();
+                let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
+                Ok(set_nulls_to_false(is_nan))
+            }
+            other => Err(DataFusionError::Internal(format!(
+                "Unsupported data type {:?} for function isnan",
+                other,
+            ))),
+        },
+        ColumnarValue::Scalar(a) => match a {
+            ScalarValue::Float64(a) => 
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
+                a.map(|x| x.is_nan()).unwrap_or(false),
+            )))),
+            ScalarValue::Float32(a) => 
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
+                a.map(|x| x.is_nan()).unwrap_or(false),
+            )))),
+            _ => Err(DataFusionError::Internal(format!(
+                "Unsupported data type {:?} for function isnan",
+                value.data_type(),
+            ))),
+        },
+    }
+}
diff --git a/native/spark-expr/src/predicate_funcs/mod.rs b/native/spark-expr/src/predicate_funcs/mod.rs
new file mode 100644
index 000000000..5f1f570c0
--- /dev/null
+++ b/native/spark-expr/src/predicate_funcs/mod.rs
@@ -0,0 +1,22 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod is_nan;
+mod rlike;
+
+pub use is_nan::spark_isnan;
+pub use rlike::RLike;
diff --git a/native/spark-expr/src/regexp.rs b/native/spark-expr/src/predicate_funcs/rlike.rs
similarity index 100%
rename from native/spark-expr/src/regexp.rs
rename to native/spark-expr/src/predicate_funcs/rlike.rs
diff --git a/native/spark-expr/src/scalar_funcs.rs b/native/spark-expr/src/scalar_funcs.rs
index e11d1c5db..9421d54fd 100644
--- a/native/spark-expr/src/scalar_funcs.rs
+++ b/native/spark-expr/src/scalar_funcs.rs
@@ -20,10 +20,14 @@ use arrow::{
         ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, 
Int16Array, Int32Array,
         Int64Array, Int64Builder, Int8Array,
     },
+    compute::kernels::numeric::{add, sub},
     datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
 };
-use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
-use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
+use arrow_array::builder::IntervalDayTimeBuilder;
+use arrow_array::types::{Int16Type, Int32Type, Int8Type, IntervalDayTime};
+use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum, 
Decimal128Array};
+use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION};
+use datafusion::physical_expr_common::datum;
 use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
 use datafusion_common::{
     exec_err, internal_err, DataFusionError, Result as DataFusionResult, 
ScalarValue,
@@ -447,6 +451,79 @@ pub fn spark_decimal_div(
     Ok(ColumnarValue::Array(Arc::new(result)))
 }
 
+macro_rules! scalar_date_arithmetic {
+    ($start:expr, $days:expr, $op:expr) => {{
+        let interval = IntervalDayTime::new(*$days as i32, 0);
+        let interval_cv = 
ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval)));
+        datum::apply($start, &interval_cv, $op)
+    }};
+}
+macro_rules! array_date_arithmetic {
+    ($days:expr, $interval_builder:expr, $intType:ty) => {{
+        for day in $days.as_primitive::<$intType>().into_iter() {
+            if let Some(non_null_day) = day {
+                
$interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0));
+            } else {
+                $interval_builder.append_null();
+            }
+        }
+    }};
+}
+
+/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days 
for the second
+/// argument, but we cannot directly add that to a Date32. We generate an 
IntervalDayTime from the
+/// second argument and use DataFusion's interface to apply Arrow's operators.
+fn spark_date_arithmetic(
+    args: &[ColumnarValue],
+    op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
+) -> Result<ColumnarValue, DataFusionError> {
+    let start = &args[0];
+    match &args[1] {
+        ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
+            scalar_date_arithmetic!(start, days, op)
+        }
+        ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
+            scalar_date_arithmetic!(start, days, op)
+        }
+        ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
+            scalar_date_arithmetic!(start, days, op)
+        }
+        ColumnarValue::Array(days) => {
+            let mut interval_builder = 
IntervalDayTimeBuilder::with_capacity(days.len());
+            match days.data_type() {
+                DataType::Int8 => {
+                    array_date_arithmetic!(days, interval_builder, Int8Type)
+                }
+                DataType::Int16 => {
+                    array_date_arithmetic!(days, interval_builder, Int16Type)
+                }
+                DataType::Int32 => {
+                    array_date_arithmetic!(days, interval_builder, Int32Type)
+                }
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "Unsupported data types {:?} for date arithmetic.",
+                        args,
+                    )))
+                }
+            }
+            let interval_cv = 
ColumnarValue::Array(Arc::new(interval_builder.finish()));
+            datum::apply(start, &interval_cv, op)
+        }
+        _ => Err(DataFusionError::Internal(format!(
+            "Unsupported data types {:?} for date arithmetic.",
+            args,
+        ))),
+    }
+}
+pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue, 
DataFusionError> {
+    spark_date_arithmetic(args, add)
+}
+
+pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue, 
DataFusionError> {
+    spark_date_arithmetic(args, sub)
+}
+
 /// Spark-compatible `isnan` expression
 pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, 
DataFusionError> {
     fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to