This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 22c138156 Feat: arrow csv decimal256 (#3711)
22c138156 is described below
commit 22c138156715bf62c8c683fb94e947f7a3200149
Author: 苏小刚 <[email protected]>
AuthorDate: Wed Feb 15 06:36:05 2023 +0800
Feat: arrow csv decimal256 (#3711)
* add test for arrow-csv Decimal256
* pass the test
There is still room for improvement in the code
* add test_write_csv_decimal for csv_writer
* support i128 and i256 in one generic function
* the test parse_decimal needs Neg
* Update arrow-array/src/array/primitive_array.rs
This will allow simplifying trait bounds in a number of other places
Co-authored-by: Raphael Taylor-Davies <[email protected]>
* return an error instead of panicking on overflow
* Decimal256(76, 6)
* adding test for Decimal256Type
---------
Co-authored-by: suxiaogang <[email protected]>
Co-authored-by: Raphael Taylor-Davies <[email protected]>
---
arrow-array/src/array/primitive_array.rs | 4 +-
arrow-csv/src/reader/mod.rs | 141 +++++++++++++++++++++----------
arrow-csv/src/writer.rs | 55 ++++++++++++
3 files changed, 155 insertions(+), 45 deletions(-)
diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs
index aeece612d..b64534e98 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -23,8 +23,8 @@ use crate::temporal_conversions::{
};
use crate::timezone::Tz;
use crate::trusted_len::trusted_len_unzip;
-use crate::types::*;
use crate::{print_long_array, Array, ArrayAccessor};
+use crate::{types::*, ArrowNativeTypeOp};
use arrow_buffer::{i256, ArrowNativeType, Buffer};
use arrow_data::bit_iterator::try_for_each_valid_idx;
use arrow_data::ArrayData;
@@ -233,7 +233,7 @@ pub type Decimal256Array = PrimitiveArray<Decimal256Type>;
/// static-typed nature of rust types ([`ArrowNativeType`]) for all types that
implement [`ArrowNativeType`].
pub trait ArrowPrimitiveType: 'static {
/// Corresponding Rust native type for the primitive type.
- type Native: ArrowNativeType;
+ type Native: ArrowNativeTypeOp;
/// the corresponding Arrow data type of this primitive type.
const DATA_TYPE: DataType;
diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs
index 610f05155..c5fe20e9d 100644
--- a/arrow-csv/src/reader/mod.rs
+++ b/arrow-csv/src/reader/mod.rs
@@ -42,6 +42,13 @@
mod records;
+use arrow_array::builder::PrimitiveBuilder;
+use arrow_array::types::*;
+use arrow_array::ArrowNativeTypeOp;
+use arrow_array::*;
+use arrow_buffer::ArrowNativeType;
+use arrow_cast::parse::Parser;
+use arrow_schema::*;
use lazy_static::lazy_static;
use regex::{Regex, RegexSet};
use std::collections::HashSet;
@@ -50,17 +57,9 @@ use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom};
use std::sync::Arc;
-use arrow_array::builder::Decimal128Builder;
-use arrow_array::types::*;
-use arrow_array::*;
-use arrow_cast::parse::Parser;
-use arrow_schema::*;
-
use crate::map_csv_error;
use crate::reader::records::{RecordDecoder, StringRecords};
-use arrow_data::decimal::validate_decimal_precision;
use csv::StringRecord;
-use std::ops::Neg;
lazy_static! {
static ref REGEX_SET: RegexSet = RegexSet::new([
@@ -608,7 +607,22 @@ fn parse(
match field.data_type() {
DataType::Boolean => build_boolean_array(line_number, rows, i),
DataType::Decimal128(precision, scale) => {
- build_decimal_array(line_number, rows, i, *precision,
*scale)
+ build_decimal_array::<Decimal128Type>(
+ line_number,
+ rows,
+ i,
+ *precision,
+ *scale,
+ )
+ }
+ DataType::Decimal256(precision, scale) => {
+ build_decimal_array::<Decimal256Type>(
+ line_number,
+ rows,
+ i,
+ *precision,
+ *scale,
+ )
}
DataType::Int8 => {
build_primitive_array::<Int8Type>(line_number, rows, i,
None)
@@ -781,22 +795,22 @@ fn parse_bool(string: &str) -> Option<bool> {
}
// parse the column string to an Arrow Array
-fn build_decimal_array(
+fn build_decimal_array<T: DecimalType>(
_line_number: usize,
rows: &StringRecords<'_>,
col_idx: usize,
precision: u8,
scale: i8,
) -> Result<ArrayRef, ArrowError> {
- let mut decimal_builder = Decimal128Builder::with_capacity(rows.len());
+ let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
for row in rows.iter() {
let s = row.get(col_idx);
if s.is_empty() {
// append null
decimal_builder.append_null();
} else {
- let decimal_value: Result<i128, _> =
- parse_decimal_with_parameter(s, precision, scale);
+ let decimal_value: Result<T::Native, _> =
+ parse_decimal_with_parameter::<T>(s, precision, scale);
match decimal_value {
Ok(v) => {
decimal_builder.append_value(v);
@@ -814,17 +828,17 @@ fn build_decimal_array(
))
}
-// Parse the string format decimal value to i128 format and checking the
precision and scale.
-// The result i128 value can't be out of bounds.
-fn parse_decimal_with_parameter(
+// Parse the string format decimal value to i128/i256 format and checking the
precision and scale.
+// The result value can't be out of bounds.
+fn parse_decimal_with_parameter<T: DecimalType>(
s: &str,
precision: u8,
scale: i8,
-) -> Result<i128, ArrowError> {
+) -> Result<T::Native, ArrowError> {
if PARSE_DECIMAL_RE.is_match(s) {
let mut offset = s.len();
let len = s.len();
- let mut base = 1;
+ let mut base = T::Native::usize_as(1);
let scale_usize = usize::from(scale as u8);
// handle the value after the '.' and meet the scale
@@ -832,7 +846,7 @@ fn parse_decimal_with_parameter(
match delimiter_position {
None => {
// there is no '.'
- base = 10_i128.pow(scale as u32);
+ base = T::Native::usize_as(10).pow_checked(scale as u32)?;
}
Some(mid) => {
// there is the '.'
@@ -841,7 +855,8 @@ fn parse_decimal_with_parameter(
offset -= len - mid - 1 - scale_usize;
} else {
// If the string value is "123.12" and the scale is 4, we
should append '00' to the tail.
- base = 10_i128.pow((scale_usize + 1 + mid - len) as u32);
+ base = T::Native::usize_as(10)
+ .pow_checked((scale_usize + 1 + mid - len) as u32)?;
}
}
};
@@ -849,25 +864,29 @@ fn parse_decimal_with_parameter(
// each byte is digit、'-' or '.'
let bytes = s.as_bytes();
let mut negative = false;
- let mut result: i128 = 0;
+ let mut result = T::Native::usize_as(0);
- bytes[0..offset].iter().rev().for_each(|&byte| match byte {
- b'-' => {
- negative = true;
- }
- b'0'..=b'9' => {
- result += i128::from(byte - b'0') * base;
- base *= 10;
+ for byte in bytes[0..offset].iter().rev() {
+ match byte {
+ b'-' => {
+ negative = true;
+ }
+ b'0'..=b'9' => {
+ let add =
+ T::Native::usize_as((byte - b'0') as
usize).mul_checked(base)?;
+ result = result.add_checked(add)?;
+ base = base.mul_checked(T::Native::usize_as(10))?;
+ }
+ // because of the PARSE_DECIMAL_RE, bytes just contains
digit、'-' and '.'.
+ _ => {}
}
- // because of the PARSE_DECIMAL_RE, bytes just contains digit、'-'
and '.'.
- _ => {}
- });
+ }
if negative {
- result = result.neg();
+ result = result.neg_checked()?;
}
- match validate_decimal_precision(result, precision) {
+ match T::validate_decimal_precision(result, precision) {
Ok(_) => Ok(result),
Err(e) => Err(ArrowError::ParseError(format!(
"parse decimal overflow: {e}"
@@ -884,6 +903,8 @@ fn parse_decimal_with_parameter(
// Like "125.12" to 12512_i128.
#[cfg(test)]
fn parse_decimal(s: &str) -> Result<i128, ArrowError> {
+ use std::ops::Neg;
+
if PARSE_DECIMAL_RE.is_match(s) {
let mut offset = s.len();
// each byte is digit、'-' or '.'
@@ -1230,6 +1251,7 @@ impl ReaderBuilder {
mod tests {
use super::*;
+ use arrow_buffer::i256;
use std::io::{Cursor, Write};
use tempfile::NamedTempFile;
@@ -1318,7 +1340,7 @@ mod tests {
let schema = Schema::new(vec![
Field::new("city", DataType::Utf8, false),
Field::new("lat", DataType::Decimal128(38, 6), false),
- Field::new("lng", DataType::Decimal128(38, 6), false),
+ Field::new("lng", DataType::Decimal256(76, 6), false),
]);
let file = File::open("test/data/decimal_test.csv").unwrap();
@@ -1343,6 +1365,23 @@ mod tests {
assert_eq!("123.000000", lat.value_as_string(7));
assert_eq!("123.000000", lat.value_as_string(8));
assert_eq!("-50.760000", lat.value_as_string(9));
+
+ let lng = batch
+ .column(2)
+ .as_any()
+ .downcast_ref::<Decimal256Array>()
+ .unwrap();
+
+ assert_eq!("-3.335724", lng.value_as_string(0));
+ assert_eq!("-2.179404", lng.value_as_string(1));
+ assert_eq!("-1.778197", lng.value_as_string(2));
+ assert_eq!("-3.179090", lng.value_as_string(3));
+ assert_eq!("-3.179090", lng.value_as_string(4));
+ assert_eq!("0.290472", lng.value_as_string(5));
+ assert_eq!("0.290472", lng.value_as_string(6));
+ assert_eq!("0.290472", lng.value_as_string(7));
+ assert_eq!("0.290472", lng.value_as_string(8));
+ assert_eq!("0.290472", lng.value_as_string(9));
}
#[test]
@@ -1788,26 +1827,42 @@ mod tests {
("-123.", -123000i128),
];
for (s, i) in tests {
- let result = parse_decimal_with_parameter(s, 20, 3);
- assert_eq!(i, result.unwrap())
+ let result_128 = parse_decimal_with_parameter::<Decimal128Type>(s,
20, 3);
+ assert_eq!(i, result_128.unwrap());
+ let result_256 = parse_decimal_with_parameter::<Decimal256Type>(s,
20, 3);
+ assert_eq!(i256::from_i128(i), result_256.unwrap());
}
let can_not_parse_tests = ["123,123", ".", "123.123.123"];
for s in can_not_parse_tests {
- let result = parse_decimal_with_parameter(s, 20, 3);
+ let result_128 = parse_decimal_with_parameter::<Decimal128Type>(s,
20, 3);
+ assert_eq!(
+ format!("Parser error: can't parse the string value {s} to
decimal"),
+ result_128.unwrap_err().to_string()
+ );
+ let result_256 = parse_decimal_with_parameter::<Decimal256Type>(s,
20, 3);
assert_eq!(
format!("Parser error: can't parse the string value {s} to
decimal"),
- result.unwrap_err().to_string()
+ result_256.unwrap_err().to_string()
);
}
let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"];
for s in overflow_parse_tests {
- let result = parse_decimal_with_parameter(s, 10, 3);
- let expected = "Parser error: parse decimal overflow";
- let actual = result.unwrap_err().to_string();
+ let result_128 = parse_decimal_with_parameter::<Decimal128Type>(s,
10, 3);
+ let expected_128 = "Parser error: parse decimal overflow";
+ let actual_128 = result_128.unwrap_err().to_string();
+
+ assert!(
+ actual_128.contains(expected_128),
+ "actual: '{actual_128}', expected: '{expected_128}'"
+ );
+
+ let result_256 = parse_decimal_with_parameter::<Decimal256Type>(s,
10, 3);
+ let expected_256 = "Parser error: parse decimal overflow";
+ let actual_256 = result_256.unwrap_err().to_string();
assert!(
- actual.contains(expected),
- "actual: '{actual}', expected: '{expected}'"
+ actual_256.contains(expected_256),
+ "actual: '{actual_256}', expected: '{expected_256}'"
);
}
}
diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs
index e0734a15f..d9331053f 100644
--- a/arrow-csv/src/writer.rs
+++ b/arrow-csv/src/writer.rs
@@ -326,7 +326,9 @@ mod tests {
use super::*;
use crate::Reader;
+ use arrow_array::builder::{Decimal128Builder, Decimal256Builder};
use arrow_array::types::*;
+ use arrow_buffer::i256;
use std::io::{Cursor, Read, Seek};
use std::sync::Arc;
@@ -406,6 +408,59 @@ sed do eiusmod
tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo
assert_eq!(expected.to_string(), String::from_utf8(buffer).unwrap());
}
+ #[test]
+ fn test_write_csv_decimal() {
+ let schema = Schema::new(vec![
+ Field::new("c1", DataType::Decimal128(38, 6), true),
+ Field::new("c2", DataType::Decimal256(76, 6), true),
+ ]);
+
+ let mut c1_builder =
+ Decimal128Builder::new().with_data_type(DataType::Decimal128(38,
6));
+ c1_builder.extend(vec![Some(-3335724), Some(2179404), None,
Some(290472)]);
+ let c1 = c1_builder.finish();
+
+ let mut c2_builder =
+ Decimal256Builder::new().with_data_type(DataType::Decimal256(76,
6));
+ c2_builder.extend(vec![
+ Some(i256::from_i128(-3335724)),
+ Some(i256::from_i128(2179404)),
+ None,
+ Some(i256::from_i128(290472)),
+ ]);
+ let c2 = c2_builder.finish();
+
+ let batch =
+ RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1),
Arc::new(c2)])
+ .unwrap();
+
+ let mut file = tempfile::tempfile().unwrap();
+
+ let mut writer = Writer::new(&mut file);
+ let batches = vec![&batch, &batch];
+ for batch in batches {
+ writer.write(batch).unwrap();
+ }
+ drop(writer);
+
+ // check that file was written successfully
+ file.rewind().unwrap();
+ let mut buffer: Vec<u8> = vec![];
+ file.read_to_end(&mut buffer).unwrap();
+
+ let expected = r#"c1,c2
+-3.335724,-3.335724
+2.179404,2.179404
+,
+0.290472,0.290472
+-3.335724,-3.335724
+2.179404,2.179404
+,
+0.290472,0.290472
+"#;
+ assert_eq!(expected.to_string(), String::from_utf8(buffer).unwrap());
+ }
+
#[test]
fn test_write_csv_custom_options() {
let schema = Schema::new(vec![