This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 605a7842e Faster decimal parsing (30-60%) (#3939)
605a7842e is described below

commit 605a7842e87abbbdc26c310a82abb4398000a43d
Author: bold <[email protected]>
AuthorDate: Fri Mar 31 15:50:50 2023 +0200

    Faster decimal parsing (30-60%) (#3939)
    
    * Improve decimal parsing
    
    * Add edge tests for decimal parsing
    
    * Add more decimal parsing tests
    
    Co-authored-by: Raphael Taylor-Davies <[email protected]>
    
    * Fix test and improve performance further
    
    * Move overflow check out of the loop
    
    * Fix "0" parsing
    
    * Add failing decimal parsing tests
    
    * Fix parse decimal tests
    
    ---------
    
    Co-authored-by: Raphael Taylor-Davies <[email protected]>
---
 arrow-cast/src/parse.rs | 159 ++++++++++++++++++++++++++++++++----------------
 1 file changed, 107 insertions(+), 52 deletions(-)

diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs
index d7e5529bc..cc8254916 100644
--- a/arrow-cast/src/parse.rs
+++ b/arrow-cast/src/parse.rs
@@ -623,64 +623,64 @@ pub fn parse_decimal<T: DecimalType>(
     precision: u8,
     scale: i8,
 ) -> Result<T::Native, ArrowError> {
-    let mut seen_dot = false;
-    let mut seen_sign = false;
-    let mut negative = false;
-
     let mut result = T::Native::usize_as(0);
     let mut fractionals = 0;
     let mut digits = 0;
     let base = T::Native::usize_as(10);
-    let mut bs = s.as_bytes().iter();
+
+    let bs = s.as_bytes();
+    let (bs, negative) = match bs.first() {
+        Some(b'-') => (&bs[1..], true),
+        Some(b'+') => (&bs[1..], false),
+        _ => (bs, false),
+    };
+
+    if bs.is_empty() {
+        return Err(ArrowError::ParseError(format!(
+            "can't parse the string value {s} to decimal"
+        )));
+    }
+
+    let mut bs = bs.iter();
+    // Overflow checks are not required if 10^(precision - 1) <= T::MAX holds.
+    // Thus, if we validate the precision correctly, we can skip overflow checks.
     while let Some(b) = bs.next() {
         match b {
             b'0'..=b'9' => {
-                if seen_dot {
-                    if fractionals == scale {
-                        // We have processed and validated the whole part of our decimal (including sign and dot).
-                        // All that is left is to validate the fractional part.
-                        if bs.any(|b| !b.is_ascii_digit()) {
-                            return Err(ArrowError::ParseError(format!(
-                                "can't parse the string value {s} to decimal"
-                            )));
-                        }
-                        break;
-                    }
-                    fractionals += 1;
+                if digits == 0 && *b == b'0' {
+                    // Ignore leading zeros.
+                    continue;
                 }
                 digits += 1;
-                if digits > precision {
-                    return Err(ArrowError::ParseError(
-                        "parse decimal overflow".to_string(),
-                    ));
-                }
-                result = result.mul_checked(base)?;
-                result = result.add_checked(T::Native::usize_as((b - b'0') as usize))?;
+                result = result.mul_wrapping(base);
+                result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize));
             }
             b'.' => {
-                if seen_dot {
-                    return Err(ArrowError::ParseError(format!(
-                        "can't parse the string value {s} to decimal"
-                    )));
-                }
-                seen_dot = true;
-            }
-            b'-' => {
-                if seen_sign || digits > 0 || seen_dot {
-                    return Err(ArrowError::ParseError(format!(
-                        "can't parse the string value {s} to decimal"
-                    )));
+                for b in bs.by_ref() {
+                    if !b.is_ascii_digit() {
+                        return Err(ArrowError::ParseError(format!(
+                            "can't parse the string value {s} to decimal"
+                        )));
+                    }
+                    if fractionals == scale {
+                        // We have processed all the digits that we need. All that
+                        // is left is to validate that the rest of the string contains
+                        // valid digits.
+                        continue;
+                    }
+                    fractionals += 1;
+                    digits += 1;
+                    result = result.mul_wrapping(base);
+                    result =
+                        result.add_wrapping(T::Native::usize_as((b - b'0') as usize));
                 }
-                seen_sign = true;
-                negative = true;
-            }
-            b'+' => {
-                if seen_sign || digits > 0 || seen_dot {
+
+                // Fail on "."
+                if digits == 0 {
                     return Err(ArrowError::ParseError(format!(
                         "can't parse the string value {s} to decimal"
                     )));
                 }
-                seen_sign = true;
             }
             _ => {
                 return Err(ArrowError::ParseError(format!(
@@ -689,24 +689,20 @@ pub fn parse_decimal<T: DecimalType>(
             }
         }
     }
-    // Fail on "."
-    if digits == 0 {
-        return Err(ArrowError::ParseError(format!(
-            "can't parse the string value {s} to decimal"
-        )));
-    }
 
     if fractionals < scale {
         let exp = scale - fractionals;
         if exp as u8 + digits > precision {
             return Err(ArrowError::ParseError("parse decimal overflow".to_string()));
         }
-        let mul = base.pow_checked(exp as _)?;
-        result = result.mul_checked(mul)?;
+        let mul = base.pow_wrapping(exp as _);
+        result = result.mul_wrapping(mul);
+    } else if digits > precision {
+        return Err(ArrowError::ParseError("parse decimal overflow".to_string()));
     }
 
     Ok(if negative {
-        result.neg_checked()?
+        result.neg_wrapping()
     } else {
         result
     })
@@ -1689,6 +1685,7 @@ mod tests {
     #[test]
     fn test_parse_decimal_with_parameter() {
         let tests = [
+            ("0", 0i128),
             ("123.123", 123123i128),
             ("123.1234", 123123i128),
             ("123.1", 123100i128),
@@ -1717,7 +1714,7 @@ mod tests {
             let result_256 = parse_decimal::<Decimal256Type>(s, 20, 3);
             assert_eq!(i256::from_i128(i), result_256.unwrap());
         }
-        let can_not_parse_tests = ["123,123", ".", "123.123.123"];
+        let can_not_parse_tests = ["123,123", ".", "123.123.123", "", "+", "-"];
         for s in can_not_parse_tests {
             let result_128 = parse_decimal::<Decimal128Type>(s, 20, 3);
             assert_eq!(
@@ -1750,5 +1747,63 @@ mod tests {
                 "actual: '{actual_256}', expected: '{expected_256}'"
             );
         }
+
+        let edge_tests_128 = [
+            (
+                "99999999999999999999999999999999999999",
+                99999999999999999999999999999999999999i128,
+                0,
+            ),
+            (
+                "999999999999999999999999999999999999.99",
+                99999999999999999999999999999999999999i128,
+                2,
+            ),
+            (
+                "9999999999999999999999999.9999999999999",
+                99999999999999999999999999999999999999i128,
+                13,
+            ),
+            (
+                "9999999999999999999999999",
+                99999999999999999999999990000000000000i128,
+                13,
+            ),
+            (
+                "0.99999999999999999999999999999999999999",
+                99999999999999999999999999999999999999i128,
+                38,
+            ),
+        ];
+        for (s, i, scale) in edge_tests_128 {
+            let result_128 = parse_decimal::<Decimal128Type>(s, 38, scale);
+            assert_eq!(i, result_128.unwrap());
+        }
+        let edge_tests_256 = [
+            (
+                "9999999999999999999999999999999999999999999999999999999999999999999999999999",
+                i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(),
+                0,
+            ),
+            (
+                "999999999999999999999999999999999999999999999999999999999999999999999999.9999",
+                i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(),
+                4,
+            ),
+            (
+                "99999999999999999999999999999999999999999999999999.99999999999999999999999999",
+                i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(),
+                26,
+            ),
+            (
+                "99999999999999999999999999999999999999999999999999",
+                
i256::from_string("9999999999999999999999999999999999999999999999999900000000000000000000000000").unwrap(),
+                26,
+            ),
+        ];
+        for (s, i, scale) in edge_tests_256 {
+            let result = parse_decimal::<Decimal256Type>(s, 76, scale);
+            assert_eq!(i, result.unwrap());
+        }
     }
 }

Reply via email to