This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 4ecf3e739 feat: prepare logical plan to logical plan without 
params/placeholders (#4561)
4ecf3e739 is described below

commit 4ecf3e7399b1f728a5bd84c10dbfe7a4f0220dd9
Author: Nga Tran <[email protected]>
AuthorDate: Mon Dec 12 09:32:43 2022 -0500

    feat: prepare logical plan to logical plan without params/placeholders 
(#4561)
    
    * feat: prepare logical plan to logical plan without params/placeholders
    
    * fix: typo
    
    * refactor: address review comments
    
    * refactor: add index of the params/values into the error message
---
 datafusion/core/tests/sql/select.rs      |  68 +++++++++
 datafusion/expr/src/logical_plan/plan.rs | 107 +++++++++++++-
 datafusion/sql/src/planner.rs            | 239 +++++++++++++++++++++++++------
 3 files changed, 373 insertions(+), 41 deletions(-)

diff --git a/datafusion/core/tests/sql/select.rs 
b/datafusion/core/tests/sql/select.rs
index 5a56247e4..c82eee7d0 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -20,6 +20,7 @@ use datafusion::{
     datasource::empty::EmptyTable, from_slice::FromSlice,
     physical_plan::collect_partitioned,
 };
+use datafusion_common::ScalarValue;
 use tempfile::TempDir;
 
 #[tokio::test]
@@ -1257,6 +1258,73 @@ async fn csv_join_unaliased_subqueries() -> Result<()> {
     Ok(())
 }
 
+// Test a prepared statement end to end, from SQL to the final result
+// This test is equivalent to the test parallel_query_with_filter below but uses a prepared statement
+#[tokio::test]
+async fn test_prepare_statement() -> Result<()> {
+    let tmp_dir = TempDir::new()?;
+    let partition_count = 4;
+    let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?;
+
+    // sql to statement then to prepare logical plan with parameters
+    // c1 defined as UInt32, c2 defined as UInt64 but the params are Int32 and Float64
+    let logical_plan =
+        ctx.create_logical_plan("PREPARE my_plan(INT, DOUBLE) AS SELECT c1, c2 
FROM test WHERE c1 > $2 AND c1 < $1")?;
+
+    // prepare logical plan to logical plan without parameters
+    let param_values = vec![ScalarValue::Int32(Some(3)), 
ScalarValue::Float64(Some(0.0))];
+    let logical_plan = logical_plan.with_param_values(param_values)?;
+
+    // logical plan to optimized logical plan
+    let logical_plan = ctx.optimize(&logical_plan)?;
+
+    // optimized logical plan to physical plan
+    let physical_plan = ctx.create_physical_plan(&logical_plan).await?;
+
+    let task_ctx = ctx.task_ctx();
+    let results = collect_partitioned(physical_plan, task_ctx).await?;
+
+    // note that the order of partitions is not deterministic
+    let mut num_rows = 0;
+    for partition in &results {
+        for batch in partition {
+            num_rows += batch.num_rows();
+        }
+    }
+    assert_eq!(20, num_rows);
+
+    let results: Vec<RecordBatch> = results.into_iter().flatten().collect();
+    let expected = vec![
+        "+----+----+",
+        "| c1 | c2 |",
+        "+----+----+",
+        "| 1  | 1  |",
+        "| 1  | 10 |",
+        "| 1  | 2  |",
+        "| 1  | 3  |",
+        "| 1  | 4  |",
+        "| 1  | 5  |",
+        "| 1  | 6  |",
+        "| 1  | 7  |",
+        "| 1  | 8  |",
+        "| 1  | 9  |",
+        "| 2  | 1  |",
+        "| 2  | 10 |",
+        "| 2  | 2  |",
+        "| 2  | 3  |",
+        "| 2  | 4  |",
+        "| 2  | 5  |",
+        "| 2  | 6  |",
+        "| 2  | 7  |",
+        "| 2  | 8  |",
+        "| 2  | 9  |",
+        "+----+----+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn parallel_query_with_filter() -> Result<()> {
     let tmp_dir = TempDir::new()?;
diff --git a/datafusion/expr/src/logical_plan/plan.rs 
b/datafusion/expr/src/logical_plan/plan.rs
index 7f38e7dbb..43e615e14 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -16,17 +16,20 @@
 // under the License.
 
 use crate::expr::BinaryExpr;
+use crate::expr_rewriter::{ExprRewritable, ExprRewriter};
 ///! Logical plan types
 use crate::logical_plan::builder::validate_unique_names;
 use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
 use crate::logical_plan::extension::UserDefinedLogicalNode;
 use crate::utils::{
-    exprlist_to_fields, from_plan, grouping_set_expr_count, 
grouping_set_to_exprlist,
+    self, exprlist_to_fields, from_plan, grouping_set_expr_count,
+    grouping_set_to_exprlist,
 };
 use crate::{Expr, ExprSchemable, TableProviderFilterPushDown, TableSource};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::{
     plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, 
OwnedTableReference,
+    ScalarValue,
 };
 use std::collections::{HashMap, HashSet};
 use std::fmt::{self, Debug, Display, Formatter};
@@ -364,6 +367,42 @@ impl LogicalPlan {
     ) -> Result<LogicalPlan, DataFusionError> {
         from_plan(self, &self.expressions(), inputs)
     }
+
+    /// Convert a prepare logical plan into its inner logical plan, with all params replaced by their corresponding values
+    pub fn with_param_values(
+        self,
+        param_values: Vec<ScalarValue>,
+    ) -> Result<LogicalPlan, DataFusionError> {
+        match self {
+            LogicalPlan::Prepare(prepare_lp) => {
+                // Verify if the number of params matches the number of values
+                if prepare_lp.data_types.len() != param_values.len() {
+                    return Err(DataFusionError::Internal(format!(
+                        "Expected {} parameters, got {}",
+                        prepare_lp.data_types.len(),
+                        param_values.len()
+                    )));
+                }
+
+                // Verify that the types of the params match the types of the values
+                let iter = 
prepare_lp.data_types.iter().zip(param_values.iter());
+                for (i, (param_type, value)) in iter.enumerate() {
+                    if *param_type != value.get_datatype() {
+                        return Err(DataFusionError::Internal(format!(
+                            "Expected parameter of type {:?}, got {:?} at 
index {}",
+                            param_type,
+                            value.get_datatype(),
+                            i
+                        )));
+                    }
+                }
+
+                let input_plan = prepare_lp.input;
+                input_plan.replace_params_with_values(&param_values)
+            }
+            _ => Ok(self),
+        }
+    }
 }
 
 /// Trait that implements the [Visitor
@@ -534,6 +573,72 @@ impl LogicalPlan {
             _ => {}
         }
     }
+
+    /// Return a logical plan with all placeholders/params (e.g. $1, $2, ...) replaced with the corresponding values provided in `param_values`
+    pub fn replace_params_with_values(
+        &self,
+        param_values: &Vec<ScalarValue>,
+    ) -> Result<LogicalPlan, DataFusionError> {
+        let exprs = self.expressions();
+        let mut new_exprs = vec![];
+        for expr in exprs {
+            new_exprs.push(Self::replace_placeholders_with_values(expr, 
param_values)?);
+        }
+
+        let new_inputs = self.inputs();
+        let mut new_inputs_with_values = vec![];
+        for input in new_inputs {
+            
new_inputs_with_values.push(input.replace_params_with_values(param_values)?);
+        }
+
+        let new_plan = utils::from_plan(self, &new_exprs, 
&new_inputs_with_values)?;
+        Ok(new_plan)
+    }
+
+    /// Return an Expr with all placeholders replaced with their corresponding values provided in `param_values`
+    fn replace_placeholders_with_values(
+        expr: Expr,
+        param_values: &Vec<ScalarValue>,
+    ) -> Result<Expr, DataFusionError> {
+        struct PlaceholderReplacer<'a> {
+            param_values: &'a Vec<ScalarValue>,
+        }
+
+        impl<'a> ExprRewriter for PlaceholderReplacer<'a> {
+            fn mutate(&mut self, expr: Expr) -> Result<Expr, DataFusionError> {
+                if let Expr::Placeholder { id, data_type } = &expr {
+                    // convert id (in format $1, $2, ..) to idx (0, 1, ..)
+                    let idx = id[1..].parse::<usize>().map_err(|e| {
+                        DataFusionError::Internal(format!(
+                            "Failed to parse placeholder id: {}",
+                            e
+                        ))
+                    })? - 1;
+                    // the value at the idx-th position in param_values should be the value for the placeholder
+                    let value = self.param_values.get(idx).ok_or_else(|| {
+                        DataFusionError::Internal(format!(
+                            "No value found for placeholder with id {}",
+                            id
+                        ))
+                    })?;
+                    // check that the data type of the value matches the data type of the placeholder
+                    if value.get_datatype() != *data_type {
+                        return Err(DataFusionError::Internal(format!(
+                            "Placeholder value type mismatch: expected {:?}, 
got {:?}",
+                            data_type,
+                            value.get_datatype()
+                        )));
+                    }
+                    // Replace the placeholder with the value
+                    Ok(Expr::Literal(value.clone()))
+                } else {
+                    Ok(expr)
+                }
+            }
+        }
+
+        expr.rewrite(&mut PlaceholderReplacer { param_values })
+    }
 }
 
 // Various implementations for printing out LogicalPlans
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index ee3839318..6f54e5803 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -5379,15 +5379,32 @@ mod tests {
         sql: &str,
         expected_plan: &str,
         expected_data_types: &str,
-    ) {
+    ) -> LogicalPlan {
         let plan = logical_plan(sql).unwrap();
+
+        let assert_plan = plan.clone();
         // verify plan
-        assert_eq!(format!("{:?}", plan), expected_plan);
+        assert_eq!(format!("{:?}", assert_plan), expected_plan);
+
         // verify data types
-        if let LogicalPlan::Prepare(Prepare { data_types, .. }) = plan {
+        if let LogicalPlan::Prepare(Prepare { data_types, .. }) = assert_plan {
             let dt = format!("{:?}", data_types);
             assert_eq!(dt, expected_data_types);
         }
+
+        plan
+    }
+
+    fn prepare_stmt_replace_params_quick_test(
+        plan: LogicalPlan,
+        param_values: Vec<ScalarValue>,
+        expected_plan: &str,
+    ) -> LogicalPlan {
+        // replace params
+        let plan = plan.with_param_values(param_values).unwrap();
+        assert_eq!(format!("{:?}", plan), expected_plan);
+
+        plan
     }
 
     struct MockContextProvider {}
@@ -6197,11 +6214,7 @@ mod tests {
         // param is not number following the $ sign
         // panic due to error returned from the parser
         let sql = "PREPARE my_plan(INT) AS SELECT id, age  FROM person WHERE 
age = $foo";
-
-        let expected_plan = "whatever";
-        let expected_dt = "whatever";
-
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        logical_plan(sql).unwrap();
     }
 
     #[test]
@@ -6210,11 +6223,7 @@ mod tests {
         // param is not number following the $ sign
         // panic due to error returned from the parser
         let sql = "PREPARE AS SELECT id, age  FROM person WHERE age = $foo";
-
-        let expected_plan = "whatever";
-        let expected_dt = "whatever";
-
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        logical_plan(sql).unwrap();
     }
 
     #[test]
@@ -6223,11 +6232,7 @@ mod tests {
     )]
     fn test_prepare_statement_to_plan_panic_no_relation_and_constant_param() {
         let sql = "PREPARE my_plan(INT) AS SELECT id + $1";
-
-        let expected_plan = "whatever";
-        let expected_dt = "whatever";
-
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        logical_plan(sql).unwrap();
     }
 
     #[test]
@@ -6237,11 +6242,7 @@ mod tests {
     fn test_prepare_statement_to_plan_panic_no_data_types() {
         // only provide 1 data type while using 2 params
         let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1 + $2";
-
-        let expected_plan = "whatever";
-        let expected_dt = "whatever";
-
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        logical_plan(sql).unwrap();
     }
 
     #[test]
@@ -6250,11 +6251,7 @@ mod tests {
     )]
     fn test_prepare_statement_to_plan_panic_is_param() {
         let sql = "PREPARE my_plan(INT) AS SELECT id, age  FROM person WHERE 
age is $1";
-
-        let expected_plan = "whatever";
-        let expected_dt = "whatever";
-
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        logical_plan(sql).unwrap();
     }
 
     #[test]
@@ -6269,9 +6266,18 @@ mod tests {
 
         let expected_dt = "[Int32]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![ScalarValue::Int32(Some(10))];
+        let expected_plan = "Projection: person.id, person.age\
+        \n  Filter: person.age = Int64(10)\
+        \n    TableScan: person";
 
-        /////////////////////////
+        prepare_stmt_replace_params_quick_test(plan, param_values, 
expected_plan);
+
+        //////////////////////////////////////////
         // no embedded parameter and no declare it
         let sql = "PREPARE my_plan AS SELECT id, age  FROM person WHERE age = 
10";
 
@@ -6282,7 +6288,54 @@ mod tests {
 
         let expected_dt = "[]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![];
+        let expected_plan = "Projection: person.id, person.age\
+        \n  Filter: person.age = Int64(10)\
+        \n    TableScan: person";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, 
expected_plan);
+    }
+
+    #[test]
+    #[should_panic(expected = "value: Internal(\"Expected 1 parameters, got 
0\")")]
+    fn test_prepare_statement_to_plan_one_param_no_value_panic() {
+        // no embedded parameter but still declare it
+        let sql = "PREPARE my_plan(INT) AS SELECT id, age  FROM person WHERE 
age = 10";
+        let plan = logical_plan(sql).unwrap();
+        // declare 1 param but provide 0
+        let param_values = vec![];
+        let expected_plan = "whatever";
+        prepare_stmt_replace_params_quick_test(plan, param_values, 
expected_plan);
+    }
+
+    #[test]
+    #[should_panic(
+        expected = "value: Internal(\"Expected parameter of type Int32, got 
Float64 at index 0\")"
+    )]
+    fn 
test_prepare_statement_to_plan_one_param_one_value_different_type_panic() {
+        // no embedded parameter but still declare it
+        let sql = "PREPARE my_plan(INT) AS SELECT id, age  FROM person WHERE 
age = 10";
+        let plan = logical_plan(sql).unwrap();
+        // declare 1 Int32 param but provide a Float64 value
+        let param_values = vec![ScalarValue::Float64(Some(20.0))];
+        let expected_plan = "whatever";
+        prepare_stmt_replace_params_quick_test(plan, param_values, 
expected_plan);
+    }
+
+    #[test]
+    #[should_panic(expected = "value: Internal(\"Expected 0 parameters, got 
1\")")]
+    fn test_prepare_statement_to_plan_no_param_on_value_panic() {
+        // no embedded parameter and no declared parameter
+        let sql = "PREPARE my_plan AS SELECT id, age  FROM person WHERE age = 
10";
+        let plan = logical_plan(sql).unwrap();
+        // declare 0 params but provide 1 value
+        let param_values = vec![ScalarValue::Int32(Some(10))];
+        let expected_plan = "whatever";
+        prepare_stmt_replace_params_quick_test(plan, param_values, 
expected_plan);
     }
 
     #[test]
@@ -6293,25 +6346,50 @@ mod tests {
         \n  Projection: $1\n    EmptyRelation";
         let expected_dt = "[Int32]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
 
-        /////////////////////////
+        ///////////////////
+        // replace params with values
+        let param_values = vec![ScalarValue::Int32(Some(10))];
+        let expected_plan = "Projection: Int32(10)\n  EmptyRelation";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, 
expected_plan);
+
+        ///////////////////////////////////////
         let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1";
 
         let expected_plan = "Prepare: \"my_plan\" [Int32] \
         \n  Projection: Int64(1) + $1\n    EmptyRelation";
         let expected_dt = "[Int32]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![ScalarValue::Int32(Some(10))];
+        let expected_plan = "Projection: Int64(1) + Int32(10)\n  
EmptyRelation";
 
-        /////////////////////////
+        prepare_stmt_replace_params_quick_test(plan, param_values, 
expected_plan);
+
+        ///////////////////////////////////////
         let sql = "PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2";
 
         let expected_plan = "Prepare: \"my_plan\" [Int32, Float64] \
         \n  Projection: Int64(1) + $1 + $2\n    EmptyRelation";
         let expected_dt = "[Int32, Float64]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![
+            ScalarValue::Int32(Some(10)),
+            ScalarValue::Float64(Some(10.0)),
+        ];
+        let expected_plan =
+            "Projection: Int64(1) + Int32(10) + Float64(10)\n  EmptyRelation";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, 
expected_plan);
     }
 
     #[test]
@@ -6325,7 +6403,41 @@ mod tests {
 
         let expected_dt = "[Int32]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![ScalarValue::Int32(Some(10))];
+        let expected_plan = "Projection: person.id, person.age\
+        \n  Filter: person.age = Int32(10)\
+        \n    TableScan: person";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, 
expected_plan);
+    }
+
+    #[test]
+    fn test_prepare_statement_to_plan_data_type() {
+        let sql = "PREPARE my_plan(DOUBLE) AS SELECT id, age  FROM person 
WHERE age = $1";
+
+        // age is defined as Int32 but the prepare statement declares the param as DOUBLE/Float64
+        // The prepare statement and its logical plan should still be created successfully
+        let expected_plan = "Prepare: \"my_plan\" [Float64] \
+        \n  Projection: person.id, person.age\
+        \n    Filter: person.age = $1\
+        \n      TableScan: person";
+
+        let expected_dt = "[Float64]";
+
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replacing params with values still succeeds and uses Float64
+        let param_values = vec![ScalarValue::Float64(Some(10.0))];
+        let expected_plan = "Projection: person.id, person.age\
+        \n  Filter: person.age = Float64(10)\
+        \n    TableScan: person";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, 
expected_plan);
     }
 
     #[test]
@@ -6342,7 +6454,24 @@ mod tests {
 
         let expected_dt = "[Int32, Utf8, Float64, Int32, Float64, Utf8]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![
+            ScalarValue::Int32(Some(10)),
+            ScalarValue::Utf8(Some("abc".to_string())),
+            ScalarValue::Float64(Some(100.0)),
+            ScalarValue::Int32(Some(20)),
+            ScalarValue::Float64(Some(200.0)),
+            ScalarValue::Utf8(Some("xyz".to_string())),
+        ];
+        let expected_plan =
+        "Projection: person.id, person.age, Utf8(\"xyz\")\
+        \n  Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > 
Float64(100) AND person.salary < Float64(200) OR person.first_name < 
Utf8(\"abc\")\
+        \n    TableScan: person";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, 
expected_plan);
     }
 
     #[test]
@@ -6364,7 +6493,24 @@ mod tests {
 
         let expected_dt = "[Int32, Float64, Float64, Float64]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![
+            ScalarValue::Int32(Some(10)),
+            ScalarValue::Float64(Some(100.0)),
+            ScalarValue::Float64(Some(200.0)),
+            ScalarValue::Float64(Some(300.0)),
+        ];
+        let expected_plan =
+        "Projection: person.id, SUM(person.age)\
+        \n  Filter: SUM(person.age) < Int32(10) AND SUM(person.age) > 
Int64(10) OR SUM(person.age) IN ([Float64(200), Float64(300)])\
+        \n    Aggregate: groupBy=[[person.id]], aggr=[[SUM(person.age)]]\
+        \n      Filter: person.salary > Float64(100)\
+        \n        TableScan: person";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, 
expected_plan);
     }
 
     #[test]
@@ -6379,7 +6525,20 @@ mod tests {
 
         let expected_dt = "[Utf8, Utf8]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![
+            ScalarValue::Utf8(Some("a".to_string())),
+            ScalarValue::Utf8(Some("b".to_string())),
+        ];
+        let expected_plan = "Projection: num, letter\
+        \n  Projection: t.column1 AS num, t.column2 AS letter\
+        \n    SubqueryAlias: t\
+        \n      Values: (Int64(1), Utf8(\"a\")), (Int64(2), Utf8(\"b\"))";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, 
expected_plan);
     }
 
     #[test]

Reply via email to