This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new de8fe455 fix: Enable cast string to int tests and fix compatibility issue (#453)
de8fe455 is described below

commit de8fe455be6e875535aa5c974f76e902636fac9e
Author: Andy Grove <[email protected]>
AuthorDate: Mon May 20 20:02:15 2024 -0600

    fix: Enable cast string to int tests and fix compatibility issue (#453)
    
    * simplify cast string to int logic and use untrimmed string in error messages
    
    * remove state enum
---
 core/src/execution/datafusion/expressions/cast.rs  | 57 +++++-----------------
 docs/source/user-guide/compatibility.md            |  8 +--
 .../org/apache/comet/expressions/CometCast.scala   |  2 +-
 .../scala/org/apache/comet/CometCastSuite.scala    |  8 +--
 4 files changed, 20 insertions(+), 55 deletions(-)

diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index 35ab23a7..f68732fb 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -82,7 +82,7 @@ macro_rules! cast_utf8_to_int {
         for i in 0..len {
             if $array.is_null(i) {
                 cast_array.append_null()
-            } else if let Some(cast_value) = $cast_method($array.value(i).trim(), $eval_mode)? {
+            } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
                 cast_array.append_value(cast_value);
             } else {
                 cast_array.append_null()
@@ -1010,14 +1010,6 @@ fn cast_string_to_int_with_range_check(
     }
 }
 
-#[derive(PartialEq)]
-enum State {
-    SkipLeadingWhiteSpace,
-    SkipTrailingWhiteSpace,
-    ParseSignAndDigits,
-    ParseFractionalDigits,
-}
-
 /// Equivalent to
 /// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
 /// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
@@ -1029,34 +1021,22 @@ fn do_cast_string_to_int<
     type_name: &str,
     min_value: T,
 ) -> CometResult<Option<T>> {
-    let len = str.len();
-    if str.is_empty() {
+    let trimmed_str = str.trim();
+    if trimmed_str.is_empty() {
         return none_or_err(eval_mode, type_name, str);
     }
-
+    let len = trimmed_str.len();
     let mut result: T = T::zero();
     let mut negative = false;
     let radix = T::from(10);
     let stop_value = min_value / radix;
-    let mut state = State::SkipLeadingWhiteSpace;
-    let mut parsed_sign = false;
-
-    for (i, ch) in str.char_indices() {
-        // skip leading whitespace
-        if state == State::SkipLeadingWhiteSpace {
-            if ch.is_whitespace() {
-                // consume this char
-                continue;
-            }
-            // change state and fall through to next section
-            state = State::ParseSignAndDigits;
-        }
+    let mut parse_sign_and_digits = true;
 
-        if state == State::ParseSignAndDigits {
-            if !parsed_sign {
+    for (i, ch) in trimmed_str.char_indices() {
+        if parse_sign_and_digits {
+            if i == 0 {
                 negative = ch == '-';
                 let positive = ch == '+';
-                parsed_sign = true;
                 if negative || positive {
                     if i + 1 == len {
                         // input string is just "+" or "-"
@@ -1070,7 +1050,7 @@ fn do_cast_string_to_int<
             if ch == '.' {
                 if eval_mode == EvalMode::Legacy {
                     // truncate decimal in legacy mode
-                    state = State::ParseFractionalDigits;
+                    parse_sign_and_digits = false;
                     continue;
                 } else {
                     return none_or_err(eval_mode, type_name, str);
@@ -1102,27 +1082,12 @@ fn do_cast_string_to_int<
                     return none_or_err(eval_mode, type_name, str);
                 }
             }
-        }
-
-        if state == State::ParseFractionalDigits {
-            // This is the case when we've encountered a decimal separator. The fractional
-            // part will not change the number, but we will verify that the fractional part
-            // is well-formed.
-            if ch.is_whitespace() {
-                // finished parsing fractional digits, now need to skip trailing whitespace
-                state = State::SkipTrailingWhiteSpace;
-                // consume this char
-                continue;
-            }
+        } else {
+            // make sure fractional digits are valid digits but ignore them
             if !ch.is_ascii_digit() {
                 return none_or_err(eval_mode, type_name, str);
             }
         }
-
-        // skip trailing whitespace
-        if state == State::SkipTrailingWhiteSpace && !ch.is_whitespace() {
-            return none_or_err(eval_mode, type_name, str);
-        }
     }
 
     if !negative {
diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md
index 278edb84..a4ed9289 100644
--- a/docs/source/user-guide/compatibility.md
+++ b/docs/source/user-guide/compatibility.md
@@ -110,6 +110,10 @@ The following cast operations are generally compatible with Spark except for the
 | decimal | float |  |
 | decimal | double |  |
 | string | boolean |  |
+| string | byte |  |
+| string | short |  |
+| string | integer |  |
+| string | long |  |
 | string | binary |  |
 | date | string |  |
 | timestamp | long |  |
@@ -125,10 +129,6 @@ The following cast operations are not compatible with Spark for all inputs and a
 |-|-|-|
 | integer | decimal  | No overflow check |
 | long | decimal  | No overflow check |
-| string | byte  | Not all invalid inputs are detected |
-| string | short  | Not all invalid inputs are detected |
-| string | integer  | Not all invalid inputs are detected |
-| string | long  | Not all invalid inputs are detected |
 | string | timestamp  | Not all valid formats are supported |
 | binary | string  | Only works for binary data representing valid UTF-8 strings |
 
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 9c3695ba..795bdb42 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -108,7 +108,7 @@ object CometCast {
         Compatible()
       case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
           DataTypes.LongType =>
-        Incompatible(Some("Not all invalid inputs are detected"))
+        Compatible()
       case DataTypes.BinaryType =>
         Compatible()
       case DataTypes.FloatType | DataTypes.DoubleType =>
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index ea3355d0..8caba14c 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -519,28 +519,28 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     "9223372036854775808" // Long.MaxValue + 1
   )
 
-  ignore("cast StringType to ByteType") {
+  test("cast StringType to ByteType") {
     // test with hand-picked values
     castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType)
     // fuzz test
     castTest(gen.generateStrings(dataSize, numericPattern, 4).toDF("a"), DataTypes.ByteType)
   }
 
-  ignore("cast StringType to ShortType") {
+  test("cast StringType to ShortType") {
     // test with hand-picked values
     castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType)
     // fuzz test
     castTest(gen.generateStrings(dataSize, numericPattern, 5).toDF("a"), DataTypes.ShortType)
   }
 
-  ignore("cast StringType to IntegerType") {
+  test("cast StringType to IntegerType") {
     // test with hand-picked values
     castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType)
     // fuzz test
     castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.IntegerType)
   }
 
-  ignore("cast StringType to LongType") {
+  test("cast StringType to LongType") {
     // test with hand-picked values
     castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType)
     // fuzz test


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to