This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new c51f977c8 Feat: support bit_get function (#1713)
c51f977c8 is described below

commit c51f977c8ef42d0e6bba7f2d3493d29750b25be5
Author: Kazantsev Maksim <kazantsev....@yandex.ru>
AuthorDate: Thu Jun 26 15:49:38 2025 +0400

    Feat: support bit_get function (#1713)
---
 native/spark-expr/src/bitwise_funcs/bitwise_get.rs | 317 +++++++++++++++++++++
 native/spark-expr/src/bitwise_funcs/mod.rs         |   2 +
 native/spark-expr/src/comet_scalar_funcs.rs        |   3 +-
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  74 ++---
 .../scala/org/apache/comet/serde/bitwise.scala     | 161 +++++++++++
 .../apache/comet/CometBitwiseExpressionSuite.scala | 209 ++++++++++++++
 .../org/apache/comet/CometExpressionSuite.scala    | 113 --------
 7 files changed, 707 insertions(+), 172 deletions(-)
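
For context: Spark's bit_get(expr, pos) returns the bit of an integral value at
position pos, counting from the least significant bit, as a TINYINT. The native
kernel added in bitwise_get.rs below reduces to a single shift-and-mask; a
minimal standalone sketch of that semantics, mirroring the bit_get helper in
the new file:

    // Extract the bit at `pos` (0 = least significant bit) of a signed integer.
    fn bit_get(arg: i64, pos: i32) -> i8 {
        ((arg >> pos) & 1) as i8
    }

    fn main() {
        // 11 = 0b1011
        assert_eq!(bit_get(11, 0), 1);
        assert_eq!(bit_get(11, 2), 0);
        assert_eq!(bit_get(11, 3), 1);
    }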

diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_get.rs b/native/spark-expr/src/bitwise_funcs/bitwise_get.rs
new file mode 100644
index 000000000..18b27ef3f
--- /dev/null
+++ b/native/spark-expr/src/bitwise_funcs/bitwise_get.rs
@@ -0,0 +1,317 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::{array::*, datatypes::DataType};
+use datafusion::common::{exec_err, internal_datafusion_err, Result, ScalarValue};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
+use std::any::Any;
+use std::sync::Arc;
+
+#[derive(Debug)]
+pub struct SparkBitwiseGet {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for SparkBitwiseGet {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkBitwiseGet {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec![],
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkBitwiseGet {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "bit_get"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Int8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let args: [ColumnarValue; 2] = args
+            .args
+            .try_into()
+            .map_err(|_| internal_datafusion_err!("bit_get expects exactly two arguments"))?;
+        spark_bit_get(&args)
+    }
+}
+
+macro_rules! bit_get_scalar_position {
+    ($args:expr, $array_type:ty, $pos:expr, $bit_size:expr) => {{
+        if let Some(pos) = $pos {
+            check_position(*pos, $bit_size as i32)?;
+        }
+        let args = $args
+            .as_any()
+            .downcast_ref::<$array_type>()
+            .expect("bit_get_scalar_position failed to downcast array");
+
+        let result: Int8Array = args
+            .iter()
+            .map(|x| x.and_then(|x| $pos.map(|pos| bit_get(x.into(), pos))))
+            .collect();
+
+        Ok(Arc::new(result))
+    }};
+}
+
+macro_rules! bit_get_array_positions {
+    ($args:expr, $array_type:ty, $positions:expr, $bit_size:expr) => {{
+        let args = $args
+            .as_any()
+            .downcast_ref::<$array_type>()
+            .expect("bit_get_array_positions failed to downcast args array");
+
+        let positions = $positions
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .expect("bit_get_array_positions failed to downcast positions array");
+
+        for pos in positions.iter().flatten() {
+            check_position(pos, $bit_size as i32)?
+        }
+
+        let result: Int8Array = args
+            .iter()
+            .zip(positions.iter())
+            .map(|(i, p)| i.and_then(|i| p.map(|p| bit_get(i.into(), p))))
+            .collect();
+
+        Ok(Arc::new(result))
+    }};
+}
+
+pub fn spark_bit_get(args: &[ColumnarValue; 2]) -> Result<ColumnarValue> {
+    match args {
+        [ColumnarValue::Array(args), ColumnarValue::Scalar(ScalarValue::Int32(pos))] => {
+            let result: Result<ArrayRef> = match args.data_type() {
+                DataType::Int8 => bit_get_scalar_position!(args, Int8Array, pos, i8::BITS),
+                DataType::Int16 => bit_get_scalar_position!(args, Int16Array, pos, i16::BITS),
+                DataType::Int32 => bit_get_scalar_position!(args, Int32Array, pos, i32::BITS),
+                DataType::Int64 => bit_get_scalar_position!(args, Int64Array, pos, i64::BITS),
+                _ => exec_err!(
+                    "Can't be evaluated because the expression's type is {:?}, not signed int",
+                    args.data_type()
+                ),
+            };
+            result.map(ColumnarValue::Array)
+        },
+        [ColumnarValue::Array(args), ColumnarValue::Array(positions)] => {
+            if args.len() != positions.len() {
+                return exec_err!(
+                    "Input arrays must have equal length. Positions array has {} elements, but arguments array has {} elements",
+                    positions.len(), args.len()
+                );
+            }
+            if !matches!(positions.data_type(), DataType::Int32) {
+                return exec_err!(
+                    "Invalid data type for positions array: expected `Int32`, found `{}`",
+                    positions.data_type()
+                );
+            }
+            let result: Result<ArrayRef> = match args.data_type() {
+                DataType::Int8 => bit_get_array_positions!(args, Int8Array, positions, i8::BITS),
+                DataType::Int16 => bit_get_array_positions!(args, Int16Array, positions, i16::BITS),
+                DataType::Int32 => bit_get_array_positions!(args, Int32Array, positions, i32::BITS),
+                DataType::Int64 => bit_get_array_positions!(args, Int64Array, positions, i64::BITS),
+                _ => exec_err!(
+                    "Can't be evaluated because the expression's type is {:?}, not signed int",
+                    args.data_type()
+                ),
+            };
+            result.map(ColumnarValue::Array)
+        }
+        _ => exec_err!(
+            "Invalid input to function bit_get. Expected (IntegralType array, Int32Scalar) or (IntegralType array, Int32Array)"
+        ),
+        ),
+    }
+}
+
+fn bit_get(arg: i64, pos: i32) -> i8 {
+    ((arg >> pos) & 1) as i8
+}
+
+fn check_position(pos: i32, bit_size: i32) -> Result<()> {
+    if pos < 0 {
+        return exec_err!("Invalid bit position: {:?} is less than zero", pos);
+    }
+    if bit_size <= pos {
+        return exec_err!(
+            "Invalid bit position: {:?} exceeds the bit upper limit: {:?}",
+            pos,
+            bit_size
+        );
+    }
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion::common::cast::as_int8_array;
+
+    #[test]
+    fn bitwise_get_scalar_position() -> Result<()> {
+        let args = [
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(1234553454),
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
+        ];
+
+        let expected = &Int8Array::from(vec![Some(0), None, Some(1)]);
+
+        let ColumnarValue::Array(result) = spark_bit_get(&args)? else {
+            unreachable!()
+        };
+
+        let result = as_int8_array(&result).expect("failed to downcast to Int8Array");
+
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn bitwise_get_scalar_negative_position() -> Result<()> {
+        let args = [
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(1234553454),
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(-1))),
+        ];
+
+        let expected = String::from("Execution error: Invalid bit position: -1 is less than zero");
+        let result = spark_bit_get(&args).err().unwrap().to_string();
+
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn bitwise_get_scalar_overflow_position() -> Result<()> {
+        let args = [
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(1234553454),
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(33))),
+        ];
+
+        let expected = String::from(
+            "Execution error: Invalid bit position: 33 exceeds the bit upper limit: 32",
+        );
+        let result = spark_bit_get(&args).err().unwrap().to_string();
+
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn bitwise_get_array_positions() -> Result<()> {
+        let args = [
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(1234553454),
+            ]))),
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(1), None, Some(1)]))),
+        ];
+
+        let expected = &Int8Array::from(vec![Some(0), None, Some(1)]);
+
+        let ColumnarValue::Array(result) = spark_bit_get(&args)? else {
+            unreachable!()
+        };
+
+        let result = as_int8_array(&result).expect("failed to downcast to Int8Array");
+
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn bitwise_get_array_positions_contains_negative() -> Result<()> {
+        let args = [
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(1234553454),
+            ]))),
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(-1), None, Some(1)]))),
+        ];
+
+        let expected = String::from("Execution error: Invalid bit position: -1 is less than zero");
+        let result = spark_bit_get(&args).err().unwrap().to_string();
+
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn bitwise_get_array_positions_contains_overflow() -> Result<()> {
+        let args = [
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(1234553454),
+            ]))),
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(33), None, Some(1)]))),
+        ];
+
+        let expected = String::from(
+            "Execution error: Invalid bit position: 33 exceeds the bit upper limit: 32",
+        );
+        let result = spark_bit_get(&args).err().unwrap().to_string();
+
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+}
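
A note on null handling in the bit_get_scalar_position and
bit_get_array_positions macros above: a null input value or a null position
yields a null output bit, which the Option combinators express directly. A
self-contained sketch of the per-element logic (bit_get_nullable is a name
chosen here for illustration):

    // None in either the value or the position propagates to None;
    // otherwise the bit at `pos` is extracted.
    fn bit_get_nullable(value: Option<i64>, pos: Option<i32>) -> Option<i8> {
        value.and_then(|v| pos.map(|p| ((v >> p) & 1) as i8))
    }

    fn main() {
        assert_eq!(bit_get_nullable(Some(1234553454), Some(1)), Some(1));
        assert_eq!(bit_get_nullable(None, Some(1)), None);
        assert_eq!(bit_get_nullable(Some(1), None), None);
    }

This is exactly the [Some(0), None, Some(1)] pattern the tests assert: bit 1 of
1 is 0, nulls stay null, and bit 1 of 1234553454 is 1.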
diff --git a/native/spark-expr/src/bitwise_funcs/mod.rs b/native/spark-expr/src/bitwise_funcs/mod.rs
index 3f148a6dc..17d418675 100644
--- a/native/spark-expr/src/bitwise_funcs/mod.rs
+++ b/native/spark-expr/src/bitwise_funcs/mod.rs
@@ -16,7 +16,9 @@
 // under the License.
 
 mod bitwise_count;
+mod bitwise_get;
 mod bitwise_not;
 
 pub use bitwise_count::SparkBitwiseCount;
+pub use bitwise_get::SparkBitwiseGet;
 pub use bitwise_not::SparkBitwiseNot;
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 11d736d04..6177ef498 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -20,7 +20,7 @@ use crate::{
     spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
     spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
     spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
-    SparkBitwiseCount, SparkBitwiseNot, SparkChrFunc, SparkDateTrunc,
+    SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkChrFunc, SparkDateTrunc,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -157,6 +157,7 @@ fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
         Arc::new(ScalarUDF::new_from_impl(SparkChrFunc::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkBitwiseNot::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
+        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseGet::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
     ]
 }
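
The registration above appends the new UDF to the list built by
all_scalar_functions(); each entry is a plain DataFusion ScalarUDF keyed by its
name() ("bit_get" here). A minimal sketch of resolving one from such a list
(find_udf is a hypothetical helper for illustration, not Comet's actual lookup
API):

    use std::sync::Arc;
    use datafusion::logical_expr::ScalarUDF;

    // Hypothetical: resolve a scalar UDF from a registry list by name.
    fn find_udf(udfs: &[Arc<ScalarUDF>], name: &str) -> Option<Arc<ScalarUDF>> {
        udfs.iter().find(|f| f.name() == name).cloned()
    }

Within this module, find_udf(&all_scalar_functions(), "bit_get") would yield
the SparkBitwiseGet wrapper registered above.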
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 4e45311d0..d0250d52a 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1625,69 +1625,27 @@ object QueryPlanSerde extends Logging with CometExprShim {
           binding,
           (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
 
-      case BitwiseNot(child) =>
-        val childProto = exprToProto(child, inputs, binding)
-        val bitNotScalarExpr =
-          scalarFunctionExprToProto("bit_not", childProto)
-        optExprWithInfo(bitNotScalarExpr, expr, expr.children: _*)
+      case _: BitwiseNot =>
+        CometBitwiseNot.convert(expr, inputs, binding)
 
-      case BitwiseOr(left, right) =>
-        createBinaryExpr(
-          expr,
-          left,
-          right,
-          inputs,
-          binding,
-          (builder, binaryExpr) => builder.setBitwiseOr(binaryExpr))
+      case _: BitwiseOr =>
+        CometBitwiseOr.convert(expr, inputs, binding)
 
-      case BitwiseXor(left, right) =>
-        createBinaryExpr(
-          expr,
-          left,
-          right,
-          inputs,
-          binding,
-          (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
-
-      case BitwiseCount(child) =>
-        val childProto = exprToProto(child, inputs, binding)
-        val bitCountScalarExpr =
-          scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, childProto)
-        optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*)
-
-      case ShiftRight(left, right) =>
-        // DataFusion bitwise shift right expression requires
-        // same data type between left and right side
-        val rightExpression = if (left.dataType == LongType) {
-          Cast(right, LongType)
-        } else {
-          right
-        }
+      case _: BitwiseXor =>
+        CometBitwiseXor.convert(expr, inputs, binding)
 
-        createBinaryExpr(
-          expr,
-          left,
-          rightExpression,
-          inputs,
-          binding,
-          (builder, binaryExpr) => builder.setBitwiseShiftRight(binaryExpr))
+      case _: ShiftRight =>
+        CometShiftRight.convert(expr, inputs, binding)
 
-      case ShiftLeft(left, right) =>
-        // DataFusion bitwise shift right expression requires
-        // same data type between left and right side
-        val rightExpression = if (left.dataType == LongType) {
-          Cast(right, LongType)
-        } else {
-          right
-        }
+      case _: BitwiseCount =>
+        CometBitwiseCount.convert(expr, inputs, binding)
+
+      case _: ShiftLeft =>
+        CometShiftLeft.convert(expr, inputs, binding)
+
+      case _: BitwiseGet =>
+        CometBitwiseGet.convert(expr, inputs, binding)
 
-        createBinaryExpr(
-          expr,
-          left,
-          rightExpression,
-          inputs,
-          binding,
-          (builder, binaryExpr) => builder.setBitwiseShiftLeft(binaryExpr))
       case In(value, list) =>
         in(expr, value, list, inputs, binding, negate = false)
 
diff --git a/spark/src/main/scala/org/apache/comet/serde/bitwise.scala b/spark/src/main/scala/org/apache/comet/serde/bitwise.scala
new file mode 100644
index 000000000..50a22e6b9
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/bitwise.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{ByteType, IntegerType, LongType}
+
+import org.apache.comet.serde.QueryPlanSerde._
+
+object CometBitwiseAnd extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val bitwiseAndExpr = expr.asInstanceOf[BitwiseAnd]
+    createBinaryExpr(
+      expr,
+      bitwiseAndExpr.left,
+      bitwiseAndExpr.right,
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
+  }
+}
+
+object CometBitwiseNot extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val bitwiseNotExpr = expr.asInstanceOf[BitwiseNot]
+    val childProto = exprToProto(bitwiseNotExpr.child, inputs, binding)
+    val bitNotScalarExpr =
+      scalarFunctionExprToProto("bit_not", childProto)
+    optExprWithInfo(bitNotScalarExpr, expr, expr.children: _*)
+  }
+}
+
+object CometBitwiseOr extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val bitwiseOrExpr = expr.asInstanceOf[BitwiseOr]
+    createBinaryExpr(
+      expr,
+      bitwiseOrExpr.left,
+      bitwiseOrExpr.right,
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setBitwiseOr(binaryExpr))
+  }
+}
+
+object CometBitwiseXor extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val bitwiseXorExpr = expr.asInstanceOf[BitwiseXor]
+    createBinaryExpr(
+      expr,
+      bitwiseXorExpr.left,
+      bitwiseXorExpr.right,
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
+  }
+}
+
+object CometShiftRight extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val shiftRightExpr = expr.asInstanceOf[ShiftRight]
+    // DataFusion bitwise shift right expression requires
+    // same data type between left and right side
+    val rightExpression = if (shiftRightExpr.left.dataType == LongType) {
+      Cast(shiftRightExpr.right, LongType)
+    } else {
+      shiftRightExpr.right
+    }
+
+    createBinaryExpr(
+      expr,
+      shiftRightExpr.left,
+      rightExpression,
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setBitwiseShiftRight(binaryExpr))
+  }
+}
+
+object CometShiftLeft extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val shiftLeftExpr = expr.asInstanceOf[ShiftLeft]
+    // DataFusion bitwise shift left expression requires
+    // same data type between left and right side
+    val rightExpression = if (shiftLeftExpr.left.dataType == LongType) {
+      Cast(shiftLeftExpr.right, LongType)
+    } else {
+      shiftLeftExpr.right
+    }
+
+    createBinaryExpr(
+      expr,
+      shiftLeftExpr.left,
+      rightExpression,
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setBitwiseShiftLeft(binaryExpr))
+  }
+}
+
+object CometBitwiseGet extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val bitwiseGetExpr = expr.asInstanceOf[BitwiseGet]
+    val argProto = exprToProto(bitwiseGetExpr.left, inputs, binding)
+    val posProto = exprToProto(bitwiseGetExpr.right, inputs, binding)
+    val bitGetScalarExpr =
+      scalarFunctionExprToProtoWithReturnType("bit_get", ByteType, argProto, posProto)
+    optExprWithInfo(bitGetScalarExpr, expr, expr.children: _*)
+  }
+}
+
+object CometBitwiseCount extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val bitwiseCountExpr = expr.asInstanceOf[BitwiseCount]
+    val childProto = exprToProto(bitwiseCountExpr.child, inputs, binding)
+    val bitCountScalarExpr =
+      scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, childProto)
+    optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*)
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala
new file mode 100644
index 000000000..d89e81b0f
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import scala.util.Random
+
+import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+
+import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
+
+class CometBitwiseExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+
+  test("bitwise expressions") {
+    Seq(false, true).foreach { dictionary =>
+      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+        val table = "test"
+        withTable(table) {
+          sql(s"create table $table(col1 int, col2 int) using parquet")
+          sql(s"insert into $table values(1111, 2)")
+          sql(s"insert into $table values(1111, 2)")
+          sql(s"insert into $table values(3333, 4)")
+          sql(s"insert into $table values(5555, 6)")
+
+          checkSparkAnswerAndOperator(
+            s"SELECT col1 & col2,  col1 | col2, col1 ^ col2 FROM $table")
+          checkSparkAnswerAndOperator(
+            s"SELECT col1 & 1234,  col1 | 1234, col1 ^ 1234 FROM $table")
+          checkSparkAnswerAndOperator(
+            s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table")
+          checkSparkAnswerAndOperator(
+            s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table")
+          checkSparkAnswerAndOperator(s"SELECT ~(11), ~col1, ~col2 FROM $table")
+        }
+      }
+    }
+  }
+
+  test("bitwise shift with different left/right types") {
+    Seq(false, true).foreach { dictionary =>
+      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+        val table = "test"
+        withTable(table) {
+          sql(s"create table $table(col1 long, col2 int) using parquet")
+          sql(s"insert into $table values(1111, 2)")
+          sql(s"insert into $table values(1111, 2)")
+          sql(s"insert into $table values(3333, 4)")
+          sql(s"insert into $table values(5555, 6)")
+
+          checkSparkAnswerAndOperator(
+            s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table")
+          checkSparkAnswerAndOperator(
+            s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table")
+        }
+      }
+    }
+  }
+
+  test("bitwise_get - throws exceptions") {
+    def checkSparkAndCometEqualThrows(query: String): Unit = {
+      checkSparkMaybeThrows(sql(query)) match {
+        case (Some(sparkExc), Some(cometExc)) =>
+          assert(sparkExc.getMessage == cometExc.getMessage)
+        case _ => fail("Exception should be thrown")
+      }
+    }
+    checkSparkAndCometEqualThrows("select bit_get(1000, -30)")
+    checkSparkAndCometEqualThrows("select bit_get(cast(1000 as byte), 9)")
+    checkSparkAndCometEqualThrows("select bit_count(cast(null as byte), 4)")
+    checkSparkAndCometEqualThrows("select bit_count(1000, cast(null as int))")
+  }
+
+  test("bitwise_get - random values (spark parquet gen)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = false,
+            generateStruct = false,
+            generateMap = false))
+      }
+      val table = spark.read.parquet(filename)
+      checkSparkAnswerAndOperator(
+        table
+          .selectExpr("bit_get(c1, 7)", "bit_get(c2, 10)", "bit_get(c3, 12)", "bit_get(c4, 16)"))
+    }
+  }
+
+  test("bitwise_get - random values (native parquet gen)") {
+    def randomBitPosition(maxBitPosition: Int): Int = {
+      Random.nextInt(maxBitPosition)
+    }
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 0, 10000, nullEnabled = false)
+        val table = spark.read.parquet(path.toString)
+        (0 to 10).foreach { _ =>
+          val byteBitPosition = randomBitPosition(java.lang.Byte.SIZE)
+          val shortBitPosition = randomBitPosition(java.lang.Short.SIZE)
+          val intBitPosition = randomBitPosition(java.lang.Integer.SIZE)
+          val longBitPosition = randomBitPosition(java.lang.Long.SIZE)
+          checkSparkAnswerAndOperator(
+            table
+              .selectExpr(
+                s"bit_get(_2, $byteBitPosition)",
+                s"bit_get(_3, $shortBitPosition)",
+                s"bit_get(_4, $intBitPosition)",
+                s"bit_get(_5, $longBitPosition)",
+                s"bit_get(_11, $longBitPosition)"))
+        }
+      }
+    }
+  }
+
+  test("bitwise_count - min/max values") {
+    Seq(false, true).foreach { dictionary =>
+      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+        val table = "bitwise_count_test"
+        withTable(table) {
+          sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet")
+          sql(s"insert into $table values(1111, 2222, 17, 7)")
+          sql(
+            s"insert into $table values(${Long.MaxValue}, ${Int.MaxValue}, ${Short.MaxValue}, ${Byte.MaxValue})")
+          sql(
+            s"insert into $table values(${Long.MinValue}, ${Int.MinValue}, ${Short.MinValue}, ${Byte.MinValue})")
+
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col1) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col2) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col3) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col4) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(true) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(false) FROM $table"))
+        }
+      }
+    }
+  }
+
+  test("bitwise_count - random values (spark gen)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          10,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = false,
+            generateStruct = false,
+            generateMap = false))
+      }
+      val table = spark.read.parquet(filename)
+      val df =
+        table.selectExpr("bit_count(c1)", "bit_count(c2)", "bit_count(c3)", "bit_count(c4)")
+
+      checkSparkAnswerAndOperator(df)
+    }
+  }
+
+  test("bitwise_count - random values (native parquet gen)") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 0, 10000, nullEnabled = false)
+        val table = spark.read.parquet(path.toString)
+        checkSparkAnswerAndOperator(
+          table
+            .selectExpr(
+              "bit_count(_2)",
+              "bit_count(_3)",
+              "bit_count(_4)",
+              "bit_count(_5)",
+              "bit_count(_11)"))
+      }
+    }
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index ce9ac120c..34e38895a 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -43,7 +43,6 @@ import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, StringType, StructType}
 
 import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
-import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
 
 class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   import testImplicits._
@@ -115,93 +114,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  test("bitwise_count - min/max values") {
-    Seq(false, true).foreach { dictionary =>
-      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
-        val table = "bitwise_count_test"
-        withTable(table) {
-          sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet")
-          sql(s"insert into $table values(1111, 2222, 17, 7)")
-          sql(
-            s"insert into $table values(${Long.MaxValue}, ${Int.MaxValue}, ${Short.MaxValue}, ${Byte.MaxValue})")
-          sql(
-            s"insert into $table values(${Long.MinValue}, ${Int.MinValue}, ${Short.MinValue}, ${Byte.MinValue})")
-
-          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col1) FROM $table"))
-          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col2) FROM $table"))
-          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col3) FROM $table"))
-          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col4) FROM $table"))
-          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(true) FROM $table"))
-          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(false) FROM $table"))
-        }
-      }
-    }
-  }
-
-  test("bitwise_count - random values (spark gen)") {
-    withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      val filename = path.toString
-      val random = new Random(42)
-      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        ParquetGenerator.makeParquetFile(
-          random,
-          spark,
-          filename,
-          10,
-          DataGenOptions(
-            allowNull = true,
-            generateNegativeZero = true,
-            generateArray = false,
-            generateStruct = false,
-            generateMap = false))
-      }
-      val table = spark.read.parquet(filename)
-      val df =
-        table.selectExpr("bit_count(c1)", "bit_count(c2)", "bit_count(c3)", "bit_count(c4)")
-
-      checkSparkAnswerAndOperator(df)
-    }
-  }
-
-  test("bitwise_count - random values (native parquet gen)") {
-    Seq(true, false).foreach { dictionaryEnabled =>
-      withTempDir { dir =>
-        val path = new Path(dir.toURI.toString, "test.parquet")
-        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 0, 10000, nullEnabled = false)
-        val table = spark.read.parquet(path.toString)
-        checkSparkAnswerAndOperator(
-          table
-            .selectExpr(
-              "bit_count(_2)",
-              "bit_count(_3)",
-              "bit_count(_4)",
-              "bit_count(_5)",
-              "bit_count(_11)"))
-      }
-    }
-  }
-
-  test("bitwise shift with different left/right types") {
-    Seq(false, true).foreach { dictionary =>
-      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
-        val table = "test"
-        withTable(table) {
-          sql(s"create table $table(col1 long, col2 int) using parquet")
-          sql(s"insert into $table values(1111, 2)")
-          sql(s"insert into $table values(1111, 2)")
-          sql(s"insert into $table values(3333, 4)")
-          sql(s"insert into $table values(5555, 6)")
-
-          checkSparkAnswerAndOperator(
-            s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table")
-          checkSparkAnswerAndOperator(
-            s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table")
-        }
-      }
-    }
-  }
-
   test("basic data type support") {
     Seq(true, false).foreach { dictionaryEnabled =>
       withTempDir { dir =>
@@ -1552,31 +1464,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       })
   }
 
-  test("bitwise expressions") {
-    Seq(false, true).foreach { dictionary =>
-      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
-        val table = "test"
-        withTable(table) {
-          sql(s"create table $table(col1 int, col2 int) using parquet")
-          sql(s"insert into $table values(1111, 2)")
-          sql(s"insert into $table values(1111, 2)")
-          sql(s"insert into $table values(3333, 4)")
-          sql(s"insert into $table values(5555, 6)")
-
-          checkSparkAnswerAndOperator(
-            s"SELECT col1 & col2,  col1 | col2, col1 ^ col2 FROM $table")
-          checkSparkAnswerAndOperator(
-            s"SELECT col1 & 1234,  col1 | 1234, col1 ^ 1234 FROM $table")
-          checkSparkAnswerAndOperator(
-            s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table")
-          checkSparkAnswerAndOperator(
-            s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table")
-          checkSparkAnswerAndOperator(s"SELECT ~(11), ~col1, ~col2 FROM $table")
-        }
-      }
-    }
-  }
-
   test("test in(set)/not in(set)") {
     Seq("100", "0").foreach { inSetThreshold =>
       Seq(false, true).foreach { dictionary =>

