This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 9fa020708b Add test case, and fix stack overflow bug for get field access expr (#7623)
9fa020708b is described below

commit 9fa020708bab6e7293f3731d807a80284c1204bd
Author: Mustafa Akur <[email protected]>
AuthorDate: Fri Sep 22 15:31:54 2023 +0300

    Add test case, and fix stack overflow bug for get field access expr (#7623)
---
 .../src/expressions/get_indexed_field.rs           | 36 +++++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
index 3a7ce568b3..ab15356dc2 100644
--- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs
+++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
@@ -61,9 +61,28 @@ impl std::fmt::Display for GetFieldAccessExpr {
 
 impl PartialEq<dyn Any> for GetFieldAccessExpr {
     fn eq(&self, other: &dyn Any) -> bool {
+        use GetFieldAccessExpr::{ListIndex, ListRange, NamedStructField};
         down_cast_any_ref(other)
             .downcast_ref::<Self>()
-            .map(|x| self.eq(x))
+            .map(|x| match (self, x) {
+                (NamedStructField { name: lhs }, NamedStructField { name: rhs 
}) => {
+                    lhs.eq(rhs)
+                }
+                (ListIndex { key: lhs }, ListIndex { key: rhs }) => 
lhs.eq(rhs),
+                (
+                    ListRange {
+                        start: start_lhs,
+                        stop: stop_lhs,
+                    },
+                    ListRange {
+                        start: start_rhs,
+                        stop: stop_rhs,
+                    },
+                ) => start_lhs.eq(start_rhs) && stop_lhs.eq(stop_rhs),
+                (NamedStructField { .. }, ListIndex { .. } | ListRange { .. }) 
=> false,
+                (ListIndex { .. }, NamedStructField { .. } | ListRange { .. }) 
=> false,
+                (ListRange { .. }, NamedStructField { .. } | ListIndex { .. }) 
=> false,
+            })
             .unwrap_or(false)
     }
 }
@@ -435,4 +454,19 @@ mod tests {
         assert!(result.is_null(0));
         Ok(())
     }
+
+    #[test]
+    fn get_indexed_field_eq() -> Result<()> {
+        let schema = list_schema(&["list", "error"]);
+        let expr = col("list", &schema).unwrap();
+        let key = col("error", &schema).unwrap();
+        let indexed_field =
+            Arc::new(GetIndexedFieldExpr::new_index(expr.clone(), key.clone()))
+                as Arc<dyn PhysicalExpr>;
+        let indexed_field_other =
+            Arc::new(GetIndexedFieldExpr::new_index(key, expr)) as Arc<dyn 
PhysicalExpr>;
+        assert!(indexed_field.eq(&indexed_field));
+        assert!(!indexed_field.eq(&indexed_field_other));
+        Ok(())
+    }
 }

Reply via email to