This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
     new c51f977c8  Feat: support bit_get function (#1713)
c51f977c8 is described below

commit c51f977c8ef42d0e6bba7f2d3493d29750b25be5
Author: Kazantsev Maksim <kazantsev....@yandex.ru>
AuthorDate: Thu Jun 26 15:49:38 2025 +0400

    Feat: support bit_get function (#1713)
---
 native/spark-expr/src/bitwise_funcs/bitwise_get.rs | 317 +++++++++++++++++++++
 native/spark-expr/src/bitwise_funcs/mod.rs         |   2 +
 native/spark-expr/src/comet_scalar_funcs.rs        |   3 +-
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  74 ++---
 .../scala/org/apache/comet/serde/bitwise.scala     | 161 +++++++++++
 .../apache/comet/CometBitwiseExpressionSuite.scala | 209 ++++++++++++++
 .../org/apache/comet/CometExpressionSuite.scala    | 113 --------
 7 files changed, 707 insertions(+), 172 deletions(-)

diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_get.rs b/native/spark-expr/src/bitwise_funcs/bitwise_get.rs
new file mode 100644
index 000000000..18b27ef3f
--- /dev/null
+++ b/native/spark-expr/src/bitwise_funcs/bitwise_get.rs
@@ -0,0 +1,317 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::{array::*, datatypes::DataType};
+use datafusion::common::{exec_err, internal_datafusion_err, Result, ScalarValue};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
+use std::any::Any;
+use std::sync::Arc;
+
+#[derive(Debug)]
+pub struct SparkBitwiseGet {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for SparkBitwiseGet {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkBitwiseGet {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec![],
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkBitwiseGet {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "bit_get"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Int8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let args: [ColumnarValue; 2] = args
+            .args
+            .try_into()
+            .map_err(|_| internal_datafusion_err!("bit_get expects exactly two arguments"))?;
+        spark_bit_get(&args)
+    }
+}
+
+macro_rules! bit_get_scalar_position {
+    ($args:expr, $array_type:ty, $pos:expr, $bit_size:expr) => {{
+        if let Some(pos) = $pos {
+            check_position(*pos, $bit_size as i32)?;
+        }
+        let args = $args
+            .as_any()
+            .downcast_ref::<$array_type>()
+            .expect("bit_get_scalar_position failed to downcast array");
+
+        let result: Int8Array = args
+            .iter()
+            .map(|x| x.and_then(|x| $pos.map(|pos| bit_get(x.into(), pos))))
+            .collect();
+
+        Ok(Arc::new(result))
+    }};
+}
+
+macro_rules! bit_get_array_positions {
+    ($args:expr, $array_type:ty, $positions:expr, $bit_size:expr) => {{
+        let args = $args
+            .as_any()
+            .downcast_ref::<$array_type>()
+            .expect("bit_get_array_positions failed to downcast args array");
+
+        let positions = $positions
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .expect("bit_get_array_positions failed to downcast positions array");
+
+        for pos in positions.iter().flatten() {
+            check_position(pos, $bit_size as i32)?
+        }
+
+        let result: Int8Array = args
+            .iter()
+            .zip(positions.iter())
+            .map(|(i, p)| i.and_then(|i| p.map(|p| bit_get(i.into(), p))))
+            .collect();
+
+        Ok(Arc::new(result))
+    }};
+}
+
+pub fn spark_bit_get(args: &[ColumnarValue; 2]) -> Result<ColumnarValue> {
+    match args {
+        [ColumnarValue::Array(args), ColumnarValue::Scalar(ScalarValue::Int32(pos))] => {
+            let result: Result<ArrayRef> = match args.data_type() {
+                DataType::Int8 => bit_get_scalar_position!(args, Int8Array, pos, i8::BITS),
+                DataType::Int16 => bit_get_scalar_position!(args, Int16Array, pos, i16::BITS),
+                DataType::Int32 => bit_get_scalar_position!(args, Int32Array, pos, i32::BITS),
+                DataType::Int64 => bit_get_scalar_position!(args, Int64Array, pos, i64::BITS),
+                _ => exec_err!(
+                    "Can't be evaluated because the expression's type is {:?}, not signed int",
+                    args.data_type()
+                ),
+            };
+            result.map(ColumnarValue::Array)
+        },
+        [ColumnarValue::Array(args), ColumnarValue::Array(positions)] => {
+            if args.len() != positions.len() {
+                return exec_err!(
+                    "Input arrays must have equal length. Positions array has {} elements, but arguments array has {} elements",
+                    positions.len(), args.len()
+                );
+            }
+            if !matches!(positions.data_type(), DataType::Int32) {
+                return exec_err!(
+                    "Invalid data type for positions array: expected `Int32`, found `{}`",
+                    positions.data_type()
+                );
+            }
+            let result: Result<ArrayRef> = match args.data_type() {
+                DataType::Int8 => bit_get_array_positions!(args, Int8Array, positions, i8::BITS),
+                DataType::Int16 => bit_get_array_positions!(args, Int16Array, positions, i16::BITS),
+                DataType::Int32 => bit_get_array_positions!(args, Int32Array, positions, i32::BITS),
+                DataType::Int64 => bit_get_array_positions!(args, Int64Array, positions, i64::BITS),
+                _ => exec_err!(
+                    "Can't be evaluated because the expression's type is {:?}, not signed int",
+                    args.data_type()
+                ),
+            };
+            result.map(ColumnarValue::Array)
+        }
+        _ => exec_err!(
+            "Invalid input to function bit_get. Expected (IntegralType array, Int32Scalar) or (IntegralType array, Int32Array)"
+        ),
+    }
+}
+
+fn bit_get(arg: i64, pos: i32) -> i8 {
+    ((arg >> pos) & 1) as i8
+}
+
+fn check_position(pos: i32, bit_size: i32) -> Result<()> {
+    if pos < 0 {
+        return exec_err!("Invalid bit position: {:?} is less than zero", pos);
+    }
+    if bit_size <= pos {
+        return exec_err!(
+            "Invalid bit position: {:?} exceeds the bit upper limit: {:?}",
+            pos,
+            bit_size
+        );
+    }
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion::common::cast::as_int8_array;
+
+    #[test]
+    fn bitwise_get_scalar_position() -> Result<()> {
+        let args = [
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(1234553454),
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
+        ];
+
+        let expected = &Int8Array::from(vec![Some(0), None, Some(1)]);
+
+        let ColumnarValue::Array(result) = spark_bit_get(&args)? else {
+            unreachable!()
+        };
+
+        let result = as_int8_array(&result).expect("failed to downcast to Int8Array");
+
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn bitwise_get_scalar_negative_position() -> Result<()> {
+        let args = [
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(1234553454),
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(-1))),
+        ];
+
+        let expected = String::from("Execution error: Invalid bit position: -1 is less than zero");
+        let result = spark_bit_get(&args).err().unwrap().to_string();
+
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn bitwise_get_scalar_overflow_position() -> Result<()> {
+        let args = [
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(1234553454),
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(33))),
+        ];
+
+        let expected = String::from(
+            "Execution error: Invalid bit position: 33 exceeds the bit upper limit: 32",
+        );
+        let result = spark_bit_get(&args).err().unwrap().to_string();
+
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn bitwise_get_array_positions() -> Result<()> {
+        let args = [
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(1234553454),
+            ]))),
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(1), None, Some(1)]))),
+        ];
+
+        let expected = &Int8Array::from(vec![Some(0), None, Some(1)]);
+
+        let ColumnarValue::Array(result) = spark_bit_get(&args)? else {
+            unreachable!()
+        };
+
+        let result = as_int8_array(&result).expect("failed to downcast to Int8Array");
+
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn bitwise_get_array_positions_contains_negative() -> Result<()> {
+        let args = [
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(1234553454),
+            ]))),
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(-1), None, Some(1)]))),
+        ];
+
+        let expected = String::from("Execution error: Invalid bit position: -1 is less than zero");
+        let result = spark_bit_get(&args).err().unwrap().to_string();
+
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn bitwise_get_array_positions_contains_overflow() -> Result<()> {
+        let args = [
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![
+                Some(1),
+                None,
+                Some(1234553454),
+            ]))),
+            ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(33), None, Some(1)]))),
+        ];
+
+        let expected = String::from(
+            "Execution error: Invalid bit position: 33 exceeds the bit upper limit: 32",
+        );
+        let result = spark_bit_get(&args).err().unwrap().to_string();
+
+        assert_eq!(result, expected);
+
+        Ok(())
+    }
+}
diff --git a/native/spark-expr/src/bitwise_funcs/mod.rs b/native/spark-expr/src/bitwise_funcs/mod.rs
index 3f148a6dc..17d418675 100644
--- a/native/spark-expr/src/bitwise_funcs/mod.rs
+++ b/native/spark-expr/src/bitwise_funcs/mod.rs
@@ -16,7 +16,9 @@
 // under the License.
 
 mod bitwise_count;
+mod bitwise_get;
 mod bitwise_not;
 
 pub use bitwise_count::SparkBitwiseCount;
+pub use bitwise_get::SparkBitwiseGet;
 pub use bitwise_not::SparkBitwiseNot;
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 11d736d04..6177ef498 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -20,7 +20,7 @@ use crate::{
     spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
     spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
     spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
-    SparkBitwiseCount, SparkBitwiseNot, SparkChrFunc, SparkDateTrunc,
+    SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkChrFunc, SparkDateTrunc,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -157,6 +157,7 @@ fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
         Arc::new(ScalarUDF::new_from_impl(SparkChrFunc::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkBitwiseNot::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
+        Arc::new(ScalarUDF::new_from_impl(SparkBitwiseGet::default())),
        Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
     ]
 }
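For reviewers who want to sanity-check the semantics in isolation: the sketch below is a minimal standalone reproduction of the kernel's core logic, with bit_get copied verbatim from bitwise_get.rs above. The main() driver and its sample values are illustrative only and are not part of this commit.

    // bit_get is copied from the new bitwise_get.rs; main() is illustrative.
    fn bit_get(arg: i64, pos: i32) -> i8 {
        // Shift the requested bit into the least significant position, then mask it.
        ((arg >> pos) & 1) as i8
    }

    fn main() {
        // 11 is 0b1011, so bits 0, 1 and 3 are set and bit 2 is not.
        assert_eq!(bit_get(11, 0), 1);
        assert_eq!(bit_get(11, 1), 1);
        assert_eq!(bit_get(11, 2), 0);
        assert_eq!(bit_get(11, 3), 1);
        // Negative inputs follow two's complement: every bit of -1 is set.
        assert_eq!(bit_get(-1, 62), 1);
    }

This matches Spark's bit_get, where SELECT bit_get(11, 2) returns 0 and bit_get(11, 3) returns 1, with a TINYINT result, hence the Int8 return type registered above.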
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 4e45311d0..d0250d52a 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1625,69 +1625,27 @@ object QueryPlanSerde extends Logging with CometExprShim {
           binding,
           (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
 
-      case BitwiseNot(child) =>
-        val childProto = exprToProto(child, inputs, binding)
-        val bitNotScalarExpr =
-          scalarFunctionExprToProto("bit_not", childProto)
-        optExprWithInfo(bitNotScalarExpr, expr, expr.children: _*)
+      case _: BitwiseNot =>
+        CometBitwiseNot.convert(expr, inputs, binding)
 
-      case BitwiseOr(left, right) =>
-        createBinaryExpr(
-          expr,
-          left,
-          right,
-          inputs,
-          binding,
-          (builder, binaryExpr) => builder.setBitwiseOr(binaryExpr))
+      case _: BitwiseOr =>
+        CometBitwiseOr.convert(expr, inputs, binding)
 
-      case BitwiseXor(left, right) =>
-        createBinaryExpr(
-          expr,
-          left,
-          right,
-          inputs,
-          binding,
-          (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
-
-      case BitwiseCount(child) =>
-        val childProto = exprToProto(child, inputs, binding)
-        val bitCountScalarExpr =
-          scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, childProto)
-        optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*)
-
-      case ShiftRight(left, right) =>
-        // DataFusion bitwise shift right expression requires
-        // same data type between left and right side
-        val rightExpression = if (left.dataType == LongType) {
-          Cast(right, LongType)
-        } else {
-          right
-        }
+      case _: BitwiseXor =>
+        CometBitwiseXor.convert(expr, inputs, binding)
 
-        createBinaryExpr(
-          expr,
-          left,
-          rightExpression,
-          inputs,
-          binding,
-          (builder, binaryExpr) => builder.setBitwiseShiftRight(binaryExpr))
+      case _: ShiftRight =>
+        CometShiftRight.convert(expr, inputs, binding)
 
-      case ShiftLeft(left, right) =>
-        // DataFusion bitwise shift right expression requires
-        // same data type between left and right side
-        val rightExpression = if (left.dataType == LongType) {
-          Cast(right, LongType)
-        } else {
-          right
-        }
+      case _: BitwiseCount =>
+        CometBitwiseCount.convert(expr, inputs, binding)
+
+      case _: ShiftLeft =>
+        CometShiftLeft.convert(expr, inputs, binding)
+
+      case _: BitwiseGet =>
+        CometBitwiseGet.convert(expr, inputs, binding)
 
-        createBinaryExpr(
-          expr,
-          left,
-          rightExpression,
-          inputs,
-          binding,
-          (builder, binaryExpr) => builder.setBitwiseShiftLeft(binaryExpr))
 
       case In(value, list) =>
         in(expr, value, list, inputs, binding, negate = false)
diff --git a/spark/src/main/scala/org/apache/comet/serde/bitwise.scala b/spark/src/main/scala/org/apache/comet/serde/bitwise.scala
new file mode 100644
index 000000000..50a22e6b9
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/bitwise.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{ByteType, IntegerType, LongType}
+
+import org.apache.comet.serde.QueryPlanSerde._
+
+object CometBitwiseAdd extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val bitwiseAndExpr = expr.asInstanceOf[BitwiseAnd]
+    createBinaryExpr(
+      expr,
+      bitwiseAndExpr.left,
+      bitwiseAndExpr.right,
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
+  }
+}
+
+object CometBitwiseNot extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val bitwiseNotExpr = expr.asInstanceOf[BitwiseNot]
+    val childProto = exprToProto(bitwiseNotExpr.child, inputs, binding)
+    val bitNotScalarExpr =
+      scalarFunctionExprToProto("bit_not", childProto)
+    optExprWithInfo(bitNotScalarExpr, expr, expr.children: _*)
+  }
+}
+
+object CometBitwiseOr extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val bitwiseOrExpr = expr.asInstanceOf[BitwiseOr]
+    createBinaryExpr(
+      expr,
+      bitwiseOrExpr.left,
+      bitwiseOrExpr.right,
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setBitwiseOr(binaryExpr))
+  }
+}
+
+object CometBitwiseXor extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val bitwiseXorExpr = expr.asInstanceOf[BitwiseXor]
+    createBinaryExpr(
+      expr,
+      bitwiseXorExpr.left,
+      bitwiseXorExpr.right,
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
+  }
+}
+
+object CometShiftRight extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val shiftRightExpr = expr.asInstanceOf[ShiftRight]
+    // DataFusion bitwise shift right expression requires
+    // same data type between left and right side
+    val rightExpression = if (shiftRightExpr.left.dataType == LongType) {
+      Cast(shiftRightExpr.right, LongType)
+    } else {
+      shiftRightExpr.right
+    }
+
+    createBinaryExpr(
+      expr,
+      shiftRightExpr.left,
+      rightExpression,
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setBitwiseShiftRight(binaryExpr))
+  }
+}
+
+object CometShiftLeft extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val shiftLeftLeft = expr.asInstanceOf[ShiftLeft]
+    // DataFusion bitwise shift right expression requires
+    // same data type between left and right side
+    val rightExpression = if (shiftLeftLeft.left.dataType == LongType) {
+      Cast(shiftLeftLeft.right, LongType)
+    } else {
+      shiftLeftLeft.right
+    }
+
+    createBinaryExpr(
+      expr,
+      shiftLeftLeft.left,
+      rightExpression,
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setBitwiseShiftLeft(binaryExpr))
+  }
+}
+
+object CometBitwiseGet extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val bitwiseGetExpr = expr.asInstanceOf[BitwiseGet]
+    val argProto = exprToProto(bitwiseGetExpr.left, inputs, binding)
+    val posProto = exprToProto(bitwiseGetExpr.right, inputs, binding)
+    val bitGetScalarExpr =
+      scalarFunctionExprToProtoWithReturnType("bit_get", ByteType, argProto, posProto)
+    optExprWithInfo(bitGetScalarExpr, expr, expr.children: _*)
+  }
+}
+
+object CometBitwiseCount extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val bitwiseCountExpr = expr.asInstanceOf[BitwiseCount]
+    val childProto = exprToProto(bitwiseCountExpr.child, inputs, binding)
+    val bitCountScalarExpr =
+      scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, childProto)
+    optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*)
+  }
+}
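One behavioral detail worth calling out before the test diff: in the array-positions form, a null in either the value column or the position column yields a null output row, and only non-null positions are range-checked (the macro iterates positions with flatten before validating). The sketch below replays that per-row logic against plain arrow arrays. It is a simplified illustration, not the committed code: it assumes only the arrow crate as a dependency, and bit_get is again copied from above.

    use arrow::array::{Int32Array, Int8Array};

    fn bit_get(arg: i64, pos: i32) -> i8 {
        ((arg >> pos) & 1) as i8
    }

    fn main() {
        let values = Int32Array::from(vec![Some(11), None, Some(4)]);
        let positions = Int32Array::from(vec![Some(1), Some(0), None]);

        // Mirrors the zip/and_then chain in bit_get_array_positions: nulls
        // propagate from either side, non-null pairs are evaluated.
        let result: Int8Array = values
            .iter()
            .zip(positions.iter())
            .map(|(v, p)| v.and_then(|v| p.map(|p| bit_get(v.into(), p))))
            .collect();

        assert_eq!(result, Int8Array::from(vec![Some(1), None, None]));
    }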
diff --git a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala
new file mode 100644
index 000000000..d89e81b0f
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import scala.util.Random
+
+import org.apache.hadoop.fs.Path
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+
+import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
+
+class CometBitwiseExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+
+  test("bitwise expressions") {
+    Seq(false, true).foreach { dictionary =>
+      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+        val table = "test"
+        withTable(table) {
+          sql(s"create table $table(col1 int, col2 int) using parquet")
+          sql(s"insert into $table values(1111, 2)")
+          sql(s"insert into $table values(1111, 2)")
+          sql(s"insert into $table values(3333, 4)")
+          sql(s"insert into $table values(5555, 6)")
+
+          checkSparkAnswerAndOperator(
+            s"SELECT col1 & col2, col1 | col2, col1 ^ col2 FROM $table")
+          checkSparkAnswerAndOperator(
+            s"SELECT col1 & 1234, col1 | 1234, col1 ^ 1234 FROM $table")
+          checkSparkAnswerAndOperator(
+            s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table")
+          checkSparkAnswerAndOperator(
+            s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table")
+          checkSparkAnswerAndOperator(s"SELECT ~(11), ~col1, ~col2 FROM $table")
+        }
+      }
+    }
+  }
+
+  test("bitwise shift with different left/right types") {
+    Seq(false, true).foreach { dictionary =>
+      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+        val table = "test"
+        withTable(table) {
+          sql(s"create table $table(col1 long, col2 int) using parquet")
+          sql(s"insert into $table values(1111, 2)")
+          sql(s"insert into $table values(1111, 2)")
+          sql(s"insert into $table values(3333, 4)")
+          sql(s"insert into $table values(5555, 6)")
+
+          checkSparkAnswerAndOperator(
+            s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table")
+          checkSparkAnswerAndOperator(
+            s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table")
+        }
+      }
+    }
+  }
+
+  test("bitwise_get - throws exceptions") {
+    def checkSparkAndCometEqualThrows(query: String): Unit = {
+      checkSparkMaybeThrows(sql(query)) match {
+        case (Some(sparkExc), Some(cometExc)) =>
+          assert(sparkExc.getMessage == cometExc.getMessage)
+        case _ => fail("Exception should be thrown")
+      }
+    }
+    checkSparkAndCometEqualThrows("select bit_get(1000, -30)")
+    checkSparkAndCometEqualThrows("select bit_get(cast(1000 as byte), 9)")
+    checkSparkAndCometEqualThrows("select bit_count(cast(null as byte), 4)")
+    checkSparkAndCometEqualThrows("select bit_count(1000, cast(null as int))")
+  }
+
+  test("bitwise_get - random values (spark parquet gen)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = false,
+            generateStruct = false,
+            generateMap = false))
+      }
+      val table = spark.read.parquet(filename)
+      checkSparkAnswerAndOperator(
+        table
+          .selectExpr("bit_get(c1, 7)", "bit_get(c2, 10)", "bit_get(c3, 12)", "bit_get(c4, 16)"))
+    }
+  }
+
+  test("bitwise_get - random values (native parquet gen)") {
+    def randomBitPosition(maxBitPosition: Int): Int = {
+      Random.nextInt(maxBitPosition)
+    }
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 0, 10000, nullEnabled = false)
+        val table = spark.read.parquet(path.toString)
+        (0 to 10).foreach { _ =>
+          val byteBitPosition = randomBitPosition(java.lang.Byte.SIZE)
+          val shortBitPosition = randomBitPosition(java.lang.Short.SIZE)
+          val intBitPosition = randomBitPosition(java.lang.Integer.SIZE)
+          val longBitPosition = randomBitPosition(java.lang.Long.SIZE)
+          checkSparkAnswerAndOperator(
+            table
+              .selectExpr(
+                s"bit_get(_2, $byteBitPosition)",
+                s"bit_get(_3, $shortBitPosition)",
+                s"bit_get(_4, $intBitPosition)",
+                s"bit_get(_5, $longBitPosition)",
+                s"bit_get(_11, $longBitPosition)"))
+        }
+      }
+    }
+  }
+
+  test("bitwise_count - min/max values") {
+    Seq(false, true).foreach { dictionary =>
+      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+        val table = "bitwise_count_test"
+        withTable(table) {
+          sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet")
+          sql(s"insert into $table values(1111, 2222, 17, 7)")
+          sql(
+            s"insert into $table values(${Long.MaxValue}, ${Int.MaxValue}, ${Short.MaxValue}, ${Byte.MaxValue})")
+          sql(
+            s"insert into $table values(${Long.MinValue}, ${Int.MinValue}, ${Short.MinValue}, ${Byte.MinValue})")
+
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col1) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col2) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col3) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col4) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(true) FROM $table"))
+          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(false) FROM $table"))
+        }
+      }
+    }
+  }
+
+  test("bitwise_count - random values (spark gen)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          10,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = false,
+            generateStruct = false,
+            generateMap = false))
+      }
+      val table = spark.read.parquet(filename)
+      val df =
+        table.selectExpr("bit_count(c1)", "bit_count(c2)", "bit_count(c3)", "bit_count(c4)")
+
+      checkSparkAnswerAndOperator(df)
+    }
+  }
+
+  test("bitwise_count - random values (native parquet gen)") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 0, 10000, nullEnabled = false)
+        val table = spark.read.parquet(path.toString)
+        checkSparkAnswerAndOperator(
+          table
+            .selectExpr(
+              "bit_count(_2)",
+              "bit_count(_3)",
+              "bit_count(_4)",
+              "bit_count(_5)",
+              "bit_count(_11)"))
+      }
+    }
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index ce9ac120c..34e38895a 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -43,7 +43,6 @@ import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, StringType, StructType}
 
 import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
-import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
 
 class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   import testImplicits._
@@ -115,93 +114,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  test("bitwise_count - min/max values") {
-    Seq(false, true).foreach { dictionary =>
-      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
-        val table = "bitwise_count_test"
-        withTable(table) {
-          sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet")
-          sql(s"insert into $table values(1111, 2222, 17, 7)")
-          sql(
-            s"insert into $table values(${Long.MaxValue}, ${Int.MaxValue}, ${Short.MaxValue}, ${Byte.MaxValue})")
-          sql(
-            s"insert into $table values(${Long.MinValue}, ${Int.MinValue}, ${Short.MinValue}, ${Byte.MinValue})")
-
-          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col1) FROM $table"))
-          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col2) FROM $table"))
-          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col3) FROM $table"))
-          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col4) FROM $table"))
-          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(true) FROM $table"))
-          checkSparkAnswerAndOperator(sql(s"SELECT bit_count(false) FROM $table"))
-        }
-      }
-    }
-  }
-
-  test("bitwise_count - random values (spark gen)") {
-    withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      val filename = path.toString
-      val random = new Random(42)
-      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        ParquetGenerator.makeParquetFile(
-          random,
-          spark,
-          filename,
-          10,
-          DataGenOptions(
-            allowNull = true,
-            generateNegativeZero = true,
-            generateArray = false,
-            generateStruct = false,
-            generateMap = false))
-      }
-      val table = spark.read.parquet(filename)
-      val df =
-        table.selectExpr("bit_count(c1)", "bit_count(c2)", "bit_count(c3)", "bit_count(c4)")
-
-      checkSparkAnswerAndOperator(df)
-    }
-  }
-
-  test("bitwise_count - random values (native parquet gen)") {
-    Seq(true, false).foreach { dictionaryEnabled =>
-      withTempDir { dir =>
-        val path = new Path(dir.toURI.toString, "test.parquet")
-        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 0, 10000, nullEnabled = false)
-        val table = spark.read.parquet(path.toString)
-        checkSparkAnswerAndOperator(
-          table
-            .selectExpr(
-              "bit_count(_2)",
-              "bit_count(_3)",
-              "bit_count(_4)",
-              "bit_count(_5)",
-              "bit_count(_11)"))
-      }
-    }
-  }
-
-  test("bitwise shift with different left/right types") {
-    Seq(false, true).foreach { dictionary =>
-      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
-        val table = "test"
-        withTable(table) {
-          sql(s"create table $table(col1 long, col2 int) using parquet")
-          sql(s"insert into $table values(1111, 2)")
-          sql(s"insert into $table values(1111, 2)")
-          sql(s"insert into $table values(3333, 4)")
-          sql(s"insert into $table values(5555, 6)")
-
-          checkSparkAnswerAndOperator(
-            s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table")
-          checkSparkAnswerAndOperator(
-            s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table")
-        }
-      }
-    }
-  }
-
   test("basic data type support") {
     Seq(true, false).foreach { dictionaryEnabled =>
       withTempDir { dir =>
@@ -1552,31 +1464,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     })
   }
 
-  test("bitwise expressions") {
-    Seq(false, true).foreach { dictionary =>
-      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
-        val table = "test"
-        withTable(table) {
-          sql(s"create table $table(col1 int, col2 int) using parquet")
-          sql(s"insert into $table values(1111, 2)")
-          sql(s"insert into $table values(1111, 2)")
-          sql(s"insert into $table values(3333, 4)")
-          sql(s"insert into $table values(5555, 6)")
-
-          checkSparkAnswerAndOperator(
-            s"SELECT col1 & col2, col1 | col2, col1 ^ col2 FROM $table")
-          checkSparkAnswerAndOperator(
-            s"SELECT col1 & 1234, col1 | 1234, col1 ^ 1234 FROM $table")
-          checkSparkAnswerAndOperator(
-            s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table")
-          checkSparkAnswerAndOperator(
-            s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table")
-          checkSparkAnswerAndOperator(s"SELECT ~(11), ~col1, ~col2 FROM $table")
-        }
-      }
-    }
-  }
-
   test("test in(set)/not in(set)") {
     Seq("100", "0").foreach { inSetThreshold =>
       Seq(false, true).foreach { dictionary =>

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org