This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new c5442cf Fix generate_non_canonical_map_case, fix `MapArray` equality (#1476)
c5442cf is described below
commit c5442cf2fd8f046e6ad75d2d5c7efb2899dd654d
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sun Mar 27 03:46:56 2022 -0700
Fix generate_non_canonical_map_case, fix `MapArray` equality (#1476)
* Revamp list_equal for map type
* Canonicalize schema
* Add nullability and metadata
---
arrow/src/array/equal/list.rs | 102 ++++++++++++++++-----
arrow/src/array/equal/utils.rs | 30 +++++-
.../src/bin/arrow-json-integration-test.rs | 45 ++++++++-
3 files changed, 153 insertions(+), 24 deletions(-)
diff --git a/arrow/src/array/equal/list.rs b/arrow/src/array/equal/list.rs
index 20e6400..000b31a 100644
--- a/arrow/src/array/equal/list.rs
+++ b/arrow/src/array/equal/list.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use crate::datatypes::DataType;
use crate::{
array::ArrayData,
array::{data::count_nulls, OffsetSizeTrait},
@@ -22,7 +23,9 @@ use crate::{
util::bit_util::get_bit,
};
-use super::{equal_range, utils::child_logical_null_buffer};
+use super::{
+ equal_range, equal_values, utils::child_logical_null_buffer, utils::equal_nulls,
+};
fn lengths_equal<T: OffsetSizeTrait>(lhs: &[T], rhs: &[T]) -> bool {
// invariant from `base_equal`
@@ -58,22 +61,47 @@ fn offset_value_equal<T: OffsetSizeTrait>(
lhs_pos: usize,
rhs_pos: usize,
len: usize,
+ data_type: &DataType,
) -> bool {
let lhs_start = lhs_offsets[lhs_pos].to_usize().unwrap();
let rhs_start = rhs_offsets[rhs_pos].to_usize().unwrap();
let lhs_len = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos];
let rhs_len = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos];
- lhs_len == rhs_len
- && equal_range(
- lhs_values,
- rhs_values,
- lhs_nulls,
- rhs_nulls,
- lhs_start,
- rhs_start,
- lhs_len.to_usize().unwrap(),
- )
+ lhs_len == rhs_len && {
+ match data_type {
+ DataType::Map(_, _) => {
+ // Don't use `equal_range` which calls `utils::base_equal` that checks
+ // struct fields, but we don't enforce struct field names.
+ equal_nulls(
+ lhs_values,
+ rhs_values,
+ lhs_nulls,
+ rhs_nulls,
+ lhs_start,
+ rhs_start,
+ lhs_len.to_usize().unwrap(),
+ ) && equal_values(
+ lhs_values,
+ rhs_values,
+ lhs_nulls,
+ rhs_nulls,
+ lhs_start,
+ rhs_start,
+ lhs_len.to_usize().unwrap(),
+ )
+ }
+ _ => equal_range(
+ lhs_values,
+ rhs_values,
+ lhs_nulls,
+ rhs_nulls,
+ lhs_start,
+ rhs_start,
+ lhs_len.to_usize().unwrap(),
+ ),
+ }
+ }
}
pub(super) fn list_equal<T: OffsetSizeTrait>(
@@ -131,17 +159,46 @@ pub(super) fn list_equal<T: OffsetSizeTrait>(
lengths_equal(
&lhs_offsets[lhs_start..lhs_start + len],
&rhs_offsets[rhs_start..rhs_start + len],
- ) && equal_range(
- lhs_values,
- rhs_values,
- child_lhs_nulls.as_ref(),
- child_rhs_nulls.as_ref(),
- lhs_offsets[lhs_start].to_usize().unwrap(),
- rhs_offsets[rhs_start].to_usize().unwrap(),
- (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
- .to_usize()
- .unwrap(),
- )
+ ) && {
+ match lhs.data_type() {
+ DataType::Map(_, _) => {
+ // Don't use `equal_range` which calls `utils::base_equal` that checks
+ // struct fields, but we don't enforce struct field names.
+ equal_nulls(
+ lhs_values,
+ rhs_values,
+ child_lhs_nulls.as_ref(),
+ child_rhs_nulls.as_ref(),
+ lhs_offsets[lhs_start].to_usize().unwrap(),
+ rhs_offsets[rhs_start].to_usize().unwrap(),
+ (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
+ .to_usize()
+ .unwrap(),
+ ) && equal_values(
+ lhs_values,
+ rhs_values,
+ child_lhs_nulls.as_ref(),
+ child_rhs_nulls.as_ref(),
+ lhs_offsets[lhs_start].to_usize().unwrap(),
+ rhs_offsets[rhs_start].to_usize().unwrap(),
+ (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
+ .to_usize()
+ .unwrap(),
+ )
+ }
+ _ => equal_range(
+ lhs_values,
+ rhs_values,
+ child_lhs_nulls.as_ref(),
+ child_rhs_nulls.as_ref(),
+ lhs_offsets[lhs_start].to_usize().unwrap(),
+ rhs_offsets[rhs_start].to_usize().unwrap(),
+ (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
+ .to_usize()
+ .unwrap(),
+ ),
+ }
+ }
} else {
// get a ref of the parent null buffer bytes, to use in testing for nullness
let lhs_null_bytes = lhs_nulls.unwrap().as_slice();
@@ -166,6 +223,7 @@ pub(super) fn list_equal<T: OffsetSizeTrait>(
lhs_pos,
rhs_pos,
1,
+ lhs.data_type(),
)
})
}
diff --git a/arrow/src/array/equal/utils.rs b/arrow/src/array/equal/utils.rs
index 819ae32..1bced97 100644
--- a/arrow/src/array/equal/utils.rs
+++ b/arrow/src/array/equal/utils.rs
@@ -66,7 +66,35 @@ pub(super) fn equal_nulls(
#[inline]
pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool {
- lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len()
+ let equal_type = match (lhs.data_type(), rhs.data_type()) {
(DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted)) => {
let field_equal = match (l_field.data_type(), r_field.data_type()) {
+ (DataType::Struct(l_fields), DataType::Struct(r_fields))
+ if l_fields.len() == 2 && r_fields.len() == 2 =>
+ {
+ let l_key_field = l_fields.get(0).unwrap();
+ let r_key_field = r_fields.get(0).unwrap();
+ let l_value_field = l_fields.get(1).unwrap();
+ let r_value_field = r_fields.get(1).unwrap();
+
+ // We don't enforce the equality of field names
+ let data_type_equal = l_key_field.data_type()
+ == r_key_field.data_type()
+ && l_value_field.data_type() == r_value_field.data_type();
+ let nullability_equal = l_key_field.is_nullable()
+ == r_key_field.is_nullable()
+ && l_value_field.is_nullable() == r_value_field.is_nullable();
+ let metadata_equal = l_key_field.metadata() == r_key_field.metadata()
+ && l_value_field.metadata() == r_value_field.metadata();
+ data_type_equal && nullability_equal && metadata_equal
+ }
_ => panic!("Map type should have 2 fields Struct in its field"),
+ };
+ field_equal && l_sorted == r_sorted
+ }
+ (l_data_type, r_data_type) => l_data_type == r_data_type,
+ };
+ equal_type && lhs.len() == rhs.len()
}
// whether the two memory regions are equal
diff --git a/integration-testing/src/bin/arrow-json-integration-test.rs b/integration-testing/src/bin/arrow-json-integration-test.rs
index 17d2528..69b73b1 100644
--- a/integration-testing/src/bin/arrow-json-integration-test.rs
+++ b/integration-testing/src/bin/arrow-json-integration-test.rs
@@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.
+use arrow::datatypes::Schema;
+use arrow::datatypes::{DataType, Field};
use arrow::error::{ArrowError, Result};
use arrow::ipc::reader::FileReader;
use arrow::ipc::writer::FileWriter;
@@ -107,6 +109,47 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()>
Ok(())
}
+fn canonicalize_schema(schema: &Schema) -> Schema {
+ let fields = schema
+ .fields()
+ .iter()
+ .map(|field| match field.data_type() {
DataType::Map(child_field, sorted) => match child_field.data_type() {
+ DataType::Struct(fields) if fields.len() == 2 => {
+ let first_field = fields.get(0).unwrap();
+ let key_field = Field::new(
+ "key",
+ first_field.data_type().clone(),
+ first_field.is_nullable(),
+ );
+ let second_field = fields.get(1).unwrap();
+ let value_field = Field::new(
+ "value",
+ second_field.data_type().clone(),
+ second_field.is_nullable(),
+ );
+
+ let struct_type = DataType::Struct(vec![key_field, value_field]);
+ let child_field =
+ Field::new("entries", struct_type, child_field.is_nullable());
+
+ Field::new(
+ field.name().as_str(),
+ DataType::Map(Box::new(child_field), *sorted),
+ field.is_nullable(),
+ )
+ }
+ _ => panic!(
+ "The child field of Map type should be Struct type with 2 fields."
+ ),
+ },
+ _ => field.clone(),
+ })
+ .collect::<Vec<_>>();
+
+ Schema::new(fields).with_metadata(schema.metadata().clone())
+}
+
fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> {
if verbose {
eprintln!("Validating {} and {}", arrow_name, json_name);
@@ -121,7 +164,7 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> {
let arrow_schema = arrow_reader.schema().as_ref().to_owned();
// compare schemas
- if json_file.schema != arrow_schema {
+ if canonicalize_schema(&json_file.schema) != canonicalize_schema(&arrow_schema) {
return Err(ArrowError::ComputeError(format!(
"Schemas do not match. JSON: {:?}. Arrow: {:?}",
json_file.schema, arrow_schema