This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 93ce75c75d Correctly handling nullable in CSV parser (#6830)
93ce75c75d is described below

commit 93ce75c75d2f4c753345bd585cbbe0bb978f4bab
Author: Edmondo Porcu <[email protected]>
AuthorDate: Thu Dec 5 02:42:27 2024 -0800

    Correctly handling nullable in CSV parser (#6830)
---
 arrow-csv/src/reader/mod.rs                      | 74 +++++++++++++++++++++---
 arrow-csv/test/data/dictionary_nullable_test.csv |  3 +
 2 files changed, 69 insertions(+), 8 deletions(-)

diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs
index f55053e503..9bdb80ef31 100644
--- a/arrow-csv/src/reader/mod.rs
+++ b/arrow-csv/src/reader/mod.rs
@@ -779,42 +779,66 @@ fn parse(
                     match key_type.as_ref() {
                         DataType::Int8 => Ok(Arc::new(
                             rows.iter()
-                                .map(|row| row.get(i))
+                                .map(|row| {
+                                    let s = row.get(i);
+                                    (!null_regex.is_null(s)).then_some(s)
+                                })
                                 .collect::<DictionaryArray<Int8Type>>(),
                         ) as ArrayRef),
                         DataType::Int16 => Ok(Arc::new(
                             rows.iter()
-                                .map(|row| row.get(i))
+                                .map(|row| {
+                                    let s = row.get(i);
+                                    (!null_regex.is_null(s)).then_some(s)
+                                })
                                 .collect::<DictionaryArray<Int16Type>>(),
                         ) as ArrayRef),
                         DataType::Int32 => Ok(Arc::new(
                             rows.iter()
-                                .map(|row| row.get(i))
+                                .map(|row| {
+                                    let s = row.get(i);
+                                    (!null_regex.is_null(s)).then_some(s)
+                                })
                                 .collect::<DictionaryArray<Int32Type>>(),
                         ) as ArrayRef),
                         DataType::Int64 => Ok(Arc::new(
                             rows.iter()
-                                .map(|row| row.get(i))
+                                .map(|row| {
+                                    let s = row.get(i);
+                                    (!null_regex.is_null(s)).then_some(s)
+                                })
                                 .collect::<DictionaryArray<Int64Type>>(),
                         ) as ArrayRef),
                         DataType::UInt8 => Ok(Arc::new(
                             rows.iter()
-                                .map(|row| row.get(i))
+                                .map(|row| {
+                                    let s = row.get(i);
+                                    (!null_regex.is_null(s)).then_some(s)
+                                })
                                 .collect::<DictionaryArray<UInt8Type>>(),
                         ) as ArrayRef),
                         DataType::UInt16 => Ok(Arc::new(
                             rows.iter()
-                                .map(|row| row.get(i))
+                                .map(|row| {
+                                    let s = row.get(i);
+                                    (!null_regex.is_null(s)).then_some(s)
+                                })
                                 .collect::<DictionaryArray<UInt16Type>>(),
                         ) as ArrayRef),
                         DataType::UInt32 => Ok(Arc::new(
                             rows.iter()
-                                .map(|row| row.get(i))
+                                .map(|row| {
+                                    let s = row.get(i);
+                                    (!null_regex.is_null(s)).then_some(s)
+                                })
                                 .collect::<DictionaryArray<UInt32Type>>(),
                         ) as ArrayRef),
                         DataType::UInt64 => Ok(Arc::new(
                             rows.iter()
-                                .map(|row| row.get(i))
+                                .map(|row| {
+                                    let s = row.get(i);
+                                    (!null_regex.is_null(s)).then_some(s)
+                                })
                                 .collect::<DictionaryArray<UInt64Type>>(),
                         ) as ArrayRef),
                         _ => Err(ArrowError::ParseError(format!(
@@ -1475,6 +1499,40 @@ mod tests {
         assert_eq!(strings.value(29), "Uckfield, East Sussex, UK");
     }
 
+    #[test]
+    fn test_csv_with_nullable_dictionary() {
+        let offset_type = vec![
+            DataType::Int8,
+            DataType::Int16,
+            DataType::Int32,
+            DataType::Int64,
+            DataType::UInt8,
+            DataType::UInt16,
+            DataType::UInt32,
+            DataType::UInt64,
+        ];
+        for data_type in offset_type {
+            let file = File::open("test/data/dictionary_nullable_test.csv").unwrap();
+            let dictionary_type =
+                DataType::Dictionary(Box::new(data_type), Box::new(DataType::Utf8));
+            let schema = Arc::new(Schema::new(vec![
+                Field::new("id", DataType::Utf8, false),
+                Field::new("name", dictionary_type.clone(), true),
+            ]));
+
+            let mut csv = ReaderBuilder::new(schema)
+                .build(file.try_clone().unwrap())
+                .unwrap();
+
+            let batch = csv.next().unwrap().unwrap();
+            assert_eq!(3, batch.num_rows());
+            assert_eq!(2, batch.num_columns());
+
+            let names = arrow_cast::cast(batch.column(1), &dictionary_type).unwrap();
+            assert!(!names.is_null(2));
+            assert!(names.is_null(1));
+        }
+    }
     #[test]
     fn test_nulls() {
         let schema = Arc::new(Schema::new(vec![
diff --git a/arrow-csv/test/data/dictionary_nullable_test.csv b/arrow-csv/test/data/dictionary_nullable_test.csv
new file mode 100644
index 0000000000..c9ada5293b
--- /dev/null
+++ b/arrow-csv/test/data/dictionary_nullable_test.csv
@@ -0,0 +1,3 @@
+id,name
+1,
+2,bob

Reply via email to