This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 96dfccf  fix: Cast string to boolean not compatible with Spark (#107)
96dfccf is described below

commit 96dfccf638470407c31b71aaada05d35836e9d93
Author: Eren Avsarogullari <erenavsarogull...@gmail.com>
AuthorDate: Sun Feb 25 21:00:45 2024 -0800

    fix: Cast string to boolean not compatible with Spark (#107)
---
 core/src/execution/datafusion/expressions/cast.rs  | 40 +++++++++++++++++++---
 .../org/apache/comet/CometExpressionSuite.scala    | 24 +++++++++++++
 2 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index d845068..447c277 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -27,7 +27,7 @@ use arrow::{
     record_batch::RecordBatch,
     util::display::FormatOptions,
 };
-use arrow_array::ArrayRef;
+use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait};
 use arrow_schema::{DataType, Schema};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion_common::{Result as DataFusionResult, ScalarValue};
@@ -73,10 +73,42 @@ impl Cast {
     }
 
     fn cast_array(&self, array: ArrayRef) -> DataFusionResult<ArrayRef> {
-        let array = array_with_timezone(array, self.timezone.clone(), Some(&self.data_type));
+        let to_type = &self.data_type;
+        let array = array_with_timezone(array, self.timezone.clone(), Some(to_type));
         let from_type = array.data_type();
-        let cast_result = cast_with_options(&array, &self.data_type, &CAST_OPTIONS)?;
-        Ok(spark_cast(cast_result, from_type, &self.data_type))
+        let cast_result = match (from_type, to_type) {
+            (DataType::Utf8, DataType::Boolean) => Self::spark_cast_utf8_to_boolean::<i32>(&array),
+            (DataType::LargeUtf8, DataType::Boolean) => {
+                Self::spark_cast_utf8_to_boolean::<i64>(&array)
+            }
+            _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?,
+        };
+        let result = spark_cast(cast_result, from_type, to_type);
+        Ok(result)
+    }
+
+    fn spark_cast_utf8_to_boolean<OffsetSize>(from: &dyn Array) -> ArrayRef
+    where
+        OffsetSize: OffsetSizeTrait,
+    {
+        let array = from
+            .as_any()
+            .downcast_ref::<GenericStringArray<OffsetSize>>()
+            .unwrap();
+
+        let output_array = array
+            .iter()
+            .map(|value| match value {
+                Some(value) => match value.to_ascii_lowercase().trim() {
+                    "t" | "true" | "y" | "yes" | "1" => Some(true),
+                    "f" | "false" | "n" | "no" | "0" => Some(false),
+                    _ => None,
+                },
+                _ => None,
+            })
+            .collect::<BooleanArray>();
+
+        Arc::new(output_array)
     }
 }
 
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 66ee275..3f29e95 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1302,4 +1302,28 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("test cast utf8 to boolean as compatible with Spark") {
+    def testCastedColumn(inputValues: Seq[String]): Unit = {
+      val table = "test_table"
+      withTable(table) {
+        val values = inputValues.map(x => s"('$x')").mkString(",")
+        sql(s"create table $table(base_column char(20)) using parquet")
+        sql(s"insert into $table values $values")
+        checkSparkAnswerAndOperator(
+          s"select base_column, cast(base_column as boolean) as casted_column from $table")
+      }
+    }
+
+    // Supported boolean values as true by both Arrow and Spark
+    testCastedColumn(inputValues = Seq("t", "true", "y", "yes", "1", "T", "TrUe", "Y", "YES"))
+    // Supported boolean values as false by both Arrow and Spark
+    testCastedColumn(inputValues = Seq("f", "false", "n", "no", "0", "F", "FaLSe", "N", "No"))
+    // Supported boolean values by Arrow but not Spark
+    testCastedColumn(inputValues =
+      Seq("TR", "FA", "tr", "tru", "ye", "on", "fa", "fal", "fals", "of", "off"))
+    // Invalid boolean casting values for Arrow and Spark
+    testCastedColumn(inputValues = Seq("car", "Truck"))
+  }
+
 }

Reply via email to