This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new aff777b668 Add standalone example for `OptimizerRule` (#11087)
aff777b668 is described below

commit aff777b6689b9862db063b907c026311ffb27109
Author: Andrew Lamb <and...@nerdnetworks.org>
AuthorDate: Wed Jun 26 13:31:56 2024 -0400

    Add standalone example for `OptimizerRule` (#11087)
    
    * Add standalone example for `OptimizerRule`
    
    * Fix typo
    
    * Update datafusion-examples/examples/optimizer_rule.rs
    
    Co-authored-by: Oleks V <comph...@users.noreply.github.com>
    
    * fmt
    
    ---------
    
    Co-authored-by: Oleks V <comph...@users.noreply.github.com>
---
 datafusion-examples/README.md                      |   1 +
 datafusion-examples/examples/optimizer_rule.rs     | 213 +++++++++++++++++++++
 datafusion/core/src/execution/context/mod.rs       |  22 ++-
 datafusion/core/src/execution/session_state.rs     |  10 +
 .../core/tests/user_defined/user_defined_plan.rs   |   3 +-
 5 files changed, 241 insertions(+), 8 deletions(-)

diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md
index 217738a467..52702361e6 100644
--- a/datafusion-examples/README.md
+++ b/datafusion-examples/README.md
@@ -64,6 +64,7 @@ cargo run --example csv_sql
 - [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros
 - [`make_date.rs`](examples/make_date.rs): Examples of using the make_date function
 - [`memtable.rs`](examples/memtable.rs): Create and query data in memory using SQL and `RecordBatch`es
+- [`optimizer_rule.rs`](examples/optimizer_rule.rs): Use a custom OptimizerRule to replace certain predicates
 - [`parquet_index.rs`](examples/parquet_index.rs): Create a secondary index over several parquet files and use it to speed up queries
 - [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file
 - [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files
diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs
new file mode 100644
index 0000000000..0578529463
--- /dev/null
+++ b/datafusion-examples/examples/optimizer_rule.rs
@@ -0,0 +1,213 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray};
+use arrow_schema::DataType;
+use datafusion::prelude::SessionContext;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::{
+    BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl,
+    Signature, Volatility,
+};
+use datafusion_optimizer::optimizer::ApplyOrder;
+use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
+use std::any::Any;
+use std::sync::Arc;
+
+/// This example demonstrates how to add your own [`OptimizerRule`]
+/// to DataFusion.
+///
+/// [`OptimizerRule`]s transform [`LogicalPlan`]s into an equivalent (but
+/// hopefully faster) form.
+///
+/// See [analyzer_rule.rs] for an example of AnalyzerRules, which are for
+/// changing plan semantics.
+#[tokio::main]
+pub async fn main() -> Result<()> {
+    // DataFusion includes many built-in OptimizerRules for tasks such as outer
+    // to inner join conversion and constant folding.
+    //
+    // Note you can change the order of optimizer rules using the lower level
+    // `SessionState` API
+    let ctx = SessionContext::new();
+    ctx.add_optimizer_rule(Arc::new(MyOptimizerRule {}));
+
+    // Now, let's plan and run queries with the new rule
+    ctx.register_batch("person", person_batch())?;
+    let sql = "SELECT * FROM person WHERE age = 22";
+    let plan = ctx.sql(sql).await?.into_optimized_plan()?;
+
+    // We can see the effect of our rewrite in the output plan: the filter
+    // has been rewritten to use `my_eq`
+    //
+    // Filter: my_eq(person.age, Int32(22))
+    //   TableScan: person projection=[name, age]
+    println!("Logical Plan:\n\n{}\n", plan.display_indent());
+
+    // The query below doesn't respect the filter `where age = 22` because
+    // the plan has been rewritten to use a UDF that always returns true
+    //
+    // And the output verifies the predicates have been changed (as the my_eq
+    // function always returns true)
+    //
+    // +--------+-----+
+    // | name   | age |
+    // +--------+-----+
+    // | Andy   | 11  |
+    // | Andrew | 22  |
+    // | Oleks  | 33  |
+    // +--------+-----+
+    ctx.sql(sql).await?.show().await?;
+
+    // however we can see the rule doesn't trigger for queries with predicates
+    // other than `=`
+    //
+    // +-------+-----+
+    // | name  | age |
+    // +-------+-----+
+    // | Andy  | 11  |
+    // | Oleks | 33  |
+    // +-------+-----+
+    ctx.sql("SELECT * FROM person WHERE age <> 22")
+        .await?
+        .show()
+        .await?;
+
+    Ok(())
+}
+
+/// An example OptimizerRule that replaces all `col = <const>` predicates with a
+/// user defined function
+struct MyOptimizerRule {}
+
+impl OptimizerRule for MyOptimizerRule {
+    fn name(&self) -> &str {
+        "my_optimizer_rule"
+    }
+
+    // New OptimizerRules should use the "rewrite" API as it is more efficient
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    /// Ask the optimizer to handle the plan recursion. `rewrite` will be called
+    /// on each plan node.
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::BottomUp)
+    }
+
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        plan.map_expressions(|expr| {
+            // This closure is called for all expressions in the current plan
+            //
+            // For example, given a plan like `SELECT a + b, 5 + 10`
+            //
+            // The closure would be called twice:
+            // 1. once for `a + b`
+            // 2. once for `5 + 10`
+            self.rewrite_expr(expr)
+        })
+    }
+}
+
+impl MyOptimizerRule {
+    /// Rewrites an Expr replacing all `<col> = <const>` expressions with
+    /// a call to my_eq udf
+    fn rewrite_expr(&self, expr: Expr) -> Result<Transformed<Expr>> {
+        // do a bottom up rewrite of the expression tree
+        expr.transform_up(|expr| {
+            // Closure called for each sub tree
+            match expr {
+                Expr::BinaryExpr(binary_expr) if is_binary_eq(&binary_expr) => {
+                    // destructure the expression
+                    let BinaryExpr { left, op: _, right } = binary_expr;
+                    // rewrite to `my_eq(left, right)`
+                    let udf = ScalarUDF::new_from_impl(MyEq::new());
+                    let call = udf.call(vec![*left, *right]);
+                    Ok(Transformed::yes(call))
+                }
+                _ => Ok(Transformed::no(expr)),
+            }
+        })
+        // Note that the TreeNode API handles propagating the transformed flag
+        // and errors up the call chain
+    }
+}
+
+/// Return true if the expression is an equality expression for a literal or
+/// column reference
+fn is_binary_eq(binary_expr: &BinaryExpr) -> bool {
+    binary_expr.op == Operator::Eq
+        && is_lit_or_col(binary_expr.left.as_ref())
+        && is_lit_or_col(binary_expr.right.as_ref())
+}
+
+/// Return true if the expression is a literal or column reference
+fn is_lit_or_col(expr: &Expr) -> bool {
+    matches!(expr, Expr::Column(_) | Expr::Literal(_))
+}
+
+/// A simple user defined filter function
+#[derive(Debug, Clone)]
+struct MyEq {
+    signature: Signature,
+}
+
+impl MyEq {
+    fn new() -> Self {
+        Self {
+            signature: Signature::any(2, Volatility::Stable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for MyEq {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "my_eq"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        // this example simply returns "true" which is not what a real
+        // implementation would do.
+        Ok(ColumnarValue::Scalar(ScalarValue::from(true)))
+    }
+}
+
+/// Return a RecordBatch with made up data
+fn person_batch() -> RecordBatch {
+    let name: ArrayRef =
+        Arc::new(StringArray::from_iter_values(["Andy", "Andrew", "Oleks"]));
+    let age: ArrayRef = Arc::new(Int32Array::from(vec![11, 22, 33]));
+    RecordBatch::try_from_iter(vec![("name", name), ("age", age)]).unwrap()
+}
diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index ac85dd95af..9ec0148d91 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -75,7 +75,7 @@ use url::Url;
 pub use datafusion_execution::config::SessionConfig;
 pub use datafusion_execution::TaskContext;
 pub use datafusion_expr::execution_props::ExecutionProps;
-use datafusion_optimizer::AnalyzerRule;
+use datafusion_optimizer::{AnalyzerRule, OptimizerRule};
 
 mod avro;
 mod csv;
@@ -332,13 +332,21 @@ impl SessionContext {
         self
     }
 
-    /// Adds an analyzer rule to the `SessionState` in the current `SessionContext`.
-    pub fn add_analyzer_rule(
-        self,
-        analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>,
-    ) -> Self {
+    /// Adds an optimizer rule to the end of the existing rules.
+    ///
+    /// See [`SessionState`] for more control of when the rule is applied.
+    pub fn add_optimizer_rule(
+        &self,
+        optimizer_rule: Arc<dyn OptimizerRule + Send + Sync>,
+    ) {
+        self.state.write().append_optimizer_rule(optimizer_rule);
+    }
+
+    /// Adds an analyzer rule to the end of the existing rules.
+    ///
+    /// See [`SessionState`] for more control of when the rule is applied.
+    pub fn add_analyzer_rule(&self, analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>) {
         self.state.write().add_analyzer_rule(analyzer_rule);
-        self
     }
 
     /// Registers an [`ObjectStore`] to be used with a specific URL prefix.
diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index 16d8508597..d2bac134b5 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -402,6 +402,16 @@ impl SessionState {
         self
     }
 
+    // the existing `add_optimizer_rule` takes an owned `self`;
+    // it should probably be renamed to `with_optimizer_rule` to follow builder style
+    // and an `add_optimizer_rule` that takes `&mut self` added instead of this
+    pub(crate) fn append_optimizer_rule(
+        &mut self,
+        optimizer_rule: Arc<dyn OptimizerRule + Send + Sync>,
+    ) {
+        self.optimizer.rules.push(optimizer_rule);
+    }
+
     /// Add `physical_optimizer_rule` to the end of the list of
     /// [`PhysicalOptimizerRule`]s used to rewrite queries.
     pub fn add_physical_optimizer_rule(
diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs
index 4b5bd3a28d..38ed142cf9 100644
--- a/datafusion/core/tests/user_defined/user_defined_plan.rs
+++ b/datafusion/core/tests/user_defined/user_defined_plan.rs
@@ -245,7 +245,8 @@ async fn normal_query() -> Result<()> {
 #[tokio::test]
 // Run the query using default planners, optimizer and custom analyzer rule
 async fn normal_query_with_analyzer() -> Result<()> {
-    let ctx = SessionContext::new().add_analyzer_rule(Arc::new(MyAnalyzerRule {}));
+    let ctx = SessionContext::new();
+    ctx.add_analyzer_rule(Arc::new(MyAnalyzerRule {}));
     run_and_compare_query_with_analyzer_rule(ctx, "MyAnalyzerRule").await
 }
 

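For context, the end-to-end usage pattern this commit enables is roughly the following. This is a minimal sketch, not part of the commit, and it assumes the `MyOptimizerRule` and `person_batch` helpers defined in `datafusion-examples/examples/optimizer_rule.rs` above:

use std::sync::Arc;
use datafusion::prelude::SessionContext;
use datafusion_common::Result;

#[tokio::main]
async fn main() -> Result<()> {
    // Register the custom rule; the new `add_optimizer_rule` takes `&self`
    // and appends the rule after the built-in optimizer rules.
    let ctx = SessionContext::new();
    ctx.add_optimizer_rule(Arc::new(MyOptimizerRule {}));

    // `MyOptimizerRule` and `person_batch` come from the example file added
    // in this commit; any query planned through this context now has the
    // custom rule applied during logical optimization.
    ctx.register_batch("person", person_batch())?;
    ctx.sql("SELECT * FROM person WHERE age = 22").await?.show().await?;
    Ok(())
}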

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org
