This is an automated email from the ASF dual-hosted git repository.
avantgardner pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new b4cf60acb Replace placeholders in ScalarSubqueries (#5216)
b4cf60acb is described below
commit b4cf60acb00294baeed6574b6b08a4963204c23f
Author: Brent Gardner <[email protected]>
AuthorDate: Wed Feb 8 08:42:38 2023 -0700
Replace placeholders in ScalarSubqueries (#5216)
* Failing subquery test
* Fix test
* fmt
---
datafusion/expr/src/logical_plan/plan.rs | 57 ++++++++++++++++++--------------
datafusion/sql/tests/integration_test.rs | 41 +++++++++++++++++++++++
2 files changed, 74 insertions(+), 24 deletions(-)
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 47d658535..8f8c4fd65 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -22,6 +22,7 @@ use crate::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
use crate::logical_plan::builder::validate_unique_names;
use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
use crate::logical_plan::extension::UserDefinedLogicalNode;
+use crate::logical_plan::plan;
use crate::utils::{
self, exprlist_to_fields, from_plan, grouping_set_expr_count,
grouping_set_to_exprlist,
@@ -710,31 +711,39 @@ impl LogicalPlan {
param_values: &[ScalarValue],
) -> Result<Expr, DataFusionError> {
rewrite_expr(expr, |expr| {
- if let Expr::Placeholder { id, data_type } = &expr {
- // convert id (in format $1, $2, ..) to idx (0, 1, ..)
- let idx = id[1..].parse::<usize>().map_err(|e| {
- DataFusionError::Internal(format!(
- "Failed to parse placeholder id: {e}"
- ))
- })? - 1;
- // value at the idx-th position in param_values should be the value for the placeholder
- let value = param_values.get(idx).ok_or_else(|| {
- DataFusionError::Internal(format!(
- "No value found for placeholder with id {id}"
- ))
- })?;
- // check if the data type of the value matches the data type of the placeholder
- if Some(value.get_datatype()) != *data_type {
- return Err(DataFusionError::Internal(format!(
- "Placeholder value type mismatch: expected {:?}, got {:?}",
- data_type,
- value.get_datatype()
- )));
+ match &expr {
+ Expr::Placeholder { id, data_type } => {
+ // convert id (in format $1, $2, ..) to idx (0, 1, ..)
+ let idx = id[1..].parse::<usize>().map_err(|e| {
+ DataFusionError::Internal(format!(
+ "Failed to parse placeholder id: {e}"
+ ))
+ })? - 1;
+ // value at the idx-th position in param_values should be the value for the placeholder
+ let value = param_values.get(idx).ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "No value found for placeholder with id {id}"
+ ))
+ })?;
+ // check if the data type of the value matches the data type of the placeholder
+ if Some(value.get_datatype()) != *data_type {
+ return Err(DataFusionError::Internal(format!(
+ "Placeholder value type mismatch: expected {:?}, got {:?}",
+ data_type,
+ value.get_datatype()
+ )));
+ }
+ // Replace the placeholder with the value
+ Ok(Expr::Literal(value.clone()))
}
- // Replace the placeholder with the value
- Ok(Expr::Literal(value.clone()))
- } else {
- Ok(expr)
+ Expr::ScalarSubquery(qry) => {
+ let subquery = Arc::new(
+ qry.subquery
+ .replace_params_with_values(&param_values.to_vec())?,
+ );
+ Ok(Expr::ScalarSubquery(plan::Subquery { subquery }))
+ }
+ _ => Ok(expr),
}
})
}
diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs
index 761bbf345..38ede4e6b 100644
--- a/datafusion/sql/tests/integration_test.rs
+++ b/datafusion/sql/tests/integration_test.rs
@@ -3527,6 +3527,47 @@ Projection: person.id, person.age
prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
}
+#[test]
+fn test_prepare_statement_infer_types_subquery() {
+ let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)";
+
+ let expected_plan = r#"
+Projection: person.id, person.age
+ Filter: person.age = (<subquery>)
+ Subquery:
+ Projection: MAX(person.age)
+ Aggregate: groupBy=[[]], aggr=[[MAX(person.age)]]
+ Filter: person.id = $1
+ TableScan: person
+ TableScan: person
+ "#
+ .trim();
+
+ let expected_dt = "[Int32]";
+ let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+ let actual_types = plan.get_parameter_types().unwrap();
+ let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]);
+ assert_eq!(actual_types, expected_types);
+
+ // replace params with values
+ let param_values = vec![ScalarValue::UInt32(Some(10))];
+ let expected_plan = r#"
+Projection: person.id, person.age
+ Filter: person.age = (<subquery>)
+ Subquery:
+ Projection: MAX(person.age)
+ Aggregate: groupBy=[[]], aggr=[[MAX(person.age)]]
+ Filter: person.id = UInt32(10)
+ TableScan: person
+ TableScan: person
+ "#
+ .trim();
+ let plan = plan.replace_params_with_values(&param_values).unwrap();
+
+ prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
+}
+
#[test]
fn test_prepare_statement_update_infer() {
let sql = "update person set age=$1 where id=$2";