This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new b03b80c Support read decimal data from csv reader if user provide the
schema with decimal data type (#941)
b03b80c is described below
commit b03b80cacc62576bfae193a0c69dbabea370a655
Author: Kun Liu <[email protected]>
AuthorDate: Wed Nov 17 00:43:29 2021 +0800
Support read decimal data from csv reader if user provide the schema with
decimal data type (#941)
* support decimal data type for csv reader
* format code and fix lint check
* fix the clippy error
* enhance the parse csv to decimal and add more test
---
arrow/src/array/builder.rs | 4 +-
arrow/src/array/mod.rs | 2 +
arrow/src/csv/reader.rs | 263 ++++++++++++++++++++++++++++++++++++++-
arrow/test/data/decimal_test.csv | 10 ++
4 files changed, 275 insertions(+), 4 deletions(-)
diff --git a/arrow/src/array/builder.rs b/arrow/src/array/builder.rs
index d08816c..af6f3c3 100644
--- a/arrow/src/array/builder.rs
+++ b/arrow/src/array/builder.rs
@@ -1118,7 +1118,7 @@ pub struct FixedSizeBinaryBuilder {
builder: FixedSizeListBuilder<UInt8Builder>,
}
-const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
+pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
9,
99,
999,
@@ -1158,7 +1158,7 @@ const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
9999999999999999999999999999999999999,
170141183460469231731687303715884105727,
];
-const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
+pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
-9,
-99,
-999,
diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs
index 5d4e57a..1298e58 100644
--- a/arrow/src/array/mod.rs
+++ b/arrow/src/array/mod.rs
@@ -393,6 +393,8 @@ pub use self::builder::StringBuilder;
pub use self::builder::StringDictionaryBuilder;
pub use self::builder::StructBuilder;
pub use self::builder::UnionBuilder;
+pub use self::builder::MAX_DECIMAL_FOR_EACH_PRECISION;
+pub use self::builder::MIN_DECIMAL_FOR_EACH_PRECISION;
pub type Int8Builder = PrimitiveBuilder<Int8Type>;
pub type Int16Builder = PrimitiveBuilder<Int16Type>;
diff --git a/arrow/src/csv/reader.rs b/arrow/src/csv/reader.rs
index 4940ea2..ac72939 100644
--- a/arrow/src/csv/reader.rs
+++ b/arrow/src/csv/reader.rs
@@ -50,7 +50,8 @@ use std::io::{Read, Seek, SeekFrom};
use std::sync::Arc;
use crate::array::{
- ArrayRef, BooleanArray, DictionaryArray, PrimitiveArray, StringArray,
+ ArrayRef, BooleanArray, DecimalBuilder, DictionaryArray, PrimitiveArray,
StringArray,
+ MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use crate::compute::kernels::cast_utils::string_to_timestamp_nanos;
use crate::datatypes::*;
@@ -58,8 +59,11 @@ use crate::error::{ArrowError, Result};
use crate::record_batch::RecordBatch;
use csv_crate::{ByteRecord, StringRecord};
+use std::ops::Neg;
lazy_static! {
+ static ref PARSE_DECIMAL_RE: Regex =
+ Regex::new(r"^-?(\d+\.?\d*|\d*\.?\d+)$").unwrap();
static ref DECIMAL_RE: Regex =
Regex::new(r"^-?(\d*\.\d+|\d+\.\d*)$").unwrap();
static ref INTEGER_RE: Regex = Regex::new(r"^-?(\d+)$").unwrap();
static ref BOOLEAN_RE: Regex = RegexBuilder::new(r"^(true)$|^(false)$")
@@ -99,7 +103,7 @@ fn infer_field_schema(string: &str) -> DataType {
///
/// If `max_read_records` is not set, the whole file is read to infer its
schema.
///
-/// Return infered schema and number of records used for inference. This
function does not change
+/// Return inferred schema and number of records used for inference. This
function does not change
/// reader cursor offset.
pub fn infer_file_schema<R: Read + Seek>(
reader: &mut R,
@@ -513,6 +517,9 @@ fn parse(
let field = &fields[i];
match field.data_type() {
DataType::Boolean => build_boolean_array(line_number, rows, i),
+ DataType::Decimal(precision, scale) => {
+ build_decimal_array(line_number, rows, i, *precision,
*scale)
+ }
DataType::Int8 =>
build_primitive_array::<Int8Type>(line_number, rows, i),
DataType::Int16 => {
build_primitive_array::<Int16Type>(line_number, rows, i)
@@ -728,6 +735,161 @@ fn parse_bool(string: &str) -> Option<bool> {
}
}
+// Parses column `col_idx` of every row into an Arrow decimal array; missing or empty fields become nulls.
+fn build_decimal_array(
+ _line_number: usize, // unused in this function (hence the leading underscore)
+ rows: &[StringRecord],
+ col_idx: usize,
+ precision: usize,
+ scale: usize,
+) -> Result<ArrayRef> {
+ let mut decimal_builder = DecimalBuilder::new(rows.len(), precision,
scale);
+ for row in rows {
+ let col_s = row.get(col_idx);
+ match col_s {
+ None => {
+ // The record has no field at `col_idx`: append a null
+ decimal_builder.append_null()?;
+ }
+ Some(s) => {
+ if s.is_empty() {
+ // An empty field is also stored as null
+ decimal_builder.append_null()?;
+ } else {
+ let decimal_value: Result<i128> =
+ parse_decimal_with_parameter(s, precision, scale);
+ match decimal_value {
+ Ok(v) => {
+ decimal_builder.append_value(v)?;
+ }
+ Err(e) => {
+ return Err(e); // the first value that fails to parse aborts the whole column
+ }
+ }
+ }
+ }
+ }
+ }
+ Ok(Arc::new(decimal_builder.finish()))
+}
+
+// Parses a decimal string into an i128 scaled by 10^scale (e.g. "123.1" at scale 3 is 123100), checking the
precision and scale.
+// Returns a ParseError if the text is not a valid decimal or the scaled value is outside the range for `precision`.
+fn parse_decimal_with_parameter(s: &str, precision: usize, scale: usize) ->
Result<i128> {
+ if PARSE_DECIMAL_RE.is_match(s) {
+ let mut offset = s.len(); // exclusive end of the bytes that will be consumed
+ let len = s.len();
+ // the regex admits only digits, an optional leading '-' and at most one '.'
+ let mut base = 1; // place value of the next digit, consumed right to left
+
+ // truncate or zero-pad the fractional part so that exactly `scale` digits are kept
+ let delimiter_position = s.find('.');
+ match delimiter_position {
+ None => {
+ // no '.': the text is an integer, so its lowest digit is worth 10^scale
+ base = 10_i128.pow(scale as u32);
+ }
+ Some(mid) => {
+ // `mid` is the byte index of the '.'
+ if len - mid >= scale + 1 {
+ // If the string value is "123.12345" and the scale is 2,
we should just remain '.12' and drop the '345' value.
+ offset -= len - mid - 1 - scale;
+ } else {
+ // If the string value is "123.12" and the scale is 4, we
should append '00' to the tail.
+ base = 10_i128.pow((scale + 1 + mid - len) as u32);
+ }
+ }
+ };
+
+ let bytes = s.as_bytes();
+ let mut negative = false;
+ let mut result: i128 = 0;
+
+ while offset > 0 {
+ match bytes[offset - 1] {
+ b'-' => {
+ negative = true;
+ }
+ b'.' => {
+ // skip the decimal point; the scale was already applied via `offset`/`base` above
+ }
+ b'0'..=b'9' => {
+ result += i128::from(bytes[offset - 1] - b'0') * base;
+ base *= 10; // NOTE(review): unchecked arithmetic - a very long digit string can overflow i128 here
+ }
+ _ => { // NOTE(review): presumably unreachable for ASCII input, since the regex already validated `s`
+ return Err(ArrowError::ParseError(format!(
+ "can't match byte {}",
+ bytes[offset - 1]
+ )));
+ }
+ }
+ offset -= 1;
+ }
+ if negative {
+ result = result.neg();
+ }
+ if result > MAX_DECIMAL_FOR_EACH_PRECISION[precision - 1] // NOTE(review): panics if `precision` is 0 or exceeds the table length (38 entries)
+ || result < MIN_DECIMAL_FOR_EACH_PRECISION[precision - 1]
+ {
+ return Err(ArrowError::ParseError(format!(
+ "parse decimal overflow, the precision {}, the scale {}, the
value {}",
+ precision, scale, s
+ )));
+ }
+ Ok(result)
+ } else {
+ Err(ArrowError::ParseError(format!(
+ "can't parse the string value {} to decimal",
+ s
+ )))
+ }
+}
+
+// Parses a decimal string into an i128 built from all of its digits, without checking the
precision and scale.
+// For example "125.12" becomes 12512_i128; the position of the '.' is discarded.
+fn parse_decimal(s: &str) -> Result<i128> {
+ if PARSE_DECIMAL_RE.is_match(s) {
+ let mut offset = s.len();
+ // the regex admits only digits, an optional leading '-' and at most one '.'
+ let bytes = s.as_bytes();
+ let mut negative = false;
+ let mut result: i128 = 0;
+ let mut base = 1; // place value of the next digit, consumed right to left
+ while offset > 0 {
+ match bytes[offset - 1] {
+ b'-' => {
+ negative = true;
+ }
+ b'.' => {
+ // skip the decimal point; it contributes no digit
+ }
+ b'0'..=b'9' => {
+ result += i128::from(bytes[offset - 1] - b'0') * base;
+ base *= 10; // NOTE(review): unchecked arithmetic - a very long digit string can overflow i128 here
+ }
+ _ => {
+ return Err(ArrowError::ParseError(format!(
+ "can't match byte {}",
+ bytes[offset - 1]
+ )));
+ }
+ }
+ offset -= 1;
+ }
+ if negative {
+ Ok(result.neg())
+ } else {
+ Ok(result)
+ }
+ } else {
+ Err(ArrowError::ParseError(format!(
+ "can't parse the string value {} to decimal",
+ s
+ )))
+ }
+}
+
// parses a specific column (col_idx) into an Arrow Array.
fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
line_number: usize,
@@ -1056,6 +1218,37 @@ mod tests {
}
#[test]
+ fn test_csv_reader_with_decimal() {
+ let schema = Schema::new(vec![
+ Field::new("city", DataType::Utf8, false),
+ Field::new("lat", DataType::Decimal(26, 6), false),
+ Field::new("lng", DataType::Decimal(26, 6), false),
+ ]);
+
+ let file = File::open("test/data/decimal_test.csv").unwrap();
+
+ let mut csv = Reader::new(file, Arc::new(schema), false, None, 1024,
None, None);
+ let batch = csv.next().unwrap().unwrap();
+ // access the `lat` column (index 1) as a decimal array
+ let lat = batch
+ .column(1)
+ .as_any()
+ .downcast_ref::<DecimalArray>()
+ .unwrap();
+
+ assert_eq!("57.653484", lat.value_as_string(0));
+ assert_eq!("53.002666", lat.value_as_string(1));
+ assert_eq!("52.412811", lat.value_as_string(2));
+ assert_eq!("51.481583", lat.value_as_string(3));
+ assert_eq!("12.123456", lat.value_as_string(4)); // input 12.12345678 is truncated to scale 6
+ assert_eq!("50.760000", lat.value_as_string(5)); // input 50.76 is zero-padded to scale 6
+ assert_eq!("0.123000", lat.value_as_string(6)); // input ".123"
+ assert_eq!("123.000000", lat.value_as_string(7)); // input "123."
+ assert_eq!("123.000000", lat.value_as_string(8)); // input "123"
+ assert_eq!("-50.760000", lat.value_as_string(9));
+ }
+
+ #[test]
fn test_csv_from_buf_reader() {
let schema = Schema::new(vec![
Field::new("city", DataType::Utf8, false),
@@ -1348,6 +1541,8 @@ mod tests {
assert_eq!(infer_field_schema("false"), DataType::Boolean);
assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
assert_eq!(infer_field_schema("2020-11-08T14:20:01"),
DataType::Date64);
+ assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
+ assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
}
#[test]
@@ -1374,6 +1569,70 @@ mod tests {
);
}
+ #[test]
+ fn test_parse_decimal() { // expected values are the input digits with the '.' removed
+ let tests = [
+ ("123.00", 12300i128),
+ ("123.123", 123123i128),
+ ("0.0123", 123i128), // leading zeros contribute nothing
+ ("0.12300", 12300i128), // trailing zeros are kept as digits
+ ("-5.123", -5123i128),
+ ("-45.432432", -45432432i128),
+ ];
+ for (s, i) in tests {
+ let result = parse_decimal(s);
+ assert_eq!(i, result.unwrap());
+ }
+ }
+
+ #[test]
+ fn test_parse_decimal_with_parameter() {
+ let tests = [
+ ("123.123", 123123i128),
+ ("123.1234", 123123i128), // extra fractional digits are truncated, not rounded
+ ("123.1", 123100i128), // short fractions are zero-padded to the scale
+ ("123", 123000i128),
+ ("-123.123", -123123i128),
+ ("-123.1234", -123123i128),
+ ("-123.1", -123100i128),
+ ("-123", -123000i128),
+ ("0.0000123", 0i128), // everything beyond scale 3 is dropped
+ ("12.", 12000i128),
+ ("-12.", -12000i128),
+ ("00.1", 100i128),
+ ("-00.1", -100i128),
+ ("12345678912345678.1234", 12345678912345678123i128),
+ ("-12345678912345678.1234", -12345678912345678123i128),
+ ("99999999999999999.999", 99999999999999999999i128),
+ ("-99999999999999999.999", -99999999999999999999i128),
+ (".123", 123i128),
+ ("-.123", -123i128),
+ ("123.", 123000i128),
+ ("-123.", -123000i128),
+ ];
+ for (s, i) in tests {
+ let result = parse_decimal_with_parameter(s, 20, 3); // precision 20, scale 3
+ assert_eq!(i, result.unwrap())
+ }
+ let can_not_parse_tests = ["123,123", "."]; // both are rejected by the regex check
+ for s in can_not_parse_tests {
+ let result = parse_decimal_with_parameter(s, 20, 3);
+ assert_eq!(
+ format!(
+ "Parser error: can't parse the string value {} to decimal",
+ s
+ ),
+ result.unwrap_err().to_string()
+ );
+ }
+ let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"]; // 8 integer digits + scale 3 exceed precision 10
+ for s in overflow_parse_tests {
+ let result = parse_decimal_with_parameter(s, 10, 3);
+ assert_eq!(format!(
+ "Parser error: parse decimal overflow, the precision {}, the
scale {}, the value {}", 10,3, s),result.unwrap_err().to_string());
+ }
+ }
+
/// Interprets a naive_datetime (with no explicit timezone offset)
/// using the local timezone and returns the timestamp in UTC (0
/// offset)
diff --git a/arrow/test/data/decimal_test.csv b/arrow/test/data/decimal_test.csv
new file mode 100644
index 0000000..460ed80
--- /dev/null
+++ b/arrow/test/data/decimal_test.csv
@@ -0,0 +1,10 @@
+"Elgin, Scotland, the UK",57.653484,-3.335724
+"Stoke-on-Trent, Staffordshire, the UK",53.002666,-2.179404
+"Solihull, Birmingham, UK",52.412811,-1.778197
+"Cardiff, Cardiff county, UK",51.481583,-3.179090
+"Cardiff, Cardiff county, UK",12.12345678,-3.179090
+"Eastbourne, East Sussex, UK",50.76,0.290472
+"Eastbourne, East Sussex, UK",.123,0.290472
+"Eastbourne, East Sussex, UK",123.,0.290472
+"Eastbourne, East Sussex, UK",123,0.290472
+"Eastbourne, East Sussex, UK",-50.76,0.290472
\ No newline at end of file