This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new e33dbe269 Support String Coercion in Raw JSON Reader (#3736)
e33dbe269 is described below

commit e33dbe26989f290010bcac8fe933379014884d68
Author: Rafael Guerreiro <[email protected]>
AuthorDate: Tue Feb 21 06:00:36 2023 -0800

    Support String Coercion in Raw JSON Reader (#3736)
    
    * Add the ability to coerce primitive values into strings behind an option
    
    * Move the coerce_primitive flag into the decoders, since it makes more sense there than in the Tape
---
 arrow-json/src/raw/list_array.rs   |  12 +++-
 arrow-json/src/raw/map_array.rs    |  19 ++++--
 arrow-json/src/raw/mod.rs          | 128 +++++++++++++++++++++++++++++++++----
 arrow-json/src/raw/string_array.rs |  34 +++++++++-
 arrow-json/src/raw/struct_array.rs |  10 ++-
 5 files changed, 181 insertions(+), 22 deletions(-)

diff --git a/arrow-json/src/raw/list_array.rs b/arrow-json/src/raw/list_array.rs
index 7d37fc51d..91ca4b727 100644
--- a/arrow-json/src/raw/list_array.rs
+++ b/arrow-json/src/raw/list_array.rs
@@ -31,13 +31,21 @@ pub struct ListArrayDecoder<O> {
 }
 
 impl<O: OffsetSizeTrait> ListArrayDecoder<O> {
-    pub fn new(data_type: DataType, is_nullable: bool) -> Result<Self, 
ArrowError> {
+    pub fn new(
+        data_type: DataType,
+        coerce_primitive: bool,
+        is_nullable: bool,
+    ) -> Result<Self, ArrowError> {
         let field = match &data_type {
             DataType::List(f) if !O::IS_LARGE => f,
             DataType::LargeList(f) if O::IS_LARGE => f,
             _ => unreachable!(),
         };
-        let decoder = make_decoder(field.data_type().clone(), 
field.is_nullable())?;
+        let decoder = make_decoder(
+            field.data_type().clone(),
+            coerce_primitive,
+            field.is_nullable(),
+        )?;
 
         Ok(Self {
             data_type,
diff --git a/arrow-json/src/raw/map_array.rs b/arrow-json/src/raw/map_array.rs
index 670210f66..ac48d8bce 100644
--- a/arrow-json/src/raw/map_array.rs
+++ b/arrow-json/src/raw/map_array.rs
@@ -30,7 +30,11 @@ pub struct MapArrayDecoder {
 }
 
 impl MapArrayDecoder {
-    pub fn new(data_type: DataType, is_nullable: bool) -> Result<Self, 
ArrowError> {
+    pub fn new(
+        data_type: DataType,
+        coerce_primitive: bool,
+        is_nullable: bool,
+    ) -> Result<Self, ArrowError> {
         let fields = match &data_type {
             DataType::Map(_, true) => {
                 return Err(ArrowError::NotYetImplemented(
@@ -48,9 +52,16 @@ impl MapArrayDecoder {
             _ => unreachable!(),
         };
 
-        let keys = make_decoder(fields[0].data_type().clone(), 
fields[0].is_nullable())?;
-        let values =
-            make_decoder(fields[1].data_type().clone(), 
fields[1].is_nullable())?;
+        let keys = make_decoder(
+            fields[0].data_type().clone(),
+            coerce_primitive,
+            fields[0].is_nullable(),
+        )?;
+        let values = make_decoder(
+            fields[1].data_type().clone(),
+            coerce_primitive,
+            fields[1].is_nullable(),
+        )?;
 
         Ok(Self {
             data_type,
diff --git a/arrow-json/src/raw/mod.rs b/arrow-json/src/raw/mod.rs
index a45ff8ea8..e597753a9 100644
--- a/arrow-json/src/raw/mod.rs
+++ b/arrow-json/src/raw/mod.rs
@@ -43,6 +43,7 @@ mod tape;
 /// A builder for [`RawReader`] and [`RawDecoder`]
 pub struct RawReaderBuilder {
     batch_size: usize,
+    coerce_primitive: bool,
 
     schema: SchemaRef,
 }
@@ -58,6 +59,7 @@ impl RawReaderBuilder {
     pub fn new(schema: SchemaRef) -> Self {
         Self {
             batch_size: 1024,
+            coerce_primitive: false,
             schema,
         }
     }
@@ -67,6 +69,14 @@ impl RawReaderBuilder {
         Self { batch_size, ..self }
     }
 
+    /// Sets whether the decoder should coerce primitive values (bool and number) into strings when the schema's column is Utf8 or LargeUtf8.
+    pub fn coerce_primitive(self, coerce_primitive: bool) -> Self {
+        Self {
+            coerce_primitive,
+            ..self
+        }
+    }
+
     /// Create a [`RawReader`] with the provided [`BufRead`]
     pub fn build<R: BufRead>(self, reader: R) -> Result<RawReader<R>, 
ArrowError> {
         Ok(RawReader {
@@ -77,7 +87,11 @@ impl RawReaderBuilder {
 
     /// Create a [`RawDecoder`]
     pub fn build_decoder(self) -> Result<RawDecoder, ArrowError> {
-        let decoder = 
make_decoder(DataType::Struct(self.schema.fields.clone()), false)?;
+        let decoder = make_decoder(
+            DataType::Struct(self.schema.fields.clone()),
+            self.coerce_primitive,
+            false,
+        )?;
         let num_fields = self.schema.all_fields().len();
 
         Ok(RawDecoder {
@@ -270,6 +284,7 @@ macro_rules! primitive_decoder {
 
 fn make_decoder(
     data_type: DataType,
+    coerce_primitive: bool,
     is_nullable: bool,
 ) -> Result<Box<dyn ArrayDecoder>, ArrowError> {
     downcast_integer! {
@@ -277,15 +292,15 @@ fn make_decoder(
         DataType::Float32 => primitive_decoder!(Float32Type, data_type),
         DataType::Float64 => primitive_decoder!(Float64Type, data_type),
         DataType::Boolean => Ok(Box::<BooleanArrayDecoder>::default()),
-        DataType::Utf8 => Ok(Box::<StringArrayDecoder::<i32>>::default()),
-        DataType::LargeUtf8 => Ok(Box::<StringArrayDecoder::<i64>>::default()),
-        DataType::List(_) => 
Ok(Box::new(ListArrayDecoder::<i32>::new(data_type, is_nullable)?)),
-        DataType::LargeList(_) => 
Ok(Box::new(ListArrayDecoder::<i64>::new(data_type, is_nullable)?)),
-        DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, 
is_nullable)?)),
+        DataType::Utf8 => 
Ok(Box::new(StringArrayDecoder::<i32>::new(coerce_primitive))),
+        DataType::LargeUtf8 => 
Ok(Box::new(StringArrayDecoder::<i64>::new(coerce_primitive))),
+        DataType::List(_) => 
Ok(Box::new(ListArrayDecoder::<i32>::new(data_type, coerce_primitive, 
is_nullable)?)),
+        DataType::LargeList(_) => 
Ok(Box::new(ListArrayDecoder::<i64>::new(data_type, coerce_primitive, 
is_nullable)?)),
+        DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, 
coerce_primitive, is_nullable)?)),
         DataType::Binary | DataType::LargeBinary | 
DataType::FixedSizeBinary(_) => {
             Err(ArrowError::JsonError(format!("{data_type} is not supported by 
JSON")))
         }
-        DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, 
is_nullable)?)),
+        DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, 
coerce_primitive, is_nullable)?)),
         d => Err(ArrowError::NotYetImplemented(format!("Support for {d} in 
JSON reader")))
     }
 }
@@ -311,13 +326,19 @@ mod tests {
     use std::io::{BufReader, Cursor, Seek};
     use std::sync::Arc;
 
-    fn do_read(buf: &str, batch_size: usize, schema: SchemaRef) -> 
Vec<RecordBatch> {
+    fn do_read(
+        buf: &str,
+        batch_size: usize,
+        coerce_primitive: bool,
+        schema: SchemaRef,
+    ) -> Vec<RecordBatch> {
         let mut unbuffered = vec![];
 
         // Test with different batch sizes to test for boundary conditions
         for batch_size in [1, 3, 100, batch_size] {
             unbuffered = RawReaderBuilder::new(schema.clone())
                 .with_batch_size(batch_size)
+                .coerce_primitive(coerce_primitive)
                 .build(Cursor::new(buf.as_bytes()))
                 .unwrap()
                 .collect::<Result<Vec<_>, _>>()
@@ -331,6 +352,7 @@ mod tests {
             for b in [1, 3, 5] {
                 let buffered = RawReaderBuilder::new(schema.clone())
                     .with_batch_size(batch_size)
+                    .coerce_primitive(coerce_primitive)
                     .build(BufReader::with_capacity(b, 
Cursor::new(buf.as_bytes())))
                     .unwrap()
                     .collect::<Result<Vec<_>, _>>()
@@ -360,7 +382,7 @@ mod tests {
             Field::new("c", DataType::Boolean, true),
         ]));
 
-        let batches = do_read(buf, 1024, schema);
+        let batches = do_read(buf, 1024, false, schema);
         assert_eq!(batches.len(), 1);
 
         let col1 = as_primitive_array::<Int64Type>(batches[0].column(0));
@@ -397,7 +419,7 @@ mod tests {
             Field::new("b", DataType::LargeUtf8, true),
         ]));
 
-        let batches = do_read(buf, 1024, schema);
+        let batches = do_read(buf, 1024, false, schema);
         assert_eq!(batches.len(), 1);
 
         let col1 = as_string_array(batches[0].column(0));
@@ -454,7 +476,7 @@ mod tests {
             ),
         ]));
 
-        let batches = do_read(buf, 1024, schema);
+        let batches = do_read(buf, 1024, false, schema);
         assert_eq!(batches.len(), 1);
 
         let list = as_list_array(batches[0].column(0).as_ref());
@@ -517,7 +539,7 @@ mod tests {
             ),
         ]));
 
-        let batches = do_read(buf, 1024, schema);
+        let batches = do_read(buf, 1024, false, schema);
         assert_eq!(batches.len(), 1);
 
         let nested = as_struct_array(batches[0].column(0).as_ref());
@@ -561,7 +583,7 @@ mod tests {
         let map = DataType::Map(Box::new(Field::new("entries", entries, 
true)), false);
         let schema = Arc::new(Schema::new(vec![Field::new("map", map, true)]));
 
-        let batches = do_read(buf, 1024, schema);
+        let batches = do_read(buf, 1024, false, schema);
         assert_eq!(batches.len(), 1);
 
         let map = as_map_array(batches[0].column(0).as_ref());
@@ -612,4 +634,84 @@ mod tests {
             assert_eq!(a_result, b_result);
         }
     }
+
+    #[test]
+    fn test_not_coercing_primitive_into_string_without_flag() {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, 
true)]));
+
+        let buf = r#"{"a": 1}"#;
+        let result = RawReaderBuilder::new(schema.clone())
+            .with_batch_size(1024)
+            .build(Cursor::new(buf.as_bytes()))
+            .unwrap()
+            .read();
+
+        assert!(result.is_err());
+        assert_eq!(
+            result.unwrap_err().to_string(),
+            "Json error: expected string got number".to_string()
+        );
+
+        let buf = r#"{"a": true}"#;
+        let result = RawReaderBuilder::new(schema)
+            .with_batch_size(1024)
+            .build(Cursor::new(buf.as_bytes()))
+            .unwrap()
+            .read();
+
+        assert!(result.is_err());
+        assert_eq!(
+            result.unwrap_err().to_string(),
+            "Json error: expected string got true".to_string()
+        );
+    }
+
+    #[test]
+    fn test_coercing_primitive_into_string() {
+        let buf = r#"
+        {"a": 1, "b": 2, "c": true}
+        {"a": 2E0, "b": 4, "c": false}
+
+        {"b": 6, "a": 2.0}
+        {"b": "5", "a": 2}
+        {"b": 4e0}
+        {"b": 7, "a": null}
+        "#;
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Utf8, true),
+            Field::new("b", DataType::Utf8, true),
+            Field::new("c", DataType::Utf8, true),
+        ]));
+
+        let batches = do_read(buf, 1024, true, schema);
+        assert_eq!(batches.len(), 1);
+
+        let col1 = as_string_array(batches[0].column(0));
+        assert_eq!(col1.null_count(), 2);
+        assert_eq!(col1.value(0), "1");
+        assert_eq!(col1.value(1), "2E0");
+        assert_eq!(col1.value(2), "2.0");
+        assert_eq!(col1.value(3), "2");
+        assert!(col1.is_null(4));
+        assert!(col1.is_null(5));
+
+        let col2 = as_string_array(batches[0].column(1));
+        assert_eq!(col2.null_count(), 0);
+        assert_eq!(col2.value(0), "2");
+        assert_eq!(col2.value(1), "4");
+        assert_eq!(col2.value(2), "6");
+        assert_eq!(col2.value(3), "5");
+        assert_eq!(col2.value(4), "4e0");
+        assert_eq!(col2.value(5), "7");
+
+        let col3 = as_string_array(batches[0].column(2));
+        assert_eq!(col3.null_count(), 4);
+        assert_eq!(col3.value(0), "true");
+        assert_eq!(col3.value(1), "false");
+        assert!(col3.is_null(2));
+        assert!(col3.is_null(3));
+        assert!(col3.is_null(4));
+        assert!(col3.is_null(5));
+    }
 }
diff --git a/arrow-json/src/raw/string_array.rs 
b/arrow-json/src/raw/string_array.rs
index 31a7a99be..104e4e83f 100644
--- a/arrow-json/src/raw/string_array.rs
+++ b/arrow-json/src/raw/string_array.rs
@@ -24,13 +24,27 @@ use std::marker::PhantomData;
 use crate::raw::tape::{Tape, TapeElement};
 use crate::raw::{tape_error, ArrayDecoder};
 
-#[derive(Default)]
+const TRUE: &str = "true";
+const FALSE: &str = "false";
+
 pub struct StringArrayDecoder<O: OffsetSizeTrait> {
+    coerce_primitive: bool,
     phantom: PhantomData<O>,
 }
 
+impl<O: OffsetSizeTrait> StringArrayDecoder<O> {
+    pub fn new(coerce_primitive: bool) -> Self {
+        Self {
+            coerce_primitive,
+            phantom: Default::default(),
+        }
+    }
+}
+
 impl<O: OffsetSizeTrait> ArrayDecoder for StringArrayDecoder<O> {
     fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, 
ArrowError> {
+        let coerce_primitive = self.coerce_primitive;
+
         let mut data_capacity = 0;
         for p in pos {
             match tape.get(*p) {
@@ -38,6 +52,15 @@ impl<O: OffsetSizeTrait> ArrayDecoder for 
StringArrayDecoder<O> {
                     data_capacity += tape.get_string(idx).len();
                 }
                 TapeElement::Null => {}
+                TapeElement::True if coerce_primitive => {
+                    data_capacity += TRUE.len();
+                }
+                TapeElement::False if coerce_primitive => {
+                    data_capacity += FALSE.len();
+                }
+                TapeElement::Number(idx) if coerce_primitive => {
+                    data_capacity += tape.get_string(idx).len();
+                }
                 d => return Err(tape_error(d, "string")),
             }
         }
@@ -58,6 +81,15 @@ impl<O: OffsetSizeTrait> ArrayDecoder for 
StringArrayDecoder<O> {
                     builder.append_value(tape.get_string(idx));
                 }
                 TapeElement::Null => builder.append_null(),
+                TapeElement::True if coerce_primitive => {
+                    builder.append_value(TRUE);
+                }
+                TapeElement::False if coerce_primitive => {
+                    builder.append_value(FALSE);
+                }
+                TapeElement::Number(idx) if coerce_primitive => {
+                    builder.append_value(tape.get_string(idx));
+                }
                 _ => unreachable!(),
             }
         }
diff --git a/arrow-json/src/raw/struct_array.rs 
b/arrow-json/src/raw/struct_array.rs
index 418d8abcc..64ceff224 100644
--- a/arrow-json/src/raw/struct_array.rs
+++ b/arrow-json/src/raw/struct_array.rs
@@ -28,10 +28,16 @@ pub struct StructArrayDecoder {
 }
 
 impl StructArrayDecoder {
-    pub fn new(data_type: DataType, is_nullable: bool) -> Result<Self, 
ArrowError> {
+    pub fn new(
+        data_type: DataType,
+        coerce_primitive: bool,
+        is_nullable: bool,
+    ) -> Result<Self, ArrowError> {
         let decoders = struct_fields(&data_type)
             .iter()
-            .map(|f| make_decoder(f.data_type().clone(), f.is_nullable()))
+            .map(|f| {
+                make_decoder(f.data_type().clone(), coerce_primitive, 
f.is_nullable())
+            })
             .collect::<Result<Vec<_>, ArrowError>>()?;
 
         Ok(Self {

Reply via email to