This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new de8fe455 fix: Enable cast string to int tests and fix compatibility issue (#453)
de8fe455 is described below
commit de8fe455be6e875535aa5c974f76e902636fac9e
Author: Andy Grove <[email protected]>
AuthorDate: Mon May 20 20:02:15 2024 -0600
fix: Enable cast string to int tests and fix compatibility issue (#453)
* simplify cast string to int logic and use untrimmed string in error messages
* remove state enum
---
core/src/execution/datafusion/expressions/cast.rs | 57 +++++-----------------
docs/source/user-guide/compatibility.md | 8 +--
.../org/apache/comet/expressions/CometCast.scala | 2 +-
.../scala/org/apache/comet/CometCastSuite.scala | 8 +--
4 files changed, 20 insertions(+), 55 deletions(-)
diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index 35ab23a7..f68732fb 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -82,7 +82,7 @@ macro_rules! cast_utf8_to_int {
for i in 0..len {
if $array.is_null(i) {
cast_array.append_null()
- } else if let Some(cast_value) = $cast_method($array.value(i).trim(), $eval_mode)? {
+ } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
cast_array.append_value(cast_value);
} else {
cast_array.append_null()
@@ -1010,14 +1010,6 @@ fn cast_string_to_int_with_range_check(
}
}
-#[derive(PartialEq)]
-enum State {
- SkipLeadingWhiteSpace,
- SkipTrailingWhiteSpace,
- ParseSignAndDigits,
- ParseFractionalDigits,
-}
-
/// Equivalent to
/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
@@ -1029,34 +1021,22 @@ fn do_cast_string_to_int<
type_name: &str,
min_value: T,
) -> CometResult<Option<T>> {
- let len = str.len();
- if str.is_empty() {
+ let trimmed_str = str.trim();
+ if trimmed_str.is_empty() {
return none_or_err(eval_mode, type_name, str);
}
-
+ let len = trimmed_str.len();
let mut result: T = T::zero();
let mut negative = false;
let radix = T::from(10);
let stop_value = min_value / radix;
- let mut state = State::SkipLeadingWhiteSpace;
- let mut parsed_sign = false;
-
- for (i, ch) in str.char_indices() {
- // skip leading whitespace
- if state == State::SkipLeadingWhiteSpace {
- if ch.is_whitespace() {
- // consume this char
- continue;
- }
- // change state and fall through to next section
- state = State::ParseSignAndDigits;
- }
+ let mut parse_sign_and_digits = true;
- if state == State::ParseSignAndDigits {
- if !parsed_sign {
+ for (i, ch) in trimmed_str.char_indices() {
+ if parse_sign_and_digits {
+ if i == 0 {
negative = ch == '-';
let positive = ch == '+';
- parsed_sign = true;
if negative || positive {
if i + 1 == len {
// input string is just "+" or "-"
@@ -1070,7 +1050,7 @@ fn do_cast_string_to_int<
if ch == '.' {
if eval_mode == EvalMode::Legacy {
// truncate decimal in legacy mode
- state = State::ParseFractionalDigits;
+ parse_sign_and_digits = false;
continue;
} else {
return none_or_err(eval_mode, type_name, str);
@@ -1102,27 +1082,12 @@ fn do_cast_string_to_int<
return none_or_err(eval_mode, type_name, str);
}
}
- }
-
- if state == State::ParseFractionalDigits {
- // This is the case when we've encountered a decimal separator. The fractional
- // part will not change the number, but we will verify that the fractional part
- // is well-formed.
- if ch.is_whitespace() {
- // finished parsing fractional digits, now need to skip trailing whitespace
- state = State::SkipTrailingWhiteSpace;
- // consume this char
- continue;
- }
+ } else {
+ // make sure fractional digits are valid digits but ignore them
if !ch.is_ascii_digit() {
return none_or_err(eval_mode, type_name, str);
}
}
-
- // skip trailing whitespace
- if state == State::SkipTrailingWhiteSpace && !ch.is_whitespace() {
- return none_or_err(eval_mode, type_name, str);
- }
}
if !negative {
diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md
index 278edb84..a4ed9289 100644
--- a/docs/source/user-guide/compatibility.md
+++ b/docs/source/user-guide/compatibility.md
@@ -110,6 +110,10 @@ The following cast operations are generally compatible with Spark except for the
| decimal | float | |
| decimal | double | |
| string | boolean | |
+| string | byte | |
+| string | short | |
+| string | integer | |
+| string | long | |
| string | binary | |
| date | string | |
| timestamp | long | |
@@ -125,10 +129,6 @@ The following cast operations are not compatible with Spark for all inputs and a
|-|-|-|
| integer | decimal | No overflow check |
| long | decimal | No overflow check |
-| string | byte | Not all invalid inputs are detected |
-| string | short | Not all invalid inputs are detected |
-| string | integer | Not all invalid inputs are detected |
-| string | long | Not all invalid inputs are detected |
| string | timestamp | Not all valid formats are supported |
| binary | string | Only works for binary data representing valid UTF-8 strings |
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 9c3695ba..795bdb42 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -108,7 +108,7 @@ object CometCast {
Compatible()
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
DataTypes.LongType =>
- Incompatible(Some("Not all invalid inputs are detected"))
+ Compatible()
case DataTypes.BinaryType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index ea3355d0..8caba14c 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -519,28 +519,28 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
"9223372036854775808" // Long.MaxValue + 1
)
- ignore("cast StringType to ByteType") {
+ test("cast StringType to ByteType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType)
// fuzz test
castTest(gen.generateStrings(dataSize, numericPattern, 4).toDF("a"),
DataTypes.ByteType)
}
- ignore("cast StringType to ShortType") {
+ test("cast StringType to ShortType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType)
// fuzz test
castTest(gen.generateStrings(dataSize, numericPattern, 5).toDF("a"),
DataTypes.ShortType)
}
- ignore("cast StringType to IntegerType") {
+ test("cast StringType to IntegerType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType)
// fuzz test
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"),
DataTypes.IntegerType)
}
- ignore("cast StringType to LongType") {
+ test("cast StringType to LongType") {
// test with hand-picked values
castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType)
// fuzz test
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]