This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 14d39734dc Rewrite `array @> array` and `array <@ array` in sql_expr_to_logical_expr (#11155)
14d39734dc is described below

commit 14d39734dce366f70ccc014e93a892d8a8b52537
Author: Jay Zhan <[email protected]>
AuthorDate: Sat Jun 29 20:14:15 2024 +0800

    Rewrite `array @> array` and `array <@ array` in sql_expr_to_logical_expr (#11155)
    
    * rewrite at arrow
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rm useless test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rm test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion/expr/src/expr_rewriter/mod.rs     |  1 +
 datafusion/functions-array/src/lib.rs        |  2 -
 datafusion/functions-array/src/rewrite.rs    | 76 ----------------------------
 datafusion/physical-expr-common/src/datum.rs | 33 ------------
 datafusion/sql/src/expr/mod.rs               | 30 ++++++++++-
 datafusion/sql/tests/sql_integration.rs      | 16 ------
 datafusion/sqllogictest/test_files/array.slt | 22 ++++++++
 7 files changed, 52 insertions(+), 128 deletions(-)

diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs
index 1441374bdb..024e4a0cea 100644
--- a/datafusion/expr/src/expr_rewriter/mod.rs
+++ b/datafusion/expr/src/expr_rewriter/mod.rs
@@ -43,6 +43,7 @@ pub use order_by::rewrite_sort_cols_by_aggs;
 /// For example, concatenating arrays `a || b` is represented as
 /// `Operator::ArrowAt`, but can be implemented by calling a function
 /// `array_concat` from the `functions-array` crate.
+// This is not used in DataFusion internally, but it is still helpful for downstream projects, so don't remove it.
 pub trait FunctionRewrite {
     /// Return a human readable name for this rewrite
     fn name(&self) -> &str;
diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs
index b2fcb5717b..543b7a6027 100644
--- a/datafusion/functions-array/src/lib.rs
+++ b/datafusion/functions-array/src/lib.rs
@@ -46,7 +46,6 @@ pub mod repeat;
 pub mod replace;
 pub mod resize;
 pub mod reverse;
-pub mod rewrite;
 pub mod set_ops;
 pub mod sort;
 pub mod string;
@@ -152,7 +151,6 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
         }
         Ok(()) as Result<()>
     })?;
-    registry.register_function_rewrite(Arc::new(rewrite::ArrayFunctionRewriter {}))?;
 
     Ok(())
 }
diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs
deleted file mode 100644
index 28bc2d5e43..0000000000
--- a/datafusion/functions-array/src/rewrite.rs
+++ /dev/null
@@ -1,76 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Rewrites for using Array Functions
-
-use crate::array_has::array_has_all;
-use datafusion_common::config::ConfigOptions;
-use datafusion_common::tree_node::Transformed;
-use datafusion_common::DFSchema;
-use datafusion_common::Result;
-use datafusion_expr::expr::ScalarFunction;
-use datafusion_expr::expr_rewriter::FunctionRewrite;
-use datafusion_expr::{BinaryExpr, Expr, Operator};
-
-/// Rewrites expressions into function calls to array functions
-pub(crate) struct ArrayFunctionRewriter {}
-
-impl FunctionRewrite for ArrayFunctionRewriter {
-    fn name(&self) -> &str {
-        "ArrayFunctionRewriter"
-    }
-
-    fn rewrite(
-        &self,
-        expr: Expr,
-        _schema: &DFSchema,
-        _config: &ConfigOptions,
-    ) -> Result<Transformed<Expr>> {
-        let transformed = match expr {
-            // array1 @> array2 -> array_has_all(array1, array2)
-            Expr::BinaryExpr(BinaryExpr { left, op, right })
-                if op == Operator::AtArrow
-                    && is_func(&left, "make_array")
-                    && is_func(&right, "make_array") =>
-            {
-                Transformed::yes(array_has_all(*left, *right))
-            }
-
-            // array1 <@ array2 -> array_has_all(array2, array1)
-            Expr::BinaryExpr(BinaryExpr { left, op, right })
-                if op == Operator::ArrowAt
-                    && is_func(&left, "make_array")
-                    && is_func(&right, "make_array") =>
-            {
-                Transformed::yes(array_has_all(*right, *left))
-            }
-
-            _ => Transformed::no(expr),
-        };
-        Ok(transformed)
-    }
-}
-
-/// Returns true if expr is a function call to the specified named function.
-/// Returns false otherwise.
-fn is_func(expr: &Expr, func_name: &str) -> bool {
-    let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
-        return false;
-    };
-
-    func.name() == func_name
-}
diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs
index 96c903180e..790e742c42 100644
--- a/datafusion/physical-expr-common/src/datum.rs
+++ b/datafusion/physical-expr-common/src/datum.rs
@@ -145,36 +145,3 @@ pub fn compare_op_for_nested(
         Ok(BooleanArray::new(values, nulls))
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use arrow::{
-        array::{make_comparator, Array, BooleanArray, ListArray},
-        buffer::NullBuffer,
-        compute::SortOptions,
-        datatypes::Int32Type,
-    };
-
-    #[test]
-    fn test123() {
-        let data = vec![
-            Some(vec![Some(0), Some(1), Some(2)]),
-            None,
-            Some(vec![Some(3), None, Some(5)]),
-            Some(vec![Some(6), Some(7)]),
-        ];
-        let a = ListArray::from_iter_primitive::<Int32Type, _, _>(data);
-        let data = vec![
-            Some(vec![Some(0), Some(1), Some(2)]),
-            None,
-            Some(vec![Some(3), None, Some(5)]),
-            Some(vec![Some(6), Some(7)]),
-        ];
-        let b = ListArray::from_iter_primitive::<Int32Type, _, _>(data);
-        let cmp = make_comparator(&a, &b, SortOptions::default()).unwrap();
-        let len = a.len().min(b.len());
-        let values = (0..len).map(|i| cmp(i, i).is_eq()).collect();
-        let nulls = NullBuffer::union(a.nulls(), b.nulls());
-        println!("res: {:?}", BooleanArray::new(values, nulls));
-    }
-}
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index a8af37ee6a..b1182b35ec 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -150,10 +150,38 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                         vec![left, right],
                     )));
                 } else {
-                    return internal_err!("array_append not found");
+                    return internal_err!("array_prepend not found");
+                }
+            }
+        } else if matches!(op, Operator::AtArrow | Operator::ArrowAt) {
+            let left_type = left.get_type(schema)?;
+            let right_type = right.get_type(schema)?;
+            let left_list_ndims = list_ndims(&left_type);
+            let right_list_ndims = list_ndims(&right_type);
+            // if both are list
+            if left_list_ndims > 0 && right_list_ndims > 0 {
+                if let Some(udf) =
+                    self.context_provider.get_function_meta("array_has_all")
+                {
+                    // array1 @> array2 -> array_has_all(array1, array2)
+                    if op == Operator::AtArrow {
+                        return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+                            udf,
+                            vec![left, right],
+                        )));
+                    // array1 <@ array2 -> array_has_all(array2, array1)
+                    } else {
+                        return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+                            udf,
+                            vec![right, left],
+                        )));
+                    }
+                } else {
+                    return internal_err!("array_has_all not found");
                 }
             }
         }
+
         Ok(Expr::BinaryExpr(BinaryExpr::new(
             Box::new(left),
             op,
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index e72a439b32..ec623a9561 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -1227,22 +1227,6 @@ fn select_binary_expr_nested() {
     quick_test(sql, expected);
 }
 
-#[test]
-fn select_at_arrow_operator() {
-    let sql = "SELECT left @> right from array";
-    let expected = "Projection: array.left @> array.right\
-                        \n  TableScan: array";
-    quick_test(sql, expected);
-}
-
-#[test]
-fn select_arrow_at_operator() {
-    let sql = "SELECT left <@ right from array";
-    let expected = "Projection: array.left <@ array.right\
-                        \n  TableScan: array";
-    quick_test(sql, expected);
-}
-
 #[test]
 fn select_wildcard_with_groupby() {
     quick_test(
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index 77d1a9da1f..7917f1d78d 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -6076,6 +6076,17 @@ select make_array(1,2,3) @> make_array(1,3),
 ----
 true false true false false false true
 
+# Make sure it is rewritten to function array_has_all()
+query TT
+explain select [1,2,3] @> [1,3];
+----
+logical_plan
+01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))
+02)--EmptyRelation
+physical_plan
+01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))]
+02)--PlaceholderRowExec
+
 # array containment operator with scalars #2 (arrow at)
 query BBBBBBB
 select make_array(1,3) <@ make_array(1,2,3),
@@ -6088,6 +6099,17 @@ select make_array(1,3) <@ make_array(1,2,3),
 ----
 true false true false false false true
 
+# Make sure it is rewritten to function array_has_all()
+query TT
+explain select [1,3] <@ [1,2,3];
+----
+logical_plan
+01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))
+02)--EmptyRelation
+physical_plan
+01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))]
+02)--PlaceholderRowExec
+
 ### Array casting tests
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to