This is an automated email from the ASF dual-hosted git repository.

kazuyukitanimura pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 53c724eac Feat: support bit_count function (#1602)
53c724eac is described below

commit 53c724eaca4cd9a8e9aa65c5a81bb0ac1a601665
Author: Kazantsev Maksim <kazantsev....@yandex.ru>
AuthorDate: Fri May 30 23:52:03 2025 +0400

    Feat: support bit_count function (#1602)
    
    ## Which issue does this PR close?
    
    Related to Epic: https://github.com/apache/datafusion-comet/issues/240
    Example: `SELECT bit_count(0)` returns 0.
    DataFusion Comet's bit_count behaves the same as Spark's bit_count function.
    Spark docs: https://spark.apache.org/docs/latest/api/sql/index.html#bit_count
    
    Closes #.
    
    ## Rationale for this change
    
    Defined under Epic: https://github.com/apache/datafusion-comet/issues/240
    
    ## What changes are included in this PR?
    
    bitwise_count.rs: implementation of the bit_count function.
    mod.rs: exports `spark_bit_count` from the bitwise_funcs module.
    comet_scalar_funcs.rs: registers `bit_count` as a Comet scalar function, so Spark's
    bit_count maps to the native implementation.
    QueryPlanSerde.scala: serializes Spark's `BitwiseCount` expression as a `bit_count`
    scalar function call returning IntegerType.
    CometExpressionSuite.scala: new unit tests for the bit_count function.
    
    ## How are these changes tested?
    
    New unit tests have been added in CometExpressionSuite.scala and bitwise_count.rs.
---
 .../spark-expr/src/bitwise_funcs/bitwise_count.rs  | 105 +++++++++++++++++++++
 native/spark-expr/src/bitwise_funcs/mod.rs         |   2 +
 native/spark-expr/src/comet_scalar_funcs.rs        |  12 ++-
 .../org/apache/comet/serde/QueryPlanSerde.scala    |   6 ++
 .../org/apache/comet/CometExpressionSuite.scala    |  68 +++++++++++++
 5 files changed, 189 insertions(+), 4 deletions(-)

diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs 
b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs
new file mode 100644
index 000000000..f0a1b0073
--- /dev/null
+++ b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs
@@ -0,0 +1,105 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::{array::*, datatypes::DataType};
+use datafusion::common::Result;
+use datafusion::{error::DataFusionError, logical_expr::ColumnarValue};
+use std::sync::Arc;
+
/// Downcasts `$OPERAND` to the concrete Arrow array type `$DT`, counts the set bits of
/// every present element with [`bit_count`] (after converting it into an `i64`), and
/// evaluates to `Ok(Arc<Int32Array>)`. Null slots remain null in the output.
///
/// The expansion uses `?`, so it must be used inside a function returning a DataFusion
/// `Result`; a failed downcast propagates an `Execution` error that names `$DT`.
macro_rules! compute_op {
    ($OPERAND:expr, $DT:ident) => {{
        let downcast_error = || {
            DataFusionError::Execution(format!(
                "compute_op failed to downcast array to: {:?}",
                stringify!($DT)
            ))
        };
        let typed_array = $OPERAND
            .as_any()
            .downcast_ref::<$DT>()
            .ok_or_else(downcast_error)?;

        // Every present value is converted into an i64 before counting; nulls map to null.
        let counts = typed_array
            .iter()
            .map(|slot| slot.map(|value| bit_count(value.into())))
            .collect::<Int32Array>();

        Ok(Arc::new(counts))
    }};
}
+
+pub fn spark_bit_count(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() != 1 {
+        return Err(DataFusionError::Internal(
+            "bit_count expects exactly one argument".to_string(),
+        ));
+    }
+    match &args[0] {
+        ColumnarValue::Array(array) => {
+            let result: Result<ArrayRef> = match array.data_type() {
+                DataType::Int8 | DataType::Boolean => compute_op!(array, 
Int8Array),
+                DataType::Int16 => compute_op!(array, Int16Array),
+                DataType::Int32 => compute_op!(array, Int32Array),
+                DataType::Int64 => compute_op!(array, Int64Array),
+                _ => Err(DataFusionError::Execution(format!(
+                    "Can't be evaluated because the expression's type is {:?}, 
not signed int",
+                    array.data_type(),
+                ))),
+            };
+            result.map(ColumnarValue::Array)
+        }
+        ColumnarValue::Scalar(_) => Err(DataFusionError::Internal(
+            "shouldn't go to bit_count scalar path".to_string(),
+        )),
+    }
+}
+
/// Counts the set bits in the 64-bit two's-complement representation of `i`.
///
/// This is the same population count as Java's `Long.bitCount`, which the previous
/// hand-written bit-twiddling routine reproduced. Callers convert narrower integers into
/// `i64` first (sign-extending), so a negative narrow value also counts the extension
/// bits: `-3456_i32` yields 54, not 22.
///
/// The result is always in `0..=64`, so the `u32 -> i32` cast cannot truncate.
fn bit_count(i: i64) -> i32 {
    i.count_ones() as i32
}
+
#[cfg(test)]
mod tests {
    use datafusion::common::{cast::as_int32_array, Result};

    use super::*;

    /// Invokes `spark_bit_count` with a single array argument and returns its Int32 output.
    fn bit_count_of(input: ArrayRef) -> Result<Int32Array> {
        let ColumnarValue::Array(result) = spark_bit_count(&[ColumnarValue::Array(input)])? else {
            unreachable!("an array argument must produce an array result")
        };
        Ok(as_int32_array(&result)
            .expect("failed to downcast to Int32Array")
            .clone())
    }

    #[test]
    fn bitwise_count_op() -> Result<()> {
        let input: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            Some(12345),
            Some(89),
            Some(-3456),
        ]));
        let expected = Int32Array::from(vec![Some(1), None, Some(6), Some(4), Some(54)]);

        assert_eq!(bit_count_of(input)?, expected);

        Ok(())
    }

    #[test]
    fn bitwise_count_sign_extends_narrow_types() -> Result<()> {
        // Negative values are widened to i64 with sign extension before counting, so
        // each type's MIN counts (64 - width + 1) bits rather than a single bit.
        let int8: ArrayRef = Arc::new(Int8Array::from(vec![
            Some(i8::MIN),
            Some(-1),
            Some(0),
            Some(i8::MAX),
            None,
        ]));
        assert_eq!(
            bit_count_of(int8)?,
            Int32Array::from(vec![Some(57), Some(64), Some(0), Some(7), None])
        );

        let int16: ArrayRef = Arc::new(Int16Array::from(vec![Some(i16::MIN), Some(i16::MAX)]));
        assert_eq!(bit_count_of(int16)?, Int32Array::from(vec![Some(49), Some(15)]));

        let int32: ArrayRef = Arc::new(Int32Array::from(vec![Some(i32::MIN), Some(i32::MAX)]));
        assert_eq!(bit_count_of(int32)?, Int32Array::from(vec![Some(33), Some(31)]));

        let int64: ArrayRef = Arc::new(Int64Array::from(vec![
            Some(i64::MIN),
            Some(i64::MAX),
            Some(-1),
        ]));
        assert_eq!(
            bit_count_of(int64)?,
            Int32Array::from(vec![Some(1), Some(63), Some(64)])
        );

        Ok(())
    }

    #[test]
    fn bitwise_count_rejects_bad_arguments() {
        let floats: ArrayRef = Arc::new(Float64Array::from(vec![1.0]));
        assert!(spark_bit_count(&[ColumnarValue::Array(floats)]).is_err());
        assert!(spark_bit_count(&[]).is_err());
    }
}
diff --git a/native/spark-expr/src/bitwise_funcs/mod.rs 
b/native/spark-expr/src/bitwise_funcs/mod.rs
index 9c2636331..718cfc7ca 100644
--- a/native/spark-expr/src/bitwise_funcs/mod.rs
+++ b/native/spark-expr/src/bitwise_funcs/mod.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+mod bitwise_count;
 mod bitwise_not;
 
+pub use bitwise_count::spark_bit_count;
 pub use bitwise_not::{bitwise_not, BitwiseNotExpr};
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs 
b/native/spark-expr/src/comet_scalar_funcs.rs
index cf06d3633..f85206000 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -17,10 +17,10 @@
 
 use crate::hash_funcs::*;
 use crate::{
-    spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, 
spark_decimal_div,
-    spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, 
spark_make_decimal,
-    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, 
spark_unscaled_value,
-    SparkChrFunc,
+    spark_array_repeat, spark_bit_count, spark_ceil, spark_date_add, 
spark_date_sub,
+    spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_hex, 
spark_isnan,
+    spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, 
spark_unhex,
+    spark_unscaled_value, SparkChrFunc,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -145,6 +145,10 @@ pub fn create_comet_physical_fun(
             let func = Arc::new(spark_array_repeat);
             make_comet_scalar_udf!("array_repeat", func, without data_type)
         }
+        "bit_count" => {
+            let func = Arc::new(spark_bit_count);
+            make_comet_scalar_udf!("bit_count", func, without data_type)
+        }
         _ => registry.udf(fun_name).map_err(|e| {
             DataFusionError::Execution(format!(
                 "Function {fun_name} not found in the registry: {e}",
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 02e7530e0..32918677e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1634,6 +1634,12 @@ object QueryPlanSerde extends Logging with CometExprShim 
{
           binding,
           (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
 
+      case BitwiseCount(child) =>
+        val childProto = exprToProto(child, inputs, binding)
+        val bitCountScalarExpr =
+          scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, 
childProto)
+        optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*)
+
       case ShiftRight(left, right) =>
         // DataFusion bitwise shift right expression requires
         // same data type between left and right side
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 2099426fa..6273ab9b0 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -37,6 +37,7 @@ import 
org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.sql.types.{Decimal, DecimalType}
 
 import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
+import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
 
 class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   import testImplicits._
@@ -99,6 +100,73 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
+  test("bitwise_count - min/max values") {
+    Seq(false, true).foreach { dictionary =>
+      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+        val table = "bitwise_count_test"
+        withTable(table) {
+          sql(s"create table $table(col1 long, col2 int, col3 short, col4 
byte) using parquet")
+          sql(s"insert into $table values(1111, 2222, 17, 7)")
+          sql(
+            s"insert into $table values(${Long.MaxValue}, ${Int.MaxValue}, 
${Short.MaxValue}, ${Byte.MaxValue})")
+          sql(
+            s"insert into $table values(${Long.MinValue}, ${Int.MinValue}, 
${Short.MinValue}, ${Byte.MinValue})")
+
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col1) FROM 
$table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col2) FROM 
$table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col3) FROM 
$table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col4) FROM 
$table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(true) FROM 
$table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(false) FROM 
$table"))
+        }
+      }
+    }
+  }
+
+  test("bitwise_count - random values (spark gen)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          10,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = false,
+            generateStruct = false,
+            generateMap = false))
+      }
+      val table = spark.read.parquet(filename)
+      val df =
+        table.selectExpr("bit_count(c1)", "bit_count(c2)", "bit_count(c3)", 
"bit_count(c4)")
+
+      checkSparkAnswerAndOperator(df)
+    }
+  }
+
+  test("bitwise_count - random values (native parquet gen)") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllTypes(path, dictionaryEnabled, 0, 10000, nullEnabled 
= false)
+        val table = spark.read.parquet(path.toString)
+        checkSparkAnswerAndOperator(
+          table
+            .selectExpr(
+              "bit_count(_2)",
+              "bit_count(_3)",
+              "bit_count(_4)",
+              "bit_count(_5)",
+              "bit_count(_11)"))
+      }
+    }
+  }
+
   test("bitwise shift with different left/right types") {
     Seq(false, true).foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to