This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d33595ab55 Add documentation and usability for prepared parameters 
(#7785)
d33595ab55 is described below

commit d33595ab55de6e4317c52bceedc0231b25af5fef
Author: Andrew Lamb <[email protected]>
AuthorDate: Mon Oct 16 10:19:01 2023 -0400

    Add documentation and usability for prepared parameters (#7785)
    
    * Add documentation for prepared parameters + make it easier to use
    
    * Update datafusion/expr/src/expr.rs
    
    Co-authored-by: jakevin <[email protected]>
    
    ---------
    
    Co-authored-by: jakevin <[email protected]>
---
 datafusion/core/src/dataframe.rs         | 37 ++++++++++++++++-
 datafusion/expr/src/expr.rs              | 58 +++++++++++++++++++++++++--
 datafusion/expr/src/expr_fn.rs           | 20 ++++++++-
 datafusion/expr/src/logical_plan/plan.rs | 69 +++++++++++++++++++++++---------
 datafusion/sql/src/expr/mod.rs           | 48 +---------------------
 datafusion/sql/tests/sql_integration.rs  | 13 ++++++
 6 files changed, 175 insertions(+), 70 deletions(-)

diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index d704c7f304..79b8fcd519 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -1210,7 +1210,42 @@ impl DataFrame {
         Ok(DataFrame::new(self.session_state, project_plan))
     }
 
-    /// Convert a prepare logical plan into its inner logical plan with all 
params replaced with their corresponding values
+    /// Replace all parameters in logical plan with the specified
+    /// values, in preparation for execution.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use datafusion::prelude::*;
+    /// # use datafusion::{error::Result, assert_batches_eq};
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<()> {
+    /// # use datafusion_common::ScalarValue;
+    /// let mut ctx = SessionContext::new();
+    /// # ctx.register_csv("example", "tests/data/example.csv", 
CsvReadOptions::new()).await?;
+    /// let results = ctx
+    ///   .sql("SELECT a FROM example WHERE b = $1")
+    ///   .await?
+    ///    // replace $1 with value 2
+    ///   .with_param_values(vec![
+    ///      // value at index 0 --> $1
+    ///      ScalarValue::from(2i64)
+    ///    ])?
+    ///   .collect()
+    ///   .await?;
+    /// assert_batches_eq!(
+    ///  &[
+    ///    "+---+",
+    ///    "| a |",
+    ///    "+---+",
+    ///    "| 1 |",
+    ///    "+---+",
+    ///  ],
+    ///  &results
+    /// );
+    /// # Ok(())
+    /// # }
+    /// ```
     pub fn with_param_values(self, param_values: Vec<ScalarValue>) -> 
Result<Self> {
         let plan = self.plan.with_param_values(param_values)?;
         Ok(Self::new(self.session_state, plan))
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 3949d25b30..0b166107fb 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -17,7 +17,6 @@
 
 //! Expr module contains core type definition for `Expr`.
 
-use crate::aggregate_function;
 use crate::built_in_function;
 use crate::expr_fn::binary_expr;
 use crate::logical_plan::Subquery;
@@ -26,8 +25,10 @@ use crate::utils::{expr_to_columns, 
find_out_reference_exprs};
 use crate::window_frame;
 use crate::window_function;
 use crate::Operator;
+use crate::{aggregate_function, ExprSchemable};
 use arrow::datatypes::DataType;
-use datafusion_common::internal_err;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::{internal_err, DFSchema};
 use datafusion_common::{plan_err, Column, DataFusionError, Result, 
ScalarValue};
 use std::collections::HashSet;
 use std::fmt;
@@ -605,10 +606,13 @@ impl InSubquery {
     }
 }
 
-/// Placeholder
+/// Placeholder, representing bind parameter values such as `$1`.
+///
+/// The type of these parameters is inferred using 
[`Expr::infer_placeholder_types`]
+/// or can be specified directly using `PREPARE` statements.
 #[derive(Clone, PartialEq, Eq, Hash, Debug)]
 pub struct Placeholder {
-    /// The identifier of the parameter (e.g, $1 or $foo)
+    /// The identifier of the parameter, including the leading `$` (e.g, 
`"$1"` or `"$foo"`)
     pub id: String,
     /// The type the parameter will be filled in with
     pub data_type: Option<DataType>,
@@ -1036,6 +1040,52 @@ impl Expr {
     pub fn contains_outer(&self) -> bool {
         !find_out_reference_exprs(self).is_empty()
     }
+
+    /// Recursively find all [`Expr::Placeholder`] expressions, and
+    /// infer their [`DataType`] from the context of their use.
+    ///
+    /// For example, given an expression like `<int32> = $0`, this will infer `$0` to
+    /// have type `int32`.
+    pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<Expr> {
+        self.transform(&|mut expr| {
+            // Default to assuming the arguments are the same type
+            if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut 
expr {
+                rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
+                rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
+            };
+            if let Expr::Between(Between {
+                expr,
+                negated: _,
+                low,
+                high,
+            }) = &mut expr
+            {
+                rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
+                rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
+            }
+            Ok(Transformed::Yes(expr))
+        })
+    }
+}
+
+// If `expr` is a placeholder without a data type, sets its data type to that of `other`
+fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> 
Result<()> {
+    if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr {
+        if data_type.is_none() {
+            let other_dt = other.get_type(schema);
+            match other_dt {
+                Err(e) => {
+                    Err(e.context(format!(
+                        "Can not find type of {other} needed to infer type of 
{expr}"
+                    )))?;
+                }
+                Ok(dt) => {
+                    *data_type = Some(dt);
+                }
+            }
+        };
+    }
+    Ok(())
 }
 
 #[macro_export]
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 711dc123a4..79a43c2353 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -19,7 +19,7 @@
 
 use crate::expr::{
     AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, 
InSubquery,
-    ScalarFunction, TryCast,
+    Placeholder, ScalarFunction, TryCast,
 };
 use crate::function::PartitionEvaluatorFactory;
 use crate::WindowUDF;
@@ -80,6 +80,24 @@ pub fn ident(name: impl Into<String>) -> Expr {
     Expr::Column(Column::from_name(name))
 }
 
+/// Create placeholder value that will be filled in (such as `$1`)
+///
+/// Note the parameter type can be inferred using 
[`Expr::infer_placeholder_types`]
+///
+/// # Example
+///
+/// ```rust
+/// # use datafusion_expr::{placeholder};
+/// let p = placeholder("$0"); // $0, refers to parameter 1
+/// assert_eq!(p.to_string(), "$0")
+/// ```
+pub fn placeholder(id: impl Into<String>) -> Expr {
+    Expr::Placeholder(Placeholder {
+        id: id.into(),
+        data_type: None,
+    })
+}
+
 /// Return a new expression `left <op> right`
 pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
     Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
diff --git a/datafusion/expr/src/logical_plan/plan.rs 
b/datafusion/expr/src/logical_plan/plan.rs
index b865b68557..1c526c7b40 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -928,8 +928,40 @@ impl LogicalPlan {
             }
         }
     }
-    /// Convert a prepared [`LogicalPlan`] into its inner logical plan
-    /// with all params replaced with their corresponding values
+    /// Replaces placeholder param values (like `$1`, `$2`) in [`LogicalPlan`]
+    /// with the specified `param_values`.
+    ///
+    /// [`LogicalPlan::Prepare`] are
+    /// converted to their inner logical plan for execution.
+    ///
+    /// # Example
+    /// ```
+    /// # use arrow::datatypes::{Field, Schema, DataType};
+    /// use datafusion_common::ScalarValue;
+    /// # use datafusion_expr::{lit, col, LogicalPlanBuilder, 
logical_plan::table_scan, placeholder};
+    /// # let schema = Schema::new(vec![
+    /// #     Field::new("id", DataType::Int32, false),
+    /// # ]);
+    /// // Build SELECT * FROM t1 WHERE id = $1
+    /// let plan = table_scan(Some("t1"), &schema, None).unwrap()
+    ///     .filter(col("id").eq(placeholder("$1"))).unwrap()
+    ///     .build().unwrap();
+    ///
+    /// assert_eq!("Filter: t1.id = $1\
+    ///            \n  TableScan: t1",
+    ///             plan.display_indent().to_string()
+    /// );
+    ///
+    /// // Fill in the parameter $1 with a literal 3
+    /// let plan = plan.with_param_values(vec![
+    ///   ScalarValue::from(3i32) // value at index 0 --> $1
+    /// ]).unwrap();
+    ///
+    /// assert_eq!("Filter: t1.id = Int32(3)\
+    ///             \n  TableScan: t1",
+    ///             plan.display_indent().to_string()
+    ///  );
+    /// ```
     pub fn with_param_values(
         self,
         param_values: Vec<ScalarValue>,
@@ -961,7 +993,7 @@ impl LogicalPlan {
                 let input_plan = prepare_lp.input;
                 input_plan.replace_params_with_values(&param_values)
             }
-            _ => Ok(self),
+            _ => self.replace_params_with_values(&param_values),
         }
     }
 
@@ -1060,7 +1092,7 @@ impl LogicalPlan {
 }
 
 impl LogicalPlan {
-    /// applies collect to any subqueries in the plan
+    /// applies `op` to any subqueries in the plan
     pub(crate) fn apply_subqueries<F>(&self, op: &mut F) -> 
datafusion_common::Result<()>
     where
         F: FnMut(&Self) -> datafusion_common::Result<VisitRecursion>,
@@ -1112,9 +1144,11 @@ impl LogicalPlan {
         Ok(())
     }
 
-    /// Return a logical plan with all placeholders/params (e.g $1 $2,
-    /// ...) replaced with corresponding values provided in the
-    /// params_values
+    /// Return a `LogicalPlan` with all placeholders (e.g. `$1`, `$2`,
+    /// ...) replaced with the corresponding values provided in
+    /// `param_values`
+    ///
+    /// See [`Self::with_param_values`] for examples and usage
     pub fn replace_params_with_values(
         &self,
         param_values: &[ScalarValue],
@@ -1122,7 +1156,10 @@ impl LogicalPlan {
         let new_exprs = self
             .expressions()
             .into_iter()
-            .map(|e| Self::replace_placeholders_with_values(e, param_values))
+            .map(|e| {
+                let e = e.infer_placeholder_types(self.schema())?;
+                Self::replace_placeholders_with_values(e, param_values)
+            })
             .collect::<Result<Vec<_>>>()?;
 
         let new_inputs_with_values = self
@@ -1219,7 +1256,9 @@ impl LogicalPlan {
 // Various implementations for printing out LogicalPlans
 impl LogicalPlan {
     /// Return a `format`able structure that produces a single line
-    /// per node. For example:
+    /// per node.
+    ///
+    /// # Example
     ///
     /// ```text
     /// Projection: employee.id
@@ -2321,7 +2360,7 @@ pub struct Unnest {
 mod tests {
     use super::*;
     use crate::logical_plan::table_scan;
-    use crate::{col, exists, in_subquery, lit};
+    use crate::{col, exists, in_subquery, lit, placeholder};
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::tree_node::TreeNodeVisitor;
     use datafusion_common::{not_impl_err, DFSchema, TableReference};
@@ -2767,10 +2806,7 @@ digraph {
 
         let plan = table_scan(TableReference::none(), &schema, None)
             .unwrap()
-            .filter(col("id").eq(Expr::Placeholder(Placeholder::new(
-                "".into(),
-                Some(DataType::Int32),
-            ))))
+            .filter(col("id").eq(placeholder("")))
             .unwrap()
             .build()
             .unwrap();
@@ -2783,10 +2819,7 @@ digraph {
 
         let plan = table_scan(TableReference::none(), &schema, None)
             .unwrap()
-            .filter(col("id").eq(Expr::Placeholder(Placeholder::new(
-                "$0".into(),
-                Some(DataType::Int32),
-            ))))
+            .filter(col("id").eq(placeholder("$0")))
             .unwrap()
             .build()
             .unwrap();
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index cb34b6ca36..a90a0f121f 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -29,13 +29,12 @@ mod value;
 
 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
 use arrow_schema::DataType;
-use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{
     internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, 
Result,
     ScalarValue,
 };
+use datafusion_expr::expr::InList;
 use datafusion_expr::expr::ScalarFunction;
-use datafusion_expr::expr::{InList, Placeholder};
 use datafusion_expr::{
     col, expr, lit, AggregateFunction, Between, BinaryExpr, 
BuiltinScalarFunction, Cast,
     Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, 
TryCast,
@@ -122,7 +121,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         let mut expr = self.sql_expr_to_logical_expr(sql, schema, 
planner_context)?;
         expr = self.rewrite_partial_qualifier(expr, schema);
         self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?;
-        let expr = infer_placeholder_types(expr, schema)?;
+        let expr = expr.infer_placeholder_types(schema)?;
         Ok(expr)
     }
 
@@ -712,49 +711,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     }
 }
 
-// modifies expr if it is a placeholder with datatype of right
-fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> 
Result<()> {
-    if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr {
-        if data_type.is_none() {
-            let other_dt = other.get_type(schema);
-            match other_dt {
-                Err(e) => {
-                    Err(e.context(format!(
-                        "Can not find type of {other} needed to infer type of 
{expr}"
-                    )))?;
-                }
-                Ok(dt) => {
-                    *data_type = Some(dt);
-                }
-            }
-        };
-    }
-    Ok(())
-}
-
-/// Find all [`Expr::Placeholder`] tokens in a logical plan, and try
-/// to infer their [`DataType`] from the context of their use.
-fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result<Expr> {
-    expr.transform(&|mut expr| {
-        // Default to assuming the arguments are the same type
-        if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr 
{
-            rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
-            rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
-        };
-        if let Expr::Between(Between {
-            expr,
-            negated: _,
-            low,
-            high,
-        }) = &mut expr
-        {
-            rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
-            rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
-        }
-        Ok(Transformed::Yes(expr))
-    })
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/sql/tests/sql_integration.rs 
b/datafusion/sql/tests/sql_integration.rs
index d95598cc3d..702d7dbce6 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -3671,6 +3671,19 @@ fn test_prepare_statement_should_infer_types() {
     assert_eq!(actual_types, expected_types);
 }
 
+#[test]
+fn test_non_prepare_statement_should_infer_types() {
+    // Non prepared statements (like SELECT) should also have their parameter 
types inferred
+    let sql = "SELECT 1 + $1";
+    let plan = logical_plan(sql).unwrap();
+    let actual_types = plan.get_parameter_types().unwrap();
+    let expected_types = HashMap::from([
+        // constant 1 is inferred to be int64
+        ("$1".to_string(), Some(DataType::Int64)),
+    ]);
+    assert_eq!(actual_types, expected_types);
+}
+
 #[test]
 #[should_panic(
     expected = "value: SQL(ParserError(\"Expected [NOT] NULL or TRUE|FALSE or 
[NOT] DISTINCT FROM after IS, found: $1\""

Reply via email to