This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 517c2557f chore: extract math_funcs expressions to folders based on
spark grouping (#1219)
517c2557f is described below
commit 517c2557f3042926b3aaaf1a17daf8b947c0b4ce
Author: Raz Luvaton <[email protected]>
AuthorDate: Mon Jan 20 02:46:07 2025 +0200
chore: extract math_funcs expressions to folders based on spark grouping
(#1219)
* extract math_funcs expressions to folders based on spark grouping
* fix merge conflicts and move chr to `string_funcs`
---
native/spark-expr/benches/decimal_div.rs | 2 +-
native/spark-expr/src/comet_scalar_funcs.rs | 8 +-
native/spark-expr/src/hash_funcs/sha2.rs | 2 +-
native/spark-expr/src/lib.rs | 25 +-
native/spark-expr/src/math_funcs/ceil.rs | 83 +++
native/spark-expr/src/math_funcs/div.rs | 92 ++++
native/spark-expr/src/math_funcs/floor.rs | 83 +++
.../src/{scalar_funcs => math_funcs}/hex.rs | 0
.../src/{ => math_funcs/internal}/checkoverflow.rs | 0
.../src/math_funcs/internal/make_decimal.rs | 66 +++
.../{string_funcs => math_funcs/internal}/mod.rs | 14 +-
.../src/{ => math_funcs/internal}/normalize_nan.rs | 0
.../src/math_funcs/internal/unscaled_value.rs | 44 ++
.../src/{string_funcs => math_funcs}/mod.rs | 23 +-
native/spark-expr/src/{ => math_funcs}/negative.rs | 2 +-
native/spark-expr/src/math_funcs/round.rs | 137 +++++
.../src/{scalar_funcs => math_funcs}/unhex.rs | 0
native/spark-expr/src/math_funcs/utils.rs | 74 +++
native/spark-expr/src/scalar_funcs.rs | 569 ---------------------
.../src/{scalar_funcs => string_funcs}/chr.rs | 0
native/spark-expr/src/string_funcs/mod.rs | 2 +
21 files changed, 625 insertions(+), 601 deletions(-)
diff --git a/native/spark-expr/benches/decimal_div.rs
b/native/spark-expr/benches/decimal_div.rs
index 89f06e505..ad527fecb 100644
--- a/native/spark-expr/benches/decimal_div.rs
+++ b/native/spark-expr/benches/decimal_div.rs
@@ -19,7 +19,7 @@ use arrow::compute::cast;
use arrow_array::builder::Decimal128Builder;
use arrow_schema::DataType;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
-use datafusion_comet_spark_expr::scalar_funcs::spark_decimal_div;
+use datafusion_comet_spark_expr::spark_decimal_div;
use datafusion_expr::ColumnarValue;
use std::sync::Arc;
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs
b/native/spark-expr/src/comet_scalar_funcs.rs
index 27c77d7f2..6070e81d2 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -16,11 +16,11 @@
// under the License.
use crate::hash_funcs::*;
-use crate::scalar_funcs::{
- spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_isnan,
spark_make_decimal,
- spark_round, spark_unhex, spark_unscaled_value, SparkChrFunc,
+use crate::{
+ spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
spark_floor, spark_hex,
+ spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round,
spark_unhex,
+ spark_unscaled_value, SparkChrFunc,
};
-use crate::{spark_date_add, spark_date_sub, spark_read_side_padding};
use arrow_schema::DataType;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_expr::registry::FunctionRegistry;
diff --git a/native/spark-expr/src/hash_funcs/sha2.rs
b/native/spark-expr/src/hash_funcs/sha2.rs
index 90917a9eb..40d8def3a 100644
--- a/native/spark-expr/src/hash_funcs/sha2.rs
+++ b/native/spark-expr/src/hash_funcs/sha2.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use crate::scalar_funcs::hex_strings;
+use crate::math_funcs::hex::hex_strings;
use arrow_array::{Array, StringArray};
use datafusion::functions::crypto::{sha224, sha256, sha384, sha512};
use datafusion_common::cast::as_binary_array;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 14982264d..c9cfab27d 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -21,29 +21,22 @@
mod error;
-mod checkoverflow;
-pub use checkoverflow::CheckOverflow;
-
mod kernels;
-pub mod scalar_funcs;
mod schema_adapter;
mod static_invoke;
pub use schema_adapter::SparkSchemaAdapterFactory;
pub use static_invoke::*;
-mod negative;
mod struct_funcs;
-pub use negative::{create_negate_expr, NegativeExpr};
-mod normalize_nan;
+pub use struct_funcs::{CreateNamedStruct, GetStructField};
mod json_funcs;
pub mod test_common;
pub mod timezone;
mod unbound;
pub use unbound::UnboundColumn;
-pub mod utils;
-pub use normalize_nan::NormalizeNaNAndZero;
mod predicate_funcs;
+pub mod utils;
pub use predicate_funcs::{spark_isnan, RLike};
mod agg_funcs;
@@ -57,11 +50,10 @@ mod string_funcs;
mod datetime_funcs;
pub use agg_funcs::*;
-pub use crate::{CreateNamedStruct, GetStructField};
-pub use crate::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr,
TimestampTruncExpr};
pub use cast::{spark_cast, Cast, SparkCastOptions};
mod conditional_funcs;
mod conversion_funcs;
+mod math_funcs;
pub use array_funcs::*;
pub use bitwise_funcs::*;
@@ -69,12 +61,19 @@ pub use conditional_funcs::*;
pub use conversion_funcs::*;
pub use comet_scalar_funcs::create_comet_physical_fun;
-pub use datetime_funcs::*;
+pub use datetime_funcs::{
+ spark_date_add, spark_date_sub, DateTruncExpr, HourExpr, MinuteExpr,
SecondExpr,
+ TimestampTruncExpr,
+};
pub use error::{SparkError, SparkResult};
pub use hash_funcs::*;
pub use json_funcs::ToJson;
+pub use math_funcs::{
+ create_negate_expr, spark_ceil, spark_decimal_div, spark_floor, spark_hex,
spark_make_decimal,
+ spark_round, spark_unhex, spark_unscaled_value, CheckOverflow,
NegativeExpr,
+ NormalizeNaNAndZero,
+};
pub use string_funcs::*;
-pub use struct_funcs::*;
/// Spark supports three evaluation modes when evaluating expressions, which
affect
/// the behavior when processing input values that are invalid or would result
in an
diff --git a/native/spark-expr/src/math_funcs/ceil.rs
b/native/spark-expr/src/math_funcs/ceil.rs
new file mode 100644
index 000000000..9c0fc9b57
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/ceil.rs
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::downcast_compute_op;
+use crate::math_funcs::utils::{get_precision_scale, make_decimal_array,
make_decimal_scalar};
+use arrow::array::{Float32Array, Float64Array, Int64Array};
+use arrow_array::{Array, ArrowNativeTypeOp};
+use arrow_schema::DataType;
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{DataFusionError, ScalarValue};
+use num::integer::div_ceil;
+use std::sync::Arc;
+
+/// `ceil` function that simulates Spark `ceil` expression
+pub fn spark_ceil(
+ args: &[ColumnarValue],
+ data_type: &DataType,
+) -> Result<ColumnarValue, DataFusionError> {
+ let value = &args[0];
+ match value {
+ ColumnarValue::Array(array) => match array.data_type() {
+ DataType::Float32 => {
+ let result = downcast_compute_op!(array, "ceil", ceil,
Float32Array, Int64Array);
+ Ok(ColumnarValue::Array(result?))
+ }
+ DataType::Float64 => {
+ let result = downcast_compute_op!(array, "ceil", ceil,
Float64Array, Int64Array);
+ Ok(ColumnarValue::Array(result?))
+ }
+ DataType::Int64 => {
+ let result =
array.as_any().downcast_ref::<Int64Array>().unwrap();
+ Ok(ColumnarValue::Array(Arc::new(result.clone())))
+ }
+ DataType::Decimal128(_, scale) if *scale > 0 => {
+ let f = decimal_ceil_f(scale);
+ let (precision, scale) = get_precision_scale(data_type);
+ make_decimal_array(array, precision, scale, &f)
+ }
+ other => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {:?} for function ceil",
+ other,
+ ))),
+ },
+ ColumnarValue::Scalar(a) => match a {
+ ScalarValue::Float32(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(
+ a.map(|x| x.ceil() as i64),
+ ))),
+ ScalarValue::Float64(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(
+ a.map(|x| x.ceil() as i64),
+ ))),
+ ScalarValue::Int64(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))),
+ ScalarValue::Decimal128(a, _, scale) if *scale > 0 => {
+ let f = decimal_ceil_f(scale);
+ let (precision, scale) = get_precision_scale(data_type);
+ make_decimal_scalar(a, precision, scale, &f)
+ }
+ _ => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {:?} for function ceil",
+ value.data_type(),
+ ))),
+ },
+ }
+}
+
+#[inline]
+fn decimal_ceil_f(scale: &i8) -> impl Fn(i128) -> i128 {
+ let div = 10_i128.pow_wrapping(*scale as u32);
+ move |x: i128| div_ceil(x, div)
+}
diff --git a/native/spark-expr/src/math_funcs/div.rs
b/native/spark-expr/src/math_funcs/div.rs
new file mode 100644
index 000000000..72c23b9e9
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/div.rs
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::math_funcs::utils::get_precision_scale;
+use arrow::{
+ array::{ArrayRef, AsArray},
+ datatypes::Decimal128Type,
+};
+use arrow_array::{Array, Decimal128Array};
+use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::DataFusionError;
+use num::{BigInt, Signed, ToPrimitive};
+use std::sync::Arc;
+
+// Let Decimal(p3, s3) be the return type, i.e. Decimal(p1, s1) / Decimal(p2, s2) =
Decimal(p3, s3).
+// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means
that, in order to
+// get enough scale to match Spark's behavior, s1 has to be widened
to s2 + s3 + 1. Since
+// both s2 and s3 are at most 38, s1 is at most 77. DataFusion division cannot
handle a scale >
+// Decimal256Type::MAX_SCALE. Therefore, we need to implement this decimal
division using BigInt.
+pub fn spark_decimal_div(
+ args: &[ColumnarValue],
+ data_type: &DataType,
+) -> Result<ColumnarValue, DataFusionError> {
+ let left = &args[0];
+ let right = &args[1];
+ let (p3, s3) = get_precision_scale(data_type);
+
+ let (left, right): (ArrayRef, ArrayRef) = match (left, right) {
+ (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l),
Arc::clone(r)),
+ (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
+ (l.to_array_of_size(r.len())?, Arc::clone(r))
+ }
+ (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
+ (Arc::clone(l), r.to_array_of_size(l.len())?)
+ }
+ (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) =>
(l.to_array()?, r.to_array()?),
+ };
+ let left = left.as_primitive::<Decimal128Type>();
+ let right = right.as_primitive::<Decimal128Type>();
+ let (p1, s1) = get_precision_scale(left.data_type());
+ let (p2, s2) = get_precision_scale(right.data_type());
+
+ let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32);
+ let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32);
+ let result: Decimal128Array = if p1 as u32 + l_exp >
DECIMAL128_MAX_PRECISION as u32
+ || p2 as u32 + r_exp > DECIMAL128_MAX_PRECISION as u32
+ {
+ let ten = BigInt::from(10);
+ let l_mul = ten.pow(l_exp);
+ let r_mul = ten.pow(r_exp);
+ let five = BigInt::from(5);
+ let zero = BigInt::from(0);
+ arrow::compute::kernels::arity::binary(left, right, |l, r| {
+ let l = BigInt::from(l) * &l_mul;
+ let r = BigInt::from(r) * &r_mul;
+ let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
+ let res = if div.is_negative() {
+ div - &five
+ } else {
+ div + &five
+ } / &ten;
+ res.to_i128().unwrap_or(i128::MAX)
+ })?
+ } else {
+ let l_mul = 10_i128.pow(l_exp);
+ let r_mul = 10_i128.pow(r_exp);
+ arrow::compute::kernels::arity::binary(left, right, |l, r| {
+ let l = l * l_mul;
+ let r = r * r_mul;
+ let div = if r == 0 { 0 } else { l / r };
+ let res = if div.is_negative() { div - 5 } else { div + 5 } / 10;
+ res.to_i128().unwrap_or(i128::MAX)
+ })?
+ };
+ let result = result.with_data_type(DataType::Decimal128(p3, s3));
+ Ok(ColumnarValue::Array(Arc::new(result)))
+}
diff --git a/native/spark-expr/src/math_funcs/floor.rs
b/native/spark-expr/src/math_funcs/floor.rs
new file mode 100644
index 000000000..9a95d95af
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/floor.rs
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::downcast_compute_op;
+use crate::math_funcs::utils::{get_precision_scale, make_decimal_array,
make_decimal_scalar};
+use arrow::array::{Float32Array, Float64Array, Int64Array};
+use arrow_array::{Array, ArrowNativeTypeOp};
+use arrow_schema::DataType;
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{DataFusionError, ScalarValue};
+use num::integer::div_floor;
+use std::sync::Arc;
+
+/// `floor` function that simulates Spark `floor` expression
+pub fn spark_floor(
+ args: &[ColumnarValue],
+ data_type: &DataType,
+) -> Result<ColumnarValue, DataFusionError> {
+ let value = &args[0];
+ match value {
+ ColumnarValue::Array(array) => match array.data_type() {
+ DataType::Float32 => {
+ let result = downcast_compute_op!(array, "floor", floor,
Float32Array, Int64Array);
+ Ok(ColumnarValue::Array(result?))
+ }
+ DataType::Float64 => {
+ let result = downcast_compute_op!(array, "floor", floor,
Float64Array, Int64Array);
+ Ok(ColumnarValue::Array(result?))
+ }
+ DataType::Int64 => {
+ let result =
array.as_any().downcast_ref::<Int64Array>().unwrap();
+ Ok(ColumnarValue::Array(Arc::new(result.clone())))
+ }
+ DataType::Decimal128(_, scale) if *scale > 0 => {
+ let f = decimal_floor_f(scale);
+ let (precision, scale) = get_precision_scale(data_type);
+ make_decimal_array(array, precision, scale, &f)
+ }
+ other => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {:?} for function floor",
+ other,
+ ))),
+ },
+ ColumnarValue::Scalar(a) => match a {
+ ScalarValue::Float32(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(
+ a.map(|x| x.floor() as i64),
+ ))),
+ ScalarValue::Float64(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(
+ a.map(|x| x.floor() as i64),
+ ))),
+ ScalarValue::Int64(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))),
+ ScalarValue::Decimal128(a, _, scale) if *scale > 0 => {
+ let f = decimal_floor_f(scale);
+ let (precision, scale) = get_precision_scale(data_type);
+ make_decimal_scalar(a, precision, scale, &f)
+ }
+ _ => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {:?} for function floor",
+ value.data_type(),
+ ))),
+ },
+ }
+}
+
+#[inline]
+fn decimal_floor_f(scale: &i8) -> impl Fn(i128) -> i128 {
+ let div = 10_i128.pow_wrapping(*scale as u32);
+ move |x: i128| div_floor(x, div)
+}
diff --git a/native/spark-expr/src/scalar_funcs/hex.rs
b/native/spark-expr/src/math_funcs/hex.rs
similarity index 100%
rename from native/spark-expr/src/scalar_funcs/hex.rs
rename to native/spark-expr/src/math_funcs/hex.rs
diff --git a/native/spark-expr/src/checkoverflow.rs
b/native/spark-expr/src/math_funcs/internal/checkoverflow.rs
similarity index 100%
rename from native/spark-expr/src/checkoverflow.rs
rename to native/spark-expr/src/math_funcs/internal/checkoverflow.rs
diff --git a/native/spark-expr/src/math_funcs/internal/make_decimal.rs
b/native/spark-expr/src/math_funcs/internal/make_decimal.rs
new file mode 100644
index 000000000..dd761cd69
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/internal/make_decimal.rs
@@ -0,0 +1,66 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::math_funcs::utils::get_precision_scale;
+use arrow::{
+ array::{AsArray, Decimal128Builder},
+ datatypes::{validate_decimal_precision, Int64Type},
+};
+use arrow_schema::DataType;
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
+use std::sync::Arc;
+
+/// Spark-compatible `MakeDecimal` expression (internal to Spark optimizer)
+pub fn spark_make_decimal(
+ args: &[ColumnarValue],
+ data_type: &DataType,
+) -> DataFusionResult<ColumnarValue> {
+ let (precision, scale) = get_precision_scale(data_type);
+ match &args[0] {
+ ColumnarValue::Scalar(v) => match v {
+ ScalarValue::Int64(n) =>
Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+ long_to_decimal(n, precision),
+ precision,
+ scale,
+ ))),
+ sv => internal_err!("Expected Int64 but found {sv:?}"),
+ },
+ ColumnarValue::Array(a) => {
+ let arr = a.as_primitive::<Int64Type>();
+ let mut result = Decimal128Builder::new();
+ for v in arr.into_iter() {
+ result.append_option(long_to_decimal(&v, precision))
+ }
+ let result_type = DataType::Decimal128(precision, scale);
+
+ Ok(ColumnarValue::Array(Arc::new(
+ result.finish().with_data_type(result_type),
+ )))
+ }
+ }
+}
+
+/// Convert the input long to decimal with the given maximum precision. If
it overflows, returns null
+/// instead.
+#[inline]
+fn long_to_decimal(v: &Option<i64>, precision: u8) -> Option<i128> {
+ match v {
+ Some(v) if validate_decimal_precision(*v as i128, precision).is_ok()
=> Some(*v as i128),
+ _ => None,
+ }
+}
diff --git a/native/spark-expr/src/string_funcs/mod.rs
b/native/spark-expr/src/math_funcs/internal/mod.rs
similarity index 76%
copy from native/spark-expr/src/string_funcs/mod.rs
copy to native/spark-expr/src/math_funcs/internal/mod.rs
index 2c2a5b37c..29295f0d5 100644
--- a/native/spark-expr/src/string_funcs/mod.rs
+++ b/native/spark-expr/src/math_funcs/internal/mod.rs
@@ -15,10 +15,12 @@
// specific language governing permissions and limitations
// under the License.
-mod prediction;
-mod string_space;
-mod substring;
+mod checkoverflow;
+mod make_decimal;
+mod normalize_nan;
+mod unscaled_value;
-pub use prediction::*;
-pub use string_space::StringSpaceExpr;
-pub use substring::SubstringExpr;
+pub use checkoverflow::CheckOverflow;
+pub use make_decimal::spark_make_decimal;
+pub use normalize_nan::NormalizeNaNAndZero;
+pub use unscaled_value::spark_unscaled_value;
diff --git a/native/spark-expr/src/normalize_nan.rs
b/native/spark-expr/src/math_funcs/internal/normalize_nan.rs
similarity index 100%
rename from native/spark-expr/src/normalize_nan.rs
rename to native/spark-expr/src/math_funcs/internal/normalize_nan.rs
diff --git a/native/spark-expr/src/math_funcs/internal/unscaled_value.rs
b/native/spark-expr/src/math_funcs/internal/unscaled_value.rs
new file mode 100644
index 000000000..053f9b078
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/internal/unscaled_value.rs
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::{
+ array::{AsArray, Int64Builder},
+ datatypes::Decimal128Type,
+};
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
+use std::sync::Arc;
+
+/// Spark-compatible `UnscaledValue` expression (internal to Spark optimizer)
+pub fn spark_unscaled_value(args: &[ColumnarValue]) ->
DataFusionResult<ColumnarValue> {
+ match &args[0] {
+ ColumnarValue::Scalar(v) => match v {
+ ScalarValue::Decimal128(d, _, _) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(
+ d.map(|n| n as i64),
+ ))),
+ dt => internal_err!("Expected Decimal128 but found {dt:}"),
+ },
+ ColumnarValue::Array(a) => {
+ let arr = a.as_primitive::<Decimal128Type>();
+ let mut result = Int64Builder::new();
+ for v in arr.into_iter() {
+ result.append_option(v.map(|v| v as i64));
+ }
+ Ok(ColumnarValue::Array(Arc::new(result.finish())))
+ }
+ }
+}
diff --git a/native/spark-expr/src/string_funcs/mod.rs
b/native/spark-expr/src/math_funcs/mod.rs
similarity index 68%
copy from native/spark-expr/src/string_funcs/mod.rs
copy to native/spark-expr/src/math_funcs/mod.rs
index 2c2a5b37c..c559ae15c 100644
--- a/native/spark-expr/src/string_funcs/mod.rs
+++ b/native/spark-expr/src/math_funcs/mod.rs
@@ -15,10 +15,21 @@
// specific language governing permissions and limitations
// under the License.
-mod prediction;
-mod string_space;
-mod substring;
+mod ceil;
+mod div;
+mod floor;
+pub(crate) mod hex;
+pub mod internal;
+mod negative;
+mod round;
+pub(crate) mod unhex;
+mod utils;
-pub use prediction::*;
-pub use string_space::StringSpaceExpr;
-pub use substring::SubstringExpr;
+pub use ceil::spark_ceil;
+pub use div::spark_decimal_div;
+pub use floor::spark_floor;
+pub use hex::spark_hex;
+pub use internal::*;
+pub use negative::{create_negate_expr, NegativeExpr};
+pub use round::spark_round;
+pub use unhex::spark_unhex;
diff --git a/native/spark-expr/src/negative.rs
b/native/spark-expr/src/math_funcs/negative.rs
similarity index 99%
rename from native/spark-expr/src/negative.rs
rename to native/spark-expr/src/math_funcs/negative.rs
index 7fb508917..cafbcfcbd 100644
--- a/native/spark-expr/src/negative.rs
+++ b/native/spark-expr/src/math_funcs/negative.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use super::arithmetic_overflow_error;
+use crate::arithmetic_overflow_error;
use crate::SparkError;
use arrow::{compute::kernels::numeric::neg_wrapping,
datatypes::IntervalDayTimeType};
use arrow_array::RecordBatch;
diff --git a/native/spark-expr/src/math_funcs/round.rs
b/native/spark-expr/src/math_funcs/round.rs
new file mode 100644
index 000000000..a47b7bc29
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/round.rs
@@ -0,0 +1,137 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::math_funcs::utils::{get_precision_scale, make_decimal_array,
make_decimal_scalar};
+use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array};
+use arrow_array::{Array, ArrowNativeTypeOp};
+use arrow_schema::DataType;
+use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
+use datafusion_common::{exec_err, internal_err, DataFusionError, ScalarValue};
+use std::{cmp::min, sync::Arc};
+
+macro_rules! integer_round {
+ ($X:expr, $DIV:expr, $HALF:expr) => {{
+ let rem = $X % $DIV;
+ if rem <= -$HALF {
+ ($X - rem).sub_wrapping($DIV)
+ } else if rem >= $HALF {
+ ($X - rem).add_wrapping($DIV)
+ } else {
+ $X - rem
+ }
+ }};
+}
+
+macro_rules! round_integer_array {
+ ($ARRAY:expr, $POINT:expr, $TYPE:ty, $NATIVE:ty) => {{
+ let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap();
+ let ten: $NATIVE = 10;
+ let result: $TYPE = if let Some(div) = ten.checked_pow((-(*$POINT)) as
u32) {
+ let half = div / 2;
+ arrow::compute::kernels::arity::unary(array, |x| integer_round!(x,
div, half))
+ } else {
+ arrow::compute::kernels::arity::unary(array, |_| 0)
+ };
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }};
+}
+
+macro_rules! round_integer_scalar {
+ ($SCALAR:expr, $POINT:expr, $TYPE:expr, $NATIVE:ty) => {{
+ let ten: $NATIVE = 10;
+ if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) {
+ let half = div / 2;
+ Ok(ColumnarValue::Scalar($TYPE(
+ $SCALAR.map(|x| integer_round!(x, div, half)),
+ )))
+ } else {
+ Ok(ColumnarValue::Scalar($TYPE(Some(0))))
+ }
+ }};
+}
+
+/// `round` function that simulates Spark `round` expression
+pub fn spark_round(
+ args: &[ColumnarValue],
+ data_type: &DataType,
+) -> Result<ColumnarValue, DataFusionError> {
+ let value = &args[0];
+ let point = &args[1];
+ let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point else {
+ return internal_err!("Invalid point argument for Round(): {:#?}",
point);
+ };
+ match value {
+ ColumnarValue::Array(array) => match array.data_type() {
+ DataType::Int64 if *point < 0 => round_integer_array!(array,
point, Int64Array, i64),
+ DataType::Int32 if *point < 0 => round_integer_array!(array,
point, Int32Array, i32),
+ DataType::Int16 if *point < 0 => round_integer_array!(array,
point, Int16Array, i16),
+ DataType::Int8 if *point < 0 => round_integer_array!(array, point,
Int8Array, i8),
+ DataType::Decimal128(_, scale) if *scale >= 0 => {
+ let f = decimal_round_f(scale, point);
+ let (precision, scale) = get_precision_scale(data_type);
+ make_decimal_array(array, precision, scale, &f)
+ }
+ DataType::Float32 | DataType::Float64 => {
+ Ok(ColumnarValue::Array(round(&[Arc::clone(array)])?))
+ }
+ dt => exec_err!("Not supported datatype for ROUND: {dt}"),
+ },
+ ColumnarValue::Scalar(a) => match a {
+ ScalarValue::Int64(a) if *point < 0 => {
+ round_integer_scalar!(a, point, ScalarValue::Int64, i64)
+ }
+ ScalarValue::Int32(a) if *point < 0 => {
+ round_integer_scalar!(a, point, ScalarValue::Int32, i32)
+ }
+ ScalarValue::Int16(a) if *point < 0 => {
+ round_integer_scalar!(a, point, ScalarValue::Int16, i16)
+ }
+ ScalarValue::Int8(a) if *point < 0 => {
+ round_integer_scalar!(a, point, ScalarValue::Int8, i8)
+ }
+ ScalarValue::Decimal128(a, _, scale) if *scale >= 0 => {
+ let f = decimal_round_f(scale, point);
+ let (precision, scale) = get_precision_scale(data_type);
+ make_decimal_scalar(a, precision, scale, &f)
+ }
+ ScalarValue::Float32(_) | ScalarValue::Float64(_) =>
Ok(ColumnarValue::Scalar(
+ ScalarValue::try_from_array(&round(&[a.to_array()?])?, 0)?,
+ )),
+ dt => exec_err!("Not supported datatype for ROUND: {dt}"),
+ },
+ }
+}
+
+// Spark uses BigDecimal. See RoundBase implementation in Spark. Instead, we
achieve the same result by
+// 1) adding half of the divisor, 2) rounding down by division, 3) adjusting precision
by multiplication
+#[inline]
+fn decimal_round_f(scale: &i8, point: &i64) -> Box<dyn Fn(i128) -> i128> {
+ if *point < 0 {
+ if let Some(div) = 10_i128.checked_pow((-(*point) as u32) + (*scale as
u32)) {
+ let half = div / 2;
+ let mul = 10_i128.pow_wrapping((-(*point)) as u32);
+ // i128 can hold 39 digits of a base 10 number, adding half will
not cause overflow
+ Box::new(move |x: i128| (x + x.signum() * half) / div * mul)
+ } else {
+ Box::new(move |_: i128| 0)
+ }
+ } else {
+ let div = 10_i128.pow_wrapping((*scale as u32) - min(*scale as u32,
*point as u32));
+ let half = div / 2;
+ Box::new(move |x: i128| (x + x.signum() * half) / div)
+ }
+}
diff --git a/native/spark-expr/src/scalar_funcs/unhex.rs
b/native/spark-expr/src/math_funcs/unhex.rs
similarity index 100%
rename from native/spark-expr/src/scalar_funcs/unhex.rs
rename to native/spark-expr/src/math_funcs/unhex.rs
diff --git a/native/spark-expr/src/math_funcs/utils.rs
b/native/spark-expr/src/math_funcs/utils.rs
new file mode 100644
index 000000000..204b7139e
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/utils.rs
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::Decimal128Type;
+use arrow_array::{ArrayRef, Decimal128Array};
+use arrow_schema::DataType;
+use datafusion_common::{DataFusionError, ScalarValue};
+use datafusion_expr_common::columnar_value::ColumnarValue;
+use std::sync::Arc;
+
+#[macro_export]
+macro_rules! downcast_compute_op {
+ ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{
+ let n = $ARRAY.as_any().downcast_ref::<$TYPE>();
+ match n {
+ Some(array) => {
+ let res: $RESULT =
+ arrow::compute::kernels::arity::unary(array, |x| x.$FUNC()
as i64);
+ Ok(Arc::new(res))
+ }
+ _ => Err(DataFusionError::Internal(format!(
+ "Invalid data type for {}",
+ $NAME
+ ))),
+ }
+ }};
+}
+
+#[inline]
+pub(crate) fn make_decimal_scalar(
+ a: &Option<i128>,
+ precision: u8,
+ scale: i8,
+ f: &dyn Fn(i128) -> i128,
+) -> Result<ColumnarValue, DataFusionError> {
+ let result = ScalarValue::Decimal128(a.map(f), precision, scale);
+ Ok(ColumnarValue::Scalar(result))
+}
+
+#[inline]
+pub(crate) fn make_decimal_array(
+ array: &ArrayRef,
+ precision: u8,
+ scale: i8,
+ f: &dyn Fn(i128) -> i128,
+) -> Result<ColumnarValue, DataFusionError> {
+ let array = array.as_primitive::<Decimal128Type>();
+ let result: Decimal128Array = arrow::compute::kernels::arity::unary(array,
f);
+ let result = result.with_data_type(DataType::Decimal128(precision, scale));
+ Ok(ColumnarValue::Array(Arc::new(result)))
+}
+
+#[inline]
+pub(crate) fn get_precision_scale(data_type: &DataType) -> (u8, i8) {
+ let DataType::Decimal128(precision, scale) = data_type else {
+ unreachable!()
+ };
+ (*precision, *scale)
+}
diff --git a/native/spark-expr/src/scalar_funcs.rs
b/native/spark-expr/src/scalar_funcs.rs
deleted file mode 100644
index 52ece10e8..000000000
--- a/native/spark-expr/src/scalar_funcs.rs
+++ /dev/null
@@ -1,569 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use arrow::{
- array::{
- ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array,
Int16Array, Int32Array,
- Int64Array, Int64Builder, Int8Array,
- },
- compute::kernels::numeric::{add, sub},
- datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
-};
-use arrow_array::builder::IntervalDayTimeBuilder;
-use arrow_array::types::{Int16Type, Int32Type, Int8Type, IntervalDayTime};
-use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum,
Decimal128Array};
-use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION};
-use datafusion::physical_expr_common::datum;
-use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
-use datafusion_common::{
- exec_err, internal_err, DataFusionError, Result as DataFusionResult,
ScalarValue,
-};
-use num::{
- integer::{div_ceil, div_floor},
- BigInt, Signed, ToPrimitive,
-};
-use std::{cmp::min, sync::Arc};
-
-mod unhex;
-pub use unhex::spark_unhex;
-
-mod hex;
-pub(crate) use hex::hex_strings;
-pub use hex::spark_hex;
-
-mod chr;
-pub use chr::SparkChrFunc;
-
-#[inline]
-fn get_precision_scale(data_type: &DataType) -> (u8, i8) {
- let DataType::Decimal128(precision, scale) = data_type else {
- unreachable!()
- };
- (*precision, *scale)
-}
-
-macro_rules! downcast_compute_op {
- ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{
- let n = $ARRAY.as_any().downcast_ref::<$TYPE>();
- match n {
- Some(array) => {
- let res: $RESULT =
- arrow::compute::kernels::arity::unary(array, |x| x.$FUNC()
as i64);
- Ok(Arc::new(res))
- }
- _ => Err(DataFusionError::Internal(format!(
- "Invalid data type for {}",
- $NAME
- ))),
- }
- }};
-}
-
-/// `ceil` function that simulates Spark `ceil` expression
-pub fn spark_ceil(
- args: &[ColumnarValue],
- data_type: &DataType,
-) -> Result<ColumnarValue, DataFusionError> {
- let value = &args[0];
- match value {
- ColumnarValue::Array(array) => match array.data_type() {
- DataType::Float32 => {
- let result = downcast_compute_op!(array, "ceil", ceil,
Float32Array, Int64Array);
- Ok(ColumnarValue::Array(result?))
- }
- DataType::Float64 => {
- let result = downcast_compute_op!(array, "ceil", ceil,
Float64Array, Int64Array);
- Ok(ColumnarValue::Array(result?))
- }
- DataType::Int64 => {
- let result =
array.as_any().downcast_ref::<Int64Array>().unwrap();
- Ok(ColumnarValue::Array(Arc::new(result.clone())))
- }
- DataType::Decimal128(_, scale) if *scale > 0 => {
- let f = decimal_ceil_f(scale);
- let (precision, scale) = get_precision_scale(data_type);
- make_decimal_array(array, precision, scale, &f)
- }
- other => Err(DataFusionError::Internal(format!(
- "Unsupported data type {:?} for function ceil",
- other,
- ))),
- },
- ColumnarValue::Scalar(a) => match a {
- ScalarValue::Float32(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(
- a.map(|x| x.ceil() as i64),
- ))),
- ScalarValue::Float64(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(
- a.map(|x| x.ceil() as i64),
- ))),
- ScalarValue::Int64(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))),
- ScalarValue::Decimal128(a, _, scale) if *scale > 0 => {
- let f = decimal_ceil_f(scale);
- let (precision, scale) = get_precision_scale(data_type);
- make_decimal_scalar(a, precision, scale, &f)
- }
- _ => Err(DataFusionError::Internal(format!(
- "Unsupported data type {:?} for function ceil",
- value.data_type(),
- ))),
- },
- }
-}
-
-/// `floor` function that simulates Spark `floor` expression
-pub fn spark_floor(
- args: &[ColumnarValue],
- data_type: &DataType,
-) -> Result<ColumnarValue, DataFusionError> {
- let value = &args[0];
- match value {
- ColumnarValue::Array(array) => match array.data_type() {
- DataType::Float32 => {
- let result = downcast_compute_op!(array, "floor", floor,
Float32Array, Int64Array);
- Ok(ColumnarValue::Array(result?))
- }
- DataType::Float64 => {
- let result = downcast_compute_op!(array, "floor", floor,
Float64Array, Int64Array);
- Ok(ColumnarValue::Array(result?))
- }
- DataType::Int64 => {
- let result =
array.as_any().downcast_ref::<Int64Array>().unwrap();
- Ok(ColumnarValue::Array(Arc::new(result.clone())))
- }
- DataType::Decimal128(_, scale) if *scale > 0 => {
- let f = decimal_floor_f(scale);
- let (precision, scale) = get_precision_scale(data_type);
- make_decimal_array(array, precision, scale, &f)
- }
- other => Err(DataFusionError::Internal(format!(
- "Unsupported data type {:?} for function floor",
- other,
- ))),
- },
- ColumnarValue::Scalar(a) => match a {
- ScalarValue::Float32(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(
- a.map(|x| x.floor() as i64),
- ))),
- ScalarValue::Float64(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(
- a.map(|x| x.floor() as i64),
- ))),
- ScalarValue::Int64(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))),
- ScalarValue::Decimal128(a, _, scale) if *scale > 0 => {
- let f = decimal_floor_f(scale);
- let (precision, scale) = get_precision_scale(data_type);
- make_decimal_scalar(a, precision, scale, &f)
- }
- _ => Err(DataFusionError::Internal(format!(
- "Unsupported data type {:?} for function floor",
- value.data_type(),
- ))),
- },
- }
-}
-
-/// Spark-compatible `UnscaledValue` expression (internal to Spark optimizer)
-pub fn spark_unscaled_value(args: &[ColumnarValue]) ->
DataFusionResult<ColumnarValue> {
- match &args[0] {
- ColumnarValue::Scalar(v) => match v {
- ScalarValue::Decimal128(d, _, _) =>
Ok(ColumnarValue::Scalar(ScalarValue::Int64(
- d.map(|n| n as i64),
- ))),
- dt => internal_err!("Expected Decimal128 but found {dt:}"),
- },
- ColumnarValue::Array(a) => {
- let arr = a.as_primitive::<Decimal128Type>();
- let mut result = Int64Builder::new();
- for v in arr.into_iter() {
- result.append_option(v.map(|v| v as i64));
- }
- Ok(ColumnarValue::Array(Arc::new(result.finish())))
- }
- }
-}
-
-/// Spark-compatible `MakeDecimal` expression (internal to Spark optimizer)
-pub fn spark_make_decimal(
- args: &[ColumnarValue],
- data_type: &DataType,
-) -> DataFusionResult<ColumnarValue> {
- let (precision, scale) = get_precision_scale(data_type);
- match &args[0] {
- ColumnarValue::Scalar(v) => match v {
- ScalarValue::Int64(n) =>
Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
- long_to_decimal(n, precision),
- precision,
- scale,
- ))),
- sv => internal_err!("Expected Int64 but found {sv:?}"),
- },
- ColumnarValue::Array(a) => {
- let arr = a.as_primitive::<Int64Type>();
- let mut result = Decimal128Builder::new();
- for v in arr.into_iter() {
- result.append_option(long_to_decimal(&v, precision))
- }
- let result_type = DataType::Decimal128(precision, scale);
-
- Ok(ColumnarValue::Array(Arc::new(
- result.finish().with_data_type(result_type),
- )))
- }
- }
-}
-
-/// Convert the input long to decimal with the given maximum precision. If overflows, returns null
-/// instead.
-#[inline]
-fn long_to_decimal(v: &Option<i64>, precision: u8) -> Option<i128> {
- match v {
- Some(v) if validate_decimal_precision(*v as i128, precision).is_ok()
=> Some(*v as i128),
- _ => None,
- }
-}
-
-#[inline]
-fn decimal_ceil_f(scale: &i8) -> impl Fn(i128) -> i128 {
- let div = 10_i128.pow_wrapping(*scale as u32);
- move |x: i128| div_ceil(x, div)
-}
-
-#[inline]
-fn decimal_floor_f(scale: &i8) -> impl Fn(i128) -> i128 {
- let div = 10_i128.pow_wrapping(*scale as u32);
- move |x: i128| div_floor(x, div)
-}
-
-// Spark uses BigDecimal. See RoundBase implementation in Spark. Instead, we do the same by
-// 1) add the half of divisor, 2) round down by division, 3) adjust precision by multiplication
-#[inline]
-fn decimal_round_f(scale: &i8, point: &i64) -> Box<dyn Fn(i128) -> i128> {
- if *point < 0 {
- if let Some(div) = 10_i128.checked_pow((-(*point) as u32) + (*scale as
u32)) {
- let half = div / 2;
- let mul = 10_i128.pow_wrapping((-(*point)) as u32);
- // i128 can hold 39 digits of a base 10 number, adding half will not cause overflow
- Box::new(move |x: i128| (x + x.signum() * half) / div * mul)
- } else {
- Box::new(move |_: i128| 0)
- }
- } else {
- let div = 10_i128.pow_wrapping((*scale as u32) - min(*scale as u32,
*point as u32));
- let half = div / 2;
- Box::new(move |x: i128| (x + x.signum() * half) / div)
- }
-}
-
-#[inline]
-fn make_decimal_array(
- array: &ArrayRef,
- precision: u8,
- scale: i8,
- f: &dyn Fn(i128) -> i128,
-) -> Result<ColumnarValue, DataFusionError> {
- let array = array.as_primitive::<Decimal128Type>();
- let result: Decimal128Array = arrow::compute::kernels::arity::unary(array,
f);
- let result = result.with_data_type(DataType::Decimal128(precision, scale));
- Ok(ColumnarValue::Array(Arc::new(result)))
-}
-
-#[inline]
-fn make_decimal_scalar(
- a: &Option<i128>,
- precision: u8,
- scale: i8,
- f: &dyn Fn(i128) -> i128,
-) -> Result<ColumnarValue, DataFusionError> {
- let result = ScalarValue::Decimal128(a.map(f), precision, scale);
- Ok(ColumnarValue::Scalar(result))
-}
-
-macro_rules! integer_round {
- ($X:expr, $DIV:expr, $HALF:expr) => {{
- let rem = $X % $DIV;
- if rem <= -$HALF {
- ($X - rem).sub_wrapping($DIV)
- } else if rem >= $HALF {
- ($X - rem).add_wrapping($DIV)
- } else {
- $X - rem
- }
- }};
-}
-
-macro_rules! round_integer_array {
- ($ARRAY:expr, $POINT:expr, $TYPE:ty, $NATIVE:ty) => {{
- let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap();
- let ten: $NATIVE = 10;
- let result: $TYPE = if let Some(div) = ten.checked_pow((-(*$POINT)) as
u32) {
- let half = div / 2;
- arrow::compute::kernels::arity::unary(array, |x| integer_round!(x,
div, half))
- } else {
- arrow::compute::kernels::arity::unary(array, |_| 0)
- };
- Ok(ColumnarValue::Array(Arc::new(result)))
- }};
-}
-
-macro_rules! round_integer_scalar {
- ($SCALAR:expr, $POINT:expr, $TYPE:expr, $NATIVE:ty) => {{
- let ten: $NATIVE = 10;
- if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) {
- let half = div / 2;
- Ok(ColumnarValue::Scalar($TYPE(
- $SCALAR.map(|x| integer_round!(x, div, half)),
- )))
- } else {
- Ok(ColumnarValue::Scalar($TYPE(Some(0))))
- }
- }};
-}
-
-/// `round` function that simulates Spark `round` expression
-pub fn spark_round(
- args: &[ColumnarValue],
- data_type: &DataType,
-) -> Result<ColumnarValue, DataFusionError> {
- let value = &args[0];
- let point = &args[1];
- let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point else {
- return internal_err!("Invalid point argument for Round(): {:#?}",
point);
- };
- match value {
- ColumnarValue::Array(array) => match array.data_type() {
- DataType::Int64 if *point < 0 => round_integer_array!(array,
point, Int64Array, i64),
- DataType::Int32 if *point < 0 => round_integer_array!(array,
point, Int32Array, i32),
- DataType::Int16 if *point < 0 => round_integer_array!(array,
point, Int16Array, i16),
- DataType::Int8 if *point < 0 => round_integer_array!(array, point,
Int8Array, i8),
- DataType::Decimal128(_, scale) if *scale >= 0 => {
- let f = decimal_round_f(scale, point);
- let (precision, scale) = get_precision_scale(data_type);
- make_decimal_array(array, precision, scale, &f)
- }
- DataType::Float32 | DataType::Float64 => {
- Ok(ColumnarValue::Array(round(&[Arc::clone(array)])?))
- }
- dt => exec_err!("Not supported datatype for ROUND: {dt}"),
- },
- ColumnarValue::Scalar(a) => match a {
- ScalarValue::Int64(a) if *point < 0 => {
- round_integer_scalar!(a, point, ScalarValue::Int64, i64)
- }
- ScalarValue::Int32(a) if *point < 0 => {
- round_integer_scalar!(a, point, ScalarValue::Int32, i32)
- }
- ScalarValue::Int16(a) if *point < 0 => {
- round_integer_scalar!(a, point, ScalarValue::Int16, i16)
- }
- ScalarValue::Int8(a) if *point < 0 => {
- round_integer_scalar!(a, point, ScalarValue::Int8, i8)
- }
- ScalarValue::Decimal128(a, _, scale) if *scale >= 0 => {
- let f = decimal_round_f(scale, point);
- let (precision, scale) = get_precision_scale(data_type);
- make_decimal_scalar(a, precision, scale, &f)
- }
- ScalarValue::Float32(_) | ScalarValue::Float64(_) =>
Ok(ColumnarValue::Scalar(
- ScalarValue::try_from_array(&round(&[a.to_array()?])?, 0)?,
- )),
- dt => exec_err!("Not supported datatype for ROUND: {dt}"),
- },
- }
-}
-
-// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
-// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to
-// get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since
-// both s2 and s3 are 38 at max., s1 is 77 at max. DataFusion division cannot handle such scale >
-// Decimal256Type::MAX_SCALE. Therefore, we need to implement this decimal division using BigInt.
-pub fn spark_decimal_div(
- args: &[ColumnarValue],
- data_type: &DataType,
-) -> Result<ColumnarValue, DataFusionError> {
- let left = &args[0];
- let right = &args[1];
- let (p3, s3) = get_precision_scale(data_type);
-
- let (left, right): (ArrayRef, ArrayRef) = match (left, right) {
- (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l),
Arc::clone(r)),
- (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
- (l.to_array_of_size(r.len())?, Arc::clone(r))
- }
- (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
- (Arc::clone(l), r.to_array_of_size(l.len())?)
- }
- (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) =>
(l.to_array()?, r.to_array()?),
- };
- let left = left.as_primitive::<Decimal128Type>();
- let right = right.as_primitive::<Decimal128Type>();
- let (p1, s1) = get_precision_scale(left.data_type());
- let (p2, s2) = get_precision_scale(right.data_type());
-
- let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32);
- let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32);
- let result: Decimal128Array = if p1 as u32 + l_exp >
DECIMAL128_MAX_PRECISION as u32
- || p2 as u32 + r_exp > DECIMAL128_MAX_PRECISION as u32
- {
- let ten = BigInt::from(10);
- let l_mul = ten.pow(l_exp);
- let r_mul = ten.pow(r_exp);
- let five = BigInt::from(5);
- let zero = BigInt::from(0);
- arrow::compute::kernels::arity::binary(left, right, |l, r| {
- let l = BigInt::from(l) * &l_mul;
- let r = BigInt::from(r) * &r_mul;
- let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
- let res = if div.is_negative() {
- div - &five
- } else {
- div + &five
- } / &ten;
- res.to_i128().unwrap_or(i128::MAX)
- })?
- } else {
- let l_mul = 10_i128.pow(l_exp);
- let r_mul = 10_i128.pow(r_exp);
- arrow::compute::kernels::arity::binary(left, right, |l, r| {
- let l = l * l_mul;
- let r = r * r_mul;
- let div = if r == 0 { 0 } else { l / r };
- let res = if div.is_negative() { div - 5 } else { div + 5 } / 10;
- res.to_i128().unwrap_or(i128::MAX)
- })?
- };
- let result = result.with_data_type(DataType::Decimal128(p3, s3));
- Ok(ColumnarValue::Array(Arc::new(result)))
-}
-
-macro_rules! scalar_date_arithmetic {
- ($start:expr, $days:expr, $op:expr) => {{
- let interval = IntervalDayTime::new(*$days as i32, 0);
- let interval_cv =
ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval)));
- datum::apply($start, &interval_cv, $op)
- }};
-}
-macro_rules! array_date_arithmetic {
- ($days:expr, $interval_builder:expr, $intType:ty) => {{
- for day in $days.as_primitive::<$intType>().into_iter() {
- if let Some(non_null_day) = day {
-
$interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0));
- } else {
- $interval_builder.append_null();
- }
- }
- }};
-}
-
-/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days for the second
-/// argument, but we cannot directly add that to a Date32. We generate an IntervalDayTime from the
-/// second argument and use DataFusion's interface to apply Arrow's operators.
-fn spark_date_arithmetic(
- args: &[ColumnarValue],
- op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
-) -> Result<ColumnarValue, DataFusionError> {
- let start = &args[0];
- match &args[1] {
- ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
- scalar_date_arithmetic!(start, days, op)
- }
- ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
- scalar_date_arithmetic!(start, days, op)
- }
- ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
- scalar_date_arithmetic!(start, days, op)
- }
- ColumnarValue::Array(days) => {
- let mut interval_builder =
IntervalDayTimeBuilder::with_capacity(days.len());
- match days.data_type() {
- DataType::Int8 => {
- array_date_arithmetic!(days, interval_builder, Int8Type)
- }
- DataType::Int16 => {
- array_date_arithmetic!(days, interval_builder, Int16Type)
- }
- DataType::Int32 => {
- array_date_arithmetic!(days, interval_builder, Int32Type)
- }
- _ => {
- return Err(DataFusionError::Internal(format!(
- "Unsupported data types {:?} for date arithmetic.",
- args,
- )))
- }
- }
- let interval_cv =
ColumnarValue::Array(Arc::new(interval_builder.finish()));
- datum::apply(start, &interval_cv, op)
- }
- _ => Err(DataFusionError::Internal(format!(
- "Unsupported data types {:?} for date arithmetic.",
- args,
- ))),
- }
-}
-pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue,
DataFusionError> {
- spark_date_arithmetic(args, add)
-}
-
-pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue,
DataFusionError> {
- spark_date_arithmetic(args, sub)
-}
-
-/// Spark-compatible `isnan` expression
-pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue,
DataFusionError> {
- fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue {
- match is_nan.nulls() {
- Some(nulls) => {
- let is_not_null = nulls.inner();
- ColumnarValue::Array(Arc::new(BooleanArray::new(
- is_nan.values() & is_not_null,
- None,
- )))
- }
- None => ColumnarValue::Array(Arc::new(is_nan)),
- }
- }
- let value = &args[0];
- match value {
- ColumnarValue::Array(array) => match array.data_type() {
- DataType::Float64 => {
- let array =
array.as_any().downcast_ref::<Float64Array>().unwrap();
- let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
- Ok(set_nulls_to_false(is_nan))
- }
- DataType::Float32 => {
- let array =
array.as_any().downcast_ref::<Float32Array>().unwrap();
- let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
- Ok(set_nulls_to_false(is_nan))
- }
- other => Err(DataFusionError::Internal(format!(
- "Unsupported data type {:?} for function isnan",
- other,
- ))),
- },
- ColumnarValue::Scalar(a) => match a {
- ScalarValue::Float64(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
- a.map(|x| x.is_nan()).unwrap_or(false),
- )))),
- ScalarValue::Float32(a) =>
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
- a.map(|x| x.is_nan()).unwrap_or(false),
- )))),
- _ => Err(DataFusionError::Internal(format!(
- "Unsupported data type {:?} for function isnan",
- value.data_type(),
- ))),
- },
- }
-}
diff --git a/native/spark-expr/src/scalar_funcs/chr.rs
b/native/spark-expr/src/string_funcs/chr.rs
similarity index 100%
rename from native/spark-expr/src/scalar_funcs/chr.rs
rename to native/spark-expr/src/string_funcs/chr.rs
diff --git a/native/spark-expr/src/string_funcs/mod.rs
b/native/spark-expr/src/string_funcs/mod.rs
index 2c2a5b37c..d56b5662c 100644
--- a/native/spark-expr/src/string_funcs/mod.rs
+++ b/native/spark-expr/src/string_funcs/mod.rs
@@ -15,10 +15,12 @@
// specific language governing permissions and limitations
// under the License.
+mod chr;
mod prediction;
mod string_space;
mod substring;
+pub use chr::SparkChrFunc;
pub use prediction::*;
pub use string_space::StringSpaceExpr;
pub use substring::SubstringExpr;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]