This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new c261af34 feat: Implement Spark-compatible CAST from non-integral numeric types to integral types (#399)
c261af34 is described below
commit c261af34527c17d1b5eb74d6c06f1f657819fe11
Author: Rohit Rastogi <[email protected]>
AuthorDate: Wed May 8 12:40:36 2024 -0700
feat: Implement Spark-compatible CAST from non-integral numeric types to integral types (#399)
* WIP - float to int, sketchy
* WIP - extremely ugly but functional
* WIP - use macro
* simplify further
* delete dead code
* make format
* progress on decimals
* refactor
* format decimal value in overflow exception
* wip - have to use 4 macros, need more decimal tests
* ready for review
* forgot to commit whoops
* bad merge
* address pr comments
* commit missed compatibility
* improve error message
* improve error message again
* revert perf regression in cast_int_to_int_macro
* remove branching in loop for legacy case
---------
Co-authored-by: Rohit Rastogi <[email protected]>
---
core/src/execution/datafusion/expressions/cast.rs | 418 ++++++++++++++++++++-
docs/source/user-guide/compatibility.md | 12 +
.../org/apache/comet/expressions/CometCast.scala | 12 +-
.../scala/org/apache/comet/CometCastSuite.scala | 116 ++++--
4 files changed, 511 insertions(+), 47 deletions(-)
diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index a6e3adac..2ad9c40d 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -31,8 +31,8 @@ use arrow::{
};
use arrow_array::{
types::{Int16Type, Int32Type, Int64Type, Int8Type},
- Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericStringArray, OffsetSizeTrait,
- PrimitiveArray,
+ Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array, GenericStringArray,
+ Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray,
};
use arrow_schema::{DataType, Schema};
use chrono::{TimeZone, Timelike};
@@ -214,11 +214,11 @@ macro_rules! cast_int_to_int_macro {
Some(value) => {
let res = <$to_native_type>::try_from(value);
if res.is_err() {
- Err(CometError::CastOverFlow {
- value: value.to_string() + spark_int_literal_suffix,
- from_type: $spark_from_data_type_name.to_string(),
- to_type: $spark_to_data_type_name.to_string(),
- })
+ Err(cast_overflow(
+ &(value.to_string() + spark_int_literal_suffix),
+ $spark_from_data_type_name,
+ $spark_to_data_type_name,
+ ))
} else {
Ok::<Option<$to_native_type>, CometError>(Some(res.unwrap()))
}
@@ -232,6 +232,240 @@ macro_rules! cast_int_to_int_macro {
}};
}
+// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short.
+// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast,
+// this can cause unexpected Short/Byte cast results. Replicate this behavior.
+macro_rules! cast_float_to_int16_down {
+ (
+ $array:expr,
+ $eval_mode:expr,
+ $src_array_type:ty,
+ $dest_array_type:ty,
+ $rust_src_type:ty,
+ $rust_dest_type:ty,
+ $src_type_str:expr,
+ $dest_type_str:expr,
+ $format_str:expr
+ ) => {{
+ let cast_array = $array
+ .as_any()
+ .downcast_ref::<$src_array_type>()
+ .expect(concat!("Expected a ", stringify!($src_array_type)));
+
+ let output_array = match $eval_mode {
+ EvalMode::Ansi => cast_array
+ .iter()
+ .map(|value| match value {
+ Some(value) => {
+ let is_overflow = value.is_nan() || value.abs() as i32 == std::i32::MAX;
+ if is_overflow {
+ return Err(cast_overflow(
+ &format!($format_str, value).replace("e", "E"),
+ $src_type_str,
+ $dest_type_str,
+ ));
+ }
+ let i32_value = value as i32;
+ <$rust_dest_type>::try_from(i32_value)
+ .map_err(|_| {
+ cast_overflow(
+ &format!($format_str, value).replace("e", "E"),
+ $src_type_str,
+ $dest_type_str,
+ )
+ })
+ .map(Some)
+ }
+ None => Ok(None),
+ })
+ .collect::<Result<$dest_array_type, _>>()?,
+ _ => cast_array
+ .iter()
+ .map(|value| match value {
+ Some(value) => {
+ let i32_value = value as i32;
+ Ok::<Option<$rust_dest_type>, CometError>(Some(
+ i32_value as $rust_dest_type,
+ ))
+ }
+ None => Ok(None),
+ })
+ .collect::<Result<$dest_array_type, _>>()?,
+ };
+ Ok(Arc::new(output_array) as ArrayRef)
+ }};
+}
+
+macro_rules! cast_float_to_int32_up {
+ (
+ $array:expr,
+ $eval_mode:expr,
+ $src_array_type:ty,
+ $dest_array_type:ty,
+ $rust_src_type:ty,
+ $rust_dest_type:ty,
+ $src_type_str:expr,
+ $dest_type_str:expr,
+ $max_dest_val:expr,
+ $format_str:expr
+ ) => {{
+ let cast_array = $array
+ .as_any()
+ .downcast_ref::<$src_array_type>()
+ .expect(concat!("Expected a ", stringify!($src_array_type)));
+
+ let output_array = match $eval_mode {
+ EvalMode::Ansi => cast_array
+ .iter()
+ .map(|value| match value {
+ Some(value) => {
+ let is_overflow =
+ value.is_nan() || value.abs() as $rust_dest_type == $max_dest_val;
+ if is_overflow {
+ return Err(cast_overflow(
+ &format!($format_str, value).replace("e", "E"),
+ $src_type_str,
+ $dest_type_str,
+ ));
+ }
+ Ok(Some(value as $rust_dest_type))
+ }
+ None => Ok(None),
+ })
+ .collect::<Result<$dest_array_type, _>>()?,
+ _ => cast_array
+ .iter()
+ .map(|value| match value {
+ Some(value) => {
+ Ok::<Option<$rust_dest_type>, CometError>(Some(value as $rust_dest_type))
+ }
+ None => Ok(None),
+ })
+ .collect::<Result<$dest_array_type, _>>()?,
+ };
+ Ok(Arc::new(output_array) as ArrayRef)
+ }};
+}
+
+// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short.
+// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast,
+// this can cause unexpected Short/Byte cast results. Replicate this behavior.
+macro_rules! cast_decimal_to_int16_down {
+ (
+ $array:expr,
+ $eval_mode:expr,
+ $dest_array_type:ty,
+ $rust_dest_type:ty,
+ $dest_type_str:expr,
+ $precision:expr,
+ $scale:expr
+ ) => {{
+ let cast_array = $array
+ .as_any()
+ .downcast_ref::<Decimal128Array>()
+ .expect(concat!("Expected a Decimal128ArrayType"));
+
+ let output_array = match $eval_mode {
+ EvalMode::Ansi => cast_array
+ .iter()
+ .map(|value| match value {
+ Some(value) => {
+ let divisor = 10_i128.pow($scale as u32);
+ let (truncated, decimal) = (value / divisor, (value % divisor).abs());
+ let is_overflow = truncated.abs() > std::i32::MAX.into();
+ if is_overflow {
+ return Err(cast_overflow(
+ &format!("{}.{}BD", truncated, decimal),
+ &format!("DECIMAL({},{})", $precision, $scale),
+ $dest_type_str,
+ ));
+ }
+ let i32_value = truncated as i32;
+ <$rust_dest_type>::try_from(i32_value)
+ .map_err(|_| {
+ cast_overflow(
+ &format!("{}.{}BD", truncated, decimal),
+ &format!("DECIMAL({},{})", $precision,
$scale),
+ $dest_type_str,
+ )
+ })
+ .map(Some)
+ }
+ None => Ok(None),
+ })
+ .collect::<Result<$dest_array_type, _>>()?,
+ _ => cast_array
+ .iter()
+ .map(|value| match value {
+ Some(value) => {
+ let divisor = 10_i128.pow($scale as u32);
+ let i32_value = (value / divisor) as i32;
+ Ok::<Option<$rust_dest_type>, CometError>(Some(
+ i32_value as $rust_dest_type,
+ ))
+ }
+ None => Ok(None),
+ })
+ .collect::<Result<$dest_array_type, _>>()?,
+ };
+ Ok(Arc::new(output_array) as ArrayRef)
+ }};
+}
+
+macro_rules! cast_decimal_to_int32_up {
+ (
+ $array:expr,
+ $eval_mode:expr,
+ $dest_array_type:ty,
+ $rust_dest_type:ty,
+ $dest_type_str:expr,
+ $max_dest_val:expr,
+ $precision:expr,
+ $scale:expr
+ ) => {{
+ let cast_array = $array
+ .as_any()
+ .downcast_ref::<Decimal128Array>()
+ .expect(concat!("Expected a Decimal128ArrayType"));
+
+ let output_array = match $eval_mode {
+ EvalMode::Ansi => cast_array
+ .iter()
+ .map(|value| match value {
+ Some(value) => {
+ let divisor = 10_i128.pow($scale as u32);
+ let (truncated, decimal) = (value / divisor, (value % divisor).abs());
+ let is_overflow = truncated.abs() > $max_dest_val.into();
+ if is_overflow {
+ return Err(cast_overflow(
+ &format!("{}.{}BD", truncated, decimal),
+ &format!("DECIMAL({},{})", $precision, $scale),
+ $dest_type_str,
+ ));
+ }
+ Ok(Some(truncated as $rust_dest_type))
+ }
+ None => Ok(None),
+ })
+ .collect::<Result<$dest_array_type, _>>()?,
+ _ => cast_array
+ .iter()
+ .map(|value| match value {
+ Some(value) => {
+ let divisor = 10_i128.pow($scale as u32);
+ let truncated = value / divisor;
+ Ok::<Option<$rust_dest_type>, CometError>(Some(
+ truncated as $rust_dest_type,
+ ))
+ }
+ None => Ok(None),
+ })
+ .collect::<Result<$dest_array_type, _>>()?,
+ };
+ Ok(Arc::new(output_array) as ArrayRef)
+ }};
+}
+
impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -332,6 +566,27 @@ impl Cast {
(DataType::Float32, DataType::LargeUtf8) => {
Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)?
}
+ (DataType::Float32, DataType::Int8)
+ | (DataType::Float32, DataType::Int16)
+ | (DataType::Float32, DataType::Int32)
+ | (DataType::Float32, DataType::Int64)
+ | (DataType::Float64, DataType::Int8)
+ | (DataType::Float64, DataType::Int16)
+ | (DataType::Float64, DataType::Int32)
+ | (DataType::Float64, DataType::Int64)
+ | (DataType::Decimal128(_, _), DataType::Int8)
+ | (DataType::Decimal128(_, _), DataType::Int16)
+ | (DataType::Decimal128(_, _), DataType::Int32)
+ | (DataType::Decimal128(_, _), DataType::Int64)
+ if self.eval_mode != EvalMode::Try =>
+ {
+ Self::spark_cast_nonintegral_numeric_to_integral(
+ &array,
+ self.eval_mode,
+ from_type,
+ to_type,
+ )?
+ }
_ => {
// when we have no Spark-specific casting we delegate to DataFusion
cast_with_options(&array, to_type, &CAST_OPTIONS)?
@@ -478,6 +733,146 @@ impl Cast {
Ok(Arc::new(output_array))
}
+
+ fn spark_cast_nonintegral_numeric_to_integral(
+ array: &dyn Array,
+ eval_mode: EvalMode,
+ from_type: &DataType,
+ to_type: &DataType,
+ ) -> CometResult<ArrayRef> {
+ match (from_type, to_type) {
+ (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!(
+ array,
+ eval_mode,
+ Float32Array,
+ Int8Array,
+ f32,
+ i8,
+ "FLOAT",
+ "TINYINT",
+ "{:e}"
+ ),
+ (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!(
+ array,
+ eval_mode,
+ Float32Array,
+ Int16Array,
+ f32,
+ i16,
+ "FLOAT",
+ "SMALLINT",
+ "{:e}"
+ ),
+ (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!(
+ array,
+ eval_mode,
+ Float32Array,
+ Int32Array,
+ f32,
+ i32,
+ "FLOAT",
+ "INT",
+ std::i32::MAX,
+ "{:e}"
+ ),
+ (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!(
+ array,
+ eval_mode,
+ Float32Array,
+ Int64Array,
+ f32,
+ i64,
+ "FLOAT",
+ "BIGINT",
+ std::i64::MAX,
+ "{:e}"
+ ),
+ (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!(
+ array,
+ eval_mode,
+ Float64Array,
+ Int8Array,
+ f64,
+ i8,
+ "DOUBLE",
+ "TINYINT",
+ "{:e}D"
+ ),
+ (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!(
+ array,
+ eval_mode,
+ Float64Array,
+ Int16Array,
+ f64,
+ i16,
+ "DOUBLE",
+ "SMALLINT",
+ "{:e}D"
+ ),
+ (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!(
+ array,
+ eval_mode,
+ Float64Array,
+ Int32Array,
+ f64,
+ i32,
+ "DOUBLE",
+ "INT",
+ std::i32::MAX,
+ "{:e}D"
+ ),
+ (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!(
+ array,
+ eval_mode,
+ Float64Array,
+ Int64Array,
+ f64,
+ i64,
+ "DOUBLE",
+ "BIGINT",
+ std::i64::MAX,
+ "{:e}D"
+ ),
+ (DataType::Decimal128(precision, scale), DataType::Int8) => {
+ cast_decimal_to_int16_down!(
+ array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
+ )
+ }
+ (DataType::Decimal128(precision, scale), DataType::Int16) => {
+ cast_decimal_to_int16_down!(
+ array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
+ )
+ }
+ (DataType::Decimal128(precision, scale), DataType::Int32) => {
+ cast_decimal_to_int32_up!(
+ array,
+ eval_mode,
+ Int32Array,
+ i32,
+ "INT",
+ std::i32::MAX,
+ *precision,
+ *scale
+ )
+ }
+ (DataType::Decimal128(precision, scale), DataType::Int64) => {
+ cast_decimal_to_int32_up!(
+ array,
+ eval_mode,
+ Int64Array,
+ i64,
+ "BIGINT",
+ std::i64::MAX,
+ *precision,
+ *scale
+ )
+ }
+ _ => unreachable!(
+ "{}",
+ format!("invalid cast from non-integral numeric type:
{from_type} to integral numeric type: {to_type}")
+ ),
+ }
+ }
}
/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte
@@ -676,6 +1071,15 @@ fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError {
}
}
+#[inline]
+fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> CometError {
+ CometError::CastOverFlow {
+ value: value.to_string(),
+ from_type: from_type.to_string(),
+ to_type: to_type.to_string(),
+ }
+}
+
impl Display for Cast {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
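Note for reviewers: the following is a minimal, standalone sketch (not part of the patch; the function name and error strings are illustrative) of the ANSI-mode narrowing path that cast_float_to_int16_down! above replicates, i.e. cast the float to i32 first and only then attempt the Byte/Short conversion, so overflow is detected the same way Spark does.

fn cast_f32_to_i8_ansi(value: f32) -> Result<i8, String> {
    // NaN, or a magnitude that saturates the f32 -> i32 cast, counts as overflow in ANSI mode.
    let is_overflow = value.is_nan() || value.abs() as i32 == i32::MAX;
    if is_overflow {
        return Err(format!("{value:e} cannot be cast to TINYINT due to an overflow"));
    }
    // Spark narrows via Int, so replicate that intermediate step before the i8 conversion.
    let i32_value = value as i32;
    i8::try_from(i32_value)
        .map_err(|_| format!("{value:e} cannot be cast to TINYINT due to an overflow"))
}

fn main() {
    assert_eq!(cast_f32_to_i8_ansi(42.9), Ok(42)); // fractional part is truncated
    assert!(cast_f32_to_i8_ansi(300.0).is_err()); // does not fit in a TINYINT
    assert!(cast_f32_to_i8_ansi(f32::NAN).is_err()); // NaN overflows under ANSI
}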
diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md
index 57a4271f..2fd4b09b 100644
--- a/docs/source/user-guide/compatibility.md
+++ b/docs/source/user-guide/compatibility.md
@@ -88,11 +88,23 @@ The following cast operations are generally compatible with Spark except for the
| long | double | |
| long | string | |
| float | boolean | |
+| float | byte | |
+| float | short | |
+| float | integer | |
+| float | long | |
| float | double | |
| float | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
| double | boolean | |
+| double | byte | |
+| double | short | |
+| double | integer | |
+| double | long | |
| double | float | |
| double | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
+| decimal | byte | |
+| decimal | short | |
+| decimal | integer | |
+| decimal | long | |
| decimal | float | |
| decimal | double | |
| string | boolean | |
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 57e07b8c..5c225e3b 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -226,19 +226,25 @@ object CometCast {
}
private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
- case DataTypes.BooleanType | DataTypes.DoubleType => Compatible()
+ case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
+ DataTypes.IntegerType | DataTypes.LongType =>
+ Compatible()
case _: DecimalType => Incompatible(Some("No overflow check"))
case _ => Unsupported
}
private def canCastFromDouble(toType: DataType): SupportLevel = toType match {
- case DataTypes.BooleanType | DataTypes.FloatType => Compatible()
+ case DataTypes.BooleanType | DataTypes.FloatType | DataTypes.ByteType | DataTypes.ShortType |
+ DataTypes.IntegerType | DataTypes.LongType =>
+ Compatible()
case _: DecimalType => Incompatible(Some("No overflow check"))
case _ => Unsupported
}
private def canCastFromDecimal(toType: DataType): SupportLevel = toType match {
- case DataTypes.FloatType | DataTypes.DoubleType => Compatible()
+ case DataTypes.FloatType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
+ DataTypes.IntegerType | DataTypes.LongType =>
+ Compatible()
case _ => Unsupported
}
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 219feca1..827f4238 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, DataTypes}
+import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType}
import org.apache.comet.expressions.{CometCast, Compatible}
@@ -320,23 +320,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateFloats(), DataTypes.BooleanType)
}
- ignore("cast FloatType to ByteType") {
- // https://github.com/apache/datafusion-comet/issues/350
+ test("cast FloatType to ByteType") {
castTest(generateFloats(), DataTypes.ByteType)
}
- ignore("cast FloatType to ShortType") {
- // https://github.com/apache/datafusion-comet/issues/350
+ test("cast FloatType to ShortType") {
castTest(generateFloats(), DataTypes.ShortType)
}
- ignore("cast FloatType to IntegerType") {
- // https://github.com/apache/datafusion-comet/issues/350
+ test("cast FloatType to IntegerType") {
castTest(generateFloats(), DataTypes.IntegerType)
}
- ignore("cast FloatType to LongType") {
- // https://github.com/apache/datafusion-comet/issues/350
+ test("cast FloatType to LongType") {
castTest(generateFloats(), DataTypes.LongType)
}
@@ -378,23 +374,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateDoubles(), DataTypes.BooleanType)
}
- ignore("cast DoubleType to ByteType") {
- // https://github.com/apache/datafusion-comet/issues/350
+ test("cast DoubleType to ByteType") {
castTest(generateDoubles(), DataTypes.ByteType)
}
- ignore("cast DoubleType to ShortType") {
- // https://github.com/apache/datafusion-comet/issues/350
+ test("cast DoubleType to ShortType") {
castTest(generateDoubles(), DataTypes.ShortType)
}
- ignore("cast DoubleType to IntegerType") {
- // https://github.com/apache/datafusion-comet/issues/350
+ test("cast DoubleType to IntegerType") {
castTest(generateDoubles(), DataTypes.IntegerType)
}
- ignore("cast DoubleType to LongType") {
- // https://github.com/apache/datafusion-comet/issues/350
+ test("cast DoubleType to LongType") {
castTest(generateDoubles(), DataTypes.LongType)
}
@@ -430,45 +422,57 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
ignore("cast DecimalType(10,2) to BooleanType") {
// Arrow error: Cast error: Casting from Decimal128(38, 18) to Boolean not supported
- castTest(generateDecimals(), DataTypes.BooleanType)
+ castTest(generateDecimalsPrecision10Scale2(), DataTypes.BooleanType)
}
- ignore("cast DecimalType(10,2) to ByteType") {
- // https://github.com/apache/datafusion-comet/issues/350
- castTest(generateDecimals(), DataTypes.ByteType)
+ test("cast DecimalType(10,2) to ByteType") {
+ castTest(generateDecimalsPrecision10Scale2(), DataTypes.ByteType)
}
- ignore("cast DecimalType(10,2) to ShortType") {
- // https://github.com/apache/datafusion-comet/issues/350
- castTest(generateDecimals(), DataTypes.ShortType)
+ test("cast DecimalType(10,2) to ShortType") {
+ castTest(generateDecimalsPrecision10Scale2(), DataTypes.ShortType)
}
- ignore("cast DecimalType(10,2) to IntegerType") {
- // https://github.com/apache/datafusion-comet/issues/350
- castTest(generateDecimals(), DataTypes.IntegerType)
+ test("cast DecimalType(10,2) to IntegerType") {
+ castTest(generateDecimalsPrecision10Scale2(), DataTypes.IntegerType)
}
- ignore("cast DecimalType(10,2) to LongType") {
- // https://github.com/apache/datafusion-comet/issues/350
- castTest(generateDecimals(), DataTypes.LongType)
+ test("cast DecimalType(10,2) to LongType") {
+ castTest(generateDecimalsPrecision10Scale2(), DataTypes.LongType)
}
test("cast DecimalType(10,2) to FloatType") {
- castTest(generateDecimals(), DataTypes.FloatType)
+ castTest(generateDecimalsPrecision10Scale2(), DataTypes.FloatType)
}
test("cast DecimalType(10,2) to DoubleType") {
- castTest(generateDecimals(), DataTypes.DoubleType)
+ castTest(generateDecimalsPrecision10Scale2(), DataTypes.DoubleType)
+ }
+
+ test("cast DecimalType(38,18) to ByteType") {
+ castTest(generateDecimalsPrecision38Scale18(), DataTypes.ByteType)
+ }
+
+ test("cast DecimalType(38,18) to ShortType") {
+ castTest(generateDecimalsPrecision38Scale18(), DataTypes.ShortType)
+ }
+
+ test("cast DecimalType(38,18) to IntegerType") {
+ castTest(generateDecimalsPrecision38Scale18(), DataTypes.IntegerType)
+ }
+
+ test("cast DecimalType(38,18) to LongType") {
+ castTest(generateDecimalsPrecision38Scale18(), DataTypes.LongType)
}
ignore("cast DecimalType(10,2) to StringType") {
// input: 0E-18, expected: 0E-18, actual: 0.000000000000000000
- castTest(generateDecimals(), DataTypes.StringType)
+ castTest(generateDecimalsPrecision10Scale2(), DataTypes.StringType)
}
ignore("cast DecimalType(10,2) to TimestampType") {
// input: -123456.789000000000000000, expected: 1969-12-30 05:42:23.211, actual: 1969-12-31 15:59:59.876544
- castTest(generateDecimals(), DataTypes.TimestampType)
+ castTest(generateDecimalsPrecision10Scale2(), DataTypes.TimestampType)
}
// CAST from StringType
@@ -800,9 +804,47 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
withNulls(values).toDF("a")
}
- private def generateDecimals(): DataFrame = {
- // TODO improve this
- val values = Seq(BigDecimal("123456.789"), BigDecimal("-123456.789"), BigDecimal("0.0"))
+ private def generateDecimalsPrecision10Scale2(): DataFrame = {
+ val values = Seq(
+ BigDecimal("-99999999.999"),
+ BigDecimal("-123456.789"),
+ BigDecimal("-32768.678"),
+ // Short Min
+ BigDecimal("-32767.123"),
+ BigDecimal("-128.12312"),
+ // Byte Min
+ BigDecimal("-127.123"),
+ BigDecimal("0.0"),
+ // Byte Max
+ BigDecimal("127.123"),
+ BigDecimal("128.12312"),
+ BigDecimal("32767.122"),
+ // Short Max
+ BigDecimal("32768.678"),
+ BigDecimal("123456.789"),
+ BigDecimal("99999999.999"))
+ withNulls(values).toDF("b").withColumn("a", col("b").cast(DecimalType(10,
2))).drop("b")
+ }
+
+ private def generateDecimalsPrecision38Scale18(): DataFrame = {
+ val values = Seq(
+ BigDecimal("-99999999999999999999.999999999999"),
+ BigDecimal("-9223372036854775808.234567"),
+ // Long Min
+ BigDecimal("-9223372036854775807.123123"),
+ BigDecimal("-2147483648.123123123"),
+ // Int Min
+ BigDecimal("-2147483647.123123123"),
+ BigDecimal("-123456.789"),
+ BigDecimal("0.00000000000"),
+ BigDecimal("123456.789"),
+ // Int Max
+ BigDecimal("2147483647.123123123"),
+ BigDecimal("2147483648.123123123"),
+ BigDecimal("9223372036854775807.123123"),
+ // Long Max
+ BigDecimal("9223372036854775808.234567"),
+ BigDecimal("99999999999999999999.999999999999"))
withNulls(values).toDF("a")
}
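As a companion to the boundary values above, here is a minimal standalone sketch (not part of the patch; the function name and error strings are illustrative) of the truncation arithmetic that cast_decimal_to_int16_down! applies: a Decimal128 value arrives as an unscaled i128, so dividing by 10^scale gives the integral part that Spark keeps, and the remainder is only used to format the overflow message.

fn decimal_to_i16_ansi(unscaled: i128, scale: u32) -> Result<i16, String> {
    let divisor = 10_i128.pow(scale);
    // Split into the integral part that survives the cast and the discarded fraction.
    let (truncated, decimal) = (unscaled / divisor, (unscaled % divisor).abs());
    // ANSI mode flags values whose integral part cannot even fit in an Int.
    if truncated.abs() > i32::MAX.into() {
        return Err(format!("{truncated}.{decimal}BD cannot be cast to SMALLINT due to an overflow"));
    }
    // Narrow via Int, mirroring how Spark reaches Short/Byte.
    i16::try_from(truncated as i32)
        .map_err(|_| format!("{truncated}.{decimal}BD cannot be cast to SMALLINT due to an overflow"))
}

fn main() {
    // DECIMAL(10,2): 32767.12 fits in a Short, 32768.67 does not.
    assert_eq!(decimal_to_i16_ansi(3276712, 2), Ok(32767));
    assert!(decimal_to_i16_ansi(3276867, 2).is_err());
}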
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]