This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new c5442cf  Fix generate_non_canonical_map_case, fix `MapArray` equality (#1476)
c5442cf is described below

commit c5442cf2fd8f046e6ad75d2d5c7efb2899dd654d
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sun Mar 27 03:46:56 2022 -0700

    Fix generate_non_canonical_map_case, fix `MapArray` equality  (#1476)
    
    * Revamp list_equal for map type
    
    * Canonicalize schema
    
    * Add nullability and metadata
---
 arrow/src/array/equal/list.rs                      | 102 ++++++++++++++++-----
 arrow/src/array/equal/utils.rs                     |  30 +++++-
 .../src/bin/arrow-json-integration-test.rs         |  45 ++++++++-
 3 files changed, 153 insertions(+), 24 deletions(-)

diff --git a/arrow/src/array/equal/list.rs b/arrow/src/array/equal/list.rs
index 20e6400..000b31a 100644
--- a/arrow/src/array/equal/list.rs
+++ b/arrow/src/array/equal/list.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::datatypes::DataType;
 use crate::{
     array::ArrayData,
     array::{data::count_nulls, OffsetSizeTrait},
@@ -22,7 +23,9 @@ use crate::{
     util::bit_util::get_bit,
 };
 
-use super::{equal_range, utils::child_logical_null_buffer};
+use super::{
+    equal_range, equal_values, utils::child_logical_null_buffer, utils::equal_nulls,
+};
 
 fn lengths_equal<T: OffsetSizeTrait>(lhs: &[T], rhs: &[T]) -> bool {
     // invariant from `base_equal`
@@ -58,22 +61,47 @@ fn offset_value_equal<T: OffsetSizeTrait>(
     lhs_pos: usize,
     rhs_pos: usize,
     len: usize,
+    data_type: &DataType,
 ) -> bool {
     let lhs_start = lhs_offsets[lhs_pos].to_usize().unwrap();
     let rhs_start = rhs_offsets[rhs_pos].to_usize().unwrap();
     let lhs_len = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos];
     let rhs_len = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos];
 
-    lhs_len == rhs_len
-        && equal_range(
-            lhs_values,
-            rhs_values,
-            lhs_nulls,
-            rhs_nulls,
-            lhs_start,
-            rhs_start,
-            lhs_len.to_usize().unwrap(),
-        )
+    lhs_len == rhs_len && {
+        match data_type {
+            DataType::Map(_, _) => {
+                // Don't use `equal_range` which calls `utils::base_equal` that checks
+                // struct fields, but we don't enforce struct field names.
+                equal_nulls(
+                    lhs_values,
+                    rhs_values,
+                    lhs_nulls,
+                    rhs_nulls,
+                    lhs_start,
+                    rhs_start,
+                    lhs_len.to_usize().unwrap(),
+                ) && equal_values(
+                    lhs_values,
+                    rhs_values,
+                    lhs_nulls,
+                    rhs_nulls,
+                    lhs_start,
+                    rhs_start,
+                    lhs_len.to_usize().unwrap(),
+                )
+            }
+            _ => equal_range(
+                lhs_values,
+                rhs_values,
+                lhs_nulls,
+                rhs_nulls,
+                lhs_start,
+                rhs_start,
+                lhs_len.to_usize().unwrap(),
+            ),
+        }
+    }
 }
 
 pub(super) fn list_equal<T: OffsetSizeTrait>(
@@ -131,17 +159,46 @@ pub(super) fn list_equal<T: OffsetSizeTrait>(
         lengths_equal(
             &lhs_offsets[lhs_start..lhs_start + len],
             &rhs_offsets[rhs_start..rhs_start + len],
-        ) && equal_range(
-            lhs_values,
-            rhs_values,
-            child_lhs_nulls.as_ref(),
-            child_rhs_nulls.as_ref(),
-            lhs_offsets[lhs_start].to_usize().unwrap(),
-            rhs_offsets[rhs_start].to_usize().unwrap(),
-            (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
-                .to_usize()
-                .unwrap(),
-        )
+        ) && {
+            match lhs.data_type() {
+                DataType::Map(_, _) => {
+                    // Don't use `equal_range` which calls `utils::base_equal` that checks
+                    // struct fields, but we don't enforce struct field names.
+                    equal_nulls(
+                        lhs_values,
+                        rhs_values,
+                        child_lhs_nulls.as_ref(),
+                        child_rhs_nulls.as_ref(),
+                        lhs_offsets[lhs_start].to_usize().unwrap(),
+                        rhs_offsets[rhs_start].to_usize().unwrap(),
+                        (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
+                            .to_usize()
+                            .unwrap(),
+                    ) && equal_values(
+                        lhs_values,
+                        rhs_values,
+                        child_lhs_nulls.as_ref(),
+                        child_rhs_nulls.as_ref(),
+                        lhs_offsets[lhs_start].to_usize().unwrap(),
+                        rhs_offsets[rhs_start].to_usize().unwrap(),
+                        (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
+                            .to_usize()
+                            .unwrap(),
+                    )
+                }
+                _ => equal_range(
+                    lhs_values,
+                    rhs_values,
+                    child_lhs_nulls.as_ref(),
+                    child_rhs_nulls.as_ref(),
+                    lhs_offsets[lhs_start].to_usize().unwrap(),
+                    rhs_offsets[rhs_start].to_usize().unwrap(),
+                    (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
+                        .to_usize()
+                        .unwrap(),
+                ),
+            }
+        }
     } else {
         // get a ref of the parent null buffer bytes, to use in testing for nullness
         let lhs_null_bytes = lhs_nulls.unwrap().as_slice();
@@ -166,6 +223,7 @@ pub(super) fn list_equal<T: OffsetSizeTrait>(
                         lhs_pos,
                         rhs_pos,
                         1,
+                        lhs.data_type(),
                     )
         })
     }
diff --git a/arrow/src/array/equal/utils.rs b/arrow/src/array/equal/utils.rs
index 819ae32..1bced97 100644
--- a/arrow/src/array/equal/utils.rs
+++ b/arrow/src/array/equal/utils.rs
@@ -66,7 +66,35 @@ pub(super) fn equal_nulls(
 
 #[inline]
 pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool {
-    lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len()
+    let equal_type = match (lhs.data_type(), rhs.data_type()) {
+        (DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted)) => {
+            let field_equal = match (l_field.data_type(), r_field.data_type()) {
+                (DataType::Struct(l_fields), DataType::Struct(r_fields))
+                    if l_fields.len() == 2 && r_fields.len() == 2 =>
+                {
+                    let l_key_field = l_fields.get(0).unwrap();
+                    let r_key_field = r_fields.get(0).unwrap();
+                    let l_value_field = l_fields.get(1).unwrap();
+                    let r_value_field = r_fields.get(1).unwrap();
+
+                    // We don't enforce the equality of field names
+                    let data_type_equal = l_key_field.data_type()
+                        == r_key_field.data_type()
+                        && l_value_field.data_type() == r_value_field.data_type();
+                    let nullability_equal = l_key_field.is_nullable()
+                        == r_key_field.is_nullable()
+                        && l_value_field.is_nullable() == r_value_field.is_nullable();
+                    let metadata_equal = l_key_field.metadata() == r_key_field.metadata()
+                        && l_value_field.metadata() == r_value_field.metadata();
+                    data_type_equal && nullability_equal && metadata_equal
+                }
+                _ => panic!("Map type should have 2 fields Struct in its field"),
+            };
+            field_equal && l_sorted == r_sorted
+        }
+        (l_data_type, r_data_type) => l_data_type == r_data_type,
+    };
+    equal_type && lhs.len() == rhs.len()
 }
 
 // whether the two memory regions are equal
diff --git a/integration-testing/src/bin/arrow-json-integration-test.rs b/integration-testing/src/bin/arrow-json-integration-test.rs
index 17d2528..69b73b1 100644
--- a/integration-testing/src/bin/arrow-json-integration-test.rs
+++ b/integration-testing/src/bin/arrow-json-integration-test.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arrow::datatypes::Schema;
+use arrow::datatypes::{DataType, Field};
 use arrow::error::{ArrowError, Result};
 use arrow::ipc::reader::FileReader;
 use arrow::ipc::writer::FileWriter;
@@ -107,6 +109,47 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()>
     Ok(())
 }
 
+fn canonicalize_schema(schema: &Schema) -> Schema {
+    let fields = schema
+        .fields()
+        .iter()
+        .map(|field| match field.data_type() {
+            DataType::Map(child_field, sorted) => match child_field.data_type() {
+                DataType::Struct(fields) if fields.len() == 2 => {
+                    let first_field = fields.get(0).unwrap();
+                    let key_field = Field::new(
+                        "key",
+                        first_field.data_type().clone(),
+                        first_field.is_nullable(),
+                    );
+                    let second_field = fields.get(1).unwrap();
+                    let value_field = Field::new(
+                        "value",
+                        second_field.data_type().clone(),
+                        second_field.is_nullable(),
+                    );
+
+                    let struct_type = DataType::Struct(vec![key_field, value_field]);
+                    let child_field =
+                        Field::new("entries", struct_type, child_field.is_nullable());
+
+                    Field::new(
+                        field.name().as_str(),
+                        DataType::Map(Box::new(child_field), *sorted),
+                        field.is_nullable(),
+                    )
+                }
+                _ => panic!(
+                    "The child field of Map type should be Struct type with 2 fields."
+                ),
+            },
+            _ => field.clone(),
+        })
+        .collect::<Vec<_>>();
+
+    Schema::new(fields).with_metadata(schema.metadata().clone())
+}
+
 fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> {
     if verbose {
         eprintln!("Validating {} and {}", arrow_name, json_name);
@@ -121,7 +164,7 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> {
     let arrow_schema = arrow_reader.schema().as_ref().to_owned();
 
     // compare schemas
-    if json_file.schema != arrow_schema {
+    if canonicalize_schema(&json_file.schema) != canonicalize_schema(&arrow_schema) {
         return Err(ArrowError::ComputeError(format!(
             "Schemas do not match. JSON: {:?}. Arrow: {:?}",
             json_file.schema, arrow_schema

Reply via email to