jecsand838 commented on code in PR #8124:
URL: https://github.com/apache/arrow-rs/pull/8124#discussion_r2280990869
##########
arrow-avro/src/reader/mod.rs:
##########

@@ -802,6 +802,411 @@ mod test {
         msg
     }
 
+    fn load_writer_schema_json(path: &str) -> Value {
+        let file = File::open(path).unwrap();
+        let header = super::read_header(BufReader::new(file)).unwrap();
+        let schema = header.schema().unwrap().unwrap();
+        serde_json::to_value(&schema).unwrap()
+    }
+
+    fn make_reader_schema_with_promotions(
+        path: &str,
+        promotions: &HashMap<&str, &str>,
+    ) -> AvroSchema {
+        let mut root = load_writer_schema_json(path);
+        assert_eq!(root["type"], "record", "writer schema must be a record");
+        let fields = root
+            .get_mut("fields")
+            .and_then(|f| f.as_array_mut())
+            .expect("record has fields");
+        for f in fields.iter_mut() {
+            let Some(name) = f.get("name").and_then(|n| n.as_str()) else {
+                continue;
+            };
+            if let Some(new_ty) = promotions.get(name) {
+                let ty = f.get_mut("type").expect("field has a type");
+                match ty {
+                    Value::String(_) => {
+                        *ty = Value::String((*new_ty).to_string());
+                    }
+                    // Union
+                    Value::Array(arr) => {
+                        for b in arr.iter_mut() {
+                            match b {
+                                Value::String(s) if s != "null" => {
+                                    *b = Value::String((*new_ty).to_string());
+                                    break;
+                                }
+                                Value::Object(_) => {
+                                    *b = Value::String((*new_ty).to_string());
+                                    break;
+                                }
+                                _ => {}
+                            }
+                        }
+                    }
+                    Value::Object(_) => {
+                        *ty = Value::String((*new_ty).to_string());
+                    }
+                    _ => {}
+                }
+            }
+        }
+        AvroSchema::new(root.to_string())
+    }
+
+    fn read_alltypes_with_reader_schema(path: &str, reader_schema: AvroSchema) -> RecordBatch {
+        let file = File::open(path).unwrap();
+        let reader = ReaderBuilder::new()
+            .with_batch_size(1024)
+            .with_utf8_view(false)
+            .with_reader_schema(reader_schema)
+            .build(BufReader::new(file))
+            .unwrap();
+
+        let schema = reader.schema();
+        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
+        arrow::compute::concat_batches(&schema, &batches).unwrap()
+    }
+
+    #[test]
+    fn test_alltypes_schema_promotion_mixed() {
+        let files = [
+            "avro/alltypes_plain.avro",
+            "avro/alltypes_plain.snappy.avro",
+            "avro/alltypes_plain.zstandard.avro",
+            "avro/alltypes_plain.bzip2.avro",
+            "avro/alltypes_plain.xz.avro",
+        ];
+        for file in files {
+            let file = arrow_test_data(file);
+            let mut promotions: HashMap<&str, &str> = HashMap::new();
+            promotions.insert("id", "long");
+            promotions.insert("tinyint_col", "float");
+            promotions.insert("smallint_col", "double");
+            promotions.insert("int_col", "double");
+            promotions.insert("bigint_col", "double");
+            promotions.insert("float_col", "double");
+            promotions.insert("date_string_col", "string");
+            promotions.insert("string_col", "string");
+            let reader_schema = make_reader_schema_with_promotions(&file, &promotions);
+            let batch = read_alltypes_with_reader_schema(&file, reader_schema);
+            let expected = RecordBatch::try_from_iter_with_nullable([
+                (
+                    "id",
+                    Arc::new(Int64Array::from(vec![4i64, 5, 6, 7, 2, 3, 0, 1])) as _,
+                    true,
+                ),
+                (
+                    "bool_col",
+                    Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _,
+                    true,
+                ),
+                (
+                    "tinyint_col",
+                    Arc::new(Float32Array::from_iter_values(
+                        (0..8).map(|x| (x % 2) as f32),
+                    )) as _,
+                    true,
+                ),
+                (
+                    "smallint_col",
+                    Arc::new(Float64Array::from_iter_values(
+                        (0..8).map(|x| (x % 2) as f64),
+                    )) as _,
+                    true,
+                ),
+                (
+                    "int_col",
+                    Arc::new(Float64Array::from_iter_values(
+                        (0..8).map(|x| (x % 2) as f64),
+                    )) as _,
+                    true,
+                ),
+                (
+                    "bigint_col",
+                    Arc::new(Float64Array::from_iter_values(
+                        (0..8).map(|x| ((x % 2) * 10) as f64),
+                    )) as _,
+                    true,
+                ),
+                (
+                    "float_col",
+                    Arc::new(Float64Array::from_iter_values(
+                        (0..8).map(|x| ((x % 2) as f32 * 1.1f32) as f64),
+                    )) as _,
+                    true,
+                ),
+                (
+                    "double_col",
+                    Arc::new(Float64Array::from_iter_values(
+                        (0..8).map(|x| (x % 2) as f64 * 10.1),
+                    )) as _,
+                    true,
+                ),
+                (
+                    "date_string_col",
+                    Arc::new(StringArray::from(vec![
+                        "03/01/09", "03/01/09", "04/01/09", "04/01/09", "02/01/09", "02/01/09",
+                        "01/01/09", "01/01/09",
+                    ])) as _,
+                    true,
+                ),
+                (
+                    "string_col",
+                    Arc::new(StringArray::from(
+                        (0..8)
+                            .map(|x| if x % 2 == 0 { "0" } else { "1" })
+                            .collect::<Vec<_>>(),
+                    )) as _,
+                    true,
+                ),
+                (
+                    "timestamp_col",
+                    Arc::new(
+                        TimestampMicrosecondArray::from_iter_values([
+                            1235865600000000, // 2009-03-01T00:00:00.000
+                            1235865660000000, // 2009-03-01T00:01:00.000
+                            1238544000000000, // 2009-04-01T00:00:00.000
+                            1238544060000000, // 2009-04-01T00:01:00.000
+                            1233446400000000, // 2009-02-01T00:00:00.000
+                            1233446460000000, // 2009-02-01T00:01:00.000
+                            1230768000000000, // 2009-01-01T00:00:00.000
+                            1230768060000000, // 2009-01-01T00:01:00.000
+                        ])
+                        .with_timezone("+00:00"),
+                    ) as _,
+                    true,
+                ),
+            ])
+            .unwrap();
+            assert_eq!(batch, expected, "mismatch for file {file}");
+        }
+    }
+
+    #[test]
+    fn test_alltypes_schema_promotion_long_to_float_only() {
+        let files = [
+            "avro/alltypes_plain.avro",
+            "avro/alltypes_plain.snappy.avro",
+            "avro/alltypes_plain.zstandard.avro",
+            "avro/alltypes_plain.bzip2.avro",
+            "avro/alltypes_plain.xz.avro",
+        ];
+        for file in files {
+            let file = arrow_test_data(file);
+            let mut promotions: HashMap<&str, &str> = HashMap::new();
+            promotions.insert("bigint_col", "float");
+            let reader_schema = make_reader_schema_with_promotions(&file, &promotions);
+            let batch = read_alltypes_with_reader_schema(&file, reader_schema);
+            let expected = RecordBatch::try_from_iter_with_nullable([
+                (
+                    "id",
+                    Arc::new(Int32Array::from(vec![4, 5, 6, 7, 2, 3, 0, 1])) as _,
+                    true,
+                ),
+                (
+                    "bool_col",
+                    Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _,
+                    true,
+                ),
+                (
+                    "tinyint_col",
+                    Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _,
+                    true,
+                ),
+                (
+                    "smallint_col",
+                    Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _,
+                    true,
+                ),
+                (
+                    "int_col",
+                    Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _,
+                    true,
+                ),
+                (
+                    "bigint_col",
+                    Arc::new(Float32Array::from_iter_values(
+                        (0..8).map(|x| ((x % 2) * 10) as f32),
+                    )) as _,
+                    true,
+                ),
+                (
+                    "float_col",
+                    Arc::new(Float32Array::from_iter_values(
+                        (0..8).map(|x| (x % 2) as f32 * 1.1),
+                    )) as _,
+                    true,
+                ),
+                (
+                    "double_col",
+                    Arc::new(Float64Array::from_iter_values(
+                        (0..8).map(|x| (x % 2) as f64 * 10.1),
+                    )) as _,
+                    true,
+                ),
+                (
+                    "date_string_col",

Review Comment:
   @alamb The underlying data in the `alltypes_plain.avro` file is typed and stored as `bytes`. Here's the Avro schema info for the `date_string_col` field:
   ```json
   {
     "name": "date_string_col",
     "type": [
       "bytes",
       "null"
     ]
   },
   ```
   I also noticed that the `string_col` is bytes as well:
   ```json
   {
     "name": "string_col",
     "type": [
       "bytes",
       "null"
     ]
   },
   ```
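   For anyone double-checking this locally, here's a minimal sketch that prints the writer-schema types for those two columns (the `show_bytes_columns` test name is made up; it assumes the test sits in the same module, so the `load_writer_schema_json` helper from this diff and `arrow_test_data` are in scope):
   ```rust
   #[test]
   fn show_bytes_columns() {
       // Read the writer schema embedded in the Avro container file header.
       let root = load_writer_schema_json(&arrow_test_data("avro/alltypes_plain.avro"));
       for f in root["fields"].as_array().expect("record has fields") {
           let name = f["name"].as_str().unwrap_or_default();
           if name == "date_string_col" || name == "string_col" {
               // Prints e.g. `date_string_col: ["bytes","null"]` (run with --nocapture)
               println!("{name}: {}", f["type"]);
           }
       }
   }
   ```
   That `["bytes","null"]` union is why the mixed-promotion test maps both columns to `string` in the reader schema (`bytes` to `string` is a legal Avro schema-resolution promotion) rather than expecting a `string` writer type.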