This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 5fc6327b feat: Implement Spark-compatible CAST float/double to string
(#346)
5fc6327b is described below
commit 5fc6327bc74249767487fb727aec1e6ef99ba7a3
Author: RickestCode <[email protected]>
AuthorDate: Fri May 3 21:22:32 2024 +0200
feat: Implement Spark-compatible CAST float/double to string (#346)
---
core/src/execution/datafusion/expressions/cast.rs | 105 ++++++++++++++++++++-
.../scala/org/apache/comet/CometCastSuite.scala | 30 +++++-
2 files changed, 129 insertions(+), 6 deletions(-)
diff --git a/core/src/execution/datafusion/expressions/cast.rs
b/core/src/execution/datafusion/expressions/cast.rs
index 7560e0c2..45859c5f 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -17,7 +17,7 @@
use std::{
any::Any,
- fmt::{Display, Formatter},
+ fmt::{Debug, Display, Formatter},
hash::{Hash, Hasher},
sync::Arc,
};
@@ -31,7 +31,8 @@ use arrow::{
};
use arrow_array::{
types::{Int16Type, Int32Type, Int64Type, Int8Type},
- Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait,
PrimitiveArray,
+ Array, ArrayRef, BooleanArray, Float32Array, Float64Array,
GenericStringArray, OffsetSizeTrait,
+ PrimitiveArray,
};
use arrow_schema::{DataType, Schema};
use chrono::{TimeZone, Timelike};
@@ -107,6 +108,74 @@ macro_rules! cast_utf8_to_timestamp {
}};
}
+macro_rules! cast_float_to_string {
+ ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty)
=> {{
+
+ fn cast<OffsetSize>(
+ from: &dyn Array,
+ _eval_mode: EvalMode,
+ ) -> CometResult<ArrayRef>
+ where
+ OffsetSize: OffsetSizeTrait, {
+ let array =
from.as_any().downcast_ref::<$output_type>().unwrap();
+
+ // If the absolute number is less than 10,000,000 and greater
than or equal to 0.001, the
+ // result is expressed without scientific notation with at
least one digit on either side of
+ // the decimal point. Otherwise, Spark uses a mantissa
followed by E and an
+ // exponent. The mantissa has an optional leading minus sign
followed by one digit to the
+ // left of the decimal point, and the minimal number of digits
greater than zero to the
+ right. The exponent has an optional leading minus sign.
+ // source:
https://docs.databricks.com/en/sql/language-manual/functions/cast.html
+
+ const LOWER_SCIENTIFIC_BOUND: $type = 0.001;
+ const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0;
+
+ let output_array = array
+ .iter()
+ .map(|value| match value {
+ Some(value) if value == <$type>::INFINITY =>
Ok(Some("Infinity".to_string())),
+ Some(value) if value == <$type>::NEG_INFINITY =>
Ok(Some("-Infinity".to_string())),
+ Some(value)
+ if (value.abs() < UPPER_SCIENTIFIC_BOUND
+ && value.abs() >= LOWER_SCIENTIFIC_BOUND)
+ || value.abs() == 0.0 =>
+ {
+ let trailing_zero = if value.fract() == 0.0 { ".0"
} else { "" };
+
+ Ok(Some(format!("{value}{trailing_zero}")))
+ }
+ Some(value)
+ if value.abs() >= UPPER_SCIENTIFIC_BOUND
+ || value.abs() < LOWER_SCIENTIFIC_BOUND =>
+ {
+ let formatted = format!("{value:E}");
+
+ if formatted.contains(".") {
+ Ok(Some(formatted))
+ } else {
+ // `formatted` is already in scientific
notation and can be split up by E
+ // in order to add the missing trailing 0
which gets removed for numbers with a fraction of 0.0
+ let prepare_number: Vec<&str> =
formatted.split("E").collect();
+
+ let coefficient = prepare_number[0];
+
+ let exponent = prepare_number[1];
+
+ Ok(Some(format!("{coefficient}.0E{exponent}")))
+ }
+ }
+ Some(value) => Ok(Some(value.to_string())),
+ _ => Ok(None),
+ })
+ .collect::<Result<GenericStringArray<OffsetSize>,
CometError>>()?;
+
+ Ok(Arc::new(output_array))
+ }
+
+ cast::<$offset_type>($from, $eval_mode)
+ }};
+}
+
impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -185,6 +254,18 @@ impl Cast {
),
}
}
+ (DataType::Float64, DataType::Utf8) => {
+ Self::spark_cast_float64_to_utf8::<i32>(&array,
self.eval_mode)?
+ }
+ (DataType::Float64, DataType::LargeUtf8) => {
+ Self::spark_cast_float64_to_utf8::<i64>(&array,
self.eval_mode)?
+ }
+ (DataType::Float32, DataType::Utf8) => {
+ Self::spark_cast_float32_to_utf8::<i32>(&array,
self.eval_mode)?
+ }
+ (DataType::Float32, DataType::LargeUtf8) => {
+ Self::spark_cast_float32_to_utf8::<i64>(&array,
self.eval_mode)?
+ }
_ => {
// when we have no Spark-specific casting we delegate to
DataFusion
cast_with_options(&array, to_type, &CAST_OPTIONS)?
@@ -248,6 +329,26 @@ impl Cast {
Ok(cast_array)
}
+ fn spark_cast_float64_to_utf8<OffsetSize>(
+ from: &dyn Array,
+ _eval_mode: EvalMode,
+ ) -> CometResult<ArrayRef>
+ where
+ OffsetSize: OffsetSizeTrait,
+ {
+ cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
+ }
+
+ fn spark_cast_float32_to_utf8<OffsetSize>(
+ from: &dyn Array,
+ _eval_mode: EvalMode,
+ ) -> CometResult<ArrayRef>
+ where
+ OffsetSize: OffsetSizeTrait,
+ {
+ cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
+ }
+
fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index a31f4e68..3be7dcb6 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -329,9 +329,22 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
castTest(generateFloats(), DataTypes.createDecimalType(10, 2))
}
- ignore("cast FloatType to StringType") {
+ test("cast FloatType to StringType") {
// https://github.com/apache/datafusion-comet/issues/312
- castTest(generateFloats(), DataTypes.StringType)
+ val r = new Random(0)
+ val values = Seq(
+ Float.MaxValue,
+ Float.MinValue,
+ Float.NaN,
+ Float.PositiveInfinity,
+ Float.NegativeInfinity,
+ 1.0f,
+ -1.0f,
+ Short.MinValue.toFloat,
+ Short.MaxValue.toFloat,
+ 0.0f) ++
+ Range(0, dataSize).map(_ => r.nextFloat())
+ withNulls(values).toDF("a")
}
ignore("cast FloatType to TimestampType") {
@@ -374,9 +387,18 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
}
- ignore("cast DoubleType to StringType") {
+ test("cast DoubleType to StringType") {
// https://github.com/apache/datafusion-comet/issues/312
- castTest(generateDoubles(), DataTypes.StringType)
+ val r = new Random(0)
+ val values = Seq(
+ Double.MaxValue,
+ Double.MinValue,
+ Double.NaN,
+ Double.PositiveInfinity,
+ Double.NegativeInfinity,
+ 0.0d) ++
+ Range(0, dataSize).map(_ => r.nextDouble())
+ withNulls(values).toDF("a")
}
ignore("cast DoubleType to TimestampType") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]