This is an automated email from the ASF dual-hosted git repository.

avantgardner pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new b4cf60acb Replace placeholders in ScalarSubqueries (#5216)
b4cf60acb is described below

commit b4cf60acb00294baeed6574b6b08a4963204c23f
Author: Brent Gardner <[email protected]>
AuthorDate: Wed Feb 8 08:42:38 2023 -0700

    Replace placeholders in ScalarSubqueries (#5216)
    
    * Failing subquery test
    
    * Fix test
    
    * fmt
---
 datafusion/expr/src/logical_plan/plan.rs | 57 ++++++++++++++++++--------------
 datafusion/sql/tests/integration_test.rs | 41 +++++++++++++++++++++++
 2 files changed, 74 insertions(+), 24 deletions(-)

diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 47d658535..8f8c4fd65 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -22,6 +22,7 @@ use crate::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
 use crate::logical_plan::builder::validate_unique_names;
 use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
 use crate::logical_plan::extension::UserDefinedLogicalNode;
+use crate::logical_plan::plan;
 use crate::utils::{
     self, exprlist_to_fields, from_plan, grouping_set_expr_count,
     grouping_set_to_exprlist,
@@ -710,31 +711,39 @@ impl LogicalPlan {
         param_values: &[ScalarValue],
     ) -> Result<Expr, DataFusionError> {
         rewrite_expr(expr, |expr| {
-            if let Expr::Placeholder { id, data_type } = &expr {
-                // convert id (in format $1, $2, ..) to idx (0, 1, ..)
-                let idx = id[1..].parse::<usize>().map_err(|e| {
-                    DataFusionError::Internal(format!(
-                        "Failed to parse placeholder id: {e}"
-                    ))
-                })? - 1;
-                // value at the idx-th position in param_values should be the value for the placeholder
-                let value = param_values.get(idx).ok_or_else(|| {
-                    DataFusionError::Internal(format!(
-                        "No value found for placeholder with id {id}"
-                    ))
-                })?;
-                // check if the data type of the value matches the data type of the placeholder
-                if Some(value.get_datatype()) != *data_type {
-                    return Err(DataFusionError::Internal(format!(
-                        "Placeholder value type mismatch: expected {:?}, got {:?}",
-                        data_type,
-                        value.get_datatype()
-                    )));
+            match &expr {
+                Expr::Placeholder { id, data_type } => {
+                    // convert id (in format $1, $2, ..) to idx (0, 1, ..)
+                    let idx = id[1..].parse::<usize>().map_err(|e| {
+                        DataFusionError::Internal(format!(
+                            "Failed to parse placeholder id: {e}"
+                        ))
+                    })? - 1;
+                    // value at the idx-th position in param_values should be the value for the placeholder
+                    let value = param_values.get(idx).ok_or_else(|| {
+                        DataFusionError::Internal(format!(
+                            "No value found for placeholder with id {id}"
+                        ))
+                    })?;
+                    // check if the data type of the value matches the data type of the placeholder
+                    if Some(value.get_datatype()) != *data_type {
+                        return Err(DataFusionError::Internal(format!(
+                            "Placeholder value type mismatch: expected {:?}, got {:?}",
+                            data_type,
+                            value.get_datatype()
+                        )));
+                    }
+                    // Replace the placeholder with the value
+                    Ok(Expr::Literal(value.clone()))
                 }
-                // Replace the placeholder with the value
-                Ok(Expr::Literal(value.clone()))
-            } else {
-                Ok(expr)
+                Expr::ScalarSubquery(qry) => {
+                    let subquery = Arc::new(
+                        qry.subquery
+                            .replace_params_with_values(&param_values.to_vec())?,
+                    );
+                    Ok(Expr::ScalarSubquery(plan::Subquery { subquery }))
+                }
+                _ => Ok(expr),
             }
         })
     }
diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs
index 761bbf345..38ede4e6b 100644
--- a/datafusion/sql/tests/integration_test.rs
+++ b/datafusion/sql/tests/integration_test.rs
@@ -3527,6 +3527,47 @@ Projection: person.id, person.age
     prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
 }
 
+#[test]
+fn test_prepare_statement_infer_types_subquery() {
+    let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)";
+
+    let expected_plan = r#"
+Projection: person.id, person.age
+  Filter: person.age = (<subquery>)
+    Subquery:
+      Projection: MAX(person.age)
+        Aggregate: groupBy=[[]], aggr=[[MAX(person.age)]]
+          Filter: person.id = $1
+            TableScan: person
+    TableScan: person
+        "#
+    .trim();
+
+    let expected_dt = "[Int32]";
+    let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+    let actual_types = plan.get_parameter_types().unwrap();
+    let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]);
+    assert_eq!(actual_types, expected_types);
+
+    // replace params with values
+    let param_values = vec![ScalarValue::UInt32(Some(10))];
+    let expected_plan = r#"
+Projection: person.id, person.age
+  Filter: person.age = (<subquery>)
+    Subquery:
+      Projection: MAX(person.age)
+        Aggregate: groupBy=[[]], aggr=[[MAX(person.age)]]
+          Filter: person.id = UInt32(10)
+            TableScan: person
+    TableScan: person
+        "#
+    .trim();
+    let plan = plan.replace_params_with_values(&param_values).unwrap();
+
+    prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
+}
+
 #[test]
 fn test_prepare_statement_update_infer() {
     let sql = "update person set age=$1 where id=$2";

Reply via email to