This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 863b40f5 chore: Refactoring of CometError/SparkError (#655)
863b40f5 is described below
commit 863b40f5f99f4116ad627accbc863fecb9017d50
Author: Andy Grove <[email protected]>
AuthorDate: Fri Jul 12 05:52:16 2024 -0600
chore: Refactoring of CometError/SparkError (#655)
---
native/Cargo.lock | 1 +
native/Cargo.toml | 1 +
native/core/Cargo.toml | 2 +-
native/core/src/errors.rs | 41 +------
.../src/execution/datafusion/expressions/cast.rs | 126 ++++++++++-----------
.../src/execution/datafusion/expressions/mod.rs | 6 +-
.../execution/datafusion/expressions/negative.rs | 6 +-
native/spark-expr/Cargo.toml | 1 +
native/spark-expr/src/abs.rs | 7 +-
native/spark-expr/src/error.rs | 73 ++++++++++++
native/spark-expr/src/lib.rs | 21 +---
11 files changed, 157 insertions(+), 128 deletions(-)
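In short, this commit moves the Spark-specific error variants (CastInvalidValue, NumericValueOutOfRange, CastOverFlow, ArithmeticOverflow) out of CometError in the core crate and into a new thiserror-based SparkError enum in the spark-expr crate. CometError then wraps SparkError in a single transparent variant, so the rendered messages are unchanged. A minimal sketch of the resulting layering (simplified from the diff below; only the shape matters):

    use thiserror::Error;

    // in datafusion-comet-spark-expr (native/spark-expr/src/error.rs)
    #[derive(Error, Debug)]
    pub enum SparkError {
        #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
        ArithmeticOverflow { from_type: String },
        // ... CastInvalidValue, NumericValueOutOfRange, CastOverFlow ...
    }

    // in the core crate (native/core/src/errors.rs)
    #[derive(Error, Debug)]
    pub enum CometError {
        /// Emulates the errors that Spark itself would return
        #[error(transparent)]
        Spark(SparkError),
        // ... Comet-specific variants unchanged ...
    }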
diff --git a/native/Cargo.lock b/native/Cargo.lock
index 9bf8247d..605af92e 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -914,6 +914,7 @@ dependencies = [
"datafusion-common",
"datafusion-functions",
"datafusion-physical-expr",
+ "thiserror",
]
[[package]]
diff --git a/native/Cargo.toml b/native/Cargo.toml
index 53afed85..944b6e28 100644
--- a/native/Cargo.toml
+++ b/native/Cargo.toml
datafusion-physical-expr-common = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", default-features = false }
datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", default-features = false }
datafusion-comet-spark-expr = { path = "spark-expr", version = "0.1.0" }
datafusion-comet-utils = { path = "utils", version = "0.1.0" }
+thiserror = "1"
[profile.release]
debug = true
diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index be135d4e..50c1ce2b 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -48,7 +48,7 @@ tokio = { version = "1", features = ["rt-multi-thread"] }
async-trait = "0.1"
log = "0.4"
log4rs = "1.2.0"
-thiserror = "1"
+thiserror = { workspace = true }
serde = { version = "1", features = ["derive"] }
lazy_static = "1.4.0"
prost = "0.12.1"
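With thiserror declared once at the workspace root (above), core here, and spark-expr below, inherit a single shared version via { workspace = true }.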
diff --git a/native/core/src/errors.rs b/native/core/src/errors.rs
index 8c02a72d..ff89e77d 100644
--- a/native/core/src/errors.rs
+++ b/native/core/src/errors.rs
@@ -38,6 +38,7 @@ use std::{
use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, jshort};
use crate::execution::operators::ExecutionError;
+use datafusion_comet_spark_expr::SparkError;
use jni::objects::{GlobalRef, JThrowable};
use jni::JNIEnv;
use lazy_static::lazy_static;
@@ -62,36 +63,10 @@ pub enum CometError {
#[error("Comet Internal Error: {0}")]
Internal(String),
- // Note that this message format is based on Spark 3.4 and is more detailed than the message
- // returned by Spark 3.3
- #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
- because it is malformed. Correct the value as per the syntax, or change its target type. \
- Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \
- set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
- CastInvalidValue {
- value: String,
- from_type: String,
- to_type: String,
- },
-
- #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")]
- NumericValueOutOfRange {
- value: String,
- precision: u8,
- scale: i8,
- },
-
- #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
- due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
- set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
- CastOverFlow {
- value: String,
- from_type: String,
- to_type: String,
- },
-
- #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
- ArithmeticOverflow { from_type: String },
+ /// CometError::Spark is typically used in native code to emulate the same errors
+ /// that Spark would return
+ #[error(transparent)]
+ Spark(SparkError),
#[error(transparent)]
Arrow {
@@ -239,11 +214,7 @@ impl jni::errors::ToException for CometError {
class: "java/lang/NullPointerException".to_string(),
msg: self.to_string(),
},
- CometError::CastInvalidValue { .. } => Exception {
- class: "org/apache/spark/SparkException".to_string(),
- msg: self.to_string(),
- },
- CometError::NumericValueOutOfRange { .. } => Exception {
+ CometError::Spark { .. } => Exception {
class: "org/apache/spark/SparkException".to_string(),
msg: self.to_string(),
},
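Because the new variant is #[error(transparent)], CometError's Display forwards to the inner SparkError, so the message carried by the SparkException thrown on the JVM side stays the same; a single Spark match arm replaces the variant-specific arms above. A quick illustration of the passthrough (the construction here is illustrative, not code from this commit):

    use crate::errors::CometError;
    use datafusion_comet_spark_expr::SparkError;

    fn example_message() -> String {
        let err = CometError::Spark(SparkError::ArithmeticOverflow {
            from_type: "long".to_string(),
        });
        // renders the inner [ARITHMETIC_OVERFLOW] message verbatim,
        // exactly what the removed CometError variant used to produce
        err.to_string()
    }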
diff --git a/native/core/src/execution/datafusion/expressions/cast.rs b/native/core/src/execution/datafusion/expressions/cast.rs
index 154ff28b..0b513e77 100644
--- a/native/core/src/execution/datafusion/expressions/cast.rs
+++ b/native/core/src/execution/datafusion/expressions/cast.rs
@@ -40,16 +40,14 @@ use arrow_array::{
use arrow_schema::{DataType, Schema};
use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike};
use datafusion::logical_expr::ColumnarValue;
+use datafusion_comet_spark_expr::{SparkError, SparkResult};
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive};
use regex::Regex;
-use crate::{
- errors::{CometError, CometResult},
- execution::datafusion::expressions::utils::{
- array_with_timezone, down_cast_any_ref, spark_cast,
- },
+use crate::execution::datafusion::expressions::utils::{
+ array_with_timezone, down_cast_any_ref, spark_cast,
};
use super::EvalMode;
@@ -87,7 +85,7 @@ macro_rules! cast_utf8_to_int {
cast_array.append_null()
}
}
- let result: CometResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
+ let result: SparkResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
result
}};
}
@@ -116,7 +114,7 @@ macro_rules! cast_float_to_string {
fn cast<OffsetSize>(
from: &dyn Array,
_eval_mode: EvalMode,
- ) -> CometResult<ArrayRef>
+ ) -> SparkResult<ArrayRef>
where
OffsetSize: OffsetSizeTrait, {
let array = from.as_any().downcast_ref::<$output_type>().unwrap();
@@ -169,7 +167,7 @@ macro_rules! cast_float_to_string {
Some(value) => Ok(Some(value.to_string())),
_ => Ok(None),
})
- .collect::<Result<GenericStringArray<OffsetSize>, CometError>>()?;
+ .collect::<Result<GenericStringArray<OffsetSize>, SparkError>>()?;
Ok(Arc::new(output_array))
}
@@ -205,7 +203,7 @@ macro_rules! cast_int_to_int_macro {
.iter()
.map(|value| match value {
Some(value) => {
- Ok::<Option<$to_native_type>, CometError>(Some(value as $to_native_type))
+ Ok::<Option<$to_native_type>, SparkError>(Some(value as $to_native_type))
}
_ => Ok(None),
})
@@ -222,14 +220,14 @@ macro_rules! cast_int_to_int_macro {
$spark_to_data_type_name,
))
} else {
- Ok::<Option<$to_native_type>, CometError>(Some(res.unwrap()))
+ Ok::<Option<$to_native_type>, SparkError>(Some(res.unwrap()))
}
}
_ => Ok(None),
})
.collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
}?;
- let result: CometResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
+ let result: SparkResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
result
}};
}
@@ -286,7 +284,7 @@ macro_rules! cast_float_to_int16_down {
.map(|value| match value {
Some(value) => {
let i32_value = value as i32;
- Ok::<Option<$rust_dest_type>, CometError>(Some(
+ Ok::<Option<$rust_dest_type>, SparkError>(Some(
i32_value as $rust_dest_type,
))
}
@@ -339,7 +337,7 @@ macro_rules! cast_float_to_int32_up {
.iter()
.map(|value| match value {
Some(value) => {
- Ok::<Option<$rust_dest_type>, CometError>(Some(value as $rust_dest_type))
+ Ok::<Option<$rust_dest_type>, SparkError>(Some(value as $rust_dest_type))
}
None => Ok(None),
})
@@ -402,7 +400,7 @@ macro_rules! cast_decimal_to_int16_down {
Some(value) => {
let divisor = 10_i128.pow($scale as u32);
let i32_value = (value / divisor) as i32;
- Ok::<Option<$rust_dest_type>, CometError>(Some(
+ Ok::<Option<$rust_dest_type>, SparkError>(Some(
i32_value as $rust_dest_type,
))
}
@@ -456,7 +454,7 @@ macro_rules! cast_decimal_to_int32_up {
Some(value) => {
let divisor = 10_i128.pow($scale as u32);
let truncated = value / divisor;
- Ok::<Option<$rust_dest_type>, CometError>(Some(
+ Ok::<Option<$rust_dest_type>, SparkError>(Some(
truncated as $rust_dest_type,
))
}
@@ -596,7 +594,7 @@ impl Cast {
// we should never reach this code because the Scala code should be checking
// for supported cast operations and falling back to Spark for anything that
// is not yet supported
- Err(CometError::Internal(format!(
+ Err(SparkError::Internal(format!(
"Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}"
)))
}
@@ -680,7 +678,7 @@ impl Cast {
to_type: &DataType,
array: &ArrayRef,
eval_mode: EvalMode,
- ) -> CometResult<ArrayRef> {
+ ) -> SparkResult<ArrayRef> {
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<OffsetSize>>()
@@ -711,7 +709,7 @@ impl Cast {
array: &ArrayRef,
to_type: &DataType,
eval_mode: EvalMode,
- ) -> CometResult<ArrayRef> {
+ ) -> SparkResult<ArrayRef> {
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
@@ -743,7 +741,7 @@ impl Cast {
array: &ArrayRef,
to_type: &DataType,
eval_mode: EvalMode,
- ) -> CometResult<ArrayRef> {
+ ) -> SparkResult<ArrayRef> {
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
@@ -768,7 +766,7 @@ impl Cast {
precision: u8,
scale: i8,
eval_mode: EvalMode,
- ) -> CometResult<ArrayRef> {
+ ) -> SparkResult<ArrayRef> {
Self::cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
}
@@ -777,7 +775,7 @@ impl Cast {
precision: u8,
scale: i8,
eval_mode: EvalMode,
- ) -> CometResult<ArrayRef> {
+ ) -> SparkResult<ArrayRef> {
Self::cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
}
@@ -786,7 +784,7 @@ impl Cast {
precision: u8,
scale: i8,
eval_mode: EvalMode,
- ) -> CometResult<ArrayRef>
+ ) -> SparkResult<ArrayRef>
where
<T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
{
@@ -806,7 +804,7 @@ impl Cast {
Some(v) => {
if Decimal128Type::validate_decimal_precision(v, precision).is_err() {
if eval_mode == EvalMode::Ansi {
- return Err(CometError::NumericValueOutOfRange {
+ return Err(SparkError::NumericValueOutOfRange {
value: input_value.to_string(),
precision,
scale,
@@ -819,7 +817,7 @@ impl Cast {
}
None => {
if eval_mode == EvalMode::Ansi {
- return Err(CometError::NumericValueOutOfRange {
+ return Err(SparkError::NumericValueOutOfRange {
value: input_value.to_string(),
precision,
scale,
@@ -843,7 +841,7 @@ impl Cast {
fn spark_cast_float64_to_utf8<OffsetSize>(
from: &dyn Array,
_eval_mode: EvalMode,
- ) -> CometResult<ArrayRef>
+ ) -> SparkResult<ArrayRef>
where
OffsetSize: OffsetSizeTrait,
{
@@ -853,7 +851,7 @@ impl Cast {
fn spark_cast_float32_to_utf8<OffsetSize>(
from: &dyn Array,
_eval_mode: EvalMode,
- ) -> CometResult<ArrayRef>
+ ) -> SparkResult<ArrayRef>
where
OffsetSize: OffsetSizeTrait,
{
@@ -865,7 +863,7 @@ impl Cast {
eval_mode: EvalMode,
from_type: &DataType,
to_type: &DataType,
- ) -> CometResult<ArrayRef> {
+ ) -> SparkResult<ArrayRef> {
match (from_type, to_type) {
(DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
@@ -895,7 +893,7 @@ impl Cast {
fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
- ) -> CometResult<ArrayRef>
+ ) -> SparkResult<ArrayRef>
where
OffsetSize: OffsetSizeTrait,
{
@@ -910,7 +908,7 @@ impl Cast {
Some(value) => match value.to_ascii_lowercase().trim() {
"t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
"f" | "false" | "n" | "no" | "0" => Ok(Some(false)),
- _ if eval_mode == EvalMode::Ansi => Err(CometError::CastInvalidValue {
+ _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "BOOLEAN".to_string(),
@@ -929,7 +927,7 @@ impl Cast {
eval_mode: EvalMode,
from_type: &DataType,
to_type: &DataType,
- ) -> CometResult<ArrayRef> {
+ ) -> SparkResult<ArrayRef> {
match (from_type, to_type) {
(DataType::Float32, DataType::Int8) => cast_float_to_int16_down!(
array,
@@ -1066,7 +1064,7 @@ impl Cast {
}
/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte
-fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult<Option<i8>> {
+fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i8>> {
Ok(cast_string_to_int_with_range_check(
str,
eval_mode,
@@ -1078,7 +1076,7 @@ fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult<Option<i8>>
}
/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort
-fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult<Option<i16>> {
+fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i16>> {
Ok(cast_string_to_int_with_range_check(
str,
eval_mode,
@@ -1090,12 +1088,12 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult<Option<i16>
}
/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper)
-fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
+fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
}
/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper intWrapper)
-fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
+fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i64>> {
do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
}
@@ -1105,7 +1103,7 @@ fn cast_string_to_int_with_range_check(
type_name: &str,
min: i32,
max: i32,
-) -> CometResult<Option<i32>> {
+) -> SparkResult<Option<i32>> {
match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
None => Ok(None),
Some(v) if v >= min && v <= max => Ok(Some(v)),
@@ -1124,7 +1122,7 @@ fn do_cast_string_to_int<
eval_mode: EvalMode,
type_name: &str,
min_value: T,
-) -> CometResult<Option<T>> {
+) -> SparkResult<Option<T>> {
let trimmed_str = str.trim();
if trimmed_str.is_empty() {
return none_or_err(eval_mode, type_name, str);
@@ -1208,9 +1206,9 @@ fn do_cast_string_to_int<
Ok(Some(result))
}
-/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode
+/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
#[inline]
-fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult<Option<T>> {
+fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
match eval_mode {
EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
_ => Ok(None),
@@ -1218,8 +1216,8 @@ fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResul
}
#[inline]
-fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError {
- CometError::CastInvalidValue {
+fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
+ SparkError::CastInvalidValue {
value: value.to_string(),
from_type: from_type.to_string(),
to_type: to_type.to_string(),
@@ -1227,8 +1225,8 @@ fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError {
}
#[inline]
-fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> CometError {
- CometError::CastOverFlow {
+fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError {
+ SparkError::CastOverFlow {
value: value.to_string(),
from_type: from_type.to_string(),
to_type: to_type.to_string(),
@@ -1316,7 +1314,7 @@ impl PhysicalExpr for Cast {
}
}
-fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
+fn timestamp_parser(value: &str, eval_mode: EvalMode) -> SparkResult<Option<i64>> {
let value = value.trim();
if value.is_empty() {
return Ok(None);
@@ -1325,7 +1323,7 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i64>
let patterns = &[
(
Regex::new(r"^\d{4}$").unwrap(),
- parse_str_to_year_timestamp as fn(&str) -> CometResult<Option<i64>>,
+ parse_str_to_year_timestamp as fn(&str) -> SparkResult<Option<i64>>,
),
(
Regex::new(r"^\d{4}-\d{2}$").unwrap(),
@@ -1369,7 +1367,7 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i64>
if timestamp.is_none() {
return if eval_mode == EvalMode::Ansi {
- Err(CometError::CastInvalidValue {
+ Err(SparkError::CastInvalidValue {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "TIMESTAMP".to_string(),
@@ -1381,20 +1379,20 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i64>
match timestamp {
Some(ts) => Ok(Some(ts)),
- None => Err(CometError::Internal(
+ None => Err(SparkError::Internal(
"Failed to parse timestamp".to_string(),
)),
}
}
-fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> CometResult<Option<i64>> {
+fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> SparkResult<Option<i64>> {
let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, 0, 0, 0);
// Check if datetime is not None
let utc_datetime = match datetime.single() {
Some(dt) => dt.with_timezone(&chrono::Utc),
None => {
- return Err(CometError::Internal(
+ return Err(SparkError::Internal(
"Failed to parse timestamp".to_string(),
));
}
@@ -1411,7 +1409,7 @@ fn parse_hms_timestamp(
minute: u32,
second: u32,
microsecond: u32,
-) -> CometResult<Option<i64>> {
+) -> SparkResult<Option<i64>> {
let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, hour, minute, second);
// Check if datetime is not None
@@ -1420,7 +1418,7 @@ fn parse_hms_timestamp(
.with_timezone(&chrono::Utc)
.with_nanosecond(microsecond * 1000),
None => {
- return Err(CometError::Internal(
+ return Err(SparkError::Internal(
"Failed to parse timestamp".to_string(),
));
}
@@ -1429,7 +1427,7 @@ fn parse_hms_timestamp(
let result = match utc_datetime {
Some(dt) => dt.timestamp_micros(),
None => {
- return Err(CometError::Internal(
+ return Err(SparkError::Internal(
"Failed to parse timestamp".to_string(),
));
}
@@ -1438,7 +1436,7 @@ fn parse_hms_timestamp(
Ok(Some(result))
}
-fn get_timestamp_values(value: &str, timestamp_type: &str) -> CometResult<Option<i64>> {
+fn get_timestamp_values(value: &str, timestamp_type: &str) -> SparkResult<Option<i64>> {
let values: Vec<_> = value
.split(|c| c == 'T' || c == '-' || c == ':' || c == '.')
.collect();
@@ -1458,7 +1456,7 @@ fn get_timestamp_values(value: &str, timestamp_type: &str) -> CometResult<Option
"minute" => parse_hms_timestamp(year, month, day, hour, minute, 0, 0),
"second" => parse_hms_timestamp(year, month, day, hour, minute,
second, 0),
"microsecond" => parse_hms_timestamp(year, month, day, hour, minute,
second, microsecond),
- _ => Err(CometError::CastInvalidValue {
+ _ => Err(SparkError::CastInvalidValue {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "TIMESTAMP".to_string(),
@@ -1466,35 +1464,35 @@ fn get_timestamp_values(value: &str, timestamp_type: &str) -> CometResult<Option
}
}
-fn parse_str_to_year_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_year_timestamp(value: &str) -> SparkResult<Option<i64>> {
get_timestamp_values(value, "year")
}
-fn parse_str_to_month_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_month_timestamp(value: &str) -> SparkResult<Option<i64>> {
get_timestamp_values(value, "month")
}
-fn parse_str_to_day_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_day_timestamp(value: &str) -> SparkResult<Option<i64>> {
get_timestamp_values(value, "day")
}
-fn parse_str_to_hour_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_hour_timestamp(value: &str) -> SparkResult<Option<i64>> {
get_timestamp_values(value, "hour")
}
-fn parse_str_to_minute_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_minute_timestamp(value: &str) -> SparkResult<Option<i64>> {
get_timestamp_values(value, "minute")
}
-fn parse_str_to_second_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_second_timestamp(value: &str) -> SparkResult<Option<i64>> {
get_timestamp_values(value, "second")
}
-fn parse_str_to_microsecond_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_microsecond_timestamp(value: &str) -> SparkResult<Option<i64>> {
get_timestamp_values(value, "microsecond")
}
-fn parse_str_to_time_only_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_time_only_timestamp(value: &str) -> SparkResult<Option<i64>> {
let values: Vec<&str> = value.split('T').collect();
let time_values: Vec<u32> = values[1]
.split(':')
@@ -1514,7 +1512,7 @@ fn parse_str_to_time_only_timestamp(value: &str) -> CometResult<Option<i64>> {
}
//a string to date parser - port of spark's SparkDateTimeUtils#stringToDate.
-fn date_parser(date_str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
+fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
// local functions
fn get_trimmed_start(bytes: &[u8]) -> usize {
let mut start = 0;
@@ -1545,9 +1543,9 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>>
|| (segment != 0 && digits > 0 && digits <= 2)
}
- fn return_result(date_str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
+ fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
if eval_mode == EvalMode::Ansi {
- Err(CometError::CastInvalidValue {
+ Err(SparkError::CastInvalidValue {
value: date_str.to_string(),
from_type: "STRING".to_string(),
to_type: "DATE".to_string(),
diff --git a/native/core/src/execution/datafusion/expressions/mod.rs b/native/core/src/execution/datafusion/expressions/mod.rs
index d573c237..c61266ce 100644
--- a/native/core/src/execution/datafusion/expressions/mod.rs
+++ b/native/core/src/execution/datafusion/expressions/mod.rs
@@ -43,10 +43,10 @@ mod utils;
pub mod variance;
pub mod xxhash64;
-pub use datafusion_comet_spark_expr::EvalMode;
+pub use datafusion_comet_spark_expr::{EvalMode, SparkError};
fn arithmetic_overflow_error(from_type: &str) -> CometError {
- CometError::ArithmeticOverflow {
+ CometError::Spark(SparkError::ArithmeticOverflow {
from_type: from_type.to_string(),
- }
+ })
}
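The arithmetic_overflow_error helper keeps existing call sites in core terse: they still produce a CometError, while the message text now lives in SparkError. A hypothetical call site (checked_negate is illustrative, not from this commit):

    // negating i64::MIN has no representable result, so surface the
    // Spark-style ARITHMETIC_OVERFLOW error via the helper
    fn checked_negate(v: i64) -> Result<i64, CometError> {
        v.checked_neg().ok_or_else(|| arithmetic_overflow_error("long"))
    }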
diff --git a/native/core/src/execution/datafusion/expressions/negative.rs b/native/core/src/execution/datafusion/expressions/negative.rs
index cd0e9bcc..9e82812b 100644
--- a/native/core/src/execution/datafusion/expressions/negative.rs
+++ b/native/core/src/execution/datafusion/expressions/negative.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use super::arithmetic_overflow_error;
use crate::errors::CometError;
use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeType};
use arrow_array::RecordBatch;
@@ -24,6 +25,7 @@ use datafusion::{
logical_expr::{interval_arithmetic::Interval, ColumnarValue},
physical_expr::PhysicalExpr,
};
+use datafusion_comet_spark_expr::SparkError;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::sort_properties::ExprProperties;
use datafusion_physical_expr::aggregate::utils::down_cast_any_ref;
@@ -33,8 +35,6 @@ use std::{
sync::Arc,
};
-use super::arithmetic_overflow_error;
-
pub fn create_negate_expr(
expr: Arc<dyn PhysicalExpr>,
fail_on_error: bool,
@@ -234,7 +234,7 @@ impl PhysicalExpr for NegativeExpr {
|| child_interval.lower() == &ScalarValue::Int64(Some(i64::MIN))
|| child_interval.upper() == &ScalarValue::Int64(Some(i64::MIN))
{
- return Err(CometError::ArithmeticOverflow {
+ return Err(SparkError::ArithmeticOverflow {
from_type: "long".to_string(),
}
.into());
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 8bf76dff..4a9b9408 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -34,6 +34,7 @@ datafusion-common = { workspace = true }
datafusion-functions = { workspace = true }
datafusion-physical-expr = { workspace = true }
datafusion-comet-utils = { workspace = true }
+thiserror = { workspace = true }
[lib]
name = "datafusion_comet_spark_expr"
diff --git a/native/spark-expr/src/abs.rs b/native/spark-expr/src/abs.rs
index 198a96e5..fa25a777 100644
--- a/native/spark-expr/src/abs.rs
+++ b/native/spark-expr/src/abs.rs
@@ -77,9 +77,10 @@ impl ScalarUDFImpl for Abs {
if self.eval_mode == EvalMode::Legacy {
Ok(args[0].clone())
} else {
- Err(DataFusionError::External(Box::new(
- SparkError::ArithmeticOverflow(self.data_type_name.clone()),
- )))
+ Err(SparkError::ArithmeticOverflow {
+ from_type: self.data_type_name.clone(),
+ }
+ .into())
}
}
other => other,
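The .into() above leans on the From<SparkError> for DataFusionError impl added in the new error module below, which boxes the error as DataFusionError::External; the previous code spelled that conversion out by hand.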
diff --git a/native/spark-expr/src/error.rs b/native/spark-expr/src/error.rs
new file mode 100644
index 00000000..728a35a9
--- /dev/null
+++ b/native/spark-expr/src/error.rs
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_schema::ArrowError;
+use datafusion_common::DataFusionError;
+
+#[derive(thiserror::Error, Debug)]
+pub enum SparkError {
+ // Note that this message format is based on Spark 3.4 and is more detailed than the message
+ // returned by Spark 3.3
+ #[error("[CAST_INVALID_INPUT] The value '{value}' of the type
\"{from_type}\" cannot be cast to \"{to_type}\" \
+ because it is malformed. Correct the value as per the syntax, or
change its target type. \
+ Use `try_cast` to tolerate malformed input and return NULL instead. If
necessary \
+ set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ CastInvalidValue {
+ value: String,
+ from_type: String,
+ to_type: String,
+ },
+
+ #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as
Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to
\"false\" to bypass this error, and return NULL instead.")]
+ NumericValueOutOfRange {
+ value: String,
+ precision: u8,
+ scale: i8,
+ },
+
+ #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\"
cannot be cast to \"{to_type}\" \
+ due to an overflow. Use `try_cast` to tolerate overflow and return
NULL instead. If necessary \
+ set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ CastOverFlow {
+ value: String,
+ from_type: String,
+ to_type: String,
+ },
+
+ #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set
\"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ ArithmeticOverflow { from_type: String },
+
+ #[error("ArrowError: {0}.")]
+ Arrow(ArrowError),
+
+ #[error("InternalError: {0}.")]
+ Internal(String),
+}
+
+pub type SparkResult<T> = Result<T, SparkError>;
+
+impl From<ArrowError> for SparkError {
+ fn from(value: ArrowError) -> Self {
+ SparkError::Arrow(value)
+ }
+}
+
+impl From<SparkError> for DataFusionError {
+ fn from(value: SparkError) -> Self {
+ DataFusionError::External(Box::new(value))
+ }
+}
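For downstream code, the intended pattern is to return SparkResult<T> internally and let the ? operator plus these From impls convert at the DataFusion boundary. A minimal sketch, with hypothetical function names:

    use datafusion_comet_spark_expr::{SparkError, SparkResult};
    use datafusion_common::DataFusionError;

    // hypothetical helper that fails with a Spark-compatible error
    fn parse_bigint(s: &str) -> SparkResult<i64> {
        s.parse::<i64>().map_err(|_| SparkError::CastInvalidValue {
            value: s.to_string(),
            from_type: "STRING".to_string(),
            to_type: "BIGINT".to_string(),
        })
    }

    // at the boundary, `?` converts SparkError into
    // DataFusionError::External via the From impl above
    fn evaluate(s: &str) -> Result<i64, DataFusionError> {
        Ok(parse_bigint(s)?)
    }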
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index c36e8855..57da56f9 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -15,13 +15,12 @@
// specific language governing permissions and limitations
// under the License.
-use std::error::Error;
-use std::fmt::{Display, Formatter};
-
mod abs;
+mod error;
mod if_expr;
pub use abs::Abs;
+pub use error::{SparkError, SparkResult};
pub use if_expr::IfExpr;
/// Spark supports three evaluation modes when evaluating expressions, which affect
@@ -42,19 +41,3 @@ pub enum EvalMode {
/// failing the entire query.
Try,
}
-
-#[derive(Debug)]
-pub enum SparkError {
- ArithmeticOverflow(String),
-}
-
-impl Error for SparkError {}
-
-impl Display for SparkError {
- fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
- match self {
- Self::ArithmeticOverflow(data_type) =>
- write!(f, "[ARITHMETIC_OVERFLOW] {} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.", data_type)
- }
- }
-}