This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push: new 96dfccf fix: Cast string to boolean not compatible with Spark (#107) 96dfccf is described below commit 96dfccf638470407c31b71aaada05d35836e9d93 Author: Eren Avsarogullari <erenavsarogull...@gmail.com> AuthorDate: Sun Feb 25 21:00:45 2024 -0800 fix: Cast string to boolean not compatible with Spark (#107) --- core/src/execution/datafusion/expressions/cast.rs | 40 +++++++++++++++++++--- .../org/apache/comet/CometExpressionSuite.scala | 24 +++++++++++++ 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index d845068..447c277 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -27,7 +27,7 @@ use arrow::{ record_batch::RecordBatch, util::display::FormatOptions, }; -use arrow_array::ArrayRef; +use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait}; use arrow_schema::{DataType, Schema}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{Result as DataFusionResult, ScalarValue}; @@ -73,10 +73,42 @@ impl Cast { } fn cast_array(&self, array: ArrayRef) -> DataFusionResult<ArrayRef> { - let array = array_with_timezone(array, self.timezone.clone(), Some(&self.data_type)); + let to_type = &self.data_type; + let array = array_with_timezone(array, self.timezone.clone(), Some(to_type)); let from_type = array.data_type(); - let cast_result = cast_with_options(&array, &self.data_type, &CAST_OPTIONS)?; - Ok(spark_cast(cast_result, from_type, &self.data_type)) + let cast_result = match (from_type, to_type) { + (DataType::Utf8, DataType::Boolean) => Self::spark_cast_utf8_to_boolean::<i32>(&array), + (DataType::LargeUtf8, DataType::Boolean) => { + Self::spark_cast_utf8_to_boolean::<i64>(&array) + } + _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?, + }; + let result = 
spark_cast(cast_result, from_type, to_type); + Ok(result) + } + + fn spark_cast_utf8_to_boolean<OffsetSize>(from: &dyn Array) -> ArrayRef + where + OffsetSize: OffsetSizeTrait, + { + let array = from + .as_any() + .downcast_ref::<GenericStringArray<OffsetSize>>() + .unwrap(); + + let output_array = array + .iter() + .map(|value| match value { + Some(value) => match value.to_ascii_lowercase().trim() { + "t" | "true" | "y" | "yes" | "1" => Some(true), + "f" | "false" | "n" | "no" | "0" => Some(false), + _ => None, + }, + _ => None, + }) + .collect::<BooleanArray>(); + + Arc::new(output_array) } } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 66ee275..3f29e95 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1302,4 +1302,28 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + + test("test cast utf8 to boolean as compatible with Spark") { + def testCastedColumn(inputValues: Seq[String]): Unit = { + val table = "test_table" + withTable(table) { + val values = inputValues.map(x => s"('$x')").mkString(",") + sql(s"create table $table(base_column char(20)) using parquet") + sql(s"insert into $table values $values") + checkSparkAnswerAndOperator( + s"select base_column, cast(base_column as boolean) as casted_column from $table") + } + } + + // Supported boolean values as true by both Arrow and Spark + testCastedColumn(inputValues = Seq("t", "true", "y", "yes", "1", "T", "TrUe", "Y", "YES")) + // Supported boolean values as false by both Arrow and Spark + testCastedColumn(inputValues = Seq("f", "false", "n", "no", "0", "F", "FaLSe", "N", "No")) + // Supported boolean values by Arrow but not Spark + testCastedColumn(inputValues = + Seq("TR", "FA", "tr", "tru", "ye", "on", "fa", "fal", "fals", "of", "off")) + // Invalid boolean casting 
values for Arrow and Spark + testCastedColumn(inputValues = Seq("car", "Truck")) + } + }