This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 2075cd125d csv: Add option to specify custom null values (#4795)
2075cd125d is described below
commit 2075cd125dc0c132be5cb9dbf65748abf52243f1
Author: Vaibhav Rabber <[email protected]>
AuthorDate: Wed Sep 13 16:27:39 2023 +0530
csv: Add option to specify custom null values (#4795)
* csv: Add option to specify custom null regex
Custom strings can be treated as `NULL` values in CSVs by specifying a
regular expression. This allows reading CSV files that use placeholders
for NULL values instead of empty strings.
Fixes #4794
Signed-off-by: Vaibhav <[email protected]>
* Apply suggestions from code review
---------
Signed-off-by: Vaibhav <[email protected]>
Co-authored-by: Raphael Taylor-Davies <[email protected]>
---
arrow-csv/src/reader/mod.rs | 203 ++++++++++++++++++++++++++-----
arrow-csv/test/data/custom_null_test.csv | 6 +
2 files changed, 180 insertions(+), 29 deletions(-)
diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs
index 328c2cd41f..695e3d4796 100644
--- a/arrow-csv/src/reader/mod.rs
+++ b/arrow-csv/src/reader/mod.rs
@@ -133,8 +133,8 @@ use arrow_schema::*;
use chrono::{TimeZone, Utc};
use csv::StringRecord;
use lazy_static::lazy_static;
-use regex::RegexSet;
-use std::fmt;
+use regex::{Regex, RegexSet};
+use std::fmt::{self, Debug};
use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom};
use std::sync::Arc;
@@ -157,6 +157,22 @@ lazy_static! {
]).unwrap();
}
+/// Decides whether a CSV field is `NULL`; `None` (the `Default`) means "only the empty string is `NULL`".
+#[derive(Debug, Clone, Default)]
+struct NullRegex(Option<Regex>); // `Some` holds the user pattern installed by `with_null_regex`
+
+impl NullRegex {
+ /// Returns true if `s` is `NULL`: `r.is_match(s)` when a regex is set (an unanchored search,
+ /// so patterns should normally be written as `^...$`), otherwise `s.is_empty()`.
+ #[inline]
+ fn is_null(&self, s: &str) -> bool {
+ match &self.0 {
+ Some(r) => r.is_match(s), // replaces the empty-string check entirely: "" is NULL only if `r` matches it
+ None => s.is_empty(), // default behaviour, equivalent to the pattern `^$`
+ }
+ }
+}
+
#[derive(Default, Copy, Clone)]
struct InferredDataType {
/// Packed booleans indicating type
@@ -213,6 +229,7 @@ pub struct Format {
escape: Option<u8>,
quote: Option<u8>,
terminator: Option<u8>,
+ null_regex: NullRegex,
}
impl Format {
@@ -241,6 +258,12 @@ impl Format {
self
}
+ /// Provide a regex to match null values, defaults to `^$`; any partial match counts, so anchor it (e.g. `^nil$`)
+ pub fn with_null_regex(mut self, null_regex: Regex) -> Self { // replaces the default: empty fields are NULL only if the regex matches ""
+ self.null_regex = NullRegex(Some(null_regex)); // also consulted by `infer_schema`, so matching values do not affect inferred types
+ self
+ }
+
/// Infer schema of CSV records from the provided `reader`
///
/// If `max_records` is `None`, all records will be read, otherwise up to
`max_records`
@@ -287,7 +310,7 @@ impl Format {
column_types.iter_mut().enumerate().take(header_length)
{
if let Some(string) = record.get(i) {
- if !string.is_empty() {
+ if !self.null_regex.is_null(string) {
column_type.update(string)
}
}
@@ -557,6 +580,9 @@ pub struct Decoder {
/// A decoder for [`StringRecords`]
record_decoder: RecordDecoder,
+
+ /// Check if the string matches this pattern for `NULL`.
+ null_regex: NullRegex,
}
impl Decoder {
@@ -603,6 +629,7 @@ impl Decoder {
Some(self.schema.metadata.clone()),
self.projection.as_ref(),
self.line_number,
+ &self.null_regex,
)?;
self.line_number += rows.len();
Ok(Some(batch))
@@ -621,6 +648,7 @@ fn parse(
metadata: Option<std::collections::HashMap<String, String>>,
projection: Option<&Vec<usize>>,
line_number: usize,
+ null_regex: &NullRegex,
) -> Result<RecordBatch, ArrowError> {
let projection: Vec<usize> = match projection {
Some(v) => v.clone(),
@@ -633,7 +661,9 @@ fn parse(
let i = *i;
let field = &fields[i];
match field.data_type() {
- DataType::Boolean => build_boolean_array(line_number, rows, i),
+ DataType::Boolean => {
+ build_boolean_array(line_number, rows, i, null_regex)
+ }
DataType::Decimal128(precision, scale) => {
build_decimal_array::<Decimal128Type>(
line_number,
@@ -641,6 +671,7 @@ fn parse(
i,
*precision,
*scale,
+ null_regex,
)
}
DataType::Decimal256(precision, scale) => {
@@ -650,53 +681,73 @@ fn parse(
i,
*precision,
*scale,
+ null_regex,
)
}
- DataType::Int8 =>
build_primitive_array::<Int8Type>(line_number, rows, i),
+ DataType::Int8 => {
+ build_primitive_array::<Int8Type>(line_number, rows, i,
null_regex)
+ }
DataType::Int16 => {
- build_primitive_array::<Int16Type>(line_number, rows, i)
+ build_primitive_array::<Int16Type>(line_number, rows, i,
null_regex)
}
DataType::Int32 => {
- build_primitive_array::<Int32Type>(line_number, rows, i)
+ build_primitive_array::<Int32Type>(line_number, rows, i,
null_regex)
}
DataType::Int64 => {
- build_primitive_array::<Int64Type>(line_number, rows, i)
+ build_primitive_array::<Int64Type>(line_number, rows, i,
null_regex)
}
DataType::UInt8 => {
- build_primitive_array::<UInt8Type>(line_number, rows, i)
+ build_primitive_array::<UInt8Type>(line_number, rows, i,
null_regex)
}
DataType::UInt16 => {
- build_primitive_array::<UInt16Type>(line_number, rows, i)
+ build_primitive_array::<UInt16Type>(line_number, rows, i,
null_regex)
}
DataType::UInt32 => {
- build_primitive_array::<UInt32Type>(line_number, rows, i)
+ build_primitive_array::<UInt32Type>(line_number, rows, i,
null_regex)
}
DataType::UInt64 => {
- build_primitive_array::<UInt64Type>(line_number, rows, i)
+ build_primitive_array::<UInt64Type>(line_number, rows, i,
null_regex)
}
DataType::Float32 => {
- build_primitive_array::<Float32Type>(line_number, rows, i)
+ build_primitive_array::<Float32Type>(line_number, rows, i,
null_regex)
}
DataType::Float64 => {
- build_primitive_array::<Float64Type>(line_number, rows, i)
+ build_primitive_array::<Float64Type>(line_number, rows, i,
null_regex)
}
DataType::Date32 => {
- build_primitive_array::<Date32Type>(line_number, rows, i)
+ build_primitive_array::<Date32Type>(line_number, rows, i,
null_regex)
}
DataType::Date64 => {
- build_primitive_array::<Date64Type>(line_number, rows, i)
- }
- DataType::Time32(TimeUnit::Second) => {
- build_primitive_array::<Time32SecondType>(line_number,
rows, i)
+ build_primitive_array::<Date64Type>(line_number, rows, i,
null_regex)
}
+ DataType::Time32(TimeUnit::Second) => build_primitive_array::<
+ Time32SecondType,
+ >(
+ line_number, rows, i, null_regex
+ ),
DataType::Time32(TimeUnit::Millisecond) => {
-
build_primitive_array::<Time32MillisecondType>(line_number, rows, i)
+ build_primitive_array::<Time32MillisecondType>(
+ line_number,
+ rows,
+ i,
+ null_regex,
+ )
}
DataType::Time64(TimeUnit::Microsecond) => {
-
build_primitive_array::<Time64MicrosecondType>(line_number, rows, i)
+ build_primitive_array::<Time64MicrosecondType>(
+ line_number,
+ rows,
+ i,
+ null_regex,
+ )
}
DataType::Time64(TimeUnit::Nanosecond) => {
- build_primitive_array::<Time64NanosecondType>(line_number,
rows, i)
+ build_primitive_array::<Time64NanosecondType>(
+ line_number,
+ rows,
+ i,
+ null_regex,
+ )
}
DataType::Timestamp(TimeUnit::Second, tz) => {
build_timestamp_array::<TimestampSecondType>(
@@ -704,6 +755,7 @@ fn parse(
rows,
i,
tz.as_deref(),
+ null_regex,
)
}
DataType::Timestamp(TimeUnit::Millisecond, tz) => {
@@ -712,6 +764,7 @@ fn parse(
rows,
i,
tz.as_deref(),
+ null_regex,
)
}
DataType::Timestamp(TimeUnit::Microsecond, tz) => {
@@ -720,6 +773,7 @@ fn parse(
rows,
i,
tz.as_deref(),
+ null_regex,
)
}
DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
@@ -728,6 +782,7 @@ fn parse(
rows,
i,
tz.as_deref(),
+ null_regex,
)
}
DataType::Utf8 => Ok(Arc::new(
@@ -827,11 +882,12 @@ fn build_decimal_array<T: DecimalType>(
col_idx: usize,
precision: u8,
scale: i8,
+ null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
for row in rows.iter() {
let s = row.get(col_idx);
- if s.is_empty() {
+ if null_regex.is_null(s) {
// append null
decimal_builder.append_null();
} else {
@@ -859,12 +915,13 @@ fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
line_number: usize,
rows: &StringRecords<'_>,
col_idx: usize,
+ null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
rows.iter()
.enumerate()
.map(|(row_index, row)| {
let s = row.get(col_idx);
- if s.is_empty() {
+ if null_regex.is_null(s) {
return Ok(None);
}
@@ -888,14 +945,27 @@ fn build_timestamp_array<T: ArrowTimestampType>(
rows: &StringRecords<'_>,
col_idx: usize,
timezone: Option<&str>,
+ null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
Ok(Arc::new(match timezone {
Some(timezone) => {
let tz: Tz = timezone.parse()?;
- build_timestamp_array_impl::<T, _>(line_number, rows, col_idx,
&tz)?
- .with_timezone(timezone)
+ build_timestamp_array_impl::<T, _>(
+ line_number,
+ rows,
+ col_idx,
+ &tz,
+ null_regex,
+ )?
+ .with_timezone(timezone)
}
- None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx,
&Utc)?,
+ None => build_timestamp_array_impl::<T, _>(
+ line_number,
+ rows,
+ col_idx,
+ &Utc,
+ null_regex,
+ )?,
}))
}
@@ -904,12 +974,13 @@ fn build_timestamp_array_impl<T: ArrowTimestampType, Tz:
TimeZone>(
rows: &StringRecords<'_>,
col_idx: usize,
timezone: &Tz,
+ null_regex: &NullRegex,
) -> Result<PrimitiveArray<T>, ArrowError> {
rows.iter()
.enumerate()
.map(|(row_index, row)| {
let s = row.get(col_idx);
- if s.is_empty() {
+ if null_regex.is_null(s) {
return Ok(None);
}
@@ -936,12 +1007,13 @@ fn build_boolean_array(
line_number: usize,
rows: &StringRecords<'_>,
col_idx: usize,
+ null_regex: &NullRegex,
) -> Result<ArrayRef, ArrowError> {
rows.iter()
.enumerate()
.map(|(row_index, row)| {
let s = row.get(col_idx);
- if s.is_empty() {
+ if null_regex.is_null(s) {
return Ok(None);
}
let parsed = parse_bool(s);
@@ -1042,6 +1114,12 @@ impl ReaderBuilder {
self
}
+ /// Provide a regex to match null values, defaults to `^$`; any partial match counts, so anchor it (e.g. `^nil$`)
+ pub fn with_null_regex(mut self, null_regex: Regex) -> Self { // replaces the default: empty fields are NULL only if the regex matches ""
+ self.format.null_regex = NullRegex(Some(null_regex)); // stored on the format and copied into the decoder at build time
+ self
+ }
+
/// Set the batch size (number of records to load at one time)
pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = batch_size;
@@ -1100,6 +1178,7 @@ impl ReaderBuilder {
end,
projection: self.projection,
batch_size: self.batch_size,
+ null_regex: self.format.null_regex,
}
}
}
@@ -1426,6 +1505,36 @@ mod tests {
assert!(!batch.column(1).is_null(4));
}
+ #[test]
+ fn test_custom_nulls() { // explicit schema + `^nil$`: "nil" fields in typed columns decode as NULL
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("c_int", DataType::UInt64, true),
+ Field::new("c_float", DataType::Float32, true),
+ Field::new("c_string", DataType::Utf8, true),
+ Field::new("c_bool", DataType::Boolean, true),
+ ]));
+
+ let file = File::open("test/data/custom_null_test.csv").unwrap();
+
+ let null_regex = Regex::new("^nil$").unwrap(); // anchored, so only the exact text "nil" matches
+
+ let mut csv = ReaderBuilder::new(schema)
+ .has_header(true)
+ .with_null_regex(null_regex)
+ .build(file)
+ .unwrap();
+
+ let batch = csv.next().unwrap().unwrap();
+
+ // "nil" fields in the UInt64, Float32 and Boolean columns must be NULL
+ assert!(batch.column(0).is_null(1)); // row 1: c_int = nil
+ assert!(batch.column(1).is_null(2)); // row 2: c_float = nil
+ assert!(batch.column(3).is_null(4)); // row 4: c_bool = nil
+ // The Utf8 column is expected to keep both "nil" and "" as non-NULL values
+ assert!(!batch.column(2).is_null(3)); // row 3: c_string = nil
+ assert!(!batch.column(2).is_null(4)); // row 4: c_string = ""
+ }
+
#[test]
fn test_nulls_with_inference() {
let mut file = File::open("test/data/various_types.csv").unwrap();
@@ -1485,6 +1594,42 @@ mod tests {
assert!(!batch.column(1).is_null(4));
}
+ #[test]
+ fn test_custom_nulls_with_inference() { // "nil" placeholders must be ignored by schema inference and then decode as NULL
+ let mut file = File::open("test/data/custom_null_test.csv").unwrap();
+
+ let null_regex = Regex::new("^nil$").unwrap();
+
+ let format = Format::default()
+ .with_header(true)
+ .with_null_regex(null_regex);
+
+ let (schema, _) = format.infer_schema(&mut file, None).unwrap(); // None: read every record
+ file.rewind().unwrap(); // infer_schema read from the file; rewind before decoding
+
+ let expected_schema = Schema::new(vec![ // values matching `^nil$` are skipped during inference
+ Field::new("c_int", DataType::Int64, true),
+ Field::new("c_float", DataType::Float64, true),
+ Field::new("c_string", DataType::Utf8, true),
+ Field::new("c_bool", DataType::Boolean, true),
+ ]);
+
+ assert_eq!(schema, expected_schema);
+
+ let builder = ReaderBuilder::new(Arc::new(schema))
+ .with_format(format) // reuse the same format, presumably including its null regex, for decoding
+ .with_batch_size(512)
+ .with_projection(vec![0, 1, 2, 3]); // all four columns
+
+ let mut csv = builder.build(file).unwrap();
+ let batch = csv.next().unwrap().unwrap(); // must succeed even though typed columns contain "nil"
+
+ assert_eq!(5, batch.num_rows());
+ assert_eq!(4, batch.num_columns());
+
+ assert_eq!(batch.schema().as_ref(), &expected_schema);
+ }
+
#[test]
fn test_parse_invalid_csv() {
let file = File::open("test/data/various_types_invalid.csv").unwrap();
diff --git a/arrow-csv/test/data/custom_null_test.csv
b/arrow-csv/test/data/custom_null_test.csv
new file mode 100644
index 0000000000..39f9fc4b3e
--- /dev/null
+++ b/arrow-csv/test/data/custom_null_test.csv
@@ -0,0 +1,6 @@
+c_int,c_float,c_string,c_bool
+1,1.1,"1.11",True
+nil,2.2,"2.22",TRUE
+3,nil,"3.33",true
+4,4.4,nil,False
+5,6.6,"",nil