This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new e33dbe269 Support String Coercion in Raw JSON Reader (#3736)
e33dbe269 is described below
commit e33dbe26989f290010bcac8fe933379014884d68
Author: Rafael Guerreiro <[email protected]>
AuthorDate: Tue Feb 21 06:00:36 2023 -0800
Support String Coercion in Raw JSON Reader (#3736)
* Add the ability to coerce primitive values into strings behind an option
* Move the coerce_primitive flag into the decoders, where it makes more sense than in the Tape
---
arrow-json/src/raw/list_array.rs | 12 +++-
arrow-json/src/raw/map_array.rs | 19 ++++--
arrow-json/src/raw/mod.rs | 128 +++++++++++++++++++++++++++++++++----
arrow-json/src/raw/string_array.rs | 34 +++++++++-
arrow-json/src/raw/struct_array.rs | 10 ++-
5 files changed, 181 insertions(+), 22 deletions(-)
diff --git a/arrow-json/src/raw/list_array.rs b/arrow-json/src/raw/list_array.rs
index 7d37fc51d..91ca4b727 100644
--- a/arrow-json/src/raw/list_array.rs
+++ b/arrow-json/src/raw/list_array.rs
@@ -31,13 +31,21 @@ pub struct ListArrayDecoder<O> {
}
impl<O: OffsetSizeTrait> ListArrayDecoder<O> {
- pub fn new(data_type: DataType, is_nullable: bool) -> Result<Self,
ArrowError> {
+ pub fn new(
+ data_type: DataType,
+ coerce_primitive: bool,
+ is_nullable: bool,
+ ) -> Result<Self, ArrowError> {
let field = match &data_type {
DataType::List(f) if !O::IS_LARGE => f,
DataType::LargeList(f) if O::IS_LARGE => f,
_ => unreachable!(),
};
- let decoder = make_decoder(field.data_type().clone(),
field.is_nullable())?;
+ let decoder = make_decoder(
+ field.data_type().clone(),
+ coerce_primitive,
+ field.is_nullable(),
+ )?;
Ok(Self {
data_type,
diff --git a/arrow-json/src/raw/map_array.rs b/arrow-json/src/raw/map_array.rs
index 670210f66..ac48d8bce 100644
--- a/arrow-json/src/raw/map_array.rs
+++ b/arrow-json/src/raw/map_array.rs
@@ -30,7 +30,11 @@ pub struct MapArrayDecoder {
}
impl MapArrayDecoder {
- pub fn new(data_type: DataType, is_nullable: bool) -> Result<Self,
ArrowError> {
+ pub fn new(
+ data_type: DataType,
+ coerce_primitive: bool,
+ is_nullable: bool,
+ ) -> Result<Self, ArrowError> {
let fields = match &data_type {
DataType::Map(_, true) => {
return Err(ArrowError::NotYetImplemented(
@@ -48,9 +52,16 @@ impl MapArrayDecoder {
_ => unreachable!(),
};
- let keys = make_decoder(fields[0].data_type().clone(),
fields[0].is_nullable())?;
- let values =
- make_decoder(fields[1].data_type().clone(),
fields[1].is_nullable())?;
+ let keys = make_decoder(
+ fields[0].data_type().clone(),
+ coerce_primitive,
+ fields[0].is_nullable(),
+ )?;
+ let values = make_decoder(
+ fields[1].data_type().clone(),
+ coerce_primitive,
+ fields[1].is_nullable(),
+ )?;
Ok(Self {
data_type,
diff --git a/arrow-json/src/raw/mod.rs b/arrow-json/src/raw/mod.rs
index a45ff8ea8..e597753a9 100644
--- a/arrow-json/src/raw/mod.rs
+++ b/arrow-json/src/raw/mod.rs
@@ -43,6 +43,7 @@ mod tape;
/// A builder for [`RawReader`] and [`RawDecoder`]
pub struct RawReaderBuilder {
batch_size: usize,
+ coerce_primitive: bool,
schema: SchemaRef,
}
@@ -58,6 +59,7 @@ impl RawReaderBuilder {
pub fn new(schema: SchemaRef) -> Self {
Self {
batch_size: 1024,
+ coerce_primitive: false,
schema,
}
}
@@ -67,6 +69,14 @@ impl RawReaderBuilder {
Self { batch_size, ..self }
}
+ /// Sets if the decoder should coerce primitive values (bool and number)
into string when the Schema's column is Utf8 or LargeUtf8.
+ pub fn coerce_primitive(self, coerce_primitive: bool) -> Self {
+ Self {
+ coerce_primitive,
+ ..self
+ }
+ }
+
/// Create a [`RawReader`] with the provided [`BufRead`]
pub fn build<R: BufRead>(self, reader: R) -> Result<RawReader<R>,
ArrowError> {
Ok(RawReader {
@@ -77,7 +87,11 @@ impl RawReaderBuilder {
/// Create a [`RawDecoder`]
pub fn build_decoder(self) -> Result<RawDecoder, ArrowError> {
- let decoder =
make_decoder(DataType::Struct(self.schema.fields.clone()), false)?;
+ let decoder = make_decoder(
+ DataType::Struct(self.schema.fields.clone()),
+ self.coerce_primitive,
+ false,
+ )?;
let num_fields = self.schema.all_fields().len();
Ok(RawDecoder {
@@ -270,6 +284,7 @@ macro_rules! primitive_decoder {
fn make_decoder(
data_type: DataType,
+ coerce_primitive: bool,
is_nullable: bool,
) -> Result<Box<dyn ArrayDecoder>, ArrowError> {
downcast_integer! {
@@ -277,15 +292,15 @@ fn make_decoder(
DataType::Float32 => primitive_decoder!(Float32Type, data_type),
DataType::Float64 => primitive_decoder!(Float64Type, data_type),
DataType::Boolean => Ok(Box::<BooleanArrayDecoder>::default()),
- DataType::Utf8 => Ok(Box::<StringArrayDecoder::<i32>>::default()),
- DataType::LargeUtf8 => Ok(Box::<StringArrayDecoder::<i64>>::default()),
- DataType::List(_) =>
Ok(Box::new(ListArrayDecoder::<i32>::new(data_type, is_nullable)?)),
- DataType::LargeList(_) =>
Ok(Box::new(ListArrayDecoder::<i64>::new(data_type, is_nullable)?)),
- DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type,
is_nullable)?)),
+ DataType::Utf8 =>
Ok(Box::new(StringArrayDecoder::<i32>::new(coerce_primitive))),
+ DataType::LargeUtf8 =>
Ok(Box::new(StringArrayDecoder::<i64>::new(coerce_primitive))),
+ DataType::List(_) =>
Ok(Box::new(ListArrayDecoder::<i32>::new(data_type, coerce_primitive,
is_nullable)?)),
+ DataType::LargeList(_) =>
Ok(Box::new(ListArrayDecoder::<i64>::new(data_type, coerce_primitive,
is_nullable)?)),
+ DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type,
coerce_primitive, is_nullable)?)),
DataType::Binary | DataType::LargeBinary |
DataType::FixedSizeBinary(_) => {
Err(ArrowError::JsonError(format!("{data_type} is not supported by
JSON")))
}
- DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type,
is_nullable)?)),
+ DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type,
coerce_primitive, is_nullable)?)),
d => Err(ArrowError::NotYetImplemented(format!("Support for {d} in
JSON reader")))
}
}
@@ -311,13 +326,19 @@ mod tests {
use std::io::{BufReader, Cursor, Seek};
use std::sync::Arc;
- fn do_read(buf: &str, batch_size: usize, schema: SchemaRef) ->
Vec<RecordBatch> {
+ fn do_read(
+ buf: &str,
+ batch_size: usize,
+ coerce_primitive: bool,
+ schema: SchemaRef,
+ ) -> Vec<RecordBatch> {
let mut unbuffered = vec![];
// Test with different batch sizes to test for boundary conditions
for batch_size in [1, 3, 100, batch_size] {
unbuffered = RawReaderBuilder::new(schema.clone())
.with_batch_size(batch_size)
+ .coerce_primitive(coerce_primitive)
.build(Cursor::new(buf.as_bytes()))
.unwrap()
.collect::<Result<Vec<_>, _>>()
@@ -331,6 +352,7 @@ mod tests {
for b in [1, 3, 5] {
let buffered = RawReaderBuilder::new(schema.clone())
.with_batch_size(batch_size)
+ .coerce_primitive(coerce_primitive)
.build(BufReader::with_capacity(b,
Cursor::new(buf.as_bytes())))
.unwrap()
.collect::<Result<Vec<_>, _>>()
@@ -360,7 +382,7 @@ mod tests {
Field::new("c", DataType::Boolean, true),
]));
- let batches = do_read(buf, 1024, schema);
+ let batches = do_read(buf, 1024, false, schema);
assert_eq!(batches.len(), 1);
let col1 = as_primitive_array::<Int64Type>(batches[0].column(0));
@@ -397,7 +419,7 @@ mod tests {
Field::new("b", DataType::LargeUtf8, true),
]));
- let batches = do_read(buf, 1024, schema);
+ let batches = do_read(buf, 1024, false, schema);
assert_eq!(batches.len(), 1);
let col1 = as_string_array(batches[0].column(0));
@@ -454,7 +476,7 @@ mod tests {
),
]));
- let batches = do_read(buf, 1024, schema);
+ let batches = do_read(buf, 1024, false, schema);
assert_eq!(batches.len(), 1);
let list = as_list_array(batches[0].column(0).as_ref());
@@ -517,7 +539,7 @@ mod tests {
),
]));
- let batches = do_read(buf, 1024, schema);
+ let batches = do_read(buf, 1024, false, schema);
assert_eq!(batches.len(), 1);
let nested = as_struct_array(batches[0].column(0).as_ref());
@@ -561,7 +583,7 @@ mod tests {
let map = DataType::Map(Box::new(Field::new("entries", entries,
true)), false);
let schema = Arc::new(Schema::new(vec![Field::new("map", map, true)]));
- let batches = do_read(buf, 1024, schema);
+ let batches = do_read(buf, 1024, false, schema);
assert_eq!(batches.len(), 1);
let map = as_map_array(batches[0].column(0).as_ref());
@@ -612,4 +634,84 @@ mod tests {
assert_eq!(a_result, b_result);
}
}
+
+ #[test]
+ fn test_not_coercing_primitive_into_string_without_flag() {
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8,
true)]));
+
+ let buf = r#"{"a": 1}"#;
+ let result = RawReaderBuilder::new(schema.clone())
+ .with_batch_size(1024)
+ .build(Cursor::new(buf.as_bytes()))
+ .unwrap()
+ .read();
+
+ assert!(result.is_err());
+ assert_eq!(
+ result.unwrap_err().to_string(),
+ "Json error: expected string got number".to_string()
+ );
+
+ let buf = r#"{"a": true}"#;
+ let result = RawReaderBuilder::new(schema)
+ .with_batch_size(1024)
+ .build(Cursor::new(buf.as_bytes()))
+ .unwrap()
+ .read();
+
+ assert!(result.is_err());
+ assert_eq!(
+ result.unwrap_err().to_string(),
+ "Json error: expected string got true".to_string()
+ );
+ }
+
+ #[test]
+ fn test_coercing_primitive_into_string() {
+ let buf = r#"
+ {"a": 1, "b": 2, "c": true}
+ {"a": 2E0, "b": 4, "c": false}
+
+ {"b": 6, "a": 2.0}
+ {"b": "5", "a": 2}
+ {"b": 4e0}
+ {"b": 7, "a": null}
+ "#;
+
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Utf8, true),
+ Field::new("b", DataType::Utf8, true),
+ Field::new("c", DataType::Utf8, true),
+ ]));
+
+ let batches = do_read(buf, 1024, true, schema);
+ assert_eq!(batches.len(), 1);
+
+ let col1 = as_string_array(batches[0].column(0));
+ assert_eq!(col1.null_count(), 2);
+ assert_eq!(col1.value(0), "1");
+ assert_eq!(col1.value(1), "2E0");
+ assert_eq!(col1.value(2), "2.0");
+ assert_eq!(col1.value(3), "2");
+ assert!(col1.is_null(4));
+ assert!(col1.is_null(5));
+
+ let col2 = as_string_array(batches[0].column(1));
+ assert_eq!(col2.null_count(), 0);
+ assert_eq!(col2.value(0), "2");
+ assert_eq!(col2.value(1), "4");
+ assert_eq!(col2.value(2), "6");
+ assert_eq!(col2.value(3), "5");
+ assert_eq!(col2.value(4), "4e0");
+ assert_eq!(col2.value(5), "7");
+
+ let col3 = as_string_array(batches[0].column(2));
+ assert_eq!(col3.null_count(), 4);
+ assert_eq!(col3.value(0), "true");
+ assert_eq!(col3.value(1), "false");
+ assert!(col3.is_null(2));
+ assert!(col3.is_null(3));
+ assert!(col3.is_null(4));
+ assert!(col3.is_null(5));
+ }
}
diff --git a/arrow-json/src/raw/string_array.rs
b/arrow-json/src/raw/string_array.rs
index 31a7a99be..104e4e83f 100644
--- a/arrow-json/src/raw/string_array.rs
+++ b/arrow-json/src/raw/string_array.rs
@@ -24,13 +24,27 @@ use std::marker::PhantomData;
use crate::raw::tape::{Tape, TapeElement};
use crate::raw::{tape_error, ArrayDecoder};
-#[derive(Default)]
+const TRUE: &str = "true";
+const FALSE: &str = "false";
+
pub struct StringArrayDecoder<O: OffsetSizeTrait> {
+ coerce_primitive: bool,
phantom: PhantomData<O>,
}
+impl<O: OffsetSizeTrait> StringArrayDecoder<O> {
+ pub fn new(coerce_primitive: bool) -> Self {
+ Self {
+ coerce_primitive,
+ phantom: Default::default(),
+ }
+ }
+}
+
impl<O: OffsetSizeTrait> ArrayDecoder for StringArrayDecoder<O> {
fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData,
ArrowError> {
+ let coerce_primitive = self.coerce_primitive;
+
let mut data_capacity = 0;
for p in pos {
match tape.get(*p) {
@@ -38,6 +52,15 @@ impl<O: OffsetSizeTrait> ArrayDecoder for
StringArrayDecoder<O> {
data_capacity += tape.get_string(idx).len();
}
TapeElement::Null => {}
+ TapeElement::True if coerce_primitive => {
+ data_capacity += TRUE.len();
+ }
+ TapeElement::False if coerce_primitive => {
+ data_capacity += FALSE.len();
+ }
+ TapeElement::Number(idx) if coerce_primitive => {
+ data_capacity += tape.get_string(idx).len();
+ }
d => return Err(tape_error(d, "string")),
}
}
@@ -58,6 +81,15 @@ impl<O: OffsetSizeTrait> ArrayDecoder for
StringArrayDecoder<O> {
builder.append_value(tape.get_string(idx));
}
TapeElement::Null => builder.append_null(),
+ TapeElement::True if coerce_primitive => {
+ builder.append_value(TRUE);
+ }
+ TapeElement::False if coerce_primitive => {
+ builder.append_value(FALSE);
+ }
+ TapeElement::Number(idx) if coerce_primitive => {
+ builder.append_value(tape.get_string(idx));
+ }
_ => unreachable!(),
}
}
diff --git a/arrow-json/src/raw/struct_array.rs
b/arrow-json/src/raw/struct_array.rs
index 418d8abcc..64ceff224 100644
--- a/arrow-json/src/raw/struct_array.rs
+++ b/arrow-json/src/raw/struct_array.rs
@@ -28,10 +28,16 @@ pub struct StructArrayDecoder {
}
impl StructArrayDecoder {
- pub fn new(data_type: DataType, is_nullable: bool) -> Result<Self,
ArrowError> {
+ pub fn new(
+ data_type: DataType,
+ coerce_primitive: bool,
+ is_nullable: bool,
+ ) -> Result<Self, ArrowError> {
let decoders = struct_fields(&data_type)
.iter()
- .map(|f| make_decoder(f.data_type().clone(), f.is_nullable()))
+ .map(|f| {
+ make_decoder(f.data_type().clone(), coerce_primitive,
f.is_nullable())
+ })
.collect::<Result<Vec<_>, ArrowError>>()?;
Ok(Self {