jecsand838 commented on code in PR #8349:
URL: https://github.com/apache/arrow-rs/pull/8349#discussion_r2373584390
##########
arrow-avro/src/reader/mod.rs:
##########
@@ -2734,6 +2736,763 @@ mod test {
}
}
+    /// Decodes `test/data/union_fields.avro` via `read_file` and checks that the
+    /// nullable unions decode as plain nullable Arrow columns, while the general
+    /// unions decode as `UnionArray`s whose per-row type ids and child values match
+    /// the fixture.
+    #[test]
+    fn test_union_fields_avro_nullable_and_general_unions() {
+        // Looks up top-level column `name` and downcasts it to a `UnionArray`.
+        fn union_column<'a>(batch: &'a RecordBatch, name: &str) -> &'a UnionArray {
+            let idx = batch.schema().index_of(name).unwrap();
+            batch
+                .column(idx)
+                .as_any()
+                .downcast_ref::<UnionArray>()
+                .unwrap_or_else(|| panic!("{name} should be Union"))
+        }
+
+        // Returns the type id of the first union child for which `pred(name, type)`
+        // holds; panics naming `what` if no child matches.
+        fn find_tid(u: &UnionArray, what: &str, pred: impl Fn(&str, &DataType) -> bool) -> i8 {
+            let DataType::Union(fields, _) = u.data_type() else {
+                panic!("expected Union, got {:?}", u.data_type());
+            };
+            fields
+                .iter()
+                .find_map(|(tid, f)| pred(f.name().as_str(), f.data_type()).then_some(tid))
+                .unwrap_or_else(|| panic!("union child '{what}' not found"))
+        }
+
+        // Returns the type id of the union child whose field name is `name`.
+        fn tid_by_name(u: &UnionArray, name: &str) -> i8 {
+            find_tid(u, name, |child, _| child == name)
+        }
+
+        // True if `dt` is a struct with exactly two fields named `first` then `second`.
+        fn is_struct_of(dt: &DataType, first: &str, second: &str) -> bool {
+            matches!(
+                dt,
+                DataType::Struct(c) if c.len() == 2 && c[0].name() == first && c[1].name() == second
+            )
+        }
+
+        // Collects a union's per-row type ids.
+        fn type_ids(u: &UnionArray) -> Vec<i8> {
+            u.type_ids().iter().copied().collect()
+        }
+
+        let path = "test/data/union_fields.avro";
+        let batch = read_file(path, 1024, false);
+        let schema = batch.schema();
+
+        // Nullable int (null branch first) decodes as a plain nullable Int32 column.
+        let idx = schema.index_of("nullable_int_nullfirst").unwrap();
+        let a = batch
+            .column(idx)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .expect("nullable_int_nullfirst should be Int32");
+        assert_eq!(a.len(), 4);
+        assert!(a.is_null(0));
+        assert_eq!(a.value(1), 42);
+        assert!(a.is_null(2));
+        assert_eq!(a.value(3), 0);
+
+        // Nullable string (null branch second) decodes as a plain nullable Utf8 column.
+        let idx = schema.index_of("nullable_string_nullsecond").unwrap();
+        let s = batch
+            .column(idx)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .expect("nullable_string_nullsecond should be Utf8");
+        assert_eq!(s.len(), 4);
+        assert_eq!(s.value(0), "s1");
+        assert!(s.is_null(1));
+        assert_eq!(s.value(2), "s3");
+        assert!(s.is_valid(3)); // empty string, not null
+        assert_eq!(s.value(3), "");
+
+        // Primitive union: rows select long, int, float, double in that order.
+        let u = union_column(&batch, "union_prim");
+        assert!(
+            matches!(u.data_type(), DataType::Union(_, UnionMode::Dense)),
+            "expect dense unions"
+        );
+        let expected_type_ids = vec![
+            tid_by_name(u, "long"),
+            tid_by_name(u, "int"),
+            tid_by_name(u, "float"),
+            tid_by_name(u, "double"),
+        ];
+        assert_eq!(
+            type_ids(u),
+            expected_type_ids,
+            "branch selection for union_prim rows"
+        );
+        // Dense layout: each selected child holds exactly the one value routed to it.
+        for (name, expected) in [
+            ("long", DataType::Int64),
+            ("int", DataType::Int32),
+            ("float", DataType::Float32),
+            ("double", DataType::Float64),
+        ] {
+            let child = u.child(tid_by_name(u, name));
+            assert_eq!(child.data_type(), &expected, "union_prim child '{name}' type");
+            assert_eq!(child.len(), 1, "union_prim child '{name}' length");
+        }
+
+        // Bytes vs. string union.
+        let u = union_column(&batch, "union_bytes_vs_string");
+        let tid_bytes = tid_by_name(u, "bytes");
+        let tid_string = tid_by_name(u, "string");
+        assert_eq!(
+            type_ids(u),
+            vec![tid_bytes, tid_string, tid_string, tid_bytes],
+            "branch selection for bytes/string union"
+        );
+        let s_child = u
+            .child(tid_string)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+        assert_eq!(s_child.len(), 2);
+        assert_eq!(s_child.value(0), "hello");
+        assert_eq!(s_child.value(1), "world");
+        let b_child = u
+            .child(tid_bytes)
+            .as_any()
+            .downcast_ref::<BinaryArray>()
+            .unwrap();
+        assert_eq!(b_child.len(), 2);
+        assert_eq!(b_child.value(0), &[0x00, 0xFF, 0x7F]);
+        assert_eq!(b_child.value(1), b"");
+
+        // Complex union: enum, two records, array and map branches.
+        let u = union_column(&batch, "union_enum_records_array_map");
+        let tid_enum = find_tid(u, "enum", |_, dt| matches!(dt, DataType::Dictionary(_, _)));
+        let tid_rec_a = find_tid(u, "RecA", |_, dt| is_struct_of(dt, "a", "b"));
+        let tid_rec_b = find_tid(u, "RecB", |_, dt| is_struct_of(dt, "x", "y"));
+        let tid_array = find_tid(u, "array<long>", |_, dt| matches!(dt, DataType::List(_)));
+        let tid_map = find_tid(u, "map", |_, dt| matches!(dt, DataType::Map(_, _)));
+        assert_eq!(
+            type_ids(u),
+            vec![tid_enum, tid_rec_a, tid_rec_b, tid_array],
+            "branch selection for complex union"
+        );
+        // No row selects the map branch, so its child must hold no values.
+        assert_eq!(u.child(tid_map).len(), 0, "map branch should hold no values");
+        let dict = u
+            .child(tid_enum)
+            .as_any()
+            .downcast_ref::<DictionaryArray<Int32Type>>()
+            .unwrap();
+        assert_eq!(dict.len(), 1);
+        assert!(dict.is_valid(0));
+        // RecA row: {"a": 7, "b": "x"}
+        let rec_a = u
+            .child(tid_rec_a)
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .unwrap();
+        assert_eq!(rec_a.len(), 1);
+        let a_val = rec_a
+            .column_by_name("a")
+            .unwrap()
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        assert_eq!(a_val.value(0), 7);
+        let b_val = rec_a
+            .column_by_name("b")
+            .unwrap()
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+        assert_eq!(b_val.value(0), "x");
+        // RecB row: {"x": 123456789, "y": b"\xFF\x00"}
+        let rec_b = u
+            .child(tid_rec_b)
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .unwrap();
+        assert_eq!(rec_b.len(), 1);
+        let x_val = rec_b
+            .column_by_name("x")
+            .unwrap()
+            .as_any()
+            .downcast_ref::<Int64Array>()
+            .unwrap();
+        assert_eq!(x_val.value(0), 123_456_789_i64);
+        let y_val = rec_b
+            .column_by_name("y")
+            .unwrap()
+            .as_any()
+            .downcast_ref::<BinaryArray>()
+            .unwrap();
+        assert_eq!(y_val.value(0), &[0xFF, 0x00]);
+        // Array row: [1, 2, 3]
+        let arr = u
+            .child(tid_array)
+            .as_any()
+            .downcast_ref::<ListArray>()
+            .unwrap();
+        assert_eq!(arr.len(), 1);
+        let first_values = arr.value(0);
+        let longs = first_values
+            .as_any()
+            .downcast_ref::<Int64Array>()
+            .unwrap();
+        assert_eq!(longs.len(), 3);
+        assert_eq!(longs.value(0), 1);
+        assert_eq!(longs.value(1), 2);
+        assert_eq!(longs.value(2), 3);
+
+        // Date vs. fixed(4) union.
+        let u = union_column(&batch, "union_date_or_fixed4");
+        let tid_date = find_tid(u, "date", |_, dt| matches!(dt, DataType::Date32));
+        let tid_fixed = find_tid(u, "fixed(4)", |_, dt| {
+            matches!(dt, DataType::FixedSizeBinary(4))
+        });
+        assert_eq!(
+            type_ids(u),
+            vec![tid_date, tid_fixed, tid_date, tid_fixed],
+            "branch selection for date/fixed4 union"
+        );
+        let dates = u
+            .child(tid_date)
+            .as_any()
+            .downcast_ref::<Date32Array>()
+            .unwrap();
+        assert_eq!(dates.len(), 2);
+        assert_eq!(dates.value(0), 19_000); // 2022-01-08
+        assert_eq!(dates.value(1), 0); // 1970-01-01 (Unix epoch)
+        let fixed = u
+            .child(tid_fixed)
+            .as_any()
+            .downcast_ref::<FixedSizeBinaryArray>()
+            .unwrap();
+        assert_eq!(fixed.len(), 2);
+        assert_eq!(fixed.value(0), b"ABCD");
+        assert_eq!(fixed.value(1), &[0x00, 0x11, 0x22, 0x33]);
+    }
+
+ #[test]
+ fn test_union_schema_resolution_all_type_combinations() {
+ let path = "test/data/union_fields.avro";
+ let baseline = read_file(path, 1024, false);
+ let baseline_schema = baseline.schema();
+ let mut root = load_writer_schema_json(path);
+ assert_eq!(root["type"], "record", "writer schema must be a record");
+ let fields = root
+ .get_mut("fields")
+ .and_then(|f| f.as_array_mut())
+ .expect("record has fields");
+ fn is_named_type(obj: &Value, ty: &str, nm: &str) -> bool {
+ obj.get("type").and_then(|v| v.as_str()) == Some(ty)
+ && obj.get("name").and_then(|v| v.as_str()) == Some(nm)
+ }
+ fn is_logical(obj: &Value, prim: &str, lt: &str) -> bool {
+ obj.get("type").and_then(|v| v.as_str()) == Some(prim)
+ && obj.get("logicalType").and_then(|v| v.as_str()) == Some(lt)
+ }
+ fn find_first(arr: &[Value], pred: impl Fn(&Value) -> bool) ->
Option<Value> {
+ arr.iter().find(|v| pred(v)).cloned()
+ }
+ fn prim(s: &str) -> Value {
+ Value::String(s.to_string())
+ }
+ for f in fields.iter_mut() {
+ let Some(name) = f.get("name").and_then(|n| n.as_str()) else {
+ continue;
+ };
+ match name {
+ // Flip null ordering – should not affect values
+ "nullable_int_nullfirst" => {
+ f["type"] = json!(["int", "null"]);
+ }
+ "nullable_string_nullsecond" => {
+ f["type"] = json!(["null", "string"]);
+ }
+ "union_prim" => {
+ let orig = f["type"].as_array().unwrap().clone();
+ let long = prim("long");
+ let double = prim("double");
+ let string = prim("string");
+ let bytes = prim("bytes");
+ let boolean = prim("boolean");
+ assert!(orig.contains(&long));
+ assert!(orig.contains(&double));
+ assert!(orig.contains(&string));
+ assert!(orig.contains(&bytes));
+ assert!(orig.contains(&boolean));
+ f["type"] = json!([long, double, string, bytes, boolean]);
+ }
+ "union_bytes_vs_string" => {
+ f["type"] = json!(["string", "bytes"]);
+ }
+ "union_fixed_dur_decfix" => {
+ let orig = f["type"].as_array().unwrap().clone();
+ let fx8 = find_first(&orig, |o| is_named_type(o, "fixed",
"Fx8")).unwrap();
+ let dur12 = find_first(&orig, |o| is_named_type(o,
"fixed", "Dur12")).unwrap();
+ let decfix16 =
+ find_first(&orig, |o| is_named_type(o, "fixed",
"DecFix16")).unwrap();
+ f["type"] = json!([decfix16, dur12, fx8]);
+ }
+ "union_enum_records_array_map" => {
+ let orig = f["type"].as_array().unwrap().clone();
+ let enum_color = find_first(&orig, |o| {
+ o.get("type").and_then(|v| v.as_str()) == Some("enum")
+ })
+ .unwrap();
+ let rec_a = find_first(&orig, |o| is_named_type(o,
"record", "RecA")).unwrap();
+ let rec_b = find_first(&orig, |o| is_named_type(o,
"record", "RecB")).unwrap();
+ let arr = find_first(&orig, |o| {
+ o.get("type").and_then(|v| v.as_str()) == Some("array")
+ })
+ .unwrap();
+ let map = find_first(&orig, |o| {
+ o.get("type").and_then(|v| v.as_str()) == Some("map")
+ })
+ .unwrap();
+ f["type"] = json!([arr, map, rec_b, rec_a, enum_color]);
+ }
+ "union_date_or_fixed4" => {
+ let orig = f["type"].as_array().unwrap().clone();
+ let date = find_first(&orig, |o| is_logical(o, "int",
"date")).unwrap();
+ let fx4 = find_first(&orig, |o| is_named_type(o, "fixed",
"Fx4")).unwrap();
+ f["type"] = json!([fx4, date]);
+ }
+ "union_time_millis_or_enum" => {
+ let orig = f["type"].as_array().unwrap().clone();
+ let time_ms =
+ find_first(&orig, |o| is_logical(o, "int",
"time-millis")).unwrap();
+ let en = find_first(&orig, |o| {
+ o.get("type").and_then(|v| v.as_str()) == Some("enum")
+ })
+ .unwrap();
+ f["type"] = json!([en, time_ms]);
+ }
+ "union_time_micros_or_string" => {
+ let orig = f["type"].as_array().unwrap().clone();
+ let time_us =
+ find_first(&orig, |o| is_logical(o, "long",
"time-micros")).unwrap();
+ f["type"] = json!(["string", time_us]);
+ }
+ "union_ts_millis_utc_or_array" => {
+ let orig = f["type"].as_array().unwrap().clone();
+ let ts_ms =
+ find_first(&orig, |o| is_logical(o, "long",
"timestamp-millis")).unwrap();
+ let arr = find_first(&orig, |o| {
+ o.get("type").and_then(|v| v.as_str()) == Some("array")
+ })
+ .unwrap();
+ f["type"] = json!([arr, ts_ms]);
+ }
+ "union_ts_micros_local_or_bytes" => {
+ let orig = f["type"].as_array().unwrap().clone();
+ let lts_us =
+ find_first(&orig, |o| is_logical(o, "long",
"local-timestamp-micros"))
+ .unwrap();
+ f["type"] = json!(["bytes", lts_us]);
+ }
+ "union_uuid_or_fixed10" => {
+ let orig = f["type"].as_array().unwrap().clone();
+ let uuid = find_first(&orig, |o| is_logical(o, "string",
"uuid")).unwrap();
+ let fx10 = find_first(&orig, |o| is_named_type(o, "fixed",
"Fx10")).unwrap();
+ f["type"] = json!([fx10, uuid]);
+ }
+ "union_dec_bytes_or_dec_fixed" => {
+ let orig = f["type"].as_array().unwrap().clone();
+ let dec_bytes = find_first(&orig, |o| {
+ o.get("type").and_then(|v| v.as_str()) == Some("bytes")
+ && o.get("logicalType").and_then(|v| v.as_str())
== Some("decimal")
+ })
+ .unwrap();
+ let dec_fix = find_first(&orig, |o| {
+ is_named_type(o, "fixed", "DecFix20")
+ && o.get("logicalType").and_then(|v| v.as_str())
== Some("decimal")
+ })
+ .unwrap();
+ f["type"] = json!([dec_fix, dec_bytes]);
+ }
+ "union_null_bytes_string" => {
+ f["type"] = json!(["bytes", "string", "null"]);
+ }
+ "array_of_union" => {
+ let obj = f
+ .get_mut("type")
+ .expect("array type")
+ .as_object_mut()
+ .unwrap();
+ obj.insert("items".to_string(), json!(["string", "long"]));
+ }
+ "map_of_union" => {
+ let obj = f
+ .get_mut("type")
+ .expect("map type")
+ .as_object_mut()
+ .unwrap();
+ obj.insert("values".to_string(), json!(["double",
"null"]));
+ }
+ "record_with_union_field" => {
+ let rec = f
+ .get_mut("type")
+ .expect("record type")
+ .as_object_mut()
+ .unwrap();
+ let rec_fields =
rec.get_mut("fields").unwrap().as_array_mut().unwrap();
+ let mut found = false;
+ for rf in rec_fields.iter_mut() {
+ if rf.get("name").and_then(|v| v.as_str()) ==
Some("u") {
+ rf["type"] = json!(["string", "long"]); // rely on
int→long promotion
+ found = true;
+ break;
+ }
+ }
+ assert!(found, "field 'u' expected in HasUnion");
+ }
+ "union_ts_micros_utc_or_map" => {
+ let orig = f["type"].as_array().unwrap().clone();
+ let ts_us =
+ find_first(&orig, |o| is_logical(o, "long",
"timestamp-micros")).unwrap();
+ let map = find_first(&orig, |o| {
+ o.get("type").and_then(|v| v.as_str()) == Some("map")
+ })
+ .unwrap();
+ f["type"] = json!([map, ts_us]);
+ }
+ "union_ts_millis_local_or_string" => {
+ let orig = f["type"].as_array().unwrap().clone();
+ let lts_ms =
+ find_first(&orig, |o| is_logical(o, "long",
"local-timestamp-millis"))
+ .unwrap();
+ f["type"] = json!(["string", lts_ms]);
+ }
+ "union_bool_or_string" => {
+ f["type"] = json!(["string", "boolean"]);
+ }
+ _ => {}
+ }
+ }
+ let reader_schema = AvroSchema::new(root.to_string());
+ let resolved = read_alltypes_with_reader_schema(path, reader_schema);
+
+ fn branch_token(dt: &DataType) -> String {
+ match dt {
+ DataType::Null => "null".into(),
+ DataType::Boolean => "boolean".into(),
+ DataType::Int32 => "int".into(),
+ DataType::Int64 => "long".into(),
+ DataType::Float32 => "float".into(),
+ DataType::Float64 => "double".into(),
+ DataType::Binary => "bytes".into(),
+ DataType::Utf8 => "string".into(),
+ DataType::Date32 => "date".into(),
+ DataType::Time32(arrow_schema::TimeUnit::Millisecond) =>
"time-millis".into(),
+ DataType::Time64(arrow_schema::TimeUnit::Microsecond) =>
"time-micros".into(),
+ DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz)
=> if tz.is_some() {
+ "timestamp-millis"
+ } else {
+ "local-timestamp-millis"
+ }
+ .into(),
+ DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz)
=> if tz.is_some() {
+ "timestamp-micros"
+ } else {
+ "local-timestamp-micros"
+ }
+ .into(),
+ DataType::Interval(IntervalUnit::MonthDayNano) =>
"duration".into(),
+ DataType::FixedSizeBinary(n) => format!("fixed{n}"),
+ DataType::Dictionary(_, _) => "enum".into(),
+ DataType::Decimal128(p, s) => format!("decimal({p},{s})"),
+ DataType::Decimal256(p, s) => format!("decimal({p},{s})"),
+ #[cfg(feature = "small_decimals")]
+ DataType::Decimal64(p, s) => format!("decimal({p},{s})"),
+ DataType::Struct(fields) => {
+ if fields.len() == 2 && fields[0].name() == "a" &&
fields[1].name() == "b" {
+ "record:RecA".into()
+ } else if fields.len() == 2
+ && fields[0].name() == "x"
+ && fields[1].name() == "y"
+ {
+ "record:RecB".into()
+ } else {
+ "record".into()
+ }
+ }
+ DataType::List(_) => "array".into(),
+ DataType::Map(_, _) => "map".into(),
+ other => format!("{other:?}"),
+ }
+ }
+
+ fn union_tokens(u: &UnionArray) -> (Vec<i8>, HashMap<i8, String>) {
+ let fields = match u.data_type() {
+ DataType::Union(fields, _) => fields,
+ other => panic!("expected Union, got {other:?}"),
+ };
+ let mut dict: HashMap<i8, String> =
HashMap::with_capacity(fields.len());
+ for (tid, f) in fields.iter() {
+ dict.insert(tid, branch_token(f.data_type()));
+ }
+ let ids: Vec<i8> = u.type_ids().iter().copied().collect();
+ (ids, dict)
+ }
+
+ fn expected_token(field_name: &str, writer_token: &str) -> String {
+ match field_name {
+ "union_prim" => match writer_token {
+ "int" => "long".into(),
+ "float" => "double".into(),
+ other => other.into(),
+ },
+ "record_with_union_field.u" => match writer_token {
+ "int" => "long".into(),
+ other => other.into(),
+ },
+ _ => writer_token.into(),
+ }
+ }
+
+ fn get_union<'a>(
+ rb: &'a RecordBatch,
+ schema: arrow_schema::SchemaRef,
+ fname: &str,
+ ) -> &'a UnionArray {
+ let idx = schema.index_of(fname).unwrap();
+ rb.column(idx)
+ .as_any()
+ .downcast_ref::<UnionArray>()
+ .unwrap_or_else(|| panic!("{fname} should be a Union"))
+ }
+
+ fn assert_union_equivalent(field_name: &str, u_writer: &UnionArray,
u_reader: &UnionArray) {
+ let (ids_w, dict_w) = union_tokens(u_writer);
+ let (ids_r, dict_r) = union_tokens(u_reader);
+ assert_eq!(
+ ids_w.len(),
+ ids_r.len(),
+ "{field_name}: row count mismatch between baseline and
resolved"
+ );
+ for (i, (id_w, id_r)) in
ids_w.iter().zip(ids_r.iter()).enumerate() {
+ let w_tok = dict_w.get(id_w).unwrap();
+ let want = expected_token(field_name, w_tok);
+ let got = dict_r.get(id_r).unwrap();
+ assert_eq!(
+ got, &want,
+ "{field_name}: row {i} resolved to wrong union branch
(writer={w_tok}, expected={want}, got={got})"
+ );
+ }
+ }
+
+ for (fname, dt) in [
+ ("nullable_int_nullfirst", DataType::Int32),
+ ("nullable_string_nullsecond", DataType::Utf8),
+ ] {
+ let idx_b = baseline_schema.index_of(fname).unwrap();
+ let idx_r = resolved.schema().index_of(fname).unwrap();
+ let col_b = baseline.column(idx_b);
+ let col_r = resolved.column(idx_r);
+ assert_eq!(
+ col_b.data_type(),
+ &dt,
+ "baseline {fname} should decode as non-union with nullability"
+ );
+ assert_eq!(
+ col_b.as_ref(),
+ col_r.as_ref(),
+ "{fname}: values must be identical regardless of null-branch
order"
+ );
+ }
+ let union_fields = [
+ "union_prim",
+ "union_bytes_vs_string",
+ "union_fixed_dur_decfix",
+ "union_enum_records_array_map",
+ "union_date_or_fixed4",
+ "union_time_millis_or_enum",
+ "union_time_micros_or_string",
+ "union_ts_millis_utc_or_array",
+ "union_ts_micros_local_or_bytes",
+ "union_uuid_or_fixed10",
+ "union_dec_bytes_or_dec_fixed",
+ "union_null_bytes_string",
+ "union_ts_micros_utc_or_map",
+ "union_ts_millis_local_or_string",
+ "union_bool_or_string",
+ ];
+ for fname in union_fields {
+ let u_b = get_union(&baseline, baseline_schema.clone(), fname);
+ let u_r = get_union(&resolved, resolved.schema(), fname);
+ assert_union_equivalent(fname, u_b, u_r);
+ }
+ {
+ let fname = "array_of_union";
+ let idx_b = baseline_schema.index_of(fname).unwrap();
+ let idx_r = resolved.schema().index_of(fname).unwrap();
+ let arr_b = baseline
+ .column(idx_b)
+ .as_any()
+ .downcast_ref::<ListArray>()
+ .expect("array_of_union should be a List");
+ let arr_r = resolved
+ .column(idx_r)
+ .as_any()
+ .downcast_ref::<ListArray>()
+ .expect("array_of_union should be a List");
+ assert_eq!(
+ arr_b.value_offsets(),
+ arr_r.value_offsets(),
+ "{fname}: list offsets changed after resolution"
+ );
+ let u_b = arr_b
+ .values()
+ .as_any()
+ .downcast_ref::<UnionArray>()
+ .expect("array items should be Union");
+ let u_r = arr_r
+ .values()
+ .as_any()
+ .downcast_ref::<UnionArray>()
+ .expect("array items should be Union");
+ let (ids_b, dict_b) = union_tokens(u_b);
+ let (ids_r, dict_r) = union_tokens(u_r);
+ assert_eq!(ids_b.len(), ids_r.len(), "{fname}: values length
mismatch");
+ for (i, (id_b, id_r)) in
ids_b.iter().zip(ids_r.iter()).enumerate() {
+ let w_tok = dict_b.get(id_b).unwrap();
+ let got = dict_r.get(id_r).unwrap();
+ assert_eq!(
+ got, w_tok,
+ "{fname}: value {i} resolved to wrong branch
(writer={w_tok}, got={got})"
+ );
+ }
+ }
+ {
+ let fname = "map_of_union";
+ let idx_b = baseline_schema.index_of(fname).unwrap();
+ let idx_r = resolved.schema().index_of(fname).unwrap();
+ let map_b = baseline
+ .column(idx_b)
+ .as_any()
+ .downcast_ref::<MapArray>()
+ .expect("map_of_union should be a Map");
+ let map_r = resolved
+ .column(idx_r)
+ .as_any()
+ .downcast_ref::<MapArray>()
+ .expect("map_of_union should be a Map");
+ assert_eq!(
+ map_b.value_offsets(),
+ map_r.value_offsets(),
+ "{fname}: map value offsets changed after resolution"
+ );
+ let ent_b = map_b.entries();
+ let ent_r = map_r.entries();
+ let val_b_any = ent_b.column(1).as_ref();
+ let val_r_any = ent_r.column(1).as_ref();
+ let b_union = val_b_any.as_any().downcast_ref::<UnionArray>();
+ let r_union = val_r_any.as_any().downcast_ref::<UnionArray>();
+ if let (Some(u_b), Some(u_r)) = (b_union, r_union) {
+ assert_union_equivalent(fname, u_b, u_r);
+ } else {
+ assert_eq!(
+ val_b_any.data_type(),
+ val_r_any.data_type(),
+ "{fname}: value data types differ after resolution"
+ );
+ assert_eq!(
+ val_b_any, val_r_any,
+ "{fname}: value arrays differ after resolution (nullable
value column case)"
+ );
+ let value_nullable = |m: &MapArray| -> bool {
+ match m.data_type() {
+ DataType::Map(entries_field, _sorted) => match
entries_field.data_type() {
+ DataType::Struct(fields) => {
+ assert_eq!(fields.len(), 2, "entries struct
must have 2 fields");
+ assert_eq!(fields[0].name(), "key");
+ assert_eq!(fields[1].name(), "value");
+ fields[1].is_nullable()
+ }
+ other => panic!("Map entries field must be Struct,
got {other:?}"),
+ },
+ other => panic!("expected Map data type, got
{other:?}"),
+ }
+ };
+ assert!(
+ value_nullable(map_b),
+ "{fname}: baseline Map value field should be nullable per
Arrow spec"
+ );
+ assert!(
+ value_nullable(map_r),
+ "{fname}: resolved Map value field should be nullable per
Arrow spec"
+ );
+ }
+ }
+ {
+ let fname = "record_with_union_field";
+ let idx_b = baseline_schema.index_of(fname).unwrap();
+ let idx_r = resolved.schema().index_of(fname).unwrap();
+ let rec_b = baseline
+ .column(idx_b)
+ .as_any()
+ .downcast_ref::<StructArray>()
+ .expect("record_with_union_field should be a Struct");
+ let rec_r = resolved
+ .column(idx_r)
+ .as_any()
+ .downcast_ref::<StructArray>()
+ .expect("record_with_union_field should be a Struct");
+ let u_b = rec_b
+ .column_by_name("u")
+ .unwrap()
+ .as_any()
+ .downcast_ref::<UnionArray>()
+ .expect("field 'u' should be Union (baseline)");
+ let u_r = rec_r
+ .column_by_name("u")
+ .unwrap()
+ .as_any()
+ .downcast_ref::<UnionArray>()
+ .expect("field 'u' should be Union (resolved)");
+ assert_union_equivalent("record_with_union_field.u", u_b, u_r);
+ }
+ }
+
Review Comment:
Just pushed up a full end-to-end test that also compares the decoded Arrow arrays. This was a good call!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]