This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new d84a1a6152f String to decimal conversion written using E/scientific notation (#5611)
d84a1a6152f is described below

commit d84a1a6152f5c7b52e44f935fe2f0421d09d14b8
Author: Nekit2217 <[email protected]>
AuthorDate: Mon Apr 15 14:56:10 2024 +0300

    String to decimal conversion written using E/scientific notation (#5611)
    
    * Added decimal casting for the E notation
    
    * Clippy
    
    * Fmt
    
    * Added several new tests
    Fixed an error for E notation without a fractional part (e.g. 0e-8, 1e7)
    
    * Clippy, fmt
    
    * Minor rework of the parse_e_notation function
    Added more tests for E/scientific notation
---
 arrow-cast/src/parse.rs | 294 +++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 279 insertions(+), 15 deletions(-)

diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs
index afa00f17629..3102fcc58ad 100644
--- a/arrow-cast/src/parse.rs
+++ b/arrow-cast/src/parse.rs
@@ -671,6 +671,126 @@ impl Parser for Date64Type {
     }
 }
 
+fn parse_e_notation<T: DecimalType>( // Finish parsing a decimal written in E notation: consume the remaining mantissa digits and the exponent, then rescale `result` to `scale`, rejecting values that exceed `precision`.
+    s: &str,
+    mut digits: u16,
+    mut fractionals: i16,
+    mut result: T::Native,
+    index: usize,
+    precision: u16,
+    scale: i16,
+) -> Result<T::Native, ArrowError> {
+    let mut exp: i16 = 0; // NOTE(review): grown below with unchecked `*= 10` / `+=`; an exponent such as "1e40000" overflows i16 (panic in debug builds, silent wrap in release)
+    let base = T::Native::usize_as(10);
+
+    let mut exp_start: bool = false;
+    // true unless the exponent carries an explicit '-' sign
+    let mut pos_shift_direction: bool = true;
+
+    // resume scanning `s` after the bytes parse_decimal already consumed (NOTE(review): `index` is used as an offset into all of `s`; if the caller enumerated a sign-stripped slice, signed inputs such as "-1.5e3" end up one byte off — verify)
+    let mut bs;
+    if fractionals > 0 {
+        // skip everything up to and including the '.' (+1), plus the `fractionals` digits already folded into `result`
+        bs = s.as_bytes().iter().skip(index + fractionals as usize + 1);
+    } else {
+        // no fractional digit was counted, so `index` is expected to point at the 'e' (NOTE(review): for "5.e3", or whenever the caller reaches the 'e' via the '.' branch with fractionals == 0, `index` is the '.', which the loop below rejects)
+        bs = s.as_bytes().iter().skip(index);
+    }
+
+    while let Some(b) = bs.next() {
+        match b {
+            b'0'..=b'9' => {
+                result = result.mul_wrapping(base);
+                result = result.add_wrapping(T::Native::usize_as((b - b'0') as 
usize));
+                if fractionals > 0 {
+                    fractionals += 1;
+                }
+                digits += 1;
+            }
+            &b'e' | &b'E' => {
+                exp_start = true;
+            }
+            _ => {
+                return Err(ArrowError::ParseError(format!(
+                    "can't parse the string value {s} to decimal"
+                )));
+            }
+        };
+
+        if exp_start {
+            pos_shift_direction = match bs.next() {
+                Some(&b'-') => false,
+                Some(&b'+') => true,
+                Some(b) => {
+                    if !b.is_ascii_digit() {
+                        return Err(ArrowError::ParseError(format!(
+                            "can't parse the string value {s} to decimal"
+                        )));
+                    }
+
+                    exp *= 10;
+                    exp += (b - b'0') as i16;
+
+                    true
+                }
+                None => {
+                    return Err(ArrowError::ParseError(format!(
+                        "can't parse the string value {s} to decimal"
+                    )))
+                }
+            };
+
+            for b in bs.by_ref() { // NOTE(review): zero digits are accepted here, so "1e+" and "1e-" parse as if the exponent were 0
+                if !b.is_ascii_digit() {
+                    return Err(ArrowError::ParseError(format!(
+                        "can't parse the string value {s} to decimal"
+                    )));
+                }
+                exp *= 10;
+                exp += (b - b'0') as i16;
+            }
+        }
+    }
+
+    if digits == 0 && fractionals == 0 && exp == 0 { // NOTE(review): a bare exponent such as "e5" (no mantissa digits, exp != 0) passes this check and parses as 0; conversely, if leading zeros are not counted in `digits`, "0e0" is rejected — verify
+        return Err(ArrowError::ParseError(format!(
+            "can't parse the string value {s} to decimal"
+        )));
+    }
+
+    if !pos_shift_direction {
+        // negative exponent: once it exceeds the significant digits plus `scale`, every digit lands below the target scale, so the value truncates to zero
+        // 1.12345e-30 => 0.0{29}12345, scale = 5
+        if exp - (digits as i16 + scale) > 0 {
+            return Ok(T::Native::usize_as(0));
+        }
+        exp *= -1;
+    }
+
+    // digits right of the decimal point after applying the exponent (negative: the integer needs trailing zeros)
+    exp = fractionals - exp;
+    // a negative exponent can add leading zeros after the point; count them as digits
+    if !pos_shift_direction && exp > digits as i16 {
+        digits = exp as u16;
+    }
+    // power of ten still needed to reach `scale`: positive appends zeros, negative truncates digits
+    exp = scale - exp;
+
+    if (digits as i16 + exp) as u16 > precision { // NOTE(review): a negative sum (e.g. with a negative `scale`) wraps through `as u16` and is reported as overflow
+        return Err(ArrowError::ParseError(format!(
+            "parse decimal overflow ({s})"
+        )));
+    }
+
+    if exp < 0 {
+        result = result.div_wrapping(base.pow_wrapping(-exp as _));
+    } else {
+        result = result.mul_wrapping(base.pow_wrapping(exp as _));
+    }
+
+    Ok(result)
+}
+
 /// Parse the string format decimal value to i128/i256 format and checking the precision and scale.
 /// The result value can't be out of bounds.
 pub fn parse_decimal<T: DecimalType>(
@@ -679,8 +799,8 @@ pub fn parse_decimal<T: DecimalType>(
     scale: i8,
 ) -> Result<T::Native, ArrowError> {
     let mut result = T::Native::usize_as(0);
-    let mut fractionals = 0;
-    let mut digits = 0;
+    let mut fractionals: i8 = 0;
+    let mut digits: u8 = 0;
     let base = T::Native::usize_as(10);
 
     let bs = s.as_bytes();
@@ -696,10 +816,13 @@ pub fn parse_decimal<T: DecimalType>(
         )));
     }
 
-    let mut bs = bs.iter();
+    let mut bs = bs.iter().enumerate(); // NOTE(review): these indices are relative to `bs`; parse_e_notation treats them as offsets into `s`, so they only line up if `bs` still starts at the beginning of `s` (e.g. no leading sign was stripped) — verify
+
+    let mut is_e_notation = false; // set once parse_e_notation has returned the fully scaled value
+
     // Overflow checks are not required if 10^(precision - 1) <= T::MAX holds.
     // Thus, if we validate the precision correctly, we can skip overflow checks.
-    while let Some(b) = bs.next() {
+    while let Some((index, b)) = bs.next() {
         match b {
             b'0'..=b'9' => {
                 if digits == 0 && *b == b'0' {
@@ -711,8 +834,28 @@ pub fn parse_decimal<T: DecimalType>(
                 result = result.add_wrapping(T::Native::usize_as((b - b'0') as 
usize));
             }
             b'.' => {
-                for b in bs.by_ref() {
+                let point_index = index; // position of the '.', handed to parse_e_notation so it can resume after the fractional digits
+
+                for (_, b) in bs.by_ref() {
                     if !b.is_ascii_digit() {
+                        if *b == b'e' || *b == b'E' { // exponent after the fractional digits: parse_e_notation finishes the value
+                            result = match parse_e_notation::<T>(
+                                s,
+                                digits as u16,
+                                fractionals as i16,
+                                result,
+                                point_index,
+                                precision as u16,
+                                scale as i16,
+                            ) {
+                                Err(e) => return Err(e),
+                                Ok(v) => v,
+                            };
+
+                            is_e_notation = true;
+
+                            break;
+                        }
                         return Err(ArrowError::ParseError(format!(
                             "can't parse the string value {s} to decimal"
                         )));
@@ -729,6 +872,10 @@ pub fn parse_decimal<T: DecimalType>(
                     result = result.add_wrapping(T::Native::usize_as((b - 
b'0') as usize));
                 }
 
+                if is_e_notation { // the value is final; skip the "." validation and leave the outer loop
+                    break;
+                }
+
                 // Fail on "."
                 if digits == 0 {
                     return Err(ArrowError::ParseError(format!(
@@ -736,6 +883,24 @@ pub fn parse_decimal<T: DecimalType>(
                     )));
                 }
             }
+            b'e' | b'E' => { // exponent directly after integer digits (e.g. "12E+6"); `index` is the position of the 'e' in the enumerated bytes
+                result = match parse_e_notation::<T>(
+                    s,
+                    digits as u16,
+                    fractionals as i16,
+                    result,
+                    index,
+                    precision as u16,
+                    scale as i16,
+                ) {
+                    Err(e) => return Err(e),
+                    Ok(v) => v,
+                };
+
+                is_e_notation = true;
+
+                break;
+            }
             _ => {
                 return Err(ArrowError::ParseError(format!(
                     "can't parse the string value {s} to decimal"
@@ -744,15 +909,21 @@ pub fn parse_decimal<T: DecimalType>(
         }
     }
 
-    if fractionals < scale {
-        let exp = scale - fractionals;
-        if exp as u8 + digits > precision {
-            return Err(ArrowError::ParseError("parse decimal 
overflow".to_string()));
+    if !is_e_notation { // the e-notation path has already applied scaling and its own precision check
+        if fractionals < scale {
+            let exp = scale - fractionals;
+            if exp as u8 + digits > precision {
+                return Err(ArrowError::ParseError(format!(
+                    "parse decimal overflow ({s})"
+                )));
+            }
+            let mul = base.pow_wrapping(exp as _);
+            result = result.mul_wrapping(mul);
+        } else if digits > precision {
+            return Err(ArrowError::ParseError(format!(
+                "parse decimal overflow ({s})"
+            )));
         }
-        let mul = base.pow_wrapping(exp as _);
-        result = result.mul_wrapping(mul);
-    } else if digits > precision {
-        return Err(ArrowError::ParseError("parse decimal 
overflow".to_string()));
     }
 
     Ok(if negative {
@@ -2202,7 +2373,65 @@ mod tests {
             let result_256 = parse_decimal::<Decimal256Type>(s, 20, 3);
             assert_eq!(i256::from_i128(i), result_256.unwrap());
         }
-        let can_not_parse_tests = ["123,123", ".", "123.123.123", "", "+", 
"-"];
+
+        let e_notation_tests = [
+            ("1.23e3", "1230.0", 2),
+            ("5.6714e+2", "567.14", 4),
+            ("5.6714e-2", "0.056714", 4),
+            ("5.6714e-2", "0.056714", 3),
+            ("5.6741214125e2", "567.41214125", 4),
+            ("8.91E4", "89100.0", 2),
+            ("3.14E+5", "314000.0", 2),
+            ("2.718e0", "2.718", 2),
+            ("9.999999e-1", "0.9999999", 4),
+            ("1.23e+3", "1230", 2),
+            ("1.234559e+3", "1234.559", 2),
+            ("1.00E-10", "0.0000000001", 11),
+            ("1.23e-4", "0.000123", 2),
+            ("9.876e7", "98760000.0", 2),
+            ("5.432E+8", "543200000.0", 10),
+            ("1.234567e9", "1234567000.0", 2),
+            ("1.234567e2", "123.45670000", 2),
+            ("4749.3e-5", "0.047493", 10),
+            ("4749.3e+5", "474930000", 10),
+            ("4749.3e-5", "0.047493", 1),
+            ("4749.3e+5", "474930000", 1),
+            ("0E-8", "0", 10),
+            ("0E+6", "0", 10),
+            ("1E-8", "0.00000001", 10),
+            ("12E+6", "12000000", 10),
+            ("12E-6", "0.000012", 10),
+            ("0.1e-6", "0.0000001", 10),
+            ("0.1e+6", "100000", 10),
+            ("0.12e-6", "0.00000012", 10),
+            ("0.12e+6", "120000", 10),
+            ("000000000001e0", "000000000001", 3),
+            ("000001.1034567002e0", "000001.1034567002", 3),
+        ];
+        for (e, d, scale) in e_notation_tests {
+            let result_128_e = parse_decimal::<Decimal128Type>(e, 20, scale);
+            let result_128_d = parse_decimal::<Decimal128Type>(d, 20, scale);
+            assert_eq!(result_128_e.unwrap(), result_128_d.unwrap());
+            let result_256_e = parse_decimal::<Decimal256Type>(e, 20, scale);
+            let result_256_d = parse_decimal::<Decimal256Type>(d, 20, scale);
+            assert_eq!(result_256_e.unwrap(), result_256_d.unwrap());
+        }
+        let can_not_parse_tests = [
+            "123,123",
+            ".",
+            "123.123.123",
+            "",
+            "+",
+            "-",
+            "e",
+            "1.3e+e3",
+            "5.6714ee-2",
+            "4.11ee-+4",
+            "4.11e++4",
+            "1.1e.12",
+            "1.23e+3.",
+            "1.23e+3.1",
+        ];
         for s in can_not_parse_tests {
             let result_128 = parse_decimal::<Decimal128Type>(s, 20, 3);
             assert_eq!(
@@ -2215,7 +2444,16 @@ mod tests {
                 result_256.unwrap_err().to_string()
             );
         }
-        let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"];
+        let overflow_parse_tests = [
+            "12345678",
+            "1.2345678e7",
+            "12345678.9",
+            "1.23456789e+7",
+            "99999999.99",
+            "9.999999999e7",
+            "12345678908765.123456",
+            "123456789087651234.56e-4",
+        ];
         for s in overflow_parse_tests {
             let result_128 = parse_decimal::<Decimal128Type>(s, 10, 3);
             let expected_128 = "Parser error: parse decimal overflow";
@@ -2262,6 +2500,16 @@ mod tests {
                 99999999999999999999999999999999999999i128,
                 38,
             ),
+            (
+                
"0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001016744",
+                0i128,
+                15,
+            ),
+            (
+                "1.016744e-320",
+                0i128,
+                15,
+            ),
         ];
         for (s, i, scale) in edge_tests_128 {
             let result_128 = parse_decimal::<Decimal128Type>(s, 38, scale);
@@ -2292,6 +2540,14 @@ mod tests {
                 .unwrap(),
                 26,
             ),
+            (
+                
"9.999999999999999999999999999999999999999999999999999999999999999999999999999e49",
+                i256::from_string(
+                    
"9999999999999999999999999999999999999999999999999999999999999999999999999999",
+                )
+                .unwrap(),
+                26,
+            ),
             (
                 "99999999999999999999999999999999999999999999999999",
                 i256::from_string(
@@ -2300,6 +2556,14 @@ mod tests {
                 .unwrap(),
                 26,
             ),
+            (
+                "9.9999999999999999999999999999999999999999999999999e+49",
+                i256::from_string(
+                    
"9999999999999999999999999999999999999999999999999900000000000000000000000000",
+                )
+                .unwrap(),
+                26,
+            ),
         ];
         for (s, i, scale) in edge_tests_256 {
             let result = parse_decimal::<Decimal256Type>(s, 76, scale);

Reply via email to