This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 605a7842e Faster decimal parsing (30-60%) (#3939)
605a7842e is described below
commit 605a7842e87abbbdc26c310a82abb4398000a43d
Author: bold <[email protected]>
AuthorDate: Fri Mar 31 15:50:50 2023 +0200
Faster decimal parsing (30-60%) (#3939)
* Improve decimal parsing
* Add edge tests for decimal parsing
* Add more decimal parsing tests
Co-authored-by: Raphael Taylor-Davies <[email protected]>
* Fix test and improve performance further
* Move overflow check out of the loop
* Fix "0" parsing
* Add failing decimal parsing tests
* Fix parse decimal tests
---------
Co-authored-by: Raphael Taylor-Davies <[email protected]>
---
arrow-cast/src/parse.rs | 159 ++++++++++++++++++++++++++++++++----------------
1 file changed, 107 insertions(+), 52 deletions(-)
diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs
index d7e5529bc..cc8254916 100644
--- a/arrow-cast/src/parse.rs
+++ b/arrow-cast/src/parse.rs
@@ -623,64 +623,64 @@ pub fn parse_decimal<T: DecimalType>(
precision: u8,
scale: i8,
) -> Result<T::Native, ArrowError> {
- let mut seen_dot = false;
- let mut seen_sign = false;
- let mut negative = false;
-
let mut result = T::Native::usize_as(0);
let mut fractionals = 0;
let mut digits = 0;
let base = T::Native::usize_as(10);
- let mut bs = s.as_bytes().iter();
+
+ let bs = s.as_bytes();
+ let (bs, negative) = match bs.first() {
+ Some(b'-') => (&bs[1..], true),
+ Some(b'+') => (&bs[1..], false),
+ _ => (bs, false),
+ };
+
+ if bs.is_empty() {
+ return Err(ArrowError::ParseError(format!(
+ "can't parse the string value {s} to decimal"
+ )));
+ }
+
+ let mut bs = bs.iter();
+ // Overflow checks are not required if 10^(precision - 1) <= T::MAX holds.
+ // Thus, if we validate the precision correctly, we can skip overflow
checks.
while let Some(b) = bs.next() {
match b {
b'0'..=b'9' => {
- if seen_dot {
- if fractionals == scale {
- // We have processed and validated the whole part of
our decimal (including sign and dot).
- // All that is left is to validate the fractional part.
- if bs.any(|b| !b.is_ascii_digit()) {
- return Err(ArrowError::ParseError(format!(
- "can't parse the string value {s} to decimal"
- )));
- }
- break;
- }
- fractionals += 1;
+ if digits == 0 && *b == b'0' {
+ // Ignore leading zeros.
+ continue;
}
digits += 1;
- if digits > precision {
- return Err(ArrowError::ParseError(
- "parse decimal overflow".to_string(),
- ));
- }
- result = result.mul_checked(base)?;
- result = result.add_checked(T::Native::usize_as((b - b'0') as
usize))?;
+ result = result.mul_wrapping(base);
+ result = result.add_wrapping(T::Native::usize_as((b - b'0') as
usize));
}
b'.' => {
- if seen_dot {
- return Err(ArrowError::ParseError(format!(
- "can't parse the string value {s} to decimal"
- )));
- }
- seen_dot = true;
- }
- b'-' => {
- if seen_sign || digits > 0 || seen_dot {
- return Err(ArrowError::ParseError(format!(
- "can't parse the string value {s} to decimal"
- )));
+ for b in bs.by_ref() {
+ if !b.is_ascii_digit() {
+ return Err(ArrowError::ParseError(format!(
+ "can't parse the string value {s} to decimal"
+ )));
+ }
+ if fractionals == scale {
+ // We have processed all the digits that we need. All
that
+ // is left is to validate that the rest of the string
contains
+ // valid digits.
+ continue;
+ }
+ fractionals += 1;
+ digits += 1;
+ result = result.mul_wrapping(base);
+ result =
+ result.add_wrapping(T::Native::usize_as((b - b'0') as
usize));
}
- seen_sign = true;
- negative = true;
- }
- b'+' => {
- if seen_sign || digits > 0 || seen_dot {
+
+ // Fail on "."
+ if digits == 0 {
return Err(ArrowError::ParseError(format!(
"can't parse the string value {s} to decimal"
)));
}
- seen_sign = true;
}
_ => {
return Err(ArrowError::ParseError(format!(
@@ -689,24 +689,20 @@ pub fn parse_decimal<T: DecimalType>(
}
}
}
- // Fail on "."
- if digits == 0 {
- return Err(ArrowError::ParseError(format!(
- "can't parse the string value {s} to decimal"
- )));
- }
if fractionals < scale {
let exp = scale - fractionals;
if exp as u8 + digits > precision {
return Err(ArrowError::ParseError("parse decimal
overflow".to_string()));
}
- let mul = base.pow_checked(exp as _)?;
- result = result.mul_checked(mul)?;
+ let mul = base.pow_wrapping(exp as _);
+ result = result.mul_wrapping(mul);
+ } else if digits > precision {
+ return Err(ArrowError::ParseError("parse decimal
overflow".to_string()));
}
Ok(if negative {
- result.neg_checked()?
+ result.neg_wrapping()
} else {
result
})
@@ -1689,6 +1685,7 @@ mod tests {
#[test]
fn test_parse_decimal_with_parameter() {
let tests = [
+ ("0", 0i128),
("123.123", 123123i128),
("123.1234", 123123i128),
("123.1", 123100i128),
@@ -1717,7 +1714,7 @@ mod tests {
let result_256 = parse_decimal::<Decimal256Type>(s, 20, 3);
assert_eq!(i256::from_i128(i), result_256.unwrap());
}
- let can_not_parse_tests = ["123,123", ".", "123.123.123"];
+ let can_not_parse_tests = ["123,123", ".", "123.123.123", "", "+",
"-"];
for s in can_not_parse_tests {
let result_128 = parse_decimal::<Decimal128Type>(s, 20, 3);
assert_eq!(
@@ -1750,5 +1747,63 @@ mod tests {
"actual: '{actual_256}', expected: '{expected_256}'"
);
}
+
+ let edge_tests_128 = [
+ (
+ "99999999999999999999999999999999999999",
+ 99999999999999999999999999999999999999i128,
+ 0,
+ ),
+ (
+ "999999999999999999999999999999999999.99",
+ 99999999999999999999999999999999999999i128,
+ 2,
+ ),
+ (
+ "9999999999999999999999999.9999999999999",
+ 99999999999999999999999999999999999999i128,
+ 13,
+ ),
+ (
+ "9999999999999999999999999",
+ 99999999999999999999999990000000000000i128,
+ 13,
+ ),
+ (
+ "0.99999999999999999999999999999999999999",
+ 99999999999999999999999999999999999999i128,
+ 38,
+ ),
+ ];
+ for (s, i, scale) in edge_tests_128 {
+ let result_128 = parse_decimal::<Decimal128Type>(s, 38, scale);
+ assert_eq!(i, result_128.unwrap());
+ }
+ let edge_tests_256 = [
+ (
+
"9999999999999999999999999999999999999999999999999999999999999999999999999999",
+i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(),
+ 0,
+ ),
+ (
+
"999999999999999999999999999999999999999999999999999999999999999999999999.9999",
+
i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(),
+ 4,
+ ),
+ (
+
"99999999999999999999999999999999999999999999999999.99999999999999999999999999",
+
i256::from_string("9999999999999999999999999999999999999999999999999999999999999999999999999999").unwrap(),
+ 26,
+ ),
+ (
+ "99999999999999999999999999999999999999999999999999",
+
i256::from_string("9999999999999999999999999999999999999999999999999900000000000000000000000000").unwrap(),
+ 26,
+ ),
+ ];
+ for (s, i, scale) in edge_tests_256 {
+ let result = parse_decimal::<Decimal256Type>(s, 76, scale);
+ assert_eq!(i, result.unwrap());
+ }
}
}