This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new fbcf02510 chore: extract predicate_functions expressions to folders
based on spark grouping (#1218)
fbcf02510 is described below
commit fbcf0251082a43b5ee25b6c5933a9262cce44071
Author: Raz Luvaton <[email protected]>
AuthorDate: Wed Jan 8 09:10:34 2025 +0200
chore: extract predicate_functions expressions to folders based on spark
grouping (#1218)
* extract predicate_functions expressions to folders based on spark grouping
* code review changes
---------
Co-authored-by: Andy Grove <[email protected]>
---
native/spark-expr/src/lib.rs | 4 +-
native/spark-expr/src/predicate_funcs/is_nan.rs | 70 +++++++++++++++++++
native/spark-expr/src/predicate_funcs/mod.rs | 22 ++++++
.../src/{regexp.rs => predicate_funcs/rlike.rs} | 0
native/spark-expr/src/scalar_funcs.rs | 81 +++++++++++++++++++++-
5 files changed, 173 insertions(+), 4 deletions(-)
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index c7c54a4e9..c614e1f0a 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -30,7 +30,6 @@ pub use checkoverflow::CheckOverflow;
mod kernels;
mod list;
-mod regexp;
pub mod scalar_funcs;
mod schema_adapter;
mod static_invoke;
@@ -50,6 +49,8 @@ mod unbound;
pub use unbound::UnboundColumn;
pub mod utils;
pub use normalize_nan::NormalizeNaNAndZero;
+mod predicate_funcs;
+pub use predicate_funcs::{spark_isnan, RLike};
mod agg_funcs;
mod comet_scalar_funcs;
@@ -66,7 +67,6 @@ pub use datetime_funcs::*;
pub use error::{SparkError, SparkResult};
pub use if_expr::IfExpr;
pub use list::{ArrayInsert, GetArrayStructFields, ListExtract};
-pub use regexp::RLike;
pub use string_funcs::*;
pub use struct_funcs::*;
pub use to_json::ToJson;
diff --git a/native/spark-expr/src/predicate_funcs/is_nan.rs
b/native/spark-expr/src/predicate_funcs/is_nan.rs
new file mode 100644
index 000000000..bf4d7e0f2
--- /dev/null
+++ b/native/spark-expr/src/predicate_funcs/is_nan.rs
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Float32Array, Float64Array};
+use arrow_array::{Array, BooleanArray};
+use arrow_schema::DataType;
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{DataFusionError, ScalarValue};
+use std::sync::Arc;
+
+/// Spark-compatible `isnan` expression: yields a non-null boolean that is true only for NaN inputs (null inputs give false); returns an Internal error for anything other than Float32/Float64.
+pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue,
DataFusionError> {
+ fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue { // forces the result to false wherever `is_nan` is null
+ match is_nan.nulls() {
+ Some(nulls) => {
+ let is_not_null = nulls.inner();
+ ColumnarValue::Array(Arc::new(BooleanArray::new(
+ is_nan.values() & is_not_null, // AND with the non-null mask so null slots become false
+ None, // the rebuilt array carries no null buffer
+ )))
+ }
+ None => ColumnarValue::Array(Arc::new(is_nan)), // no null buffer present: return the array unchanged
+ }
+ }
+ let value = &args[0]; // only the first argument is read; indexing panics if `args` is empty
+ match value {
+ ColumnarValue::Array(array) => match array.data_type() {
+ DataType::Float64 => {
+ let array =
array.as_any().downcast_ref::<Float64Array>().unwrap();
+ let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
+ Ok(set_nulls_to_false(is_nan))
+ }
+ DataType::Float32 => {
+ let array =
array.as_any().downcast_ref::<Float32Array>().unwrap();
+ let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
+ Ok(set_nulls_to_false(is_nan))
+ }
+ other => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {:?} for function isnan",
+ other,
+ ))),
+ },
+ ColumnarValue::Scalar(a) => match a {
+ ScalarValue::Float64(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
+ a.map(|x| x.is_nan()).unwrap_or(false), // a null scalar yields false, never null
+ )))),
+ ScalarValue::Float32(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
+ a.map(|x| x.is_nan()).unwrap_or(false), // a null scalar yields false, never null
+ )))),
+ _ => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {:?} for function isnan",
+ value.data_type(),
+ ))),
+ },
+ }
+}
diff --git a/native/spark-expr/src/predicate_funcs/mod.rs
b/native/spark-expr/src/predicate_funcs/mod.rs
new file mode 100644
index 000000000..5f1f570c0
--- /dev/null
+++ b/native/spark-expr/src/predicate_funcs/mod.rs
@@ -0,0 +1,22 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod is_nan;
+mod rlike;
+
+pub use is_nan::spark_isnan;
+pub use rlike::RLike;
diff --git a/native/spark-expr/src/regexp.rs
b/native/spark-expr/src/predicate_funcs/rlike.rs
similarity index 100%
rename from native/spark-expr/src/regexp.rs
rename to native/spark-expr/src/predicate_funcs/rlike.rs
diff --git a/native/spark-expr/src/scalar_funcs.rs
b/native/spark-expr/src/scalar_funcs.rs
index e11d1c5db..9421d54fd 100644
--- a/native/spark-expr/src/scalar_funcs.rs
+++ b/native/spark-expr/src/scalar_funcs.rs
@@ -20,10 +20,14 @@ use arrow::{
ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array,
Int16Array, Int32Array,
Int64Array, Int64Builder, Int8Array,
},
+ compute::kernels::numeric::{add, sub},
datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
};
-use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
-use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
+use arrow_array::builder::IntervalDayTimeBuilder;
+use arrow_array::types::{Int16Type, Int32Type, Int8Type, IntervalDayTime};
+use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum,
Decimal128Array};
+use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION};
+use datafusion::physical_expr_common::datum;
use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
use datafusion_common::{
exec_err, internal_err, DataFusionError, Result as DataFusionResult,
ScalarValue,
@@ -447,6 +451,79 @@ pub fn spark_decimal_div(
Ok(ColumnarValue::Array(Arc::new(result)))
}
+macro_rules! scalar_date_arithmetic { // wraps a scalar day count in an IntervalDayTime scalar and applies `$op` to `$start` and that interval
+ ($start:expr, $days:expr, $op:expr) => {{
+ let interval = IntervalDayTime::new(*$days as i32, 0); // `$days` cast to i32 is the first field; the second field is 0 (presumably milliseconds -- confirm against the arrow docs)
+ let interval_cv =
ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval)));
+ datum::apply($start, &interval_cv, $op) // the macro evaluates to the result of `datum::apply`
+ }};
+}
+macro_rules! array_date_arithmetic { // fills `$interval_builder` with one interval per element of `$days`, read as `$intType`
+ ($days:expr, $interval_builder:expr, $intType:ty) => {{
+ for day in $days.as_primitive::<$intType>().into_iter() {
+ if let Some(non_null_day) = day {
+
$interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0));
+ } else {
+ $interval_builder.append_null(); // a null day count stays null in the interval array
+ }
+ }
+ }};
+}
+
+/// Spark-compatible `date_add` and `date_sub` expressions, which assume days
for the second
+/// argument, but we cannot directly add that to a Date32. We generate an
IntervalDayTime from the
+/// second argument and use DataFusion's interface to apply Arrow's operators.
+fn spark_date_arithmetic(
+ args: &[ColumnarValue],
+ op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
+) -> Result<ColumnarValue, DataFusionError> {
+ let start = &args[0]; // left operand, handed to `op` unchanged; indexing panics if `args` is empty
+ match &args[1] { // second argument is the day count; indexing panics if fewer than two args are given
+ ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
+ scalar_date_arithmetic!(start, days, op)
+ }
+ ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
+ scalar_date_arithmetic!(start, days, op)
+ }
+ ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
+ scalar_date_arithmetic!(start, days, op)
+ }
+ ColumnarValue::Array(days) => { // array of day counts: build a parallel interval array first
+ let mut interval_builder =
IntervalDayTimeBuilder::with_capacity(days.len());
+ match days.data_type() {
+ DataType::Int8 => {
+ array_date_arithmetic!(days, interval_builder, Int8Type)
+ }
+ DataType::Int16 => {
+ array_date_arithmetic!(days, interval_builder, Int16Type)
+ }
+ DataType::Int32 => {
+ array_date_arithmetic!(days, interval_builder, Int32Type)
+ }
+ _ => { // any other element type is rejected before `op` is applied
+ return Err(DataFusionError::Internal(format!(
+ "Unsupported data types {:?} for date arithmetic.",
+ args,
+ )))
+ }
+ }
+ let interval_cv =
ColumnarValue::Array(Arc::new(interval_builder.finish()));
+ datum::apply(start, &interval_cv, op)
+ }
+ _ => Err(DataFusionError::Internal(format!( // reached for every other scalar, including null Int8/Int16/Int32 scalars (only `Some` is matched above)
+ "Unsupported data types {:?} for date arithmetic.",
+ args,
+ ))),
+ }
+}
+pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue,
DataFusionError> {
+ spark_date_arithmetic(args, add) // date_add: delegates with the `add` kernel, so args[1] days are added to args[0]
+}
+
+pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue,
DataFusionError> {
+ spark_date_arithmetic(args, sub) // date_sub: delegates with the `sub` kernel, so args[1] days are subtracted from args[0]
+}
+
/// Spark-compatible `isnan` expression
pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue,
DataFusionError> {
fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]