This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 4ecf3e739 feat: prepare logical plan to logical plan without
params/placeholders (#4561)
4ecf3e739 is described below
commit 4ecf3e7399b1f728a5bd84c10dbfe7a4f0220dd9
Author: Nga Tran <[email protected]>
AuthorDate: Mon Dec 12 09:32:43 2022 -0500
feat: prepare logical plan to logical plan without params/placeholders
(#4561)
* feat: prepare logical plan to logical plan without params/placeholders
* fix: typo
* refactor: address review comments
* refactor: add index of the params/values into the error message
---
datafusion/core/tests/sql/select.rs | 68 +++++++++
datafusion/expr/src/logical_plan/plan.rs | 107 +++++++++++++-
datafusion/sql/src/planner.rs | 239 +++++++++++++++++++++++++------
3 files changed, 373 insertions(+), 41 deletions(-)
diff --git a/datafusion/core/tests/sql/select.rs
b/datafusion/core/tests/sql/select.rs
index 5a56247e4..c82eee7d0 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -20,6 +20,7 @@ use datafusion::{
datasource::empty::EmptyTable, from_slice::FromSlice,
physical_plan::collect_partitioned,
};
+use datafusion_common::ScalarValue;
use tempfile::TempDir;
#[tokio::test]
@@ -1257,6 +1258,73 @@ async fn csv_join_unaliased_subqueries() -> Result<()> {
Ok(())
}
+// Test prepare statement from sql to final result
+// This test is equivalent with the test parallel_query_with_filter below but
using prepare statement
+#[tokio::test]
+async fn test_prepare_statement() -> Result<()> {
+ let tmp_dir = TempDir::new()?;
+ let partition_count = 4;
+ let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?;
+
+ // sql to statement then to prepare logical plan with parameters
+ // c1 defined as UINT32, c2 defined as UInt64 but the params are Int32 and
Float64
+ let logical_plan =
+ ctx.create_logical_plan("PREPARE my_plan(INT, DOUBLE) AS SELECT c1, c2
FROM test WHERE c1 > $2 AND c1 < $1")?;
+
+ // prepare logical plan to logical plan without parameters
+ let param_values = vec![ScalarValue::Int32(Some(3)),
ScalarValue::Float64(Some(0.0))];
+ let logical_plan = logical_plan.with_param_values(param_values)?;
+
+ // logical plan to optimized logical plan
+ let logical_plan = ctx.optimize(&logical_plan)?;
+
+ // optimized logical plan to physical plan
+ let physical_plan = ctx.create_physical_plan(&logical_plan).await?;
+
+ let task_ctx = ctx.task_ctx();
+ let results = collect_partitioned(physical_plan, task_ctx).await?;
+
+ // note that the order of partitions is not deterministic
+ let mut num_rows = 0;
+ for partition in &results {
+ for batch in partition {
+ num_rows += batch.num_rows();
+ }
+ }
+ assert_eq!(20, num_rows);
+
+ let results: Vec<RecordBatch> = results.into_iter().flatten().collect();
+ let expected = vec![
+ "+----+----+",
+ "| c1 | c2 |",
+ "+----+----+",
+ "| 1 | 1 |",
+ "| 1 | 10 |",
+ "| 1 | 2 |",
+ "| 1 | 3 |",
+ "| 1 | 4 |",
+ "| 1 | 5 |",
+ "| 1 | 6 |",
+ "| 1 | 7 |",
+ "| 1 | 8 |",
+ "| 1 | 9 |",
+ "| 2 | 1 |",
+ "| 2 | 10 |",
+ "| 2 | 2 |",
+ "| 2 | 3 |",
+ "| 2 | 4 |",
+ "| 2 | 5 |",
+ "| 2 | 6 |",
+ "| 2 | 7 |",
+ "| 2 | 8 |",
+ "| 2 | 9 |",
+ "+----+----+",
+ ];
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
#[tokio::test]
async fn parallel_query_with_filter() -> Result<()> {
let tmp_dir = TempDir::new()?;
diff --git a/datafusion/expr/src/logical_plan/plan.rs
b/datafusion/expr/src/logical_plan/plan.rs
index 7f38e7dbb..43e615e14 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -16,17 +16,20 @@
// under the License.
use crate::expr::BinaryExpr;
+use crate::expr_rewriter::{ExprRewritable, ExprRewriter};
///! Logical plan types
use crate::logical_plan::builder::validate_unique_names;
use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
use crate::logical_plan::extension::UserDefinedLogicalNode;
use crate::utils::{
- exprlist_to_fields, from_plan, grouping_set_expr_count,
grouping_set_to_exprlist,
+ self, exprlist_to_fields, from_plan, grouping_set_expr_count,
+ grouping_set_to_exprlist,
};
use crate::{Expr, ExprSchemable, TableProviderFilterPushDown, TableSource};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::{
plan_err, Column, DFSchema, DFSchemaRef, DataFusionError,
OwnedTableReference,
+ ScalarValue,
};
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug, Display, Formatter};
@@ -364,6 +367,42 @@ impl LogicalPlan {
) -> Result<LogicalPlan, DataFusionError> {
from_plan(self, &self.expressions(), inputs)
}
+
+ /// Convert a prepare logical plan into its inner logical plan with all
params replaced with their corresponding values
+ pub fn with_param_values(
+ self,
+ param_values: Vec<ScalarValue>,
+ ) -> Result<LogicalPlan, DataFusionError> {
+ match self {
+ LogicalPlan::Prepare(prepare_lp) => {
+ // Verify if the number of params matches the number of values
+ if prepare_lp.data_types.len() != param_values.len() {
+ return Err(DataFusionError::Internal(format!(
+ "Expected {} parameters, got {}",
+ prepare_lp.data_types.len(),
+ param_values.len()
+ )));
+ }
+
+ // Verify if the types of the params matches the types of the
values
+ let iter =
prepare_lp.data_types.iter().zip(param_values.iter());
+ for (i, (param_type, value)) in iter.enumerate() {
+ if *param_type != value.get_datatype() {
+ return Err(DataFusionError::Internal(format!(
+ "Expected parameter of type {:?}, got {:?} at
index {}",
+ param_type,
+ value.get_datatype(),
+ i
+ )));
+ }
+ }
+
+ let input_plan = prepare_lp.input;
+ input_plan.replace_params_with_values(&param_values)
+ }
+ _ => Ok(self),
+ }
+ }
}
/// Trait that implements the [Visitor
@@ -534,6 +573,72 @@ impl LogicalPlan {
_ => {}
}
}
+
+ /// Return a logical plan with all placeholders/params (e.g $1 $2, ...)
replaced with corresponding values provided in the param_values
+ pub fn replace_params_with_values(
+ &self,
+ param_values: &Vec<ScalarValue>,
+ ) -> Result<LogicalPlan, DataFusionError> {
+ let exprs = self.expressions();
+ let mut new_exprs = vec![];
+ for expr in exprs {
+ new_exprs.push(Self::replace_placeholders_with_values(expr,
param_values)?);
+ }
+
+ let new_inputs = self.inputs();
+ let mut new_inputs_with_values = vec![];
+ for input in new_inputs {
+
new_inputs_with_values.push(input.replace_params_with_values(param_values)?);
+ }
+
+ let new_plan = utils::from_plan(self, &new_exprs,
&new_inputs_with_values)?;
+ Ok(new_plan)
+ }
+
+ /// Return an Expr with all placeholders replaced with their corresponding
values provided in the param_values
+ fn replace_placeholders_with_values(
+ expr: Expr,
+ param_values: &Vec<ScalarValue>,
+ ) -> Result<Expr, DataFusionError> {
+ struct PlaceholderReplacer<'a> {
+ param_values: &'a Vec<ScalarValue>,
+ }
+
+ impl<'a> ExprRewriter for PlaceholderReplacer<'a> {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr, DataFusionError> {
+ if let Expr::Placeholder { id, data_type } = &expr {
+ // convert id (in format $1, $2, ..) to idx (0, 1, ..)
+ let idx = id[1..].parse::<usize>().map_err(|e| {
+ DataFusionError::Internal(format!(
+ "Failed to parse placeholder id: {}",
+ e
+ ))
+ })? - 1;
+ // value at the idx-th position in param_values should be
the value for the placeholder
+ let value = self.param_values.get(idx).ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "No value found for placeholder with id {}",
+ id
+ ))
+ })?;
+ // check if the data type of the value matches the data
type of the placeholder
+ if value.get_datatype() != *data_type {
+ return Err(DataFusionError::Internal(format!(
+ "Placeholder value type mismatch: expected {:?},
got {:?}",
+ data_type,
+ value.get_datatype()
+ )));
+ }
+ // Replace the placeholder with the value
+ Ok(Expr::Literal(value.clone()))
+ } else {
+ Ok(expr)
+ }
+ }
+ }
+
+ expr.rewrite(&mut PlaceholderReplacer { param_values })
+ }
}
// Various implementations for printing out LogicalPlans
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index ee3839318..6f54e5803 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -5379,15 +5379,32 @@ mod tests {
sql: &str,
expected_plan: &str,
expected_data_types: &str,
- ) {
+ ) -> LogicalPlan {
let plan = logical_plan(sql).unwrap();
+
+ let assert_plan = plan.clone();
// verify plan
- assert_eq!(format!("{:?}", plan), expected_plan);
+ assert_eq!(format!("{:?}", assert_plan), expected_plan);
+
// verify data types
- if let LogicalPlan::Prepare(Prepare { data_types, .. }) = plan {
+ if let LogicalPlan::Prepare(Prepare { data_types, .. }) = assert_plan {
let dt = format!("{:?}", data_types);
assert_eq!(dt, expected_data_types);
}
+
+ plan
+ }
+
+ fn prepare_stmt_replace_params_quick_test(
+ plan: LogicalPlan,
+ param_values: Vec<ScalarValue>,
+ expected_plan: &str,
+ ) -> LogicalPlan {
+ // replace params
+ let plan = plan.with_param_values(param_values).unwrap();
+ assert_eq!(format!("{:?}", plan), expected_plan);
+
+ plan
}
struct MockContextProvider {}
@@ -6197,11 +6214,7 @@ mod tests {
// param is not number following the $ sign
// panic due to error returned from the parser
let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE
age = $foo";
-
- let expected_plan = "whatever";
- let expected_dt = "whatever";
-
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ logical_plan(sql).unwrap();
}
#[test]
@@ -6210,11 +6223,7 @@ mod tests {
// param is not number following the $ sign
// panic due to error returned from the parser
let sql = "PREPARE AS SELECT id, age FROM person WHERE age = $foo";
-
- let expected_plan = "whatever";
- let expected_dt = "whatever";
-
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ logical_plan(sql).unwrap();
}
#[test]
@@ -6223,11 +6232,7 @@ mod tests {
)]
fn test_prepare_statement_to_plan_panic_no_relation_and_constant_param() {
let sql = "PREPARE my_plan(INT) AS SELECT id + $1";
-
- let expected_plan = "whatever";
- let expected_dt = "whatever";
-
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ logical_plan(sql).unwrap();
}
#[test]
@@ -6237,11 +6242,7 @@ mod tests {
fn test_prepare_statement_to_plan_panic_no_data_types() {
// only provide 1 data type while using 2 params
let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1 + $2";
-
- let expected_plan = "whatever";
- let expected_dt = "whatever";
-
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ logical_plan(sql).unwrap();
}
#[test]
@@ -6250,11 +6251,7 @@ mod tests {
)]
fn test_prepare_statement_to_plan_panic_is_param() {
let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE
age is $1";
-
- let expected_plan = "whatever";
- let expected_dt = "whatever";
-
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ logical_plan(sql).unwrap();
}
#[test]
@@ -6269,9 +6266,18 @@ mod tests {
let expected_dt = "[Int32]";
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+ ///////////////////
+ // replace params with values
+ let param_values = vec![ScalarValue::Int32(Some(10))];
+ let expected_plan = "Projection: person.id, person.age\
+ \n Filter: person.age = Int64(10)\
+ \n TableScan: person";
- /////////////////////////
+ prepare_stmt_replace_params_quick_test(plan, param_values,
expected_plan);
+
+ //////////////////////////////////////////
// no embedded parameter and no declare it
let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age =
10";
@@ -6282,7 +6288,54 @@ mod tests {
let expected_dt = "[]";
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+ ///////////////////
+ // replace params with values
+ let param_values = vec![];
+ let expected_plan = "Projection: person.id, person.age\
+ \n Filter: person.age = Int64(10)\
+ \n TableScan: person";
+
+ prepare_stmt_replace_params_quick_test(plan, param_values,
expected_plan);
+ }
+
+ #[test]
+ #[should_panic(expected = "value: Internal(\"Expected 1 parameters, got
0\")")]
+ fn test_prepare_statement_to_plan_one_param_no_value_panic() {
+ // no embedded parameter but still declare it
+ let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE
age = 10";
+ let plan = logical_plan(sql).unwrap();
+ // declare 1 param but provide 0
+ let param_values = vec![];
+ let expected_plan = "whatever";
+ prepare_stmt_replace_params_quick_test(plan, param_values,
expected_plan);
+ }
+
+ #[test]
+ #[should_panic(
+ expected = "value: Internal(\"Expected parameter of type Int32, got
Float64 at index 0\")"
+ )]
+ fn
test_prepare_statement_to_plan_one_param_one_value_different_type_panic() {
+ // no embedded parameter but still declare it
+ let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE
age = 10";
+ let plan = logical_plan(sql).unwrap();
+ // declare 1 Int32 param but provide a Float64 value
+ let param_values = vec![ScalarValue::Float64(Some(20.0))];
+ let expected_plan = "whatever";
+ prepare_stmt_replace_params_quick_test(plan, param_values,
expected_plan);
+ }
+
+ #[test]
+ #[should_panic(expected = "value: Internal(\"Expected 0 parameters, got
1\")")]
+ fn test_prepare_statement_to_plan_no_param_on_value_panic() {
+ // no embedded parameter and none declared
+ let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age =
10";
+ let plan = logical_plan(sql).unwrap();
+ // declare 0 params but provide 1 value
+ let param_values = vec![ScalarValue::Int32(Some(10))];
+ let expected_plan = "whatever";
+ prepare_stmt_replace_params_quick_test(plan, param_values,
expected_plan);
}
#[test]
@@ -6293,25 +6346,50 @@ mod tests {
\n Projection: $1\n EmptyRelation";
let expected_dt = "[Int32]";
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
- /////////////////////////
+ ///////////////////
+ // replace params with values
+ let param_values = vec![ScalarValue::Int32(Some(10))];
+ let expected_plan = "Projection: Int32(10)\n EmptyRelation";
+
+ prepare_stmt_replace_params_quick_test(plan, param_values,
expected_plan);
+
+ ///////////////////////////////////////
let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1";
let expected_plan = "Prepare: \"my_plan\" [Int32] \
\n Projection: Int64(1) + $1\n EmptyRelation";
let expected_dt = "[Int32]";
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+ ///////////////////
+ // replace params with values
+ let param_values = vec![ScalarValue::Int32(Some(10))];
+ let expected_plan = "Projection: Int64(1) + Int32(10)\n
EmptyRelation";
- /////////////////////////
+ prepare_stmt_replace_params_quick_test(plan, param_values,
expected_plan);
+
+ ///////////////////////////////////////
let sql = "PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2";
let expected_plan = "Prepare: \"my_plan\" [Int32, Float64] \
\n Projection: Int64(1) + $1 + $2\n EmptyRelation";
let expected_dt = "[Int32, Float64]";
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+ ///////////////////
+ // replace params with values
+ let param_values = vec![
+ ScalarValue::Int32(Some(10)),
+ ScalarValue::Float64(Some(10.0)),
+ ];
+ let expected_plan =
+ "Projection: Int64(1) + Int32(10) + Float64(10)\n EmptyRelation";
+
+ prepare_stmt_replace_params_quick_test(plan, param_values,
expected_plan);
}
#[test]
@@ -6325,7 +6403,41 @@ mod tests {
let expected_dt = "[Int32]";
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+ ///////////////////
+ // replace params with values
+ let param_values = vec![ScalarValue::Int32(Some(10))];
+ let expected_plan = "Projection: person.id, person.age\
+ \n Filter: person.age = Int32(10)\
+ \n TableScan: person";
+
+ prepare_stmt_replace_params_quick_test(plan, param_values,
expected_plan);
+ }
+
+ #[test]
+ fn test_prepare_statement_to_plan_data_type() {
+ let sql = "PREPARE my_plan(DOUBLE) AS SELECT id, age FROM person
WHERE age = $1";
+
+ // age is defined as Int32 but prepare statement declares it as
DOUBLE/Float64
+ // Prepare statement and its logical plan should be created
successfully
+ let expected_plan = "Prepare: \"my_plan\" [Float64] \
+ \n Projection: person.id, person.age\
+ \n Filter: person.age = $1\
+ \n TableScan: person";
+
+ let expected_dt = "[Float64]";
+
+ let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+ ///////////////////
+ // replace params with values still succeed and use Float64
+ let param_values = vec![ScalarValue::Float64(Some(10.0))];
+ let expected_plan = "Projection: person.id, person.age\
+ \n Filter: person.age = Float64(10)\
+ \n TableScan: person";
+
+ prepare_stmt_replace_params_quick_test(plan, param_values,
expected_plan);
}
#[test]
@@ -6342,7 +6454,24 @@ mod tests {
let expected_dt = "[Int32, Utf8, Float64, Int32, Float64, Utf8]";
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+ ///////////////////
+ // replace params with values
+ let param_values = vec![
+ ScalarValue::Int32(Some(10)),
+ ScalarValue::Utf8(Some("abc".to_string())),
+ ScalarValue::Float64(Some(100.0)),
+ ScalarValue::Int32(Some(20)),
+ ScalarValue::Float64(Some(200.0)),
+ ScalarValue::Utf8(Some("xyz".to_string())),
+ ];
+ let expected_plan =
+ "Projection: person.id, person.age, Utf8(\"xyz\")\
+ \n Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary >
Float64(100) AND person.salary < Float64(200) OR person.first_name <
Utf8(\"abc\")\
+ \n TableScan: person";
+
+ prepare_stmt_replace_params_quick_test(plan, param_values,
expected_plan);
}
#[test]
@@ -6364,7 +6493,24 @@ mod tests {
let expected_dt = "[Int32, Float64, Float64, Float64]";
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+ ///////////////////
+ // replace params with values
+ let param_values = vec![
+ ScalarValue::Int32(Some(10)),
+ ScalarValue::Float64(Some(100.0)),
+ ScalarValue::Float64(Some(200.0)),
+ ScalarValue::Float64(Some(300.0)),
+ ];
+ let expected_plan =
+ "Projection: person.id, SUM(person.age)\
+ \n Filter: SUM(person.age) < Int32(10) AND SUM(person.age) >
Int64(10) OR SUM(person.age) IN ([Float64(200), Float64(300)])\
+ \n Aggregate: groupBy=[[person.id]], aggr=[[SUM(person.age)]]\
+ \n Filter: person.salary > Float64(100)\
+ \n TableScan: person";
+
+ prepare_stmt_replace_params_quick_test(plan, param_values,
expected_plan);
}
#[test]
@@ -6379,7 +6525,20 @@ mod tests {
let expected_dt = "[Utf8, Utf8]";
- prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+ let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+ ///////////////////
+ // replace params with values
+ let param_values = vec![
+ ScalarValue::Utf8(Some("a".to_string())),
+ ScalarValue::Utf8(Some("b".to_string())),
+ ];
+ let expected_plan = "Projection: num, letter\
+ \n Projection: t.column1 AS num, t.column2 AS letter\
+ \n SubqueryAlias: t\
+ \n Values: (Int64(1), Utf8(\"a\")), (Int64(2), Utf8(\"b\"))";
+
+ prepare_stmt_replace_params_quick_test(plan, param_values,
expected_plan);
}
#[test]