This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new c73ac2e7f feat: implement cast from whole numbers to binary format and bool to decimal (#3083)
c73ac2e7f is described below
commit c73ac2e7f79a1b885a19851b5ecf2bb24cb55dfd
Author: B Vadlamani <[email protected]>
AuthorDate: Wed Feb 11 12:38:38 2026 -0800
feat: implement cast from whole numbers to binary format and bool to decimal (#3083)
---
native/spark-expr/src/conversion_funcs/cast.rs | 62 ++++++++++--
.../org/apache/comet/expressions/CometCast.scala | 108 +++++++++++----------
.../scala/org/apache/comet/CometCastSuite.scala | 75 ++++++++------
3 files changed, 163 insertions(+), 82 deletions(-)
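
For orientation (not part of the patch): in legacy eval mode, Spark renders an integral value cast to BINARY as its fixed-width big-endian two's-complement bytes, and a BOOLEAN cast to DECIMAL(p, s) maps true/false to 1/0 at the target scale. A minimal Rust sketch of those semantics, using only the standard library:

    fn main() {
        // Big-endian two's-complement bytes, matching Spark's int-to-binary cast.
        assert_eq!(42i32.to_be_bytes(), [0x00, 0x00, 0x00, 0x2a]);
        assert_eq!((-1i64).to_be_bytes(), [0xff; 8]);
        // Boolean-to-decimal: true becomes 10^scale as the unscaled backing
        // value, e.g. 100 at scale 2, which prints as "1.00".
        assert_eq!(10_i128.pow(2), 100);
    }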
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 5c6533618..be5257477 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -16,11 +16,12 @@
// under the License.
use crate::utils::array_with_timezone;
+use crate::EvalMode::Legacy;
use crate::{timezone, BinaryOutputStyle};
use crate::{EvalMode, SparkError, SparkResult};
use arrow::array::builder::StringBuilder;
use arrow::array::{
- BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
+ BinaryBuilder, BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
PrimitiveBuilder, StringArray, StructArray, TimestampMicrosecondBuilder,
};
use arrow::compute::can_cast_types;
@@ -304,14 +305,17 @@ fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> b
fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
use DataType::*;
- matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
+ matches!(
+ to_type,
+ Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
+ )
}
fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
use DataType::*;
matches!(
to_type,
- Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
+ Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) | Binary
)
}
@@ -319,14 +323,14 @@ fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool {
use DataType::*;
matches!(
to_type,
- Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
+ Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) | Binary
)
}
fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
use DataType::*;
match to_type {
- Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true,
+ Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 | Binary => true,
Decimal128(_, _) => {
// incompatible: no overflow check
options.allow_incompat
@@ -338,7 +342,7 @@ fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
use DataType::*;
match to_type {
- Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
+ Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Binary => true,
Decimal128(_, _) => {
// incompatible: no overflow check
options.allow_incompat
@@ -501,6 +505,29 @@ macro_rules! cast_float_to_string {
}};
}
+// eval mode is not needed since all ints can be represented in binary format
+macro_rules! cast_whole_num_to_binary {
+ ($array:expr, $primitive_type:ty, $byte_size:expr) => {{
+ let input_arr = $array
+ .as_any()
+ .downcast_ref::<$primitive_type>()
+ .ok_or_else(|| SparkError::Internal("Expected numeric array".to_string()))?;
+
+ let len = input_arr.len();
+ let mut builder = BinaryBuilder::with_capacity(len, len * $byte_size);
+
+ for i in 0..input_arr.len() {
+ if input_arr.is_null(i) {
+ builder.append_null();
+ } else {
+ builder.append_value(input_arr.value(i).to_be_bytes());
+ }
+ }
+
+ Ok(Arc::new(builder.finish()) as ArrayRef)
+ }};
+}
+
macro_rules! cast_int_to_int_macro {
(
$array: expr,
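
For reference, here is a standalone sketch of what the cast_whole_num_to_binary! macro above expands to for Int32Array (the free function name is illustrative, not part of the patch):

    use std::sync::Arc;

    use arrow::array::{Array, ArrayRef, BinaryBuilder, Int32Array};

    // Each non-null value becomes its 4-byte big-endian representation;
    // nulls are passed through unchanged.
    fn int32_to_binary(input: &Int32Array) -> ArrayRef {
        let mut builder = BinaryBuilder::with_capacity(input.len(), input.len() * 4);
        for i in 0..input.len() {
            if input.is_null(i) {
                builder.append_null();
            } else {
                builder.append_value(input.value(i).to_be_bytes());
            }
        }
        Arc::new(builder.finish())
    }

So 42 maps to [0x00, 0x00, 0x00, 0x2a] and -1 to [0xff, 0xff, 0xff, 0xff], matching Spark's legacy-mode output.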
@@ -1101,6 +1128,19 @@ fn cast_array(
}
(Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
(Date32, Timestamp(_, tz)) => Ok(cast_date_to_timestamp(&array, cast_options, tz)?),
+ (Int8, Binary) if (eval_mode == Legacy) => cast_whole_num_to_binary!(&array, Int8Array, 1),
+ (Int16, Binary) if (eval_mode == Legacy) => {
+ cast_whole_num_to_binary!(&array, Int16Array, 2)
+ }
+ (Int32, Binary) if (eval_mode == Legacy) => {
+ cast_whole_num_to_binary!(&array, Int32Array, 4)
+ }
+ (Int64, Binary) if (eval_mode == Legacy) => {
+ cast_whole_num_to_binary!(&array, Int64Array, 8)
+ }
+ (Boolean, Decimal128(precision, scale)) => {
+ cast_boolean_to_decimal(&array, *precision, *scale)
+ }
_ if cast_options.is_adapting_schema
|| is_datafusion_spark_compatible(from_type, to_type) =>
{
@@ -1163,6 +1203,16 @@ fn cast_date_to_timestamp(
))
}
+fn cast_boolean_to_decimal(array: &ArrayRef, precision: u8, scale: i8) -> SparkResult<ArrayRef> {
+ let bool_array = array.as_boolean();
+ let scaled_val = 10_i128.pow(scale as u32);
+ let result: Decimal128Array = bool_array
+ .iter()
+ .map(|v| v.map(|b| if b { scaled_val } else { 0 }))
+ .collect();
+ Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
+}
+
fn cast_string_to_float(
array: &ArrayRef,
to_type: &DataType,
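
A minimal sketch of the cast_boolean_to_decimal mapping above, assuming a DECIMAL(10, 2) target (the sample values are illustrative):

    use arrow::array::{Array, BooleanArray, Decimal128Array};

    fn main() {
        // true -> 10^scale (the value 1 at the target scale), false -> 0.
        let bools = BooleanArray::from(vec![Some(true), Some(false), None]);
        let scale: i8 = 2;
        let scaled = 10_i128.pow(scale as u32); // 100
        let decimals: Decimal128Array = bools
            .iter()
            .map(|v| v.map(|b| if b { scaled } else { 0 }))
            .collect();
        let decimals = decimals.with_precision_and_scale(10, scale).unwrap();
        assert_eq!(decimals.value_as_string(0), "1.00"); // true
        assert_eq!(decimals.value_as_string(1), "0.00"); // false
        assert!(decimals.is_null(2)); // null passes through
    }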
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index f42a5d8d8..000cc5fd4 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -147,13 +147,13 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
case (DataTypes.BooleanType, _) =>
canCastFromBoolean(toType)
case (DataTypes.ByteType, _) =>
- canCastFromByte(toType)
+ canCastFromByte(toType, evalMode)
case (DataTypes.ShortType, _) =>
- canCastFromShort(toType)
+ canCastFromShort(toType, evalMode)
case (DataTypes.IntegerType, _) =>
- canCastFromInt(toType)
+ canCastFromInt(toType, evalMode)
case (DataTypes.LongType, _) =>
- canCastFromLong(toType)
+ canCastFromLong(toType, evalMode)
case (DataTypes.FloatType, _) =>
canCastFromFloat(toType)
case (DataTypes.DoubleType, _) =>
@@ -264,58 +264,68 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
private def canCastFromBoolean(toType: DataType): SupportLevel = toType match {
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType |
- DataTypes.FloatType | DataTypes.DoubleType =>
+ DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
Compatible()
case _ => unsupported(DataTypes.BooleanType, toType)
}
- private def canCastFromByte(toType: DataType): SupportLevel = toType match {
- case DataTypes.BooleanType =>
- Compatible()
- case DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType =>
- Compatible()
- case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
- Compatible()
- case _ =>
- unsupported(DataTypes.ByteType, toType)
- }
+ private def canCastFromByte(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
+ toType match {
+ case DataTypes.BooleanType =>
+ Compatible()
+ case DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType =>
+ Compatible()
+ case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
+ Compatible()
+ case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
+ Compatible()
+ case _ =>
+ unsupported(DataTypes.ByteType, toType)
+ }
- private def canCastFromShort(toType: DataType): SupportLevel = toType match {
- case DataTypes.BooleanType =>
- Compatible()
- case DataTypes.ByteType | DataTypes.IntegerType | DataTypes.LongType =>
- Compatible()
- case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
- Compatible()
- case _ =>
- unsupported(DataTypes.ShortType, toType)
- }
+ private def canCastFromShort(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
+ toType match {
+ case DataTypes.BooleanType =>
+ Compatible()
+ case DataTypes.ByteType | DataTypes.IntegerType | DataTypes.LongType =>
+ Compatible()
+ case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
+ Compatible()
+ case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
+ Compatible()
+ case _ =>
+ unsupported(DataTypes.ShortType, toType)
+ }
- private def canCastFromInt(toType: DataType): SupportLevel = toType match {
- case DataTypes.BooleanType =>
- Compatible()
- case DataTypes.ByteType | DataTypes.ShortType | DataTypes.LongType =>
- Compatible()
- case DataTypes.FloatType | DataTypes.DoubleType =>
- Compatible()
- case _: DecimalType =>
- Compatible()
- case _ =>
- unsupported(DataTypes.IntegerType, toType)
- }
+ private def canCastFromInt(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
+ toType match {
+ case DataTypes.BooleanType =>
+ Compatible()
+ case DataTypes.ByteType | DataTypes.ShortType | DataTypes.LongType =>
+ Compatible()
+ case DataTypes.FloatType | DataTypes.DoubleType =>
+ Compatible()
+ case _: DecimalType =>
+ Compatible()
+ case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
+ case _ =>
+ unsupported(DataTypes.IntegerType, toType)
+ }
- private def canCastFromLong(toType: DataType): SupportLevel = toType match {
- case DataTypes.BooleanType =>
- Compatible()
- case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType =>
- Compatible()
- case DataTypes.FloatType | DataTypes.DoubleType =>
- Compatible()
- case _: DecimalType =>
- Compatible()
- case _ =>
- unsupported(DataTypes.LongType, toType)
- }
+ private def canCastFromLong(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
+ toType match {
+ case DataTypes.BooleanType =>
+ Compatible()
+ case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType =>
+ Compatible()
+ case DataTypes.FloatType | DataTypes.DoubleType =>
+ Compatible()
+ case _: DecimalType =>
+ Compatible()
+ case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
+ case _ =>
+ unsupported(DataTypes.LongType, toType)
+ }
private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index bea701d49..9fc9a1657 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -134,11 +134,18 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateBools(), DataTypes.DoubleType)
}
- ignore("cast BooleanType to DecimalType(10,2)") {
- // Arrow error: Cast error: Casting from Boolean to Decimal128(10, 2) not supported
+ test("cast BooleanType to DecimalType(10,2)") {
castTest(generateBools(), DataTypes.createDecimalType(10, 2))
}
+ test("cast BooleanType to DecimalType(14,4)") {
+ castTest(generateBools(), DataTypes.createDecimalType(14, 4))
+ }
+
+ test("cast BooleanType to DecimalType(30,0)") {
+ castTest(generateBools(), DataTypes.createDecimalType(30, 0))
+ }
+
test("cast BooleanType to StringType") {
castTest(generateBools(), DataTypes.StringType)
}
@@ -206,11 +213,14 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
hasIncompatibleType = usingParquetExecWithIncompatTypes)
}
- ignore("cast ByteType to BinaryType") {
+ test("cast ByteType to BinaryType") {
+ // Spark does not support ANSI or Try mode
castTest(
generateBytes(),
DataTypes.BinaryType,
- hasIncompatibleType = usingParquetExecWithIncompatTypes)
+ hasIncompatibleType = usingParquetExecWithIncompatTypes,
+ testAnsi = false,
+ testTry = false)
}
ignore("cast ByteType to TimestampType") {
@@ -280,11 +290,14 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
hasIncompatibleType = usingParquetExecWithIncompatTypes)
}
- ignore("cast ShortType to BinaryType") {
+ test("cast ShortType to BinaryType") {
+ // Spark does not support ANSI or Try mode
castTest(
generateShorts(),
DataTypes.BinaryType,
- hasIncompatibleType = usingParquetExecWithIncompatTypes)
+ hasIncompatibleType = usingParquetExecWithIncompatTypes,
+ testAnsi = false,
+ testTry = false)
}
ignore("cast ShortType to TimestampType") {
@@ -345,8 +358,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateInts(), DataTypes.StringType)
}
- ignore("cast IntegerType to BinaryType") {
- castTest(generateInts(), DataTypes.BinaryType)
+ test("cast IntegerType to BinaryType") {
+ // Spark does not support ANSI or Try mode
+ castTest(generateInts(), DataTypes.BinaryType, testAnsi = false, testTry = false)
}
ignore("cast IntegerType to TimestampType") {
@@ -391,8 +405,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateLongs(), DataTypes.StringType)
}
- ignore("cast LongType to BinaryType") {
- castTest(generateLongs(), DataTypes.BinaryType)
+ test("cast LongType to BinaryType") {
+ // Spark does not support ANSI or Try mode
+ castTest(generateLongs(), DataTypes.BinaryType, testAnsi = false, testTry = false)
}
ignore("cast LongType to TimestampType") {
@@ -1416,28 +1431,32 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
input: DataFrame,
toType: DataType,
hasIncompatibleType: Boolean = false,
- testAnsi: Boolean = true): Unit = {
+ testAnsi: Boolean = true,
+ testTry: Boolean = true): Unit = {
withTempPath { dir =>
val data = roundtripParquet(input, dir).coalesce(1)
- data.createOrReplaceTempView("t")
withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
// cast() should return null for invalid inputs when ansi mode is disabled
- val df = spark.sql(s"select a, cast(a as ${toType.sql}) from t order by a")
+ val df = data.select(col("a"), col("a").cast(toType)).orderBy(col("a"))
if (hasIncompatibleType) {
checkSparkAnswer(df)
} else {
checkSparkAnswerAndOperator(df)
}
- // try_cast() should always return null for invalid inputs
- val df2 =
- spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by
a")
- if (hasIncompatibleType) {
- checkSparkAnswer(df2)
- } else {
- checkSparkAnswerAndOperator(df2)
+ if (testTry) {
+ data.createOrReplaceTempView("t")
+ // try_cast() should always return null for invalid inputs
+ // not using the Spark DSL since `try_cast` is only available from Spark 4.x
+ val df2 =
+ spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by
a")
+ if (hasIncompatibleType) {
+ checkSparkAnswer(df2)
+ } else {
+ checkSparkAnswerAndOperator(df2)
+ }
}
}
@@ -1495,14 +1514,16 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
// try_cast() should always return null for invalid inputs
- val df2 =
- spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by
a")
- if (hasIncompatibleType) {
- checkSparkAnswer(df2)
- } else {
- checkSparkAnswerAndOperator(df2)
+ if (testTry) {
+ data.createOrReplaceTempView("t")
+ val df2 =
+ spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order
by a")
+ if (hasIncompatibleType) {
+ checkSparkAnswer(df2)
+ } else {
+ checkSparkAnswerAndOperator(df2)
+ }
}
-
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]