This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 2075cd125d csv: Add option to specify custom null values (#4795)
2075cd125d is described below

commit 2075cd125dc0c132be5cb9dbf65748abf52243f1
Author: Vaibhav Rabber <[email protected]>
AuthorDate: Wed Sep 13 16:27:39 2023 +0530

    csv: Add option to specify custom null values (#4795)
    
    * csv: Add option to specify custom null regex
    
    Adds the ability to specify, as a regular expression, custom strings
    that should be read as `NULL` values in CSVs. This allows reading CSV
    files that use placeholders for NULL values instead of empty strings.
    
    Fixes #4794
    
    Signed-off-by: Vaibhav <[email protected]>
    
    * Apply suggestions from code review
    
    ---------
    
    Signed-off-by: Vaibhav <[email protected]>
    Co-authored-by: Raphael Taylor-Davies <[email protected]>
---
 arrow-csv/src/reader/mod.rs              | 203 ++++++++++++++++++++++++++-----
 arrow-csv/test/data/custom_null_test.csv |   6 +
 2 files changed, 180 insertions(+), 29 deletions(-)

diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs
index 328c2cd41f..695e3d4796 100644
--- a/arrow-csv/src/reader/mod.rs
+++ b/arrow-csv/src/reader/mod.rs
@@ -133,8 +133,8 @@ use arrow_schema::*;
 use chrono::{TimeZone, Utc};
 use csv::StringRecord;
 use lazy_static::lazy_static;
-use regex::RegexSet;
-use std::fmt;
+use regex::{Regex, RegexSet};
+use std::fmt::{self, Debug};
 use std::fs::File;
 use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom};
 use std::sync::Arc;
@@ -157,6 +157,22 @@ lazy_static! {
     ]).unwrap();
 }
 
+/// A wrapper over `Option<Regex>` to check if the value is `NULL`.
+#[derive(Debug, Clone, Default)]
+struct NullRegex(Option<Regex>);
+
+impl NullRegex {
+    /// Returns true if the value should be considered as `NULL` according to
+    /// the provided regular expression.
+    #[inline]
+    fn is_null(&self, s: &str) -> bool {
+        match &self.0 {
+            Some(r) => r.is_match(s),
+            None => s.is_empty(),
+        }
+    }
+}
+
 #[derive(Default, Copy, Clone)]
 struct InferredDataType {
     /// Packed booleans indicating type
@@ -213,6 +229,7 @@ pub struct Format {
     escape: Option<u8>,
     quote: Option<u8>,
     terminator: Option<u8>,
+    null_regex: NullRegex,
 }
 
 impl Format {
@@ -241,6 +258,12 @@ impl Format {
         self
     }
 
+    /// Provide a regex to match null values, defaults to `^$`
+    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
+        self.null_regex = NullRegex(Some(null_regex));
+        self
+    }
+
     /// Infer schema of CSV records from the provided `reader`
     ///
     /// If `max_records` is `None`, all records will be read, otherwise up to 
`max_records`
@@ -287,7 +310,7 @@ impl Format {
                 column_types.iter_mut().enumerate().take(header_length)
             {
                 if let Some(string) = record.get(i) {
-                    if !string.is_empty() {
+                    if !self.null_regex.is_null(string) {
                         column_type.update(string)
                     }
                 }
@@ -557,6 +580,9 @@ pub struct Decoder {
 
     /// A decoder for [`StringRecords`]
     record_decoder: RecordDecoder,
+
+    /// Check if the string matches this pattern for `NULL`.
+    null_regex: NullRegex,
 }
 
 impl Decoder {
@@ -603,6 +629,7 @@ impl Decoder {
             Some(self.schema.metadata.clone()),
             self.projection.as_ref(),
             self.line_number,
+            &self.null_regex,
         )?;
         self.line_number += rows.len();
         Ok(Some(batch))
@@ -621,6 +648,7 @@ fn parse(
     metadata: Option<std::collections::HashMap<String, String>>,
     projection: Option<&Vec<usize>>,
     line_number: usize,
+    null_regex: &NullRegex,
 ) -> Result<RecordBatch, ArrowError> {
     let projection: Vec<usize> = match projection {
         Some(v) => v.clone(),
@@ -633,7 +661,9 @@ fn parse(
             let i = *i;
             let field = &fields[i];
             match field.data_type() {
-                DataType::Boolean => build_boolean_array(line_number, rows, i),
+                DataType::Boolean => {
+                    build_boolean_array(line_number, rows, i, null_regex)
+                }
                 DataType::Decimal128(precision, scale) => {
                     build_decimal_array::<Decimal128Type>(
                         line_number,
@@ -641,6 +671,7 @@ fn parse(
                         i,
                         *precision,
                         *scale,
+                        null_regex,
                     )
                 }
                 DataType::Decimal256(precision, scale) => {
@@ -650,53 +681,73 @@ fn parse(
                         i,
                         *precision,
                         *scale,
+                        null_regex,
                     )
                 }
-                DataType::Int8 => 
build_primitive_array::<Int8Type>(line_number, rows, i),
+                DataType::Int8 => {
+                    build_primitive_array::<Int8Type>(line_number, rows, i, 
null_regex)
+                }
                 DataType::Int16 => {
-                    build_primitive_array::<Int16Type>(line_number, rows, i)
+                    build_primitive_array::<Int16Type>(line_number, rows, i, 
null_regex)
                 }
                 DataType::Int32 => {
-                    build_primitive_array::<Int32Type>(line_number, rows, i)
+                    build_primitive_array::<Int32Type>(line_number, rows, i, 
null_regex)
                 }
                 DataType::Int64 => {
-                    build_primitive_array::<Int64Type>(line_number, rows, i)
+                    build_primitive_array::<Int64Type>(line_number, rows, i, 
null_regex)
                 }
                 DataType::UInt8 => {
-                    build_primitive_array::<UInt8Type>(line_number, rows, i)
+                    build_primitive_array::<UInt8Type>(line_number, rows, i, 
null_regex)
                 }
                 DataType::UInt16 => {
-                    build_primitive_array::<UInt16Type>(line_number, rows, i)
+                    build_primitive_array::<UInt16Type>(line_number, rows, i, 
null_regex)
                 }
                 DataType::UInt32 => {
-                    build_primitive_array::<UInt32Type>(line_number, rows, i)
+                    build_primitive_array::<UInt32Type>(line_number, rows, i, 
null_regex)
                 }
                 DataType::UInt64 => {
-                    build_primitive_array::<UInt64Type>(line_number, rows, i)
+                    build_primitive_array::<UInt64Type>(line_number, rows, i, 
null_regex)
                 }
                 DataType::Float32 => {
-                    build_primitive_array::<Float32Type>(line_number, rows, i)
+                    build_primitive_array::<Float32Type>(line_number, rows, i, 
null_regex)
                 }
                 DataType::Float64 => {
-                    build_primitive_array::<Float64Type>(line_number, rows, i)
+                    build_primitive_array::<Float64Type>(line_number, rows, i, 
null_regex)
                 }
                 DataType::Date32 => {
-                    build_primitive_array::<Date32Type>(line_number, rows, i)
+                    build_primitive_array::<Date32Type>(line_number, rows, i, 
null_regex)
                 }
                 DataType::Date64 => {
-                    build_primitive_array::<Date64Type>(line_number, rows, i)
-                }
-                DataType::Time32(TimeUnit::Second) => {
-                    build_primitive_array::<Time32SecondType>(line_number, 
rows, i)
+                    build_primitive_array::<Date64Type>(line_number, rows, i, 
null_regex)
                 }
+                DataType::Time32(TimeUnit::Second) => build_primitive_array::<
+                    Time32SecondType,
+                >(
+                    line_number, rows, i, null_regex
+                ),
                 DataType::Time32(TimeUnit::Millisecond) => {
-                    
build_primitive_array::<Time32MillisecondType>(line_number, rows, i)
+                    build_primitive_array::<Time32MillisecondType>(
+                        line_number,
+                        rows,
+                        i,
+                        null_regex,
+                    )
                 }
                 DataType::Time64(TimeUnit::Microsecond) => {
-                    
build_primitive_array::<Time64MicrosecondType>(line_number, rows, i)
+                    build_primitive_array::<Time64MicrosecondType>(
+                        line_number,
+                        rows,
+                        i,
+                        null_regex,
+                    )
                 }
                 DataType::Time64(TimeUnit::Nanosecond) => {
-                    build_primitive_array::<Time64NanosecondType>(line_number, 
rows, i)
+                    build_primitive_array::<Time64NanosecondType>(
+                        line_number,
+                        rows,
+                        i,
+                        null_regex,
+                    )
                 }
                 DataType::Timestamp(TimeUnit::Second, tz) => {
                     build_timestamp_array::<TimestampSecondType>(
@@ -704,6 +755,7 @@ fn parse(
                         rows,
                         i,
                         tz.as_deref(),
+                        null_regex,
                     )
                 }
                 DataType::Timestamp(TimeUnit::Millisecond, tz) => {
@@ -712,6 +764,7 @@ fn parse(
                         rows,
                         i,
                         tz.as_deref(),
+                        null_regex,
                     )
                 }
                 DataType::Timestamp(TimeUnit::Microsecond, tz) => {
@@ -720,6 +773,7 @@ fn parse(
                         rows,
                         i,
                         tz.as_deref(),
+                        null_regex,
                     )
                 }
                 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
@@ -728,6 +782,7 @@ fn parse(
                         rows,
                         i,
                         tz.as_deref(),
+                        null_regex,
                     )
                 }
                 DataType::Utf8 => Ok(Arc::new(
@@ -827,11 +882,12 @@ fn build_decimal_array<T: DecimalType>(
     col_idx: usize,
     precision: u8,
     scale: i8,
+    null_regex: &NullRegex,
 ) -> Result<ArrayRef, ArrowError> {
     let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
     for row in rows.iter() {
         let s = row.get(col_idx);
-        if s.is_empty() {
+        if null_regex.is_null(s) {
             // append null
             decimal_builder.append_null();
         } else {
@@ -859,12 +915,13 @@ fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
     line_number: usize,
     rows: &StringRecords<'_>,
     col_idx: usize,
+    null_regex: &NullRegex,
 ) -> Result<ArrayRef, ArrowError> {
     rows.iter()
         .enumerate()
         .map(|(row_index, row)| {
             let s = row.get(col_idx);
-            if s.is_empty() {
+            if null_regex.is_null(s) {
                 return Ok(None);
             }
 
@@ -888,14 +945,27 @@ fn build_timestamp_array<T: ArrowTimestampType>(
     rows: &StringRecords<'_>,
     col_idx: usize,
     timezone: Option<&str>,
+    null_regex: &NullRegex,
 ) -> Result<ArrayRef, ArrowError> {
     Ok(Arc::new(match timezone {
         Some(timezone) => {
             let tz: Tz = timezone.parse()?;
-            build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, 
&tz)?
-                .with_timezone(timezone)
+            build_timestamp_array_impl::<T, _>(
+                line_number,
+                rows,
+                col_idx,
+                &tz,
+                null_regex,
+            )?
+            .with_timezone(timezone)
         }
-        None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, 
&Utc)?,
+        None => build_timestamp_array_impl::<T, _>(
+            line_number,
+            rows,
+            col_idx,
+            &Utc,
+            null_regex,
+        )?,
     }))
 }
 
@@ -904,12 +974,13 @@ fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
     rows: &StringRecords<'_>,
     col_idx: usize,
     timezone: &Tz,
+    null_regex: &NullRegex,
 ) -> Result<PrimitiveArray<T>, ArrowError> {
     rows.iter()
         .enumerate()
         .map(|(row_index, row)| {
             let s = row.get(col_idx);
-            if s.is_empty() {
+            if null_regex.is_null(s) {
                 return Ok(None);
             }
 
@@ -936,12 +1007,13 @@ fn build_boolean_array(
     line_number: usize,
     rows: &StringRecords<'_>,
     col_idx: usize,
+    null_regex: &NullRegex,
 ) -> Result<ArrayRef, ArrowError> {
     rows.iter()
         .enumerate()
         .map(|(row_index, row)| {
             let s = row.get(col_idx);
-            if s.is_empty() {
+            if null_regex.is_null(s) {
                 return Ok(None);
             }
             let parsed = parse_bool(s);
@@ -1042,6 +1114,12 @@ impl ReaderBuilder {
         self
     }
 
+    /// Provide a regex to match null values, defaults to `^$`
+    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
+        self.format.null_regex = NullRegex(Some(null_regex));
+        self
+    }
+
     /// Set the batch size (number of records to load at one time)
     pub fn with_batch_size(mut self, batch_size: usize) -> Self {
         self.batch_size = batch_size;
@@ -1100,6 +1178,7 @@ impl ReaderBuilder {
             end,
             projection: self.projection,
             batch_size: self.batch_size,
+            null_regex: self.format.null_regex,
         }
     }
 }
@@ -1426,6 +1505,36 @@ mod tests {
         assert!(!batch.column(1).is_null(4));
     }
 
+    #[test]
+    fn test_custom_nulls() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("c_int", DataType::UInt64, true),
+            Field::new("c_float", DataType::Float32, true),
+            Field::new("c_string", DataType::Utf8, true),
+            Field::new("c_bool", DataType::Boolean, true),
+        ]));
+
+        let file = File::open("test/data/custom_null_test.csv").unwrap();
+
+        let null_regex = Regex::new("^nil$").unwrap();
+
+        let mut csv = ReaderBuilder::new(schema)
+            .has_header(true)
+            .with_null_regex(null_regex)
+            .build(file)
+            .unwrap();
+
+        let batch = csv.next().unwrap().unwrap();
+
+        // "nil"s should be NULL
+        assert!(batch.column(0).is_null(1));
+        assert!(batch.column(1).is_null(2));
+        assert!(batch.column(3).is_null(4));
+        // String won't be empty
+        assert!(!batch.column(2).is_null(3));
+        assert!(!batch.column(2).is_null(4));
+    }
+
     #[test]
     fn test_nulls_with_inference() {
         let mut file = File::open("test/data/various_types.csv").unwrap();
@@ -1485,6 +1594,42 @@ mod tests {
         assert!(!batch.column(1).is_null(4));
     }
 
+    #[test]
+    fn test_custom_nulls_with_inference() {
+        let mut file = File::open("test/data/custom_null_test.csv").unwrap();
+
+        let null_regex = Regex::new("^nil$").unwrap();
+
+        let format = Format::default()
+            .with_header(true)
+            .with_null_regex(null_regex);
+
+        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
+        file.rewind().unwrap();
+
+        let expected_schema = Schema::new(vec![
+            Field::new("c_int", DataType::Int64, true),
+            Field::new("c_float", DataType::Float64, true),
+            Field::new("c_string", DataType::Utf8, true),
+            Field::new("c_bool", DataType::Boolean, true),
+        ]);
+
+        assert_eq!(schema, expected_schema);
+
+        let builder = ReaderBuilder::new(Arc::new(schema))
+            .with_format(format)
+            .with_batch_size(512)
+            .with_projection(vec![0, 1, 2, 3]);
+
+        let mut csv = builder.build(file).unwrap();
+        let batch = csv.next().unwrap().unwrap();
+
+        assert_eq!(5, batch.num_rows());
+        assert_eq!(4, batch.num_columns());
+
+        assert_eq!(batch.schema().as_ref(), &expected_schema);
+    }
+
     #[test]
     fn test_parse_invalid_csv() {
         let file = File::open("test/data/various_types_invalid.csv").unwrap();
diff --git a/arrow-csv/test/data/custom_null_test.csv b/arrow-csv/test/data/custom_null_test.csv
new file mode 100644
index 0000000000..39f9fc4b3e
--- /dev/null
+++ b/arrow-csv/test/data/custom_null_test.csv
@@ -0,0 +1,6 @@
+c_int,c_float,c_string,c_bool
+1,1.1,"1.11",True
+nil,2.2,"2.22",TRUE
+3,nil,"3.33",true
+4,4.4,nil,False
+5,6.6,"",nil

Reply via email to