This is an automated email from the ASF dual-hosted git repository. kazuyukitanimura pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push: new 53c724eac Feat: support bit_count function (#1602) 53c724eac is described below commit 53c724eaca4cd9a8e9aa65c5a81bb0ac1a601665 Author: Kazantsev Maksim <kazantsev....@yandex.ru> AuthorDate: Fri May 30 23:52:03 2025 +0400 Feat: support bit_count function (#1602) ## Which issue does this PR close? Related to Epic: https://github.com/apache/datafusion-comet/issues/240 bit_count: SELECT bit_count(0) => 0 DataFusionComet bit_count has the same behavior as Spark's bit_count function Spark: https://spark.apache.org/docs/latest/api/sql/index.html#bit_count Closes #. ## Rationale for this change Defined under Epic: https://github.com/apache/datafusion-comet/issues/240 ## What changes are included in this PR? bitwise_count.rs: impl for bit_count function planner.rs: Maps Spark's bit_count function to DataFusionComet bit_count physical expression from Spark physical expression expr.proto: bit_count has been added, QueryPlanSerde.scala: bit_count pattern matching case has been added, CometExpressionSuite.scala: A new UT has been added for bit_count function. ## How are these changes tested? A new UT has been added. --- .../spark-expr/src/bitwise_funcs/bitwise_count.rs | 105 +++++++++++++++++++++ native/spark-expr/src/bitwise_funcs/mod.rs | 2 + native/spark-expr/src/comet_scalar_funcs.rs | 12 ++- .../org/apache/comet/serde/QueryPlanSerde.scala | 6 ++ .../org/apache/comet/CometExpressionSuite.scala | 68 +++++++++++++ 5 files changed, 189 insertions(+), 4 deletions(-) diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs new file mode 100644 index 000000000..f0a1b0073 --- /dev/null +++ b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.
See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::{array::*, datatypes::DataType};
+use datafusion::common::Result;
+use datafusion::{error::DataFusionError, logical_expr::ColumnarValue};
+use std::sync::Arc;
+
+macro_rules! compute_op {
+    ($OPERAND:expr, $DT:ident) => {{
+        let operand = $OPERAND.as_any().downcast_ref::<$DT>().ok_or_else(|| {
+            DataFusionError::Execution(format!(
+                "compute_op failed to downcast array to: {:?}",
+                stringify!($DT)
+            ))
+        })?;
+
+        let result: Int32Array = operand
+            .iter()
+            .map(|x| x.map(|y| bit_count(y.into())))
+            .collect();
+
+        Ok(Arc::new(result))
+    }};
+}
+
+/// Native implementation of Spark's `bit_count`: per-row count of set bits as an `Int32` array.
+/// Nulls stay null. Accepts `Boolean`, `Int8`, `Int16`, `Int32` and `Int64` arrays; errors on
+/// other types, on an argument count other than one, and on scalar input.
+pub fn spark_bit_count(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() != 1 {
+        return Err(DataFusionError::Internal(
+            "bit_count expects exactly one argument".to_string(),
+        ));
+    }
+    match &args[0] {
+        ColumnarValue::Array(array) => {
+            let result: Result<ArrayRef> = match array.data_type() {
+                // `BooleanArray` is a distinct type, so it cannot be downcast to `Int8Array`.
+                DataType::Boolean => compute_op!(array, BooleanArray),
+                DataType::Int8 => compute_op!(array, Int8Array),
+                DataType::Int16 => compute_op!(array, Int16Array),
+                DataType::Int32 => compute_op!(array, Int32Array),
+                DataType::Int64 => compute_op!(array, Int64Array),
+                _ => Err(DataFusionError::Execution(format!(
+                    "Can't be evaluated because the expression's type is {:?}, not signed int",
+                    array.data_type(),
+                ))),
+            };
+            result.map(ColumnarValue::Array)
+        }
+        ColumnarValue::Scalar(_) => Err(DataFusionError::Internal(
+            "shouldn't go to bit_count scalar path".to_string(),
+        )),
+    }
+}
+
+/// Popcount of `i`'s 64-bit two's-complement form. Narrower signed inputs arrive sign-extended
+/// (via `.into()` in `compute_op`), so `-3456_i32` counts 54 set bits (see `bitwise_count_op`).
+fn bit_count(i: i64) -> i32 {
+    // `count_ones` is at most 64, so the cast to `i32` cannot truncate.
+    i.count_ones() as i32
+}
+
+#[cfg(test)]
+mod tests {
+    use datafusion::common::{cast::as_int32_array, Result};
+
+    use super::*;
+
+    #[test]
+    fn bitwise_count_op() -> Result<()> {
+        let args = vec![ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+            Some(1),
+            None,
+            Some(12345),
+            Some(89),
+            Some(-3456),
+        ])))];
+        let expected = &Int32Array::from(vec![Some(1), None, Some(6), Some(4), Some(54)]);
+
+        let ColumnarValue::Array(result) = spark_bit_count(&args)? else {
+            unreachable!()
+        };
+
+        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+}
diff --git a/native/spark-expr/src/bitwise_funcs/mod.rs b/native/spark-expr/src/bitwise_funcs/mod.rs
index 9c2636331..718cfc7ca 100644
--- a/native/spark-expr/src/bitwise_funcs/mod.rs
+++ b/native/spark-expr/src/bitwise_funcs/mod.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
+mod bitwise_count; mod bitwise_not; +pub use bitwise_count::spark_bit_count; pub use bitwise_not::{bitwise_not, BitwiseNotExpr}; diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index cf06d3633..f85206000 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -17,10 +17,10 @@ use crate::hash_funcs::*; use crate::{ - spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, - spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal, - spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, - SparkChrFunc, + spark_array_repeat, spark_bit_count, spark_ceil, spark_date_add, spark_date_sub, + spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, + spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, + spark_unscaled_value, SparkChrFunc, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -145,6 +145,10 @@ pub fn create_comet_physical_fun( let func = Arc::new(spark_array_repeat); make_comet_scalar_udf!("array_repeat", func, without data_type) } + "bit_count" => { + let func = Arc::new(spark_bit_count); + make_comet_scalar_udf!("bit_count", func, without data_type) + } _ => registry.udf(fun_name).map_err(|e| { DataFusionError::Execution(format!( "Function {fun_name} not found in the registry: {e}", diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 02e7530e0..32918677e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1634,6 +1634,12 @@ object QueryPlanSerde extends Logging with CometExprShim { binding, (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr)) + case 
BitwiseCount(child) => + val childProto = exprToProto(child, inputs, binding) + val bitCountScalarExpr = + scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, childProto) + optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*) + case ShiftRight(left, right) => // DataFusion bitwise shift right expression requires // same data type between left and right side diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 2099426fa..6273ab9b0 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE import org.apache.spark.sql.types.{Decimal, DecimalType} import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus +import org.apache.comet.testing.{DataGenOptions, ParquetGenerator} class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ @@ -99,6 +100,73 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("bitwise_count - min/max values") { + Seq(false, true).foreach { dictionary => + withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + val table = "bitwise_count_test" + withTable(table) { + sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet") + sql(s"insert into $table values(1111, 2222, 17, 7)") + sql( + s"insert into $table values(${Long.MaxValue}, ${Int.MaxValue}, ${Short.MaxValue}, ${Byte.MaxValue})") + sql( + s"insert into $table values(${Long.MinValue}, ${Int.MinValue}, ${Short.MinValue}, ${Byte.MinValue})") + + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col1) FROM $table")) + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col2) FROM $table")) + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col3) FROM $table")) + 
checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col4) FROM $table")) + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(true) FROM $table")) + checkSparkAnswerAndOperator(sql(s"SELECT bit_count(false) FROM $table")) + } + } + } + } + + test("bitwise_count - random values (spark gen)") { + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + val filename = path.toString + val random = new Random(42) + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + ParquetGenerator.makeParquetFile( + random, + spark, + filename, + 10, + DataGenOptions( + allowNull = true, + generateNegativeZero = true, + generateArray = false, + generateStruct = false, + generateMap = false)) + } + val table = spark.read.parquet(filename) + val df = + table.selectExpr("bit_count(c1)", "bit_count(c2)", "bit_count(c3)", "bit_count(c4)") + + checkSparkAnswerAndOperator(df) + } + } + + test("bitwise_count - random values (native parquet gen)") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 0, 10000, nullEnabled = false) + val table = spark.read.parquet(path.toString) + checkSparkAnswerAndOperator( + table + .selectExpr( + "bit_count(_2)", + "bit_count(_3)", + "bit_count(_4)", + "bit_count(_5)", + "bit_count(_11)")) + } + } + } + test("bitwise shift with different left/right types") { Seq(false, true).foreach { dictionary => withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org