This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new abfe18483 Ignore Field Metadata in equals_datatype for Dictionary, RunEndEncoded, Map and Union (#4111)
abfe18483 is described below

commit abfe184831a105e34d9939070acbaa9fcbfe56f2
Author: Igor Izvekov <[email protected]>
AuthorDate: Tue Apr 25 13:54:41 2023 +0300

    Ignore Field Metadata in equals_datatype for Dictionary, RunEndEncoded, Map and Union (#4111)
    
    * fix: equality of nested data types
    
    * fix: cargo clippy
    
    * fix: cargo fmt
    
    * feat: add tests with differing nullability
---
 arrow-schema/src/datatype.rs | 152 +++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 146 insertions(+), 6 deletions(-)

diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs
index 0bbd64f30..edd1dd096 100644
--- a/arrow-schema/src/datatype.rs
+++ b/arrow-schema/src/datatype.rs
@@ -431,7 +431,40 @@ impl DataType {
             (
                 DataType::Map(a_field, a_is_sorted),
                 DataType::Map(b_field, b_is_sorted),
-            ) => a_field == b_field && a_is_sorted == b_is_sorted,
+            ) => {
+                a_field.is_nullable() == b_field.is_nullable()
+                    && a_field.data_type().equals_datatype(b_field.data_type())
+                    && a_is_sorted == b_is_sorted
+            }
+            (
+                DataType::Dictionary(a_key, a_value),
+                DataType::Dictionary(b_key, b_value),
+            ) => a_key.equals_datatype(b_key) && 
a_value.equals_datatype(b_value),
+            (
+                DataType::RunEndEncoded(a_run_ends, a_values),
+                DataType::RunEndEncoded(b_run_ends, b_values),
+            ) => {
+                a_run_ends.is_nullable() == b_run_ends.is_nullable()
+                    && a_run_ends
+                        .data_type()
+                        .equals_datatype(b_run_ends.data_type())
+                    && a_values.is_nullable() == b_values.is_nullable()
+                    && 
a_values.data_type().equals_datatype(b_values.data_type())
+            }
+            (
+                DataType::Union(a_union_fields, a_union_mode),
+                DataType::Union(b_union_fields, b_union_mode),
+            ) => {
+                a_union_mode == b_union_mode
+                    && a_union_fields.len() == b_union_fields.len()
+                    && a_union_fields.iter().all(|a| {
+                        b_union_fields.iter().any(|b| {
+                            a.0 == b.0
+                                && a.1.is_nullable() == b.1.is_nullable()
+                                && 
a.1.data_type().equals_datatype(b.1.data_type())
+                        })
+                    })
+            }
             _ => self == other,
         }
     }
@@ -564,7 +597,7 @@ pub const DECIMAL_DEFAULT_SCALE: i8 = 10;
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::Field;
+    use crate::{Field, UnionMode};
 
     #[test]
     #[cfg(feature = "serde")]
@@ -628,10 +661,14 @@ mod tests {
         assert!(!list_b.equals_datatype(&list_c));
         assert!(!list_a.equals_datatype(&list_d));
 
-        let list_e =
-            DataType::FixedSizeList(Arc::new(Field::new("item", list_a, 
false)), 3);
-        let list_f =
-            DataType::FixedSizeList(Arc::new(Field::new("array", list_b, 
false)), 3);
+        let list_e = DataType::FixedSizeList(
+            Arc::new(Field::new("item", list_a.clone(), false)),
+            3,
+        );
+        let list_f = DataType::FixedSizeList(
+            Arc::new(Field::new("array", list_b.clone(), false)),
+            3,
+        );
         let list_g = DataType::FixedSizeList(
             Arc::new(Field::new("item", DataType::FixedSizeBinary(3), true)),
             3,
@@ -664,6 +701,109 @@ mod tests {
         assert!(!list_h.equals_datatype(&list_j));
         assert!(!list_k.equals_datatype(&list_l));
         assert!(list_k.equals_datatype(&list_m));
+
+        let list_n =
+            DataType::Map(Arc::new(Field::new("f1", list_a.clone(), true)), 
true);
+        let list_o =
+            DataType::Map(Arc::new(Field::new("f2", list_b.clone(), true)), 
true);
+        let list_p =
+            DataType::Map(Arc::new(Field::new("f2", list_b.clone(), true)), 
false);
+        let list_q =
+            DataType::Map(Arc::new(Field::new("f2", list_c.clone(), true)), 
true);
+        let list_r =
+            DataType::Map(Arc::new(Field::new("f1", list_a.clone(), false)), 
true);
+
+        assert!(list_n.equals_datatype(&list_o));
+        assert!(!list_n.equals_datatype(&list_p));
+        assert!(!list_n.equals_datatype(&list_q));
+        assert!(!list_n.equals_datatype(&list_r));
+
+        let list_s = DataType::Dictionary(Box::new(DataType::UInt8), 
Box::new(list_a));
+        let list_t =
+            DataType::Dictionary(Box::new(DataType::UInt8), 
Box::new(list_b.clone()));
+        let list_u = DataType::Dictionary(Box::new(DataType::Int8), 
Box::new(list_b));
+        let list_v = DataType::Dictionary(Box::new(DataType::UInt8), 
Box::new(list_c));
+
+        assert!(list_s.equals_datatype(&list_t));
+        assert!(!list_s.equals_datatype(&list_u));
+        assert!(!list_s.equals_datatype(&list_v));
+
+        let union_a = DataType::Union(
+            UnionFields::new(
+                vec![1, 2],
+                vec![
+                    Field::new("f1", DataType::Utf8, false),
+                    Field::new("f2", DataType::UInt8, false),
+                ],
+            ),
+            UnionMode::Sparse,
+        );
+        let union_b = DataType::Union(
+            UnionFields::new(
+                vec![1, 2],
+                vec![
+                    Field::new("ff1", DataType::Utf8, false),
+                    Field::new("ff2", DataType::UInt8, false),
+                ],
+            ),
+            UnionMode::Sparse,
+        );
+        let union_c = DataType::Union(
+            UnionFields::new(
+                vec![2, 1],
+                vec![
+                    Field::new("fff2", DataType::UInt8, false),
+                    Field::new("fff1", DataType::Utf8, false),
+                ],
+            ),
+            UnionMode::Sparse,
+        );
+        let union_d = DataType::Union(
+            UnionFields::new(
+                vec![2, 1],
+                vec![
+                    Field::new("fff1", DataType::Int8, false),
+                    Field::new("fff2", DataType::UInt8, false),
+                ],
+            ),
+            UnionMode::Sparse,
+        );
+        let union_e = DataType::Union(
+            UnionFields::new(
+                vec![1, 2],
+                vec![
+                    Field::new("f1", DataType::Utf8, true),
+                    Field::new("f2", DataType::UInt8, false),
+                ],
+            ),
+            UnionMode::Sparse,
+        );
+
+        assert!(union_a.equals_datatype(&union_b));
+        assert!(union_a.equals_datatype(&union_c));
+        assert!(!union_a.equals_datatype(&union_d));
+        assert!(!union_a.equals_datatype(&union_e));
+
+        let list_w = DataType::RunEndEncoded(
+            Arc::new(Field::new("f1", DataType::Int64, true)),
+            Arc::new(Field::new("f2", DataType::Utf8, true)),
+        );
+        let list_x = DataType::RunEndEncoded(
+            Arc::new(Field::new("ff1", DataType::Int64, true)),
+            Arc::new(Field::new("ff2", DataType::Utf8, true)),
+        );
+        let list_y = DataType::RunEndEncoded(
+            Arc::new(Field::new("ff1", DataType::UInt16, true)),
+            Arc::new(Field::new("ff2", DataType::Utf8, true)),
+        );
+        let list_z = DataType::RunEndEncoded(
+            Arc::new(Field::new("f1", DataType::Int64, false)),
+            Arc::new(Field::new("f2", DataType::Utf8, true)),
+        );
+
+        assert!(list_w.equals_datatype(&list_x));
+        assert!(!list_w.equals_datatype(&list_y));
+        assert!(!list_w.equals_datatype(&list_z));
     }
 
     #[test]

Reply via email to