This is an automated email from the ASF dual-hosted git repository. agrove pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push: new 91f4b9c54 Chore: implement bit_count as ScalarUDFImpl (#1826) 91f4b9c54 is described below commit 91f4b9c54d5517b4eb4975fd122e276aeb75cc1d Author: Kazantsev Maksim <kazantsev....@yandex.ru> AuthorDate: Tue Jun 3 22:19:30 2025 +0400 Chore: implement bit_count as ScalarUDFImpl (#1826) --- native/core/src/execution/planner.rs | 5 +- .../spark-expr/src/bitwise_funcs/bitwise_count.rs | 81 +++++++++++++++++----- native/spark-expr/src/bitwise_funcs/mod.rs | 2 +- native/spark-expr/src/comet_scalar_funcs.rs | 12 ++-- 4 files changed, 71 insertions(+), 29 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 9bdb7e6c2..94e79e4e0 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -64,7 +64,9 @@ use datafusion::{ }, prelude::SessionContext, }; -use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr, SparkBitwiseNot}; +use datafusion_comet_spark_expr::{ + create_comet_physical_fun, create_negate_expr, SparkBitwiseCount, SparkBitwiseNot, +}; use crate::execution::operators::ExecutionError::GeneralError; use crate::execution::shuffle::CompressionCodec; @@ -155,6 +157,7 @@ impl PhysicalPlanner { // register UDFs from datafusion-spark crate session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseCount::default())); Self { exec_context_id: TEST_EXEC_CONTEXT_ID, session_ctx, diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs index f0a1b0073..bee7b1327 100644 --- a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs +++ b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs @@ -16,10 +16,63 @@ // under the License. 
use arrow::{array::*, datatypes::DataType}; -use datafusion::common::Result; +use datafusion::common::{exec_err, internal_datafusion_err, internal_err, Result}; +use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion::{error::DataFusionError, logical_expr::ColumnarValue}; +use std::any::Any; use std::sync::Arc; +#[derive(Debug)] +pub struct SparkBitwiseCount { + signature: Signature, + aliases: Vec<String>, +} + +impl Default for SparkBitwiseCount { + fn default() -> Self { + Self::new() + } +} + +impl SparkBitwiseCount { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkBitwiseCount { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bit_count" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result<DataType> { + Ok(DataType::Int32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + let args: [ColumnarValue; 1] = args + .args + .try_into() + .map_err(|_| internal_datafusion_err!("bit_count expects exactly one argument"))?; + spark_bit_count(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + macro_rules! compute_op { ($OPERAND:expr, $DT:ident) => {{ let operand = $OPERAND.as_any().downcast_ref::<$DT>().ok_or_else(|| { @@ -38,29 +91,19 @@ macro_rules! 
compute_op { }}; } -pub fn spark_bit_count(args: &[ColumnarValue]) -> Result<ColumnarValue> { - if args.len() != 1 { - return Err(DataFusionError::Internal( - "bit_count expects exactly one argument".to_string(), - )); - } - match &args[0] { - ColumnarValue::Array(array) => { +pub fn spark_bit_count(args: [ColumnarValue; 1]) -> Result<ColumnarValue> { + match args { + [ColumnarValue::Array(array)] => { let result: Result<ArrayRef> = match array.data_type() { DataType::Int8 | DataType::Boolean => compute_op!(array, Int8Array), DataType::Int16 => compute_op!(array, Int16Array), DataType::Int32 => compute_op!(array, Int32Array), DataType::Int64 => compute_op!(array, Int64Array), - _ => Err(DataFusionError::Execution(format!( - "Can't be evaluated because the expression's type is {:?}, not signed int", - array.data_type(), - ))), + _ => exec_err!("bit_count can't be evaluated because the expression's type is {:?}, not signed int", array.data_type()), }; result.map(ColumnarValue::Array) } - ColumnarValue::Scalar(_) => Err(DataFusionError::Internal( - "shouldn't go to bit_count scalar path".to_string(), - )), + [ColumnarValue::Scalar(_)] => internal_err!("shouldn't go to bitwise count scalar path"), } } @@ -84,16 +127,16 @@ mod tests { #[test] fn bitwise_count_op() -> Result<()> { - let args = vec![ColumnarValue::Array(Arc::new(Int32Array::from(vec![ + let args = ColumnarValue::Array(Arc::new(Int32Array::from(vec![ Some(1), None, Some(12345), Some(89), Some(-3456), - ])))]; + ]))); let expected = &Int32Array::from(vec![Some(1), None, Some(6), Some(4), Some(54)]); - let ColumnarValue::Array(result) = spark_bit_count(&args)? else { + let ColumnarValue::Array(result) = spark_bit_count([args])? 
else { unreachable!() }; diff --git a/native/spark-expr/src/bitwise_funcs/mod.rs b/native/spark-expr/src/bitwise_funcs/mod.rs index 47267d852..3f148a6dc 100644 --- a/native/spark-expr/src/bitwise_funcs/mod.rs +++ b/native/spark-expr/src/bitwise_funcs/mod.rs @@ -18,5 +18,5 @@ mod bitwise_count; mod bitwise_not; -pub use bitwise_count::spark_bit_count; +pub use bitwise_count::SparkBitwiseCount; pub use bitwise_not::SparkBitwiseNot; diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index f85206000..cf06d3633 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -17,10 +17,10 @@ use crate::hash_funcs::*; use crate::{ - spark_array_repeat, spark_bit_count, spark_ceil, spark_date_add, spark_date_sub, - spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, - spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, - spark_unscaled_value, SparkChrFunc, + spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, + spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal, + spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, + SparkChrFunc, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -145,10 +145,6 @@ pub fn create_comet_physical_fun( let func = Arc::new(spark_array_repeat); make_comet_scalar_udf!("array_repeat", func, without data_type) } - "bit_count" => { - let func = Arc::new(spark_bit_count); - make_comet_scalar_udf!("bit_count", func, without data_type) - } _ => registry.udf(fun_name).map_err(|e| { DataFusionError::Execution(format!( "Function {fun_name} not found in the registry: {e}", --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: 
commits-help@datafusion.apache.org