This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 22c138156 Feat: arrow csv decimal256 (#3711)
22c138156 is described below

commit 22c138156715bf62c8c683fb94e947f7a3200149
Author: 苏小刚 <[email protected]>
AuthorDate: Wed Feb 15 06:36:05 2023 +0800

    Feat: arrow csv decimal256 (#3711)
    
    * add test for arrow-csv Decimal256
    
    * pass the test
    There is still room for improvement in the code
    
    * add test_write_csv_decimal for csv_writer
    
    * support i128 and i256 in one generic function
    
    * the test parse_decimal need Neg
    
    * Update arrow-array/src/array/primitive_array.rs
    
    This will allow simplifying trait bounds in a number of other places
    
    Co-authored-by: Raphael Taylor-Davies <[email protected]>
    
    * return an error instead of panicking on overflow
    
    * Decimal256(76, 6)
    
    * adding test for Decimal256Type
    
    ---------
    
    Co-authored-by: suxiaogang <[email protected]>
    Co-authored-by: Raphael Taylor-Davies <[email protected]>
---
 arrow-array/src/array/primitive_array.rs |   4 +-
 arrow-csv/src/reader/mod.rs              | 141 +++++++++++++++++++++----------
 arrow-csv/src/writer.rs                  |  55 ++++++++++++
 3 files changed, 155 insertions(+), 45 deletions(-)

diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs
index aeece612d..b64534e98 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -23,8 +23,8 @@ use crate::temporal_conversions::{
 };
 use crate::timezone::Tz;
 use crate::trusted_len::trusted_len_unzip;
-use crate::types::*;
 use crate::{print_long_array, Array, ArrayAccessor};
+use crate::{types::*, ArrowNativeTypeOp};
 use arrow_buffer::{i256, ArrowNativeType, Buffer};
 use arrow_data::bit_iterator::try_for_each_valid_idx;
 use arrow_data::ArrayData;
@@ -233,7 +233,7 @@ pub type Decimal256Array = PrimitiveArray<Decimal256Type>;
 /// static-typed nature of rust types ([`ArrowNativeType`]) for all types that implement [`ArrowNativeType`].
 pub trait ArrowPrimitiveType: 'static {
     /// Corresponding Rust native type for the primitive type.
-    type Native: ArrowNativeType;
+    type Native: ArrowNativeTypeOp;
 
     /// the corresponding Arrow data type of this primitive type.
     const DATA_TYPE: DataType;
diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs
index 610f05155..c5fe20e9d 100644
--- a/arrow-csv/src/reader/mod.rs
+++ b/arrow-csv/src/reader/mod.rs
@@ -42,6 +42,13 @@
 
 mod records;
 
+use arrow_array::builder::PrimitiveBuilder;
+use arrow_array::types::*;
+use arrow_array::ArrowNativeTypeOp;
+use arrow_array::*;
+use arrow_buffer::ArrowNativeType;
+use arrow_cast::parse::Parser;
+use arrow_schema::*;
 use lazy_static::lazy_static;
 use regex::{Regex, RegexSet};
 use std::collections::HashSet;
@@ -50,17 +57,9 @@ use std::fs::File;
 use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom};
 use std::sync::Arc;
 
-use arrow_array::builder::Decimal128Builder;
-use arrow_array::types::*;
-use arrow_array::*;
-use arrow_cast::parse::Parser;
-use arrow_schema::*;
-
 use crate::map_csv_error;
 use crate::reader::records::{RecordDecoder, StringRecords};
-use arrow_data::decimal::validate_decimal_precision;
 use csv::StringRecord;
-use std::ops::Neg;
 
 lazy_static! {
     static ref REGEX_SET: RegexSet = RegexSet::new([
@@ -608,7 +607,22 @@ fn parse(
             match field.data_type() {
                 DataType::Boolean => build_boolean_array(line_number, rows, i),
                 DataType::Decimal128(precision, scale) => {
-                    build_decimal_array(line_number, rows, i, *precision, *scale)
+                    build_decimal_array::<Decimal128Type>(
+                        line_number,
+                        rows,
+                        i,
+                        *precision,
+                        *scale,
+                    )
+                }
+                DataType::Decimal256(precision, scale) => {
+                    build_decimal_array::<Decimal256Type>(
+                        line_number,
+                        rows,
+                        i,
+                        *precision,
+                        *scale,
+                    )
                 }
                 DataType::Int8 => {
                     build_primitive_array::<Int8Type>(line_number, rows, i, None)
@@ -781,22 +795,22 @@ fn parse_bool(string: &str) -> Option<bool> {
 }
 
 // parse the column string to an Arrow Array
-fn build_decimal_array(
+fn build_decimal_array<T: DecimalType>(
     _line_number: usize,
     rows: &StringRecords<'_>,
     col_idx: usize,
     precision: u8,
     scale: i8,
 ) -> Result<ArrayRef, ArrowError> {
-    let mut decimal_builder = Decimal128Builder::with_capacity(rows.len());
+    let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
     for row in rows.iter() {
         let s = row.get(col_idx);
         if s.is_empty() {
             // append null
             decimal_builder.append_null();
         } else {
-            let decimal_value: Result<i128, _> =
-                parse_decimal_with_parameter(s, precision, scale);
+            let decimal_value: Result<T::Native, _> =
+                parse_decimal_with_parameter::<T>(s, precision, scale);
             match decimal_value {
                 Ok(v) => {
                     decimal_builder.append_value(v);
@@ -814,17 +828,17 @@ fn build_decimal_array(
     ))
 }
 
-// Parse the string format decimal value to i128 format and checking the precision and scale.
-// The result i128 value can't be out of bounds.
-fn parse_decimal_with_parameter(
+// Parse the string format decimal value to i128/i256 format and checking the precision and scale.
+// The result value can't be out of bounds.
+fn parse_decimal_with_parameter<T: DecimalType>(
     s: &str,
     precision: u8,
     scale: i8,
-) -> Result<i128, ArrowError> {
+) -> Result<T::Native, ArrowError> {
     if PARSE_DECIMAL_RE.is_match(s) {
         let mut offset = s.len();
         let len = s.len();
-        let mut base = 1;
+        let mut base = T::Native::usize_as(1);
         let scale_usize = usize::from(scale as u8);
 
         // handle the value after the '.' and meet the scale
@@ -832,7 +846,7 @@ fn parse_decimal_with_parameter(
         match delimiter_position {
             None => {
                 // there is no '.'
-                base = 10_i128.pow(scale as u32);
+                base = T::Native::usize_as(10).pow_checked(scale as u32)?;
             }
             Some(mid) => {
                 // there is the '.'
@@ -841,7 +855,8 @@ fn parse_decimal_with_parameter(
                     offset -= len - mid - 1 - scale_usize;
                 } else {
                     // If the string value is "123.12" and the scale is 4, we should append '00' to the tail.
-                    base = 10_i128.pow((scale_usize + 1 + mid - len) as u32);
+                    base = T::Native::usize_as(10)
+                        .pow_checked((scale_usize + 1 + mid - len) as u32)?;
                 }
             }
         };
@@ -849,25 +864,29 @@ fn parse_decimal_with_parameter(
         // each byte is digit、'-' or '.'
         let bytes = s.as_bytes();
         let mut negative = false;
-        let mut result: i128 = 0;
+        let mut result = T::Native::usize_as(0);
 
-        bytes[0..offset].iter().rev().for_each(|&byte| match byte {
-            b'-' => {
-                negative = true;
-            }
-            b'0'..=b'9' => {
-                result += i128::from(byte - b'0') * base;
-                base *= 10;
+        for byte in bytes[0..offset].iter().rev() {
+            match byte {
+                b'-' => {
+                    negative = true;
+                }
+                b'0'..=b'9' => {
+                    let add =
+                        T::Native::usize_as((byte - b'0') as usize).mul_checked(base)?;
+                    result = result.add_checked(add)?;
+                    base = base.mul_checked(T::Native::usize_as(10))?;
+                }
+                // because of the PARSE_DECIMAL_RE, bytes just contains digit、'-' and '.'.
+                _ => {}
             }
-            // because of the PARSE_DECIMAL_RE, bytes just contains digit、'-' and '.'.
-            _ => {}
-        });
+        }
 
         if negative {
-            result = result.neg();
+            result = result.neg_checked()?;
         }
 
-        match validate_decimal_precision(result, precision) {
+        match T::validate_decimal_precision(result, precision) {
             Ok(_) => Ok(result),
             Err(e) => Err(ArrowError::ParseError(format!(
                 "parse decimal overflow: {e}"
@@ -884,6 +903,8 @@ fn parse_decimal_with_parameter(
 // Like "125.12" to 12512_i128.
 #[cfg(test)]
 fn parse_decimal(s: &str) -> Result<i128, ArrowError> {
+    use std::ops::Neg;
+
     if PARSE_DECIMAL_RE.is_match(s) {
         let mut offset = s.len();
         // each byte is digit、'-' or '.'
@@ -1230,6 +1251,7 @@ impl ReaderBuilder {
 mod tests {
     use super::*;
 
+    use arrow_buffer::i256;
     use std::io::{Cursor, Write};
     use tempfile::NamedTempFile;
 
@@ -1318,7 +1340,7 @@ mod tests {
         let schema = Schema::new(vec![
             Field::new("city", DataType::Utf8, false),
             Field::new("lat", DataType::Decimal128(38, 6), false),
-            Field::new("lng", DataType::Decimal128(38, 6), false),
+            Field::new("lng", DataType::Decimal256(76, 6), false),
         ]);
 
         let file = File::open("test/data/decimal_test.csv").unwrap();
@@ -1343,6 +1365,23 @@ mod tests {
         assert_eq!("123.000000", lat.value_as_string(7));
         assert_eq!("123.000000", lat.value_as_string(8));
         assert_eq!("-50.760000", lat.value_as_string(9));
+
+        let lng = batch
+            .column(2)
+            .as_any()
+            .downcast_ref::<Decimal256Array>()
+            .unwrap();
+
+        assert_eq!("-3.335724", lng.value_as_string(0));
+        assert_eq!("-2.179404", lng.value_as_string(1));
+        assert_eq!("-1.778197", lng.value_as_string(2));
+        assert_eq!("-3.179090", lng.value_as_string(3));
+        assert_eq!("-3.179090", lng.value_as_string(4));
+        assert_eq!("0.290472", lng.value_as_string(5));
+        assert_eq!("0.290472", lng.value_as_string(6));
+        assert_eq!("0.290472", lng.value_as_string(7));
+        assert_eq!("0.290472", lng.value_as_string(8));
+        assert_eq!("0.290472", lng.value_as_string(9));
     }
 
     #[test]
@@ -1788,26 +1827,42 @@ mod tests {
             ("-123.", -123000i128),
         ];
         for (s, i) in tests {
-            let result = parse_decimal_with_parameter(s, 20, 3);
-            assert_eq!(i, result.unwrap())
+            let result_128 = parse_decimal_with_parameter::<Decimal128Type>(s, 20, 3);
+            assert_eq!(i, result_128.unwrap());
+            let result_256 = parse_decimal_with_parameter::<Decimal256Type>(s, 20, 3);
+            assert_eq!(i256::from_i128(i), result_256.unwrap());
         }
         let can_not_parse_tests = ["123,123", ".", "123.123.123"];
         for s in can_not_parse_tests {
-            let result = parse_decimal_with_parameter(s, 20, 3);
+            let result_128 = parse_decimal_with_parameter::<Decimal128Type>(s, 20, 3);
+            assert_eq!(
+                format!("Parser error: can't parse the string value {s} to decimal"),
+                result_128.unwrap_err().to_string()
+            );
+            let result_256 = parse_decimal_with_parameter::<Decimal256Type>(s, 20, 3);
             assert_eq!(
                 format!("Parser error: can't parse the string value {s} to decimal"),
-                result.unwrap_err().to_string()
+                result_256.unwrap_err().to_string()
             );
         }
         let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"];
         for s in overflow_parse_tests {
-            let result = parse_decimal_with_parameter(s, 10, 3);
-            let expected = "Parser error: parse decimal overflow";
-            let actual = result.unwrap_err().to_string();
+            let result_128 = parse_decimal_with_parameter::<Decimal128Type>(s, 10, 3);
+            let expected_128 = "Parser error: parse decimal overflow";
+            let actual_128 = result_128.unwrap_err().to_string();
+
+            assert!(
+                actual_128.contains(expected_128),
+                "actual: '{actual_128}', expected: '{expected_128}'"
+            );
+
+            let result_256 = parse_decimal_with_parameter::<Decimal256Type>(s, 10, 3);
+            let expected_256 = "Parser error: parse decimal overflow";
+            let actual_256 = result_256.unwrap_err().to_string();
 
             assert!(
-                actual.contains(expected),
-                "actual: '{actual}', expected: '{expected}'"
+                actual_256.contains(expected_256),
+                "actual: '{actual_256}', expected: '{expected_256}'"
             );
         }
     }
diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs
index e0734a15f..d9331053f 100644
--- a/arrow-csv/src/writer.rs
+++ b/arrow-csv/src/writer.rs
@@ -326,7 +326,9 @@ mod tests {
     use super::*;
 
     use crate::Reader;
+    use arrow_array::builder::{Decimal128Builder, Decimal256Builder};
     use arrow_array::types::*;
+    use arrow_buffer::i256;
     use std::io::{Cursor, Read, Seek};
     use std::sync::Arc;
 
@@ -406,6 +408,59 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo
         assert_eq!(expected.to_string(), String::from_utf8(buffer).unwrap());
     }
 
+    #[test]
+    fn test_write_csv_decimal() {
+        let schema = Schema::new(vec![
+            Field::new("c1", DataType::Decimal128(38, 6), true),
+            Field::new("c2", DataType::Decimal256(76, 6), true),
+        ]);
+
+        let mut c1_builder =
+            Decimal128Builder::new().with_data_type(DataType::Decimal128(38, 6));
+        c1_builder.extend(vec![Some(-3335724), Some(2179404), None, Some(290472)]);
+        let c1 = c1_builder.finish();
+
+        let mut c2_builder =
+            Decimal256Builder::new().with_data_type(DataType::Decimal256(76, 6));
+        c2_builder.extend(vec![
+            Some(i256::from_i128(-3335724)),
+            Some(i256::from_i128(2179404)),
+            None,
+            Some(i256::from_i128(290472)),
+        ]);
+        let c2 = c2_builder.finish();
+
+        let batch =
+            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)])
+                .unwrap();
+
+        let mut file = tempfile::tempfile().unwrap();
+
+        let mut writer = Writer::new(&mut file);
+        let batches = vec![&batch, &batch];
+        for batch in batches {
+            writer.write(batch).unwrap();
+        }
+        drop(writer);
+
+        // check that file was written successfully
+        file.rewind().unwrap();
+        let mut buffer: Vec<u8> = vec![];
+        file.read_to_end(&mut buffer).unwrap();
+
+        let expected = r#"c1,c2
+-3.335724,-3.335724
+2.179404,2.179404
+,
+0.290472,0.290472
+-3.335724,-3.335724
+2.179404,2.179404
+,
+0.290472,0.290472
+"#;
+        assert_eq!(expected.to_string(), String::from_utf8(buffer).unwrap());
+    }
+
     #[test]
     fn test_write_csv_custom_options() {
         let schema = Schema::new(vec![

Reply via email to