This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new c73ac2e7f feat: implement cast from whole numbers to binary format and bool to decimal (#3083)
c73ac2e7f is described below

commit c73ac2e7f79a1b885a19851b5ecf2bb24cb55dfd
Author: B Vadlamani <[email protected]>
AuthorDate: Wed Feb 11 12:38:38 2026 -0800

    feat: implement cast from whole numbers to binary format and bool to decimal (#3083)
---
 native/spark-expr/src/conversion_funcs/cast.rs     |  62 ++++++++++--
 .../org/apache/comet/expressions/CometCast.scala   | 108 +++++++++++----------
 .../scala/org/apache/comet/CometCastSuite.scala    |  75 ++++++++------
 3 files changed, 163 insertions(+), 82 deletions(-)

diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 5c6533618..be5257477 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -16,11 +16,12 @@
 // under the License.
 
 use crate::utils::array_with_timezone;
+use crate::EvalMode::Legacy;
 use crate::{timezone, BinaryOutputStyle};
 use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
 use arrow::array::{
-    BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
+    BinaryBuilder, BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
     PrimitiveBuilder, StringArray, StructArray, TimestampMicrosecondBuilder,
 };
 use arrow::compute::can_cast_types;
@@ -304,14 +305,17 @@ fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> b
 
 fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
     use DataType::*;
-    matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
+    matches!(
+        to_type,
+        Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
+    )
 }
 
 fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
     use DataType::*;
     matches!(
         to_type,
-        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) | Binary
     )
 }
 
@@ -319,14 +323,14 @@ fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool {
     use DataType::*;
     matches!(
         to_type,
-        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) | Binary
     )
 }
 
 fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
     use DataType::*;
     match to_type {
-        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true,
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 | Binary => true,
         Decimal128(_, _) => {
             // incompatible: no overflow check
             options.allow_incompat
@@ -338,7 +342,7 @@ fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
 fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
     use DataType::*;
     match to_type {
-        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Binary => true,
         Decimal128(_, _) => {
             // incompatible: no overflow check
             options.allow_incompat
@@ -501,6 +505,29 @@ macro_rules! cast_float_to_string {
     }};
 }
 
+// Eval mode is not needed since every whole number can be represented in binary, so this cast cannot fail
+macro_rules! cast_whole_num_to_binary {
+    ($array:expr, $primitive_type:ty, $byte_size:expr) => {{
+        let input_arr = $array
+            .as_any()
+            .downcast_ref::<$primitive_type>()
+            .ok_or_else(|| SparkError::Internal("Expected numeric array".to_string()))?;
+
+        let len = input_arr.len();
+        let mut builder = BinaryBuilder::with_capacity(len, len * $byte_size);
+
+        for i in 0..input_arr.len() {
+            if input_arr.is_null(i) {
+                builder.append_null();
+            } else {
+                builder.append_value(input_arr.value(i).to_be_bytes());
+            }
+        }
+
+        Ok(Arc::new(builder.finish()) as ArrayRef)
+    }};
+}
+
 macro_rules! cast_int_to_int_macro {
     (
         $array: expr,
@@ -1101,6 +1128,19 @@ fn cast_array(
         }
         (Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
         (Date32, Timestamp(_, tz)) => Ok(cast_date_to_timestamp(&array, cast_options, tz)?),
+        (Int8, Binary) if (eval_mode == Legacy) => cast_whole_num_to_binary!(&array, Int8Array, 1),
+        (Int16, Binary) if (eval_mode == Legacy) => {
+            cast_whole_num_to_binary!(&array, Int16Array, 2)
+        }
+        (Int32, Binary) if (eval_mode == Legacy) => {
+            cast_whole_num_to_binary!(&array, Int32Array, 4)
+        }
+        (Int64, Binary) if (eval_mode == Legacy) => {
+            cast_whole_num_to_binary!(&array, Int64Array, 8)
+        }
+        (Boolean, Decimal128(precision, scale)) => {
+            cast_boolean_to_decimal(&array, *precision, *scale)
+        }
         _ if cast_options.is_adapting_schema
             || is_datafusion_spark_compatible(from_type, to_type) =>
         {
@@ -1163,6 +1203,16 @@ fn cast_date_to_timestamp(
     ))
 }
 
+fn cast_boolean_to_decimal(array: &ArrayRef, precision: u8, scale: i8) -> SparkResult<ArrayRef> {
+    let bool_array = array.as_boolean();
+    let scaled_val = 10_i128.pow(scale as u32);
+    let result: Decimal128Array = bool_array
+        .iter()
+        .map(|v| v.map(|b| if b { scaled_val } else { 0 }))
+        .collect();
+    Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
+}
+
 fn cast_string_to_float(
     array: &ArrayRef,
     to_type: &DataType,
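
For readers following the Rust changes: the Int32 arm above, cast_whole_num_to_binary!(&array, Int32Array, 4), roughly expands to the standalone sketch below (runnable against the arrow crate; the int32_to_binary helper name and the unwrap-based error handling are illustrative, not part of the patch):

    use std::sync::Arc;
    use arrow::array::{Array, ArrayRef, BinaryBuilder, Int32Array};

    // Hand-expanded sketch of cast_whole_num_to_binary!(&array, Int32Array, 4).
    fn int32_to_binary(array: &ArrayRef) -> ArrayRef {
        let input = array.as_any().downcast_ref::<Int32Array>().unwrap();
        let len = input.len();
        // Each non-null Int32 contributes exactly 4 big-endian bytes.
        let mut builder = BinaryBuilder::with_capacity(len, len * 4);
        for i in 0..len {
            if input.is_null(i) {
                builder.append_null();
            } else {
                builder.append_value(input.value(i).to_be_bytes());
            }
        }
        Arc::new(builder.finish())
    }

    fn main() {
        let input: ArrayRef = Arc::new(Int32Array::from(vec![Some(42), None]));
        let output = int32_to_binary(&input);
        // Nulls propagate; 42 becomes the 4 bytes [0, 0, 0, 42].
        assert_eq!(output.len(), 2);
        assert!(output.is_null(1));
    }

The real macro differs only in propagating the downcast failure as SparkError::Internal instead of unwrapping.
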
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index f42a5d8d8..000cc5fd4 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -147,13 +147,13 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
       case (DataTypes.BooleanType, _) =>
         canCastFromBoolean(toType)
       case (DataTypes.ByteType, _) =>
-        canCastFromByte(toType)
+        canCastFromByte(toType, evalMode)
       case (DataTypes.ShortType, _) =>
-        canCastFromShort(toType)
+        canCastFromShort(toType, evalMode)
       case (DataTypes.IntegerType, _) =>
-        canCastFromInt(toType)
+        canCastFromInt(toType, evalMode)
       case (DataTypes.LongType, _) =>
-        canCastFromLong(toType)
+        canCastFromLong(toType, evalMode)
       case (DataTypes.FloatType, _) =>
         canCastFromFloat(toType)
       case (DataTypes.DoubleType, _) =>
@@ -264,58 +264,68 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
 
   private def canCastFromBoolean(toType: DataType): SupportLevel = toType match {
     case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType |
-        DataTypes.FloatType | DataTypes.DoubleType =>
+        DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
       Compatible()
     case _ => unsupported(DataTypes.BooleanType, toType)
   }
 
-  private def canCastFromByte(toType: DataType): SupportLevel = toType match {
-    case DataTypes.BooleanType =>
-      Compatible()
-    case DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType =>
-      Compatible()
-    case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
-      Compatible()
-    case _ =>
-      unsupported(DataTypes.ByteType, toType)
-  }
+  private def canCastFromByte(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
+    toType match {
+      case DataTypes.BooleanType =>
+        Compatible()
+      case DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType =>
+        Compatible()
+      case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
+        Compatible()
+      case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
+        Compatible()
+      case _ =>
+        unsupported(DataTypes.ByteType, toType)
+    }
 
-  private def canCastFromShort(toType: DataType): SupportLevel = toType match {
-    case DataTypes.BooleanType =>
-      Compatible()
-    case DataTypes.ByteType | DataTypes.IntegerType | DataTypes.LongType =>
-      Compatible()
-    case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
-      Compatible()
-    case _ =>
-      unsupported(DataTypes.ShortType, toType)
-  }
+  private def canCastFromShort(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
+    toType match {
+      case DataTypes.BooleanType =>
+        Compatible()
+      case DataTypes.ByteType | DataTypes.IntegerType | DataTypes.LongType =>
+        Compatible()
+      case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
+        Compatible()
+      case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
+        Compatible()
+      case _ =>
+        unsupported(DataTypes.ShortType, toType)
+    }
 
-  private def canCastFromInt(toType: DataType): SupportLevel = toType match {
-    case DataTypes.BooleanType =>
-      Compatible()
-    case DataTypes.ByteType | DataTypes.ShortType | DataTypes.LongType =>
-      Compatible()
-    case DataTypes.FloatType | DataTypes.DoubleType =>
-      Compatible()
-    case _: DecimalType =>
-      Compatible()
-    case _ =>
-      unsupported(DataTypes.IntegerType, toType)
-  }
+  private def canCastFromInt(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
+    toType match {
+      case DataTypes.BooleanType =>
+        Compatible()
+      case DataTypes.ByteType | DataTypes.ShortType | DataTypes.LongType =>
+        Compatible()
+      case DataTypes.FloatType | DataTypes.DoubleType =>
+        Compatible()
+      case _: DecimalType =>
+        Compatible()
+      case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
+      case _ =>
+        unsupported(DataTypes.IntegerType, toType)
+    }
 
-  private def canCastFromLong(toType: DataType): SupportLevel = toType match {
-    case DataTypes.BooleanType =>
-      Compatible()
-    case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType =>
-      Compatible()
-    case DataTypes.FloatType | DataTypes.DoubleType =>
-      Compatible()
-    case _: DecimalType =>
-      Compatible()
-    case _ =>
-      unsupported(DataTypes.LongType, toType)
-  }
+  private def canCastFromLong(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
+    toType match {
+      case DataTypes.BooleanType =>
+        Compatible()
+      case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType =>
+        Compatible()
+      case DataTypes.FloatType | DataTypes.DoubleType =>
+        Compatible()
+      case _: DecimalType =>
+        Compatible()
+      case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
+      case _ =>
+        unsupported(DataTypes.LongType, toType)
+    }
 
   private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
     case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
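
The CometEvalMode.LEGACY guards above mirror the (IntN, Binary) if eval_mode == Legacy arms on the native side. A minimal Rust sketch of the same gating pattern, using stand-in enums for illustration rather than Comet's real types:

    // Stand-in types for illustration only; not Comet's actual definitions.
    #[derive(PartialEq)]
    enum EvalMode {
        Legacy,
        Ansi,
        Try,
    }

    enum DataType {
        Int32,
        Binary,
    }

    // Mirrors canCastFromInt: Binary is only supported in legacy eval mode,
    // since Spark rejects integer-to-binary casts under ANSI/TRY.
    fn int_cast_supported(to: &DataType, eval_mode: &EvalMode) -> bool {
        match to {
            DataType::Binary => *eval_mode == EvalMode::Legacy,
            DataType::Int32 => true,
        }
    }

    fn main() {
        assert!(int_cast_supported(&DataType::Binary, &EvalMode::Legacy));
        assert!(!int_cast_supported(&DataType::Binary, &EvalMode::Ansi));
        assert!(!int_cast_supported(&DataType::Binary, &EvalMode::Try));
    }
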
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index bea701d49..9fc9a1657 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -134,11 +134,18 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateBools(), DataTypes.DoubleType)
   }
 
-  ignore("cast BooleanType to DecimalType(10,2)") {
-    // Arrow error: Cast error: Casting from Boolean to Decimal128(10, 2) not supported
+  test("cast BooleanType to DecimalType(10,2)") {
     castTest(generateBools(), DataTypes.createDecimalType(10, 2))
   }
 
+  test("cast BooleanType to DecimalType(14,4)") {
+    castTest(generateBools(), DataTypes.createDecimalType(14, 4))
+  }
+
+  test("cast BooleanType to DecimalType(30,0)") {
+    castTest(generateBools(), DataTypes.createDecimalType(30, 0))
+  }
+
   test("cast BooleanType to StringType") {
     castTest(generateBools(), DataTypes.StringType)
   }
@@ -206,11 +213,14 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       hasIncompatibleType = usingParquetExecWithIncompatTypes)
   }
 
-  ignore("cast ByteType to BinaryType") {
+  test("cast ByteType to BinaryType") {
+    // Spark does not support ANSI or TRY mode for this cast
     castTest(
       generateBytes(),
       DataTypes.BinaryType,
-      hasIncompatibleType = usingParquetExecWithIncompatTypes)
+      hasIncompatibleType = usingParquetExecWithIncompatTypes,
+      testAnsi = false,
+      testTry = false)
   }
 
   ignore("cast ByteType to TimestampType") {
@@ -280,11 +290,14 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       hasIncompatibleType = usingParquetExecWithIncompatTypes)
   }
 
-  ignore("cast ShortType to BinaryType") {
+  test("cast ShortType to BinaryType") {
+    // Spark does not support ANSI or TRY mode for this cast
     castTest(
       generateShorts(),
       DataTypes.BinaryType,
-      hasIncompatibleType = usingParquetExecWithIncompatTypes)
+      hasIncompatibleType = usingParquetExecWithIncompatTypes,
+      testAnsi = false,
+      testTry = false)
   }
 
   ignore("cast ShortType to TimestampType") {
@@ -345,8 +358,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateInts(), DataTypes.StringType)
   }
 
-  ignore("cast IntegerType to BinaryType") {
-    castTest(generateInts(), DataTypes.BinaryType)
+  test("cast IntegerType to BinaryType") {
+    // Spark does not support ANSI or TRY mode for this cast
+    castTest(generateInts(), DataTypes.BinaryType, testAnsi = false, testTry = false)
   }
 
   ignore("cast IntegerType to TimestampType") {
@@ -391,8 +405,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateLongs(), DataTypes.StringType)
   }
 
-  ignore("cast LongType to BinaryType") {
-    castTest(generateLongs(), DataTypes.BinaryType)
+  test("cast LongType to BinaryType") {
+    // Spark does not support ANSI or TRY mode for this cast
+    castTest(generateLongs(), DataTypes.BinaryType, testAnsi = false, testTry = false)
   }
 
   ignore("cast LongType to TimestampType") {
@@ -1416,28 +1431,32 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       input: DataFrame,
       toType: DataType,
       hasIncompatibleType: Boolean = false,
-      testAnsi: Boolean = true): Unit = {
+      testAnsi: Boolean = true,
+      testTry: Boolean = true): Unit = {
 
     withTempPath { dir =>
       val data = roundtripParquet(input, dir).coalesce(1)
-      data.createOrReplaceTempView("t")
 
       withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
         // cast() should return null for invalid inputs when ansi mode is disabled
-        val df = spark.sql(s"select a, cast(a as ${toType.sql}) from t order by a")
+        val df = data.select(col("a"), col("a").cast(toType)).orderBy(col("a"))
         if (hasIncompatibleType) {
           checkSparkAnswer(df)
         } else {
           checkSparkAnswerAndOperator(df)
         }
 
-        // try_cast() should always return null for invalid inputs
-        val df2 =
-          spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
-        if (hasIncompatibleType) {
-          checkSparkAnswer(df2)
-        } else {
-          checkSparkAnswerAndOperator(df2)
+        if (testTry) {
+          data.createOrReplaceTempView("t")
+          // try_cast() should always return null for invalid inputs
+          // not using the DataFrame DSL since `try_cast` is only available there from Spark 4.x
+          val df2 =
+            spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
+          if (hasIncompatibleType) {
+            checkSparkAnswer(df2)
+          } else {
+            checkSparkAnswerAndOperator(df2)
+          }
         }
       }
 
@@ -1495,14 +1514,16 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
           }
 
           // try_cast() should always return null for invalid inputs
-          val df2 =
-            spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
-          if (hasIncompatibleType) {
-            checkSparkAnswer(df2)
-          } else {
-            checkSparkAnswerAndOperator(df2)
+          if (testTry) {
+            data.createOrReplaceTempView("t")
+            val df2 =
+              spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
+            if (hasIncompatibleType) {
+              checkSparkAnswer(df2)
+            } else {
+              checkSparkAnswerAndOperator(df2)
+            }
           }
-
         }
       }
     }
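
As a sanity check of the values the new tests compare against Spark, the underlying arithmetic can be verified in plain Rust with no Arrow or Spark dependency (a minimal illustrative sketch):

    fn main() {
        // Int -> Binary: Spark's cast emits big-endian two's-complement
        // bytes, which is exactly what to_be_bytes() produces.
        assert_eq!(42i32.to_be_bytes(), [0x00, 0x00, 0x00, 0x2A]);
        assert_eq!((-1i16).to_be_bytes(), [0xFF, 0xFF]);

        // Boolean -> Decimal(p, s): true is stored as the unscaled value
        // 10^s and false as 0, so true as DECIMAL(10, 2) reads back as 1.00.
        let scale: u32 = 2;
        assert_eq!(10_i128.pow(scale), 100);
    }

This is why the "cast BooleanType to DecimalType(10,2)" test expects true to round-trip as 1.00.
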


