This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d33595ab55 Add documentation and usability for prepared parameters
(#7785)
d33595ab55 is described below
commit d33595ab55de6e4317c52bceedc0231b25af5fef
Author: Andrew Lamb <[email protected]>
AuthorDate: Mon Oct 16 10:19:01 2023 -0400
Add documentation and usability for prepared parameters (#7785)
* Add documentation for prepared parameters + make it easier to use
* Update datafusion/expr/src/expr.rs
Co-authored-by: jakevin <[email protected]>
---------
Co-authored-by: jakevin <[email protected]>
---
datafusion/core/src/dataframe.rs | 37 ++++++++++++++++-
datafusion/expr/src/expr.rs | 58 +++++++++++++++++++++++++--
datafusion/expr/src/expr_fn.rs | 20 ++++++++-
datafusion/expr/src/logical_plan/plan.rs | 69 +++++++++++++++++++++++---------
datafusion/sql/src/expr/mod.rs | 48 +---------------------
datafusion/sql/tests/sql_integration.rs | 13 ++++++
6 files changed, 175 insertions(+), 70 deletions(-)
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index d704c7f304..79b8fcd519 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -1210,7 +1210,42 @@ impl DataFrame {
Ok(DataFrame::new(self.session_state, project_plan))
}
- /// Convert a prepare logical plan into its inner logical plan with all
params replaced with their corresponding values
+ /// Replace all parameters in logical plan with the specified
+ /// values, in preparation for execution.
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use datafusion::prelude::*;
+ /// # use datafusion::{error::Result, assert_batches_eq};
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<()> {
+ /// # use datafusion_common::ScalarValue;
+ /// let mut ctx = SessionContext::new();
+ /// # ctx.register_csv("example", "tests/data/example.csv",
CsvReadOptions::new()).await?;
+ /// let results = ctx
+ /// .sql("SELECT a FROM example WHERE b = $1")
+ /// .await?
+ /// // replace $1 with value 2
+ /// .with_param_values(vec![
+ /// // value at index 0 --> $1
+ /// ScalarValue::from(2i64)
+ /// ])?
+ /// .collect()
+ /// .await?;
+ /// assert_batches_eq!(
+ /// &[
+ /// "+---+",
+ /// "| a |",
+ /// "+---+",
+ /// "| 1 |",
+ /// "+---+",
+ /// ],
+ /// &results
+ /// );
+ /// # Ok(())
+ /// # }
+ /// ```
pub fn with_param_values(self, param_values: Vec<ScalarValue>) ->
Result<Self> {
let plan = self.plan.with_param_values(param_values)?;
Ok(Self::new(self.session_state, plan))
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 3949d25b30..0b166107fb 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -17,7 +17,6 @@
//! Expr module contains core type definition for `Expr`.
-use crate::aggregate_function;
use crate::built_in_function;
use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
@@ -26,8 +25,10 @@ use crate::utils::{expr_to_columns,
find_out_reference_exprs};
use crate::window_frame;
use crate::window_function;
use crate::Operator;
+use crate::{aggregate_function, ExprSchemable};
use arrow::datatypes::DataType;
-use datafusion_common::internal_err;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::{internal_err, DFSchema};
use datafusion_common::{plan_err, Column, DataFusionError, Result,
ScalarValue};
use std::collections::HashSet;
use std::fmt;
@@ -605,10 +606,13 @@ impl InSubquery {
}
}
-/// Placeholder
+/// Placeholder, representing bind parameter values such as `$1`.
+///
+/// The type of these parameters is inferred using
[`Expr::infer_placeholder_types`]
+/// or can be specified directly using `PREPARE` statements.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Placeholder {
- /// The identifier of the parameter (e.g, $1 or $foo)
+ /// The identifier of the parameter, including the leading `$` (e.g,
`"$1"` or `"$foo"`)
pub id: String,
/// The type the parameter will be filled in with
pub data_type: Option<DataType>,
@@ -1036,6 +1040,52 @@ impl Expr {
pub fn contains_outer(&self) -> bool {
!find_out_reference_exprs(self).is_empty()
}
+
+ /// Recursively find all [`Expr::Placeholder`] expressions, and
+ /// infer their [`DataType`] from the context of their use.
+ ///
+ /// For example, given an expression like `<int32> = $0`, this will infer `$0` to
+ /// have type `int32`.
+ pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<Expr> {
+ self.transform(&|mut expr| {
+ // Default to assuming the arguments are the same type
+ if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut
expr {
+ rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
+ rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
+ };
+ if let Expr::Between(Between {
+ expr,
+ negated: _,
+ low,
+ high,
+ }) = &mut expr
+ {
+ rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
+ rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
+ }
+ Ok(Transformed::Yes(expr))
+ })
+ }
+}
+
+// If `expr` is a placeholder with no data type yet, sets its data type to the type of `other`
+fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) ->
Result<()> {
+ if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr {
+ if data_type.is_none() {
+ let other_dt = other.get_type(schema);
+ match other_dt {
+ Err(e) => {
+ Err(e.context(format!(
+ "Can not find type of {other} needed to infer type of
{expr}"
+ )))?;
+ }
+ Ok(dt) => {
+ *data_type = Some(dt);
+ }
+ }
+ };
+ }
+ Ok(())
}
#[macro_export]
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 711dc123a4..79a43c2353 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -19,7 +19,7 @@
use crate::expr::{
AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList,
InSubquery,
- ScalarFunction, TryCast,
+ Placeholder, ScalarFunction, TryCast,
};
use crate::function::PartitionEvaluatorFactory;
use crate::WindowUDF;
@@ -80,6 +80,24 @@ pub fn ident(name: impl Into<String>) -> Expr {
Expr::Column(Column::from_name(name))
}
+/// Create placeholder value that will be filled in (such as `$1`)
+///
+/// Note the parameter type can be inferred using
[`Expr::infer_placeholder_types`]
+///
+/// # Example
+///
+/// ```rust
+/// # use datafusion_expr::{placeholder};
+/// let p = placeholder("$0"); // $0, refers to parameter 1
+/// assert_eq!(p.to_string(), "$0")
+/// ```
+pub fn placeholder(id: impl Into<String>) -> Expr {
+ Expr::Placeholder(Placeholder {
+ id: id.into(),
+ data_type: None,
+ })
+}
+
/// Return a new expression `left <op> right`
pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
diff --git a/datafusion/expr/src/logical_plan/plan.rs
b/datafusion/expr/src/logical_plan/plan.rs
index b865b68557..1c526c7b40 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -928,8 +928,40 @@ impl LogicalPlan {
}
}
}
- /// Convert a prepared [`LogicalPlan`] into its inner logical plan
- /// with all params replaced with their corresponding values
+ /// Replaces placeholder param values (like `$1`, `$2`) in [`LogicalPlan`]
+ /// with the specified `param_values`.
+ ///
+ /// [`LogicalPlan::Prepare`] plans are
+ /// converted to their inner logical plan for execution.
+ ///
+ /// # Example
+ /// ```
+ /// # use arrow::datatypes::{Field, Schema, DataType};
+ /// use datafusion_common::ScalarValue;
+ /// # use datafusion_expr::{lit, col, LogicalPlanBuilder,
logical_plan::table_scan, placeholder};
+ /// # let schema = Schema::new(vec![
+ /// # Field::new("id", DataType::Int32, false),
+ /// # ]);
+ /// // Build SELECT * FROM t1 WHERE id = $1
+ /// let plan = table_scan(Some("t1"), &schema, None).unwrap()
+ /// .filter(col("id").eq(placeholder("$1"))).unwrap()
+ /// .build().unwrap();
+ ///
+ /// assert_eq!("Filter: t1.id = $1\
+ /// \n TableScan: t1",
+ /// plan.display_indent().to_string()
+ /// );
+ ///
+ /// // Fill in the parameter $1 with a literal 3
+ /// let plan = plan.with_param_values(vec![
+ /// ScalarValue::from(3i32) // value at index 0 --> $1
+ /// ]).unwrap();
+ ///
+ /// assert_eq!("Filter: t1.id = Int32(3)\
+ /// \n TableScan: t1",
+ /// plan.display_indent().to_string()
+ /// );
+ /// ```
pub fn with_param_values(
self,
param_values: Vec<ScalarValue>,
@@ -961,7 +993,7 @@ impl LogicalPlan {
let input_plan = prepare_lp.input;
input_plan.replace_params_with_values(¶m_values)
}
- _ => Ok(self),
+ _ => self.replace_params_with_values(¶m_values),
}
}
@@ -1060,7 +1092,7 @@ impl LogicalPlan {
}
impl LogicalPlan {
- /// applies collect to any subqueries in the plan
+ /// applies `op` to any subqueries in the plan
pub(crate) fn apply_subqueries<F>(&self, op: &mut F) ->
datafusion_common::Result<()>
where
F: FnMut(&Self) -> datafusion_common::Result<VisitRecursion>,
@@ -1112,9 +1144,11 @@ impl LogicalPlan {
Ok(())
}
- /// Return a logical plan with all placeholders/params (e.g $1 $2,
- /// ...) replaced with corresponding values provided in the
- /// params_values
+ /// Return a `LogicalPlan` with all placeholders (e.g $1 $2,
+ /// ...) replaced with corresponding values provided in
+ /// `params_values`
+ ///
+ /// See [`Self::with_param_values`] for examples and usage
pub fn replace_params_with_values(
&self,
param_values: &[ScalarValue],
@@ -1122,7 +1156,10 @@ impl LogicalPlan {
let new_exprs = self
.expressions()
.into_iter()
- .map(|e| Self::replace_placeholders_with_values(e, param_values))
+ .map(|e| {
+ let e = e.infer_placeholder_types(self.schema())?;
+ Self::replace_placeholders_with_values(e, param_values)
+ })
.collect::<Result<Vec<_>>>()?;
let new_inputs_with_values = self
@@ -1219,7 +1256,9 @@ impl LogicalPlan {
// Various implementations for printing out LogicalPlans
impl LogicalPlan {
/// Return a `format`able structure that produces a single line
- /// per node. For example:
+ /// per node.
+ ///
+ /// # Example
///
/// ```text
/// Projection: employee.id
@@ -2321,7 +2360,7 @@ pub struct Unnest {
mod tests {
use super::*;
use crate::logical_plan::table_scan;
- use crate::{col, exists, in_subquery, lit};
+ use crate::{col, exists, in_subquery, lit, placeholder};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::tree_node::TreeNodeVisitor;
use datafusion_common::{not_impl_err, DFSchema, TableReference};
@@ -2767,10 +2806,7 @@ digraph {
let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
- .filter(col("id").eq(Expr::Placeholder(Placeholder::new(
- "".into(),
- Some(DataType::Int32),
- ))))
+ .filter(col("id").eq(placeholder("")))
.unwrap()
.build()
.unwrap();
@@ -2783,10 +2819,7 @@ digraph {
let plan = table_scan(TableReference::none(), &schema, None)
.unwrap()
- .filter(col("id").eq(Expr::Placeholder(Placeholder::new(
- "$0".into(),
- Some(DataType::Int32),
- ))))
+ .filter(col("id").eq(placeholder("$0")))
.unwrap()
.build()
.unwrap();
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index cb34b6ca36..a90a0f121f 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -29,13 +29,12 @@ mod value;
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use arrow_schema::DataType;
-use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{
internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError,
Result,
ScalarValue,
};
+use datafusion_expr::expr::InList;
use datafusion_expr::expr::ScalarFunction;
-use datafusion_expr::expr::{InList, Placeholder};
use datafusion_expr::{
col, expr, lit, AggregateFunction, Between, BinaryExpr,
BuiltinScalarFunction, Cast,
Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator,
TryCast,
@@ -122,7 +121,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let mut expr = self.sql_expr_to_logical_expr(sql, schema,
planner_context)?;
expr = self.rewrite_partial_qualifier(expr, schema);
self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?;
- let expr = infer_placeholder_types(expr, schema)?;
+ let expr = expr.infer_placeholder_types(schema)?;
Ok(expr)
}
@@ -712,49 +711,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}
-// modifies expr if it is a placeholder with datatype of right
-fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) ->
Result<()> {
- if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr {
- if data_type.is_none() {
- let other_dt = other.get_type(schema);
- match other_dt {
- Err(e) => {
- Err(e.context(format!(
- "Can not find type of {other} needed to infer type of
{expr}"
- )))?;
- }
- Ok(dt) => {
- *data_type = Some(dt);
- }
- }
- };
- }
- Ok(())
-}
-
-/// Find all [`Expr::Placeholder`] tokens in a logical plan, and try
-/// to infer their [`DataType`] from the context of their use.
-fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result<Expr> {
- expr.transform(&|mut expr| {
- // Default to assuming the arguments are the same type
- if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr
{
- rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
- rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
- };
- if let Expr::Between(Between {
- expr,
- negated: _,
- low,
- high,
- }) = &mut expr
- {
- rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
- rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
- }
- Ok(Transformed::Yes(expr))
- })
-}
-
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/sql/tests/sql_integration.rs
b/datafusion/sql/tests/sql_integration.rs
index d95598cc3d..702d7dbce6 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -3671,6 +3671,19 @@ fn test_prepare_statement_should_infer_types() {
assert_eq!(actual_types, expected_types);
}
+#[test]
+fn test_non_prepare_statement_should_infer_types() {
+ // Non prepared statements (like SELECT) should also have their parameter
types inferred
+ let sql = "SELECT 1 + $1";
+ let plan = logical_plan(sql).unwrap();
+ let actual_types = plan.get_parameter_types().unwrap();
+ let expected_types = HashMap::from([
+ // constant 1 is inferred to be int64
+ ("$1".to_string(), Some(DataType::Int64)),
+ ]);
+ assert_eq!(actual_types, expected_types);
+}
+
#[test]
#[should_panic(
expected = "value: SQL(ParserError(\"Expected [NOT] NULL or TRUE|FALSE or
[NOT] DISTINCT FROM after IS, found: $1\""