This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new df69ef57d0 fix: coerce_primitive for serde decoded data (#5101)
df69ef57d0 is described below
commit df69ef57d055453c399fa925ad315d19211d7ab2
Author: fan <[email protected]>
AuthorDate: Tue Nov 21 16:42:51 2023 +0800
fix: coerce_primitive for serde decoded data (#5101)
* fix: fix JSON number decoding
Signed-off-by: fan <[email protected]>
* address review comments
Signed-off-by: fan <[email protected]>
* address review comments
Signed-off-by: fan <[email protected]>
* use a fixed-size capacity estimate
Signed-off-by: fan <[email protected]>
---------
Signed-off-by: fan <[email protected]>
---
arrow-json/src/reader/mod.rs | 43 ++++++++++++++++++++++++++++++++++-
arrow-json/src/reader/string_array.rs | 33 ++++++++++++++++++++++++++-
2 files changed, 74 insertions(+), 2 deletions(-)
diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs
index 71a73df9fe..5afe0dec27 100644
--- a/arrow-json/src/reader/mod.rs
+++ b/arrow-json/src/reader/mod.rs
@@ -717,7 +717,9 @@ mod tests {
use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
- use arrow_array::{make_array, Array, BooleanArray, ListArray, StringArray,
StructArray};
+ use arrow_array::{
+ make_array, Array, BooleanArray, Float64Array, ListArray, StringArray,
StructArray,
+ };
use arrow_buffer::{ArrowNativeType, Buffer};
use arrow_cast::display::{ArrayFormatter, FormatOptions};
use arrow_data::ArrayDataBuilder;
@@ -2259,4 +2261,43 @@ mod tests {
.values();
assert_eq!(values, &[1699148028689, 2, 3, 4]);
}
+
+ #[test]
+ fn test_coercing_primitive_into_string_decoder() { // with coerce_primitive, numeric and boolean JSON values fed via serde must land in Utf8 columns as text
+ let buf = &format!(
+ r#"[{{"a": 1, "b": "A", "c": "T"}}, {{"a": 2, "b": "BB", "c":
"F"}}, {{"a": {}, "b": 123, "c": false}}, {{"a": {}, "b": 789, "c": true}}]"#,
+ (std::i32::MAX as i64 + 10), // exceeds i32::MAX
+ std::i64::MAX - 10 // close to i64::MAX
+ );
+ let schema = Schema::new(vec![
+ Field::new("a", DataType::Float64, true),
+ Field::new("b", DataType::Utf8, true),
+ Field::new("c", DataType::Utf8, true),
+ ]);
+ let json_array: Vec<serde_json::Value> =
serde_json::from_str(buf).unwrap();
+ let schema_ref = Arc::new(schema);
+
+ // Feed the parsed serde_json values through `Decoder::serialize` (not raw bytes) with primitive coercion enabled
+ let reader =
ReaderBuilder::new(schema_ref.clone()).with_coerce_primitive(true);
+ let mut decoder = reader.build_decoder().unwrap();
+ decoder.serialize(json_array.as_slice()).unwrap();
+ let batch = decoder.flush().unwrap().unwrap();
+ assert_eq!(
+ batch,
+ RecordBatch::try_new(
+ schema_ref,
+ vec![
+ Arc::new(Float64Array::from(vec![
+ 1.0,
+ 2.0,
+ (std::i32::MAX as i64 + 10) as f64,
+ (std::i64::MAX - 10) as f64
+ ])),
+ Arc::new(StringArray::from(vec!["A", "BB", "123", "789"])), // numbers in "b" become their decimal text
+ Arc::new(StringArray::from(vec!["T", "F", "false",
"true"])), // booleans in "c" become "true"/"false"
+ ]
+ )
+ .unwrap()
+ );
+ }
}
diff --git a/arrow-json/src/reader/string_array.rs
b/arrow-json/src/reader/string_array.rs
index 63a9bcedb7..5ab4d09d5d 100644
--- a/arrow-json/src/reader/string_array.rs
+++ b/arrow-json/src/reader/string_array.rs
@@ -61,7 +61,18 @@ impl<O: OffsetSizeTrait> ArrayDecoder for
StringArrayDecoder<O> {
TapeElement::Number(idx) if coerce_primitive => {
data_capacity += tape.get_string(idx).len();
}
- _ => return Err(tape.error(*p, "string")),
+ TapeElement::I64(_)
+ | TapeElement::I32(_)
+ | TapeElement::F64(_)
+ | TapeElement::F32(_)
+ if coerce_primitive =>
+ {
+ // An arbitrary estimate: unlike Number, these elements carry no source text whose length can be measured
+ data_capacity += 10;
+ }
+ _ => {
+ return Err(tape.error(*p, "string")); // any other element, or a primitive when coerce_primitive is off, is rejected
+ }
}
}
@@ -89,6 +100,26 @@ impl<O: OffsetSizeTrait> ArrayDecoder for
StringArrayDecoder<O> {
TapeElement::Number(idx) if coerce_primitive => {
builder.append_value(tape.get_string(idx));
}
+ TapeElement::I64(high) if coerce_primitive => match tape.get(p
+ 1) {
+ TapeElement::I32(low) => {
+ let val = (high as i64) << 32 | (low as u32) as i64; // `low` goes through u32 so it is not sign-extended
+ builder.append_value(val.to_string());
+ }
+ _ => unreachable!(), // an I64 (high half) is expected to be followed by its I32 low half
+ },
+ TapeElement::I32(n) if coerce_primitive => {
+ builder.append_value(n.to_string());
+ }
+ TapeElement::F32(n) if coerce_primitive => {
+ builder.append_value(f32::from_bits(n).to_string()); // F32 holds raw IEEE-754 bits (see the F64 arm), not a number to print directly
+ }
+ TapeElement::F64(high) if coerce_primitive => match tape.get(p
+ 1) {
+ TapeElement::F32(low) => {
+ let val = f64::from_bits((high as u64) << 32 | low as
u64);
+ builder.append_value(val.to_string());
+ }
+ _ => unreachable!(), // an F64 (high bits) is expected to be followed by its F32 low bits
+ },
_ => unreachable!(),
}
}