This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 9fa020708b Add test case, and fix stack overflow bug for get field access expr (#7623)
9fa020708b is described below
commit 9fa020708bab6e7293f3731d807a80284c1204bd
Author: Mustafa Akur <[email protected]>
AuthorDate: Fri Sep 22 15:31:54 2023 +0300
Add test case, and fix stack overflow bug for get field access expr (#7623)
---
.../src/expressions/get_indexed_field.rs | 36 +++++++++++++++++++++-
1 file changed, 35 insertions(+), 1 deletion(-)
diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
index 3a7ce568b3..ab15356dc2 100644
--- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs
+++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
@@ -61,9 +61,28 @@ impl std::fmt::Display for GetFieldAccessExpr {
impl PartialEq<dyn Any> for GetFieldAccessExpr {
fn eq(&self, other: &dyn Any) -> bool {
+ use GetFieldAccessExpr::{ListIndex, ListRange, NamedStructField}; // Equal only to another GetFieldAccessExpr of the same variant with equal operands; anything else yields false via unwrap_or(false).
down_cast_any_ref(other)
.downcast_ref::<Self>()
- .map(|x| self.eq(x))
+ .map(|x| match (self, x) { // Compare variant-by-variant: the removed `self.eq(x)` appears to dispatch back into this same impl (x coerces to &dyn Any), i.e. the stack overflow named in the commit title.
+ (NamedStructField { name: lhs }, NamedStructField { name: rhs }) => {
+ lhs.eq(rhs)
+ }
+ (ListIndex { key: lhs }, ListIndex { key: rhs }) => lhs.eq(rhs),
+ (
+ ListRange {
+ start: start_lhs,
+ stop: stop_lhs,
+ },
+ ListRange {
+ start: start_rhs,
+ stop: stop_rhs,
+ },
+ ) => start_lhs.eq(start_rhs) && stop_lhs.eq(stop_rhs), // a range matches only if both bounds match
+ (NamedStructField { .. }, ListIndex { .. } | ListRange { .. }) => false, // Mismatched variants are never equal; listed explicitly (no `_`) so a new variant forces this match to be revisited.
+ (ListIndex { .. }, NamedStructField { .. } | ListRange { .. }) => false,
+ (ListRange { .. }, NamedStructField { .. } | ListIndex { .. }) => false,
+ })
.unwrap_or(false)
}
}
@@ -435,4 +454,19 @@ mod tests {
assert!(result.is_null(0));
Ok(())
}
+
+ #[test] // Regression test for #7623: equality on a list-index GetIndexedFieldExpr must terminate and compare operands.
+ fn get_indexed_field_eq() -> Result<()> {
+ let schema = list_schema(&["list", "error"]);
+ let expr = col("list", &schema)?; // propagate with `?`: this test already returns Result
+ let key = col("error", &schema)?;
+ let indexed_field =
+ Arc::new(GetIndexedFieldExpr::new_index(Arc::clone(&expr), Arc::clone(&key)))
+ as Arc<dyn PhysicalExpr>;
+ let indexed_field_other = // same construction with the operands swapped
+ Arc::new(GetIndexedFieldExpr::new_index(key, expr)) as Arc<dyn PhysicalExpr>;
+ assert!(indexed_field.eq(&indexed_field)); // reflexive comparison must return true (and not recurse)
+ assert!(!indexed_field.eq(&indexed_field_other));
+ Ok(())
+ }
}