This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 14d39734dc Rewrite `array @> array` and `array <@ array` in
sql_expr_to_logical_expr (#11155)
14d39734dc is described below
commit 14d39734dce366f70ccc014e93a892d8a8b52537
Author: Jay Zhan <[email protected]>
AuthorDate: Sat Jun 29 20:14:15 2024 +0800
Rewrite `array @> array` and `array <@ array` in sql_expr_to_logical_expr
(#11155)
* rewrite at arrow
Signed-off-by: jayzhan211 <[email protected]>
* rm useless test
Signed-off-by: jayzhan211 <[email protected]>
* add test
Signed-off-by: jayzhan211 <[email protected]>
* rm test
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion/expr/src/expr_rewriter/mod.rs | 1 +
datafusion/functions-array/src/lib.rs | 2 -
datafusion/functions-array/src/rewrite.rs | 76 ----------------------------
datafusion/physical-expr-common/src/datum.rs | 33 ------------
datafusion/sql/src/expr/mod.rs | 30 ++++++++++-
datafusion/sql/tests/sql_integration.rs | 16 ------
datafusion/sqllogictest/test_files/array.slt | 22 ++++++++
7 files changed, 52 insertions(+), 128 deletions(-)
diff --git a/datafusion/expr/src/expr_rewriter/mod.rs
b/datafusion/expr/src/expr_rewriter/mod.rs
index 1441374bdb..024e4a0cea 100644
--- a/datafusion/expr/src/expr_rewriter/mod.rs
+++ b/datafusion/expr/src/expr_rewriter/mod.rs
@@ -43,6 +43,7 @@ pub use order_by::rewrite_sort_cols_by_aggs;
/// For example, concatenating arrays `a || b` is represented as
/// `Operator::ArrowAt`, but can be implemented by calling a function
/// `array_concat` from the `functions-array` crate.
+// This is not used in datafusion internally, but it is still helpful for
downstream projects, so don't remove it.
pub trait FunctionRewrite {
/// Return a human readable name for this rewrite
fn name(&self) -> &str;
diff --git a/datafusion/functions-array/src/lib.rs
b/datafusion/functions-array/src/lib.rs
index b2fcb5717b..543b7a6027 100644
--- a/datafusion/functions-array/src/lib.rs
+++ b/datafusion/functions-array/src/lib.rs
@@ -46,7 +46,6 @@ pub mod repeat;
pub mod replace;
pub mod resize;
pub mod reverse;
-pub mod rewrite;
pub mod set_ops;
pub mod sort;
pub mod string;
@@ -152,7 +151,6 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) ->
Result<()> {
}
Ok(()) as Result<()>
})?;
- registry.register_function_rewrite(Arc::new(rewrite::ArrayFunctionRewriter
{}))?;
Ok(())
}
diff --git a/datafusion/functions-array/src/rewrite.rs
b/datafusion/functions-array/src/rewrite.rs
deleted file mode 100644
index 28bc2d5e43..0000000000
--- a/datafusion/functions-array/src/rewrite.rs
+++ /dev/null
@@ -1,76 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Rewrites for using Array Functions
-
-use crate::array_has::array_has_all;
-use datafusion_common::config::ConfigOptions;
-use datafusion_common::tree_node::Transformed;
-use datafusion_common::DFSchema;
-use datafusion_common::Result;
-use datafusion_expr::expr::ScalarFunction;
-use datafusion_expr::expr_rewriter::FunctionRewrite;
-use datafusion_expr::{BinaryExpr, Expr, Operator};
-
-/// Rewrites expressions into function calls to array functions
-pub(crate) struct ArrayFunctionRewriter {}
-
-impl FunctionRewrite for ArrayFunctionRewriter {
- fn name(&self) -> &str {
- "ArrayFunctionRewriter"
- }
-
- fn rewrite(
- &self,
- expr: Expr,
- _schema: &DFSchema,
- _config: &ConfigOptions,
- ) -> Result<Transformed<Expr>> {
- let transformed = match expr {
- // array1 @> array2 -> array_has_all(array1, array2)
- Expr::BinaryExpr(BinaryExpr { left, op, right })
- if op == Operator::AtArrow
- && is_func(&left, "make_array")
- && is_func(&right, "make_array") =>
- {
- Transformed::yes(array_has_all(*left, *right))
- }
-
- // array1 <@ array2 -> array_has_all(array2, array1)
- Expr::BinaryExpr(BinaryExpr { left, op, right })
- if op == Operator::ArrowAt
- && is_func(&left, "make_array")
- && is_func(&right, "make_array") =>
- {
- Transformed::yes(array_has_all(*right, *left))
- }
-
- _ => Transformed::no(expr),
- };
- Ok(transformed)
- }
-}
-
-/// Returns true if expr is a function call to the specified named function.
-/// Returns false otherwise.
-fn is_func(expr: &Expr, func_name: &str) -> bool {
- let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
- return false;
- };
-
- func.name() == func_name
-}
diff --git a/datafusion/physical-expr-common/src/datum.rs
b/datafusion/physical-expr-common/src/datum.rs
index 96c903180e..790e742c42 100644
--- a/datafusion/physical-expr-common/src/datum.rs
+++ b/datafusion/physical-expr-common/src/datum.rs
@@ -145,36 +145,3 @@ pub fn compare_op_for_nested(
Ok(BooleanArray::new(values, nulls))
}
}
-
-#[cfg(test)]
-mod tests {
- use arrow::{
- array::{make_comparator, Array, BooleanArray, ListArray},
- buffer::NullBuffer,
- compute::SortOptions,
- datatypes::Int32Type,
- };
-
- #[test]
- fn test123() {
- let data = vec![
- Some(vec![Some(0), Some(1), Some(2)]),
- None,
- Some(vec![Some(3), None, Some(5)]),
- Some(vec![Some(6), Some(7)]),
- ];
- let a = ListArray::from_iter_primitive::<Int32Type, _, _>(data);
- let data = vec![
- Some(vec![Some(0), Some(1), Some(2)]),
- None,
- Some(vec![Some(3), None, Some(5)]),
- Some(vec![Some(6), Some(7)]),
- ];
- let b = ListArray::from_iter_primitive::<Int32Type, _, _>(data);
- let cmp = make_comparator(&a, &b, SortOptions::default()).unwrap();
- let len = a.len().min(b.len());
- let values = (0..len).map(|i| cmp(i, i).is_eq()).collect();
- let nulls = NullBuffer::union(a.nulls(), b.nulls());
- println!("res: {:?}", BooleanArray::new(values, nulls));
- }
-}
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index a8af37ee6a..b1182b35ec 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -150,10 +150,38 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
vec![left, right],
)));
} else {
- return internal_err!("array_append not found");
+ return internal_err!("array_prepend not found");
+ }
+ }
+ } else if matches!(op, Operator::AtArrow | Operator::ArrowAt) {
+ let left_type = left.get_type(schema)?;
+ let right_type = right.get_type(schema)?;
+ let left_list_ndims = list_ndims(&left_type);
+ let right_list_ndims = list_ndims(&right_type);
+ // if both are list
+ if left_list_ndims > 0 && right_list_ndims > 0 {
+ if let Some(udf) =
+ self.context_provider.get_function_meta("array_has_all")
+ {
+ // array1 @> array2 -> array_has_all(array1, array2)
+ if op == Operator::AtArrow {
+ return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+ udf,
+ vec![left, right],
+ )));
+ // array1 <@ array2 -> array_has_all(array2, array1)
+ } else {
+ return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+ udf,
+ vec![right, left],
+ )));
+ }
+ } else {
+ return internal_err!("array_has_all not found");
}
}
}
+
Ok(Expr::BinaryExpr(BinaryExpr::new(
Box::new(left),
op,
diff --git a/datafusion/sql/tests/sql_integration.rs
b/datafusion/sql/tests/sql_integration.rs
index e72a439b32..ec623a9561 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -1227,22 +1227,6 @@ fn select_binary_expr_nested() {
quick_test(sql, expected);
}
-#[test]
-fn select_at_arrow_operator() {
- let sql = "SELECT left @> right from array";
- let expected = "Projection: array.left @> array.right\
- \n TableScan: array";
- quick_test(sql, expected);
-}
-
-#[test]
-fn select_arrow_at_operator() {
- let sql = "SELECT left <@ right from array";
- let expected = "Projection: array.left <@ array.right\
- \n TableScan: array";
- quick_test(sql, expected);
-}
-
#[test]
fn select_wildcard_with_groupby() {
quick_test(
diff --git a/datafusion/sqllogictest/test_files/array.slt
b/datafusion/sqllogictest/test_files/array.slt
index 77d1a9da1f..7917f1d78d 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -6076,6 +6076,17 @@ select make_array(1,2,3) @> make_array(1,3),
----
true false true false false false true
+# Make sure it is rewritten to function array_has_all()
+query TT
+explain select [1,2,3] @> [1,3];
+----
+logical_plan
+01)Projection: Boolean(true) AS
array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))
+02)--EmptyRelation
+physical_plan
+01)ProjectionExec: expr=[true as
array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))]
+02)--PlaceholderRowExec
+
# array containment operator with scalars #2 (arrow at)
query BBBBBBB
select make_array(1,3) <@ make_array(1,2,3),
@@ -6088,6 +6099,17 @@ select make_array(1,3) <@ make_array(1,2,3),
----
true false true false false false true
+# Make sure it is rewritten to function array_has_all()
+query TT
+explain select [1,3] <@ [1,2,3];
+----
+logical_plan
+01)Projection: Boolean(true) AS
array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))
+02)--EmptyRelation
+physical_plan
+01)ProjectionExec: expr=[true as
array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))]
+02)--PlaceholderRowExec
+
### Array casting tests
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]