This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b8fab5cdf4 Replace `GetFieldAccess` with indexing function in `SqlToRel ` (#10375)
b8fab5cdf4 is described below

commit b8fab5cdf418e1fba5e6012b815a5bc40c7771cc
Author: Jay Zhan <[email protected]>
AuthorDate: Tue May 14 10:22:28 2024 +0800

    Replace `GetFieldAccess` with indexing function in `SqlToRel ` (#10375)
    
    * use func in parser
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add tests
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rm test1
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * parser done
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix exprapi test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix conflicts
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion/core/tests/expr_api/mod.rs       |  14 +---
 datafusion/functions-array/src/rewrite.rs   |  29 +------
 datafusion/sql/src/expr/identifier.rs       |  17 ++++-
 datafusion/sql/src/expr/mod.rs              |  48 ++++++++++--
 datafusion/sqllogictest/test_files/expr.slt | 114 +++++++++++++++++++++++++++-
 5 files changed, 172 insertions(+), 50 deletions(-)

diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs
index 0dde7604cc..d7e839824b 100644
--- a/datafusion/core/tests/expr_api/mod.rs
+++ b/datafusion/core/tests/expr_api/mod.rs
@@ -60,9 +60,8 @@ fn test_eq_with_coercion() {
 
 #[test]
 fn test_get_field() {
-    // field access Expr::field() requires a rewrite to work
     evaluate_expr_test(
-        col("props").field("a"),
+        get_field(col("props"), lit("a")),
         vec![
             "+------------+",
             "| expr       |",
@@ -77,11 +76,8 @@ fn test_get_field() {
 
 #[test]
 fn test_nested_get_field() {
-    // field access Expr::field() requires a rewrite to work, test when it is
-    // not the root expression
     evaluate_expr_test(
-        col("props")
-            .field("a")
+        get_field(col("props"), lit("a"))
             .eq(lit("2021-02-02"))
             .or(col("id").eq(lit(1))),
         vec![
@@ -98,9 +94,8 @@ fn test_nested_get_field() {
 
 #[test]
 fn test_list() {
-    // list access also requires a rewrite to work
     evaluate_expr_test(
-        col("list").index(lit(1i64)),
+        array_element(col("list"), lit(1i64)),
         vec![
             "+------+", "| expr |", "+------+", "| one  |", "| two  |", "| 
five |",
             "+------+",
@@ -110,9 +105,8 @@ fn test_list() {
 
 #[test]
 fn test_list_range() {
-    // range access also requires a rewrite to work
     evaluate_expr_test(
-        col("list").range(lit(1i64), lit(2i64)),
+        array_slice(col("list"), lit(1i64), lit(2i64), None),
         vec![
             "+--------------+",
             "| expr         |",
diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs
index 5280355a82..a7aba78c1d 100644
--- a/datafusion/functions-array/src/rewrite.rs
+++ b/datafusion/functions-array/src/rewrite.rs
@@ -19,7 +19,6 @@
 
 use crate::array_has::array_has_all;
 use crate::concat::{array_append, array_concat, array_prepend};
-use crate::extract::{array_element, array_slice};
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::Transformed;
 use datafusion_common::utils::list_ndims;
@@ -27,8 +26,7 @@ use datafusion_common::Result;
 use datafusion_common::{Column, DFSchema};
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::expr_rewriter::FunctionRewrite;
-use datafusion_expr::{BinaryExpr, Expr, GetFieldAccess, GetIndexedField, 
Operator};
-use datafusion_functions::expr_fn::get_field;
+use datafusion_expr::{BinaryExpr, Expr, Operator};
 
 /// Rewrites expressions into function calls to array functions
 pub(crate) struct ArrayFunctionRewriter {}
@@ -148,31 +146,6 @@ impl FunctionRewrite for ArrayFunctionRewriter {
                 Transformed::yes(array_prepend(*left, *right))
             }
 
-            Expr::GetIndexedField(GetIndexedField {
-                expr,
-                field: GetFieldAccess::NamedStructField { name },
-            }) => {
-                let name = Expr::Literal(name);
-                Transformed::yes(get_field(*expr, name))
-            }
-
-            // expr[idx] ==> array_element(expr, idx)
-            Expr::GetIndexedField(GetIndexedField {
-                expr,
-                field: GetFieldAccess::ListIndex { key },
-            }) => Transformed::yes(array_element(*expr, *key)),
-
-            // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride)
-            Expr::GetIndexedField(GetIndexedField {
-                expr,
-                field:
-                    GetFieldAccess::ListRange {
-                        start,
-                        stop,
-                        stride,
-                    },
-            }) => Transformed::yes(array_slice(*expr, *start, *stop, 
Some(*stride))),
-
             _ => Transformed::no(expr),
         };
         Ok(transformed)
diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs
index 713ad6f72c..d297b2e4df 100644
--- a/datafusion/sql/src/expr/identifier.rs
+++ b/datafusion/sql/src/expr/identifier.rs
@@ -19,9 +19,9 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
 use arrow_schema::Field;
 use datafusion_common::{
     internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, 
Result,
-    TableReference,
+    ScalarValue, TableReference,
 };
-use datafusion_expr::{Case, Expr};
+use datafusion_expr::{expr::ScalarFunction, lit, Case, Expr};
 use sqlparser::ast::{Expr as SQLExpr, Ident};
 
 impl<'a, S: ContextProvider> SqlToRel<'a, S> {
@@ -133,7 +133,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                         );
                     }
                     let nested_name = nested_names[0].to_string();
-                    Ok(Expr::Column(Column::from((qualifier, 
field))).field(nested_name))
+
+                    let col = Expr::Column(Column::from((qualifier, field)));
+                    if let Some(udf) =
+                        self.context_provider.get_function_meta("get_field")
+                    {
+                        Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+                            udf,
+                            vec![col, lit(ScalarValue::from(nested_name))],
+                        )))
+                    } else {
+                        internal_err!("get_field not found")
+                    }
                 }
                 // found matching field with no spare identifier(s)
                 Some((field, qualifier, _nested_names)) => {
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index ed5421edfb..6445c3f7a8 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -29,7 +29,7 @@ use datafusion_expr::expr::InList;
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::{
     col, expr, lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, 
ExprSchemable,
-    GetFieldAccess, GetIndexedField, Like, Literal, Operator, TryCast,
+    GetFieldAccess, Like, Literal, Operator, TryCast,
 };
 
 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
@@ -1019,10 +1019,48 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             expr
         };
 
-        Ok(Expr::GetIndexedField(GetIndexedField::new(
-            Box::new(expr),
-            self.plan_indices(indices, schema, planner_context)?,
-        )))
+        let field = self.plan_indices(indices, schema, planner_context)?;
+        match field {
+            GetFieldAccess::NamedStructField { name } => {
+                if let Some(udf) = 
self.context_provider.get_function_meta("get_field") {
+                    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+                        udf,
+                        vec![expr, lit(name)],
+                    )))
+                } else {
+                    internal_err!("get_field not found")
+                }
+            }
+            // expr[idx] ==> array_element(expr, idx)
+            GetFieldAccess::ListIndex { key } => {
+                if let Some(udf) =
+                    self.context_provider.get_function_meta("array_element")
+                {
+                    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+                        udf,
+                        vec![expr, *key],
+                    )))
+                } else {
+                    internal_err!("get_field not found")
+                }
+            }
+            // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride)
+            GetFieldAccess::ListRange {
+                start,
+                stop,
+                stride,
+            } => {
+                if let Some(udf) = 
self.context_provider.get_function_meta("array_slice")
+                {
+                    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+                        udf,
+                        vec![expr, *start, *stop, *stride],
+                    )))
+                } else {
+                    internal_err!("array_slice not found")
+                }
+            }
+        }
     }
 }
 
diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt
index 4b5f4d770a..2dc00cbc50 100644
--- a/datafusion/sqllogictest/test_files/expr.slt
+++ b/datafusion/sqllogictest/test_files/expr.slt
@@ -2324,28 +2324,134 @@ host3 3.3
 
 # can have an aggregate function with an inner CASE WHEN
 query TR
-select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case 
when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then 
t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select 
struct(time,load1,load2,host) from t1) t2 where 
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by 
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+select 
+    t2.server_host as host, 
+    sum((
+        case when t2.server_host is not null 
+        then t2.server_load2
+        end
+    )) 
+    from (
+        select 
+            struct(time,load1,load2,host)['c2'] as server_load2,
+            struct(time,load1,load2,host)['c3'] as server_host
+        from t1
+    ) t2 
+    where server_host IS NOT NULL 
+    group by server_host order by host;
 ----
 host1 101
 host2 202
 host3 303
 
+# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364
+query error
+select 
+    t2.server['c3'] as host, 
+    sum((
+        case when t2.server['c3'] is not null 
+        then t2.server['c2']
+        end
+    )) 
+    from (
+        select 
+            struct(time,load1,load2,host) as server
+        from t1
+    ) t2 
+    where t2.server['c3'] IS NOT NULL 
+    group by t2.server['c3'] order by host;
+
 # can have 2 projections with aggr(short_circuited), with different short-circuited expr
 query TRR
-select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, 
sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case 
when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then 
t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select 
struct(time,load1,load2,host) from t1) t2 where 
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by 
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+select 
+    t2.server_host as host, 
+    sum(coalesce(server_load1)),
+    sum((
+        case when t2.server_host is not null 
+        then t2.server_load2
+        end
+    )) 
+    from (
+        select 
+            struct(time,load1,load2,host)['c1'] as server_load1,
+            struct(time,load1,load2,host)['c2'] as server_load2,
+            struct(time,load1,load2,host)['c3'] as server_host
+        from t1
+    ) t2 
+    where server_host IS NOT NULL 
+    group by server_host order by host;
 ----
 host1 1.1 101
 host2 2.2 202
 host3 3.3 303
 
-# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. CASE WHEN)
+# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364
+query error
+select 
+    t2.server['c3'] as host, 
+    sum(coalesce(server['c1'])),
+    sum((
+        case when t2.server['c3'] is not null 
+        then t2.server['c2']
+        end
+    )) 
+    from (
+        select 
+            struct(time,load1,load2,host) as server,
+        from t1
+    ) t2 
+    where server_host IS NOT NULL 
+    group by server_host order by host;
+
 query TRR
-select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case 
when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then 
t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when 
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then 
t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select 
struct(time,load1,load2,host) from t1) t2 where 
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by [...]
+select 
+    t2.server_host as host, 
+    sum((
+        case when t2.server_host is not null 
+        then server_load1 
+        end
+    )), 
+    sum((
+        case when server_host is not null 
+        then server_load2 
+        end
+    )) 
+    from (
+        select 
+            struct(time,load1,load2,host)['c1'] as server_load1,
+            struct(time,load1,load2,host)['c2'] as server_load2,
+            struct(time,load1,load2,host)['c3'] as server_host
+        from t1
+    ) t2 
+    where server_host IS NOT NULL 
+    group by server_host order by host;
 ----
 host1 1.1 101
 host2 2.2 202
 host3 3.3 303
 
+# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364
+query error
+select 
+    t2.server['c3'] as host, 
+    sum((
+        case when t2.server['c3'] is not null 
+        then t2.server['c1']
+        end
+    )), 
+    sum((
+        case when t2.server['c3'] is not null 
+        then t2.server['c2']
+        end
+    )) 
+    from (
+        select 
+            struct(time,load1,load2,host) as server 
+        from t1
+    ) t2 
+    where t2.server['c3'] IS NOT NULL 
+    group by t2.server['c3'] order by host;
+
 # can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. coalesce)
 query TRR
 select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, 
sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), 
sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from 
(select struct(time,load1,load2,host) from t1) t2 where 
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by 
t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to