This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new b8fab5cdf4 Replace `GetFieldAccess` with indexing function in
`SqlToRel ` (#10375)
b8fab5cdf4 is described below
commit b8fab5cdf418e1fba5e6012b815a5bc40c7771cc
Author: Jay Zhan <[email protected]>
AuthorDate: Tue May 14 10:22:28 2024 +0800
Replace `GetFieldAccess` with indexing function in `SqlToRel ` (#10375)
* use func in parser
Signed-off-by: jayzhan211 <[email protected]>
* add tests
Signed-off-by: jayzhan211 <[email protected]>
* add test
Signed-off-by: jayzhan211 <[email protected]>
* rm test1
Signed-off-by: jayzhan211 <[email protected]>
* parser done
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* fix exprapi test
Signed-off-by: jayzhan211 <[email protected]>
* fix test
Signed-off-by: jayzhan211 <[email protected]>
* fix conflicts
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion/core/tests/expr_api/mod.rs | 14 +---
datafusion/functions-array/src/rewrite.rs | 29 +------
datafusion/sql/src/expr/identifier.rs | 17 ++++-
datafusion/sql/src/expr/mod.rs | 48 ++++++++++--
datafusion/sqllogictest/test_files/expr.slt | 114 +++++++++++++++++++++++++++-
5 files changed, 172 insertions(+), 50 deletions(-)
diff --git a/datafusion/core/tests/expr_api/mod.rs
b/datafusion/core/tests/expr_api/mod.rs
index 0dde7604cc..d7e839824b 100644
--- a/datafusion/core/tests/expr_api/mod.rs
+++ b/datafusion/core/tests/expr_api/mod.rs
@@ -60,9 +60,8 @@ fn test_eq_with_coercion() {
#[test]
fn test_get_field() {
- // field access Expr::field() requires a rewrite to work
evaluate_expr_test(
- col("props").field("a"),
+ get_field(col("props"), lit("a")),
vec![
"+------------+",
"| expr |",
@@ -77,11 +76,8 @@ fn test_get_field() {
#[test]
fn test_nested_get_field() {
- // field access Expr::field() requires a rewrite to work, test when it is
- // not the root expression
evaluate_expr_test(
- col("props")
- .field("a")
+ get_field(col("props"), lit("a"))
.eq(lit("2021-02-02"))
.or(col("id").eq(lit(1))),
vec![
@@ -98,9 +94,8 @@ fn test_nested_get_field() {
#[test]
fn test_list() {
- // list access also requires a rewrite to work
evaluate_expr_test(
- col("list").index(lit(1i64)),
+ array_element(col("list"), lit(1i64)),
vec![
"+------+", "| expr |", "+------+", "| one |", "| two |", "|
five |",
"+------+",
@@ -110,9 +105,8 @@ fn test_list() {
#[test]
fn test_list_range() {
- // range access also requires a rewrite to work
evaluate_expr_test(
- col("list").range(lit(1i64), lit(2i64)),
+ array_slice(col("list"), lit(1i64), lit(2i64), None),
vec![
"+--------------+",
"| expr |",
diff --git a/datafusion/functions-array/src/rewrite.rs
b/datafusion/functions-array/src/rewrite.rs
index 5280355a82..a7aba78c1d 100644
--- a/datafusion/functions-array/src/rewrite.rs
+++ b/datafusion/functions-array/src/rewrite.rs
@@ -19,7 +19,6 @@
use crate::array_has::array_has_all;
use crate::concat::{array_append, array_concat, array_prepend};
-use crate::extract::{array_element, array_slice};
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::Transformed;
use datafusion_common::utils::list_ndims;
@@ -27,8 +26,7 @@ use datafusion_common::Result;
use datafusion_common::{Column, DFSchema};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::expr_rewriter::FunctionRewrite;
-use datafusion_expr::{BinaryExpr, Expr, GetFieldAccess, GetIndexedField,
Operator};
-use datafusion_functions::expr_fn::get_field;
+use datafusion_expr::{BinaryExpr, Expr, Operator};
/// Rewrites expressions into function calls to array functions
pub(crate) struct ArrayFunctionRewriter {}
@@ -148,31 +146,6 @@ impl FunctionRewrite for ArrayFunctionRewriter {
Transformed::yes(array_prepend(*left, *right))
}
- Expr::GetIndexedField(GetIndexedField {
- expr,
- field: GetFieldAccess::NamedStructField { name },
- }) => {
- let name = Expr::Literal(name);
- Transformed::yes(get_field(*expr, name))
- }
-
- // expr[idx] ==> array_element(expr, idx)
- Expr::GetIndexedField(GetIndexedField {
- expr,
- field: GetFieldAccess::ListIndex { key },
- }) => Transformed::yes(array_element(*expr, *key)),
-
- // expr[start, stop, stride] ==> array_slice(expr, start, stop,
stride)
- Expr::GetIndexedField(GetIndexedField {
- expr,
- field:
- GetFieldAccess::ListRange {
- start,
- stop,
- stride,
- },
- }) => Transformed::yes(array_slice(*expr, *start, *stop,
Some(*stride))),
-
_ => Transformed::no(expr),
};
Ok(transformed)
diff --git a/datafusion/sql/src/expr/identifier.rs
b/datafusion/sql/src/expr/identifier.rs
index 713ad6f72c..d297b2e4df 100644
--- a/datafusion/sql/src/expr/identifier.rs
+++ b/datafusion/sql/src/expr/identifier.rs
@@ -19,9 +19,9 @@ use crate::planner::{ContextProvider, PlannerContext,
SqlToRel};
use arrow_schema::Field;
use datafusion_common::{
internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError,
Result,
- TableReference,
+ ScalarValue, TableReference,
};
-use datafusion_expr::{Case, Expr};
+use datafusion_expr::{expr::ScalarFunction, lit, Case, Expr};
use sqlparser::ast::{Expr as SQLExpr, Ident};
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
@@ -133,7 +133,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
);
}
let nested_name = nested_names[0].to_string();
- Ok(Expr::Column(Column::from((qualifier,
field))).field(nested_name))
+
+ let col = Expr::Column(Column::from((qualifier, field)));
+ if let Some(udf) =
+ self.context_provider.get_function_meta("get_field")
+ {
+ Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+ udf,
+ vec![col, lit(ScalarValue::from(nested_name))],
+ )))
+ } else {
+ internal_err!("get_field not found")
+ }
}
// found matching field with no spare identifier(s)
Some((field, qualifier, _nested_names)) => {
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index ed5421edfb..6445c3f7a8 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -29,7 +29,7 @@ use datafusion_expr::expr::InList;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{
col, expr, lit, AggregateFunction, Between, BinaryExpr, Cast, Expr,
ExprSchemable,
- GetFieldAccess, GetIndexedField, Like, Literal, Operator, TryCast,
+ GetFieldAccess, Like, Literal, Operator, TryCast,
};
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
@@ -1019,10 +1019,48 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
expr
};
- Ok(Expr::GetIndexedField(GetIndexedField::new(
- Box::new(expr),
- self.plan_indices(indices, schema, planner_context)?,
- )))
+ let field = self.plan_indices(indices, schema, planner_context)?;
+ match field {
+ GetFieldAccess::NamedStructField { name } => {
+ if let Some(udf) =
self.context_provider.get_function_meta("get_field") {
+ Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+ udf,
+ vec![expr, lit(name)],
+ )))
+ } else {
+ internal_err!("get_field not found")
+ }
+ }
+ // expr[idx] ==> array_element(expr, idx); internal error if array_element is not registered
+ GetFieldAccess::ListIndex { key } => {
+ if let Some(udf) =
+ self.context_provider.get_function_meta("array_element")
+ {
+ Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+ udf,
+ vec![expr, *key],
+ )))
+ } else {
+ internal_err!("array_element not found")
+ }
+ }
+ // expr[start, stop, stride] ==> array_slice(expr, start, stop,
stride)
+ GetFieldAccess::ListRange {
+ start,
+ stop,
+ stride,
+ } => {
+ if let Some(udf) =
self.context_provider.get_function_meta("array_slice")
+ {
+ Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+ udf,
+ vec![expr, *start, *stop, *stride],
+ )))
+ } else {
+ internal_err!("array_slice not found")
+ }
+ }
+ }
}
}
diff --git a/datafusion/sqllogictest/test_files/expr.slt
b/datafusion/sqllogictest/test_files/expr.slt
index 4b5f4d770a..2dc00cbc50 100644
--- a/datafusion/sqllogictest/test_files/expr.slt
+++ b/datafusion/sqllogictest/test_files/expr.slt
@@ -2324,28 +2324,134 @@ host3 3.3
# can have an aggregate function with an inner CASE WHEN
query TR
-select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case
when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then
t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select
struct(time,load1,load2,host) from t1) t2 where
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+select
+ t2.server_host as host,
+ sum((
+ case when t2.server_host is not null
+ then t2.server_load2
+ end
+ ))
+ from (
+ select
+ struct(time,load1,load2,host)['c2'] as server_load2,
+ struct(time,load1,load2,host)['c3'] as server_host
+ from t1
+ ) t2
+ where server_host IS NOT NULL
+ group by server_host order by host;
----
host1 101
host2 202
host3 303
+# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364
+query error
+select
+ t2.server['c3'] as host,
+ sum((
+ case when t2.server['c3'] is not null
+ then t2.server['c2']
+ end
+ ))
+ from (
+ select
+ struct(time,load1,load2,host) as server
+ from t1
+ ) t2
+ where t2.server['c3'] IS NOT NULL
+ group by t2.server['c3'] order by host;
+
# can have 2 projections with aggr(short_circuited), with different
short-circuited expr
query TRR
-select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host,
sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case
when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then
t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select
struct(time,load1,load2,host) from t1) t2 where
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+select
+ t2.server_host as host,
+ sum(coalesce(server_load1)),
+ sum((
+ case when t2.server_host is not null
+ then t2.server_load2
+ end
+ ))
+ from (
+ select
+ struct(time,load1,load2,host)['c1'] as server_load1,
+ struct(time,load1,load2,host)['c2'] as server_load2,
+ struct(time,load1,load2,host)['c3'] as server_host
+ from t1
+ ) t2
+ where server_host IS NOT NULL
+ group by server_host order by host;
----
host1 1.1 101
host2 2.2 202
host3 3.3 303
-# can have 2 projections with aggr(short_circuited), with the same
short-circuited expr (e.g. CASE WHEN)
+# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364
+query error
+select
+ t2.server['c3'] as host,
+ sum(coalesce(server['c1'])),
+ sum((
+ case when t2.server['c3'] is not null
+ then t2.server['c2']
+ end
+ ))
+ from (
+ select
+ struct(time,load1,load2,host) as server
+ from t1
+ ) t2
+ where t2.server['c3'] IS NOT NULL
+ group by t2.server['c3'] order by host;
+
query TRR
-select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case
when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then
t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then
t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select
struct(time,load1,load2,host) from t1) t2 where
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by [...]
+select
+ t2.server_host as host,
+ sum((
+ case when t2.server_host is not null
+ then server_load1
+ end
+ )),
+ sum((
+ case when server_host is not null
+ then server_load2
+ end
+ ))
+ from (
+ select
+ struct(time,load1,load2,host)['c1'] as server_load1,
+ struct(time,load1,load2,host)['c2'] as server_load2,
+ struct(time,load1,load2,host)['c3'] as server_host
+ from t1
+ ) t2
+ where server_host IS NOT NULL
+ group by server_host order by host;
----
host1 1.1 101
host2 2.2 202
host3 3.3 303
+# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364
+query error
+select
+ t2.server['c3'] as host,
+ sum((
+ case when t2.server['c3'] is not null
+ then t2.server['c1']
+ end
+ )),
+ sum((
+ case when t2.server['c3'] is not null
+ then t2.server['c2']
+ end
+ ))
+ from (
+ select
+ struct(time,load1,load2,host) as server
+ from t1
+ ) t2
+ where t2.server['c3'] IS NOT NULL
+ group by t2.server['c3'] order by host;
+
# can have 2 projections with aggr(short_circuited), with the same
short-circuited expr (e.g. coalesce)
query TRR
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host,
sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']),
sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from
(select struct(time,load1,load2,host) from t1) t2 where
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]