This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 45a316ca7e Extract CSE logic to `datafusion_common` (#13002)
45a316ca7e is described below
commit 45a316ca7e3b5faa21cc80bab6d879c2be3d90c8
Author: Peter Toth <[email protected]>
AuthorDate: Mon Oct 21 21:47:05 2024 +0200
Extract CSE logic to `datafusion_common` (#13002)
* Extract CSE logic
* address review comments, move `HashNode` to `datafusion_common::cse`,
shorter names for eliminator and controller, change
`CSE::extract_common_nodes()` to return `Result<FoundCommonNodes<N>>` (instead
of `Result<Transformed<FoundCommonNodes<N>>>`)
---
datafusion-cli/Cargo.lock | 21 +-
datafusion/common/Cargo.toml | 1 +
datafusion/common/src/cse.rs | 800 ++++++++++++
datafusion/common/src/lib.rs | 1 +
datafusion/common/src/tree_node.rs | 135 +-
datafusion/expr/src/expr.rs | 79 +-
.../optimizer/src/common_subexpr_eliminate.rs | 1361 ++++++--------------
7 files changed, 1314 insertions(+), 1084 deletions(-)
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 612209fdd9..401f203dd9 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -406,9 +406,9 @@ dependencies = [
[[package]]
name = "async-compression"
-version = "0.4.16"
+version = "0.4.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "103db485efc3e41214fe4fda9f3dbeae2eb9082f48fd236e6095627a9422066e"
+checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857"
dependencies = [
"bzip2",
"flate2",
@@ -917,9 +917,9 @@ dependencies = [
[[package]]
name = "cc"
-version = "1.1.30"
+version = "1.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945"
+checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f"
dependencies = [
"jobserver",
"libc",
@@ -1293,6 +1293,7 @@ dependencies = [
"chrono",
"half",
"hashbrown 0.14.5",
+ "indexmap",
"instant",
"libc",
"num_cpus",
@@ -2615,9 +2616,9 @@ dependencies = [
[[package]]
name = "object_store"
-version = "0.11.0"
+version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "25a0c4b3a0e31f8b66f71ad8064521efa773910196e2cde791436f13409f3b45"
+checksum = "6eb4c22c6154a1e759d7099f9ffad7cc5ef8245f9efbab4a41b92623079c82f3"
dependencies = [
"async-trait",
"base64 0.22.1",
@@ -3411,9 +3412,9 @@ dependencies = [
[[package]]
name = "serde_json"
-version = "1.0.130"
+version = "1.0.132"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "610f75ff4a8e3cb29b85da56eabdd1bff5b06739059a4b8e2967fef32e5d9944"
+checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03"
dependencies = [
"itoa",
"memchr",
@@ -3605,9 +3606,9 @@ checksum =
"13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
[[package]]
name = "syn"
-version = "2.0.79"
+version = "2.0.82"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590"
+checksum = "83540f837a8afc019423a8edb95b52a8effe46957ee402287f4292fae35be021"
dependencies = [
"proc-macro2",
"quote",
diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml
index 1ac27b40c2..0747672a18 100644
--- a/datafusion/common/Cargo.toml
+++ b/datafusion/common/Cargo.toml
@@ -56,6 +56,7 @@ arrow-schema = { workspace = true }
chrono = { workspace = true }
half = { workspace = true }
hashbrown = { workspace = true }
+indexmap = { workspace = true }
libc = "0.2.140"
num_cpus = { workspace = true }
object_store = { workspace = true, optional = true }
diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs
new file mode 100644
index 0000000000..453ae26e73
--- /dev/null
+++ b/datafusion/common/src/cse.rs
@@ -0,0 +1,800 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Common Subexpression Elimination logic implemented in [`CSE`] can be
controlled with
+//! a [`CSEController`], that defines how to eliminate common subtrees from a
particular
+//! [`TreeNode`] tree.
+
+use crate::hash_utils::combine_hashes;
+use crate::tree_node::{
+ Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
TreeNodeRewriter,
+ TreeNodeVisitor,
+};
+use crate::Result;
+use indexmap::IndexMap;
+use std::collections::HashMap;
+use std::hash::{BuildHasher, Hash, Hasher, RandomState};
+use std::marker::PhantomData;
+use std::sync::Arc;
+
+/// Hashes the direct content of a [`TreeNode`] without recursing into its
children.
+///
+/// This method is useful to incrementally compute hashes, such as in [`CSE`]
which builds
+/// a deep hash of a node and its descendants during the bottom-up phase of
the first
+/// traversal and so avoids computing the hash of the node and then the hash of
its
+/// descendants separately.
+///
+/// If a node doesn't have any children then the value returned by
`hash_node()` is
+/// similar to `.hash()`, but does not necessarily return the same value.
+pub trait HashNode {
+ fn hash_node<H: Hasher>(&self, state: &mut H);
+}
+
+impl<T: HashNode + ?Sized> HashNode for Arc<T> {
+ fn hash_node<H: Hasher>(&self, state: &mut H) {
+ (**self).hash_node(state);
+ }
+}
+
+/// Identifier that represents a [`TreeNode`] tree.
+///
+/// This identifier is designed to be efficient to "hash", "accumulate" and
compare for
+/// equality, and to have as few collisions as possible.
+#[derive(Debug, Eq, PartialEq)]
+struct Identifier<'n, N> {
+ // Hash of `node` built up incrementally during the first, visiting
traversal.
+ // Its value is not necessarily equal to default hash of the node. E.g. it
is not
+ // equal to `expr.hash()` if the node is `Expr`.
+ hash: u64,
+ node: &'n N,
+}
+
+impl<N> Clone for Identifier<'_, N> {
+ fn clone(&self) -> Self {
+ *self
+ }
+}
+impl<N> Copy for Identifier<'_, N> {}
+
+impl<N> Hash for Identifier<'_, N> {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ state.write_u64(self.hash);
+ }
+}
+
+impl<'n, N: HashNode> Identifier<'n, N> {
+ fn new(node: &'n N, random_state: &RandomState) -> Self {
+ let mut hasher = random_state.build_hasher();
+ node.hash_node(&mut hasher);
+ let hash = hasher.finish();
+ Self { hash, node }
+ }
+
+ fn combine(mut self, other: Option<Self>) -> Self {
+ other.map_or(self, |other_id| {
+ self.hash = combine_hashes(self.hash, other_id.hash);
+ self
+ })
+ }
+}
+
+/// A cache that contains the postorder index and the identifier of
[`TreeNode`]s by the
+/// preorder index of the nodes.
+///
+/// This cache is filled by [`CSEVisitor`] during the first traversal and is
+/// used by [`CSERewriter`] during the second traversal.
+///
+/// The purpose of this cache is to quickly find the identifier of a node
during the
+/// second traversal.
+///
+/// Elements in this array are added during `f_down` so the indexes represent
the preorder
+/// index of nodes and thus element 0 belongs to the root of the tree.
+///
+/// The elements of the array are tuples that contain:
+/// - Postorder index that belongs to the preorder index. Assigned during
`f_up`, start
+/// from 0.
+/// - The optional [`Identifier`] of the node. If none the node should not be
considered
+/// for CSE.
+///
+/// # Example
+/// An expression tree like `(a + b)` would have the following `IdArray`:
+/// ```text
+/// [
+/// (2, Some(Identifier(hash_of("a + b"), &"a + b"))),
+/// (1, Some(Identifier(hash_of("a"), &"a"))),
+/// (0, Some(Identifier(hash_of("b"), &"b")))
+/// ]
+/// ```
+type IdArray<'n, N> = Vec<(usize, Option<Identifier<'n, N>>)>;
+
+/// A map that contains the number of normal and conditional occurrences of
[`TreeNode`]s
+/// by their identifiers.
+type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (usize, usize)>;
+
+/// A map that contains the common [`TreeNode`]s and their alias by their
identifiers,
+/// extracted during the second, rewriting traversal.
+type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;
+
+type ChildrenList<N> = (Vec<N>, Vec<N>);
+
+/// The [`TreeNode`] specific definition of elimination.
+pub trait CSEController {
+ /// The type of the tree nodes.
+ type Node;
+
+ /// Splits the children to normal and conditionally evaluated ones or
returns `None`
+ /// if all are always evaluated.
+ fn conditional_children(node: &Self::Node) ->
Option<ChildrenList<&Self::Node>>;
+
+ // Returns true if a node is valid. If a node is invalid then it can't be
eliminated.
+ // Validity is propagated up which means no subtree can be eliminated that
contains
+ // an invalid node.
+ // (E.g. volatile expressions are not valid and subtrees containing such a
node can't
+ // be extracted.)
+ fn is_valid(node: &Self::Node) -> bool;
+
+ // Returns true if a node should be ignored during CSE. Contrary to
validity of a node,
+ // it is not propagated up.
+ fn is_ignored(&self, node: &Self::Node) -> bool;
+
+ // Generates a new name for the extracted subtree.
+ fn generate_alias(&self) -> String;
+
+ // Replaces a node to the generated alias.
+ fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;
+
+ // A helper method called on each node during top-down traversal during
the second,
+ // rewriting traversal of CSE.
+ fn rewrite_f_down(&mut self, _node: &Self::Node) {}
+
+ // A helper method called on each node during bottom-up traversal during
the second,
+ // rewriting traversal of CSE.
+ fn rewrite_f_up(&mut self, _node: &Self::Node) {}
+}
+
+/// The result of potentially rewriting a list of [`TreeNode`]s to eliminate
common
+/// subtrees.
+#[derive(Debug)]
+pub enum FoundCommonNodes<N> {
+ /// No common [`TreeNode`]s were found
+ No { original_nodes_list: Vec<Vec<N>> },
+
+ /// Common [`TreeNode`]s were found
+ Yes {
+ /// extracted common [`TreeNode`]
+ common_nodes: Vec<(N, String)>,
+
+ /// new [`TreeNode`]s with common subtrees replaced
+ new_nodes_list: Vec<Vec<N>>,
+
+ /// original [`TreeNode`]s
+ original_nodes_list: Vec<Vec<N>>,
+ },
+}
+
+/// Go through a [`TreeNode`] tree and generate identifiers for each subtree.
+///
+/// An identifier contains information of the [`TreeNode`] itself and its
subtrees.
+/// This visitor implementation uses a stack `visit_stack` to track traversal,
which
+/// lets us know when a subtree's visiting is finished. When `pre_visit` is
called
+/// (traversing to a new node), an `EnterMark` and a `NodeItem` will be
pushed onto the stack.
+/// On leaving a node (`f_up()`) we try to pop out an `EnterMark`. All
`NodeItem`s
+/// before the first `EnterMark` are considered to be sub-trees of the leaving
node.
+///
+/// This visitor also records identifiers in `id_array`, so that the following
traversal
+/// pass can get the identifier of a node without recalculating it. We assign
each node
+/// in the tree a series number, starting from 1, maintained by `series_number`.
+/// The series number represents the order we left (`f_up()`) a node and has
the property
+/// that a child node's series number is always smaller than its parent's,
while `id_array` is
+/// organized in the order we enter (`f_down()`) a node. `node_count` helps us
to
+/// get the index into `id_array` for each node.
+///
+/// A [`TreeNode`] without any children (column, literal etc.) will not have an
identifier
+/// because it should not be recognized as a common subtree.
+struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
+ /// statistics of [`TreeNode`]s
+ node_stats: &'a mut NodeStats<'n, N>,
+
+ /// cache to speed up second traversal
+ id_array: &'a mut IdArray<'n, N>,
+
+ /// inner states
+ visit_stack: Vec<VisitRecord<'n, N>>,
+
+ /// preorder index, start from 0.
+ down_index: usize,
+
+ /// postorder index, start from 0.
+ up_index: usize,
+
+ /// a [`RandomState`] to generate hashes during the first traversal
+ random_state: &'a RandomState,
+
+ /// a flag to indicate that common [`TreeNode`]s found
+ found_common: bool,
+
+ /// if we are in a conditional branch. A conditional branch means that the
[`TreeNode`]
+ /// might not be executed depending on the runtime values of other
[`TreeNode`]s, and
+ /// thus can not be extracted as a common [`TreeNode`].
+ conditional: bool,
+
+ controller: &'a C,
+}
+
+/// Record item that used when traversing a [`TreeNode`] tree.
+enum VisitRecord<'n, N> {
+ /// Marks the beginning of [`TreeNode`]. It contains:
+ /// - The post-order index assigned during the first, visiting traversal.
+ EnterMark(usize),
+
+ /// Marks an accumulated subtree. It contains:
+ /// - The accumulated identifier of a subtree.
+ /// - An accumulated boolean flag if the subtree is valid for CSE.
+ /// The flag is propagated up from children to parent. (E.g. volatile
expressions
+ /// are not valid and can't be extracted, but non-volatile children of
volatile
+ /// expressions can be extracted.)
+ NodeItem(Identifier<'n, N>, bool),
+}
+
+impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_,
'n, N, C> {
+ /// Find the first `EnterMark` in the stack, and accumulates every
`NodeItem` before
+ /// it. Returns a tuple that contains:
+ /// - The pre-order index of the [`TreeNode`] we marked.
+ /// - The accumulated identifier of the children of the marked
[`TreeNode`].
+ /// - An accumulated boolean flag from the children of the marked
[`TreeNode`] if all
+ /// children are valid for CSE (i.e. it is safe to extract the
[`TreeNode`] as a
+ /// common [`TreeNode`] from its children POV).
+ /// (E.g. if any of the children of the marked expression is not valid
(e.g. is
+ /// volatile) then the expression is also not valid, so we can propagate
this
+ /// information up from children to parents via `visit_stack` during the
first,
+ /// visiting traversal and no need to test the expression's validity
beforehand with
+ /// an extra traversal).
+ fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n, N>>, bool) {
+ let mut node_id = None;
+ let mut is_valid = true;
+
+ while let Some(item) = self.visit_stack.pop() {
+ match item {
+ VisitRecord::EnterMark(down_index) => {
+ return (down_index, node_id, is_valid);
+ }
+ VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => {
+ node_id = Some(sub_node_id.combine(node_id));
+ is_valid &= sub_node_is_valid;
+ }
+ }
+ }
+ unreachable!("EnterMark should paired with NodeItem");
+ }
+}
+
+impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>>
TreeNodeVisitor<'n>
+ for CSEVisitor<'_, 'n, N, C>
+{
+ type Node = N;
+
+ fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
+ self.id_array.push((0, None));
+ self.visit_stack
+ .push(VisitRecord::EnterMark(self.down_index));
+ self.down_index += 1;
+
+ // If a node can short-circuit then some of its children might not be
executed so
+ // count the occurrence either normal or conditional.
+ Ok(if self.conditional {
+ // If we are already in a conditionally evaluated subtree then
continue
+ // traversal.
+ TreeNodeRecursion::Continue
+ } else {
+ // If we are already in a node that can short-circuit then start
new
+ // traversals on its normal conditional children.
+ match C::conditional_children(node) {
+ Some((normal, conditional)) => {
+ normal
+ .into_iter()
+ .try_for_each(|n| n.visit(self).map(|_| ()))?;
+ self.conditional = true;
+ conditional
+ .into_iter()
+ .try_for_each(|n| n.visit(self).map(|_| ()))?;
+ self.conditional = false;
+
+ TreeNodeRecursion::Jump
+ }
+
+ // In case of non-short-circuit node continue the traversal.
+ _ => TreeNodeRecursion::Continue,
+ }
+ })
+ }
+
+ fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
+ let (down_index, sub_node_id, sub_node_is_valid) =
self.pop_enter_mark();
+
+ let node_id = Identifier::new(node,
self.random_state).combine(sub_node_id);
+ let is_valid = C::is_valid(node) && sub_node_is_valid;
+
+ self.id_array[down_index].0 = self.up_index;
+ if is_valid && !self.controller.is_ignored(node) {
+ self.id_array[down_index].1 = Some(node_id);
+ let (count, conditional_count) =
+ self.node_stats.entry(node_id).or_insert((0, 0));
+ if self.conditional {
+ *conditional_count += 1;
+ } else {
+ *count += 1;
+ }
+ if *count > 1 || (*count == 1 && *conditional_count > 0) {
+ self.found_common = true;
+ }
+ }
+ self.visit_stack
+ .push(VisitRecord::NodeItem(node_id, is_valid));
+ self.up_index += 1;
+
+ Ok(TreeNodeRecursion::Continue)
+ }
+}
+
+/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the
+/// corresponding temporary [`TreeNode`], which contains the evaluation result
of the
+/// replaced [`TreeNode`] tree.
+struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
+ /// statistics of [`TreeNode`]s
+ node_stats: &'a NodeStats<'n, N>,
+
+ /// cache to speed up second traversal
+ id_array: &'a IdArray<'n, N>,
+
+ /// common [`TreeNode`]s, that are replaced during the second traversal,
are collected
+ /// to this map
+ common_nodes: &'a mut CommonNodes<'n, N>,
+
+ // preorder index, starts from 0.
+ down_index: usize,
+
+ controller: &'a mut C,
+}
+
+impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
+ for CSERewriter<'_, '_, N, C>
+{
+ type Node = N;
+
+ fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+ self.controller.rewrite_f_down(&node);
+
+ let (up_index, node_id) = self.id_array[self.down_index];
+ self.down_index += 1;
+
+ // Handle nodes with identifiers only
+ if let Some(node_id) = node_id {
+ let (count, conditional_count) =
self.node_stats.get(&node_id).unwrap();
+ if *count > 1 || *count == 1 && *conditional_count > 0 {
+ // step index to skip all sub-node (which has smaller series
number).
+ while self.down_index < self.id_array.len()
+ && self.id_array[self.down_index].0 < up_index
+ {
+ self.down_index += 1;
+ }
+
+ let (node, alias) =
+ self.common_nodes.entry(node_id).or_insert_with(|| {
+ let node_alias = self.controller.generate_alias();
+ (node, node_alias)
+ });
+
+ let rewritten = self.controller.rewrite(node, alias);
+
+ return Ok(Transformed::new(rewritten, true,
TreeNodeRecursion::Jump));
+ }
+ }
+
+ Ok(Transformed::no(node))
+ }
+
+ fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+ self.controller.rewrite_f_up(&node);
+
+ Ok(Transformed::no(node))
+ }
+}
+
+/// The main entry point of Common Subexpression Elimination.
+///
+/// [`CSE`] requires a [`CSEController`], that defines how common subtrees of
a particular
+/// [`TreeNode`] tree can be eliminated. The elimination process can be
started with the
+/// [`CSE::extract_common_nodes()`] method.
+pub struct CSE<N, C: CSEController<Node = N>> {
+ random_state: RandomState,
+ phantom_data: PhantomData<N>,
+ controller: C,
+}
+
+impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N,
C> {
+ pub fn new(controller: C) -> Self {
+ Self {
+ random_state: RandomState::new(),
+ phantom_data: PhantomData,
+ controller,
+ }
+ }
+
+ /// Add an identifier to `id_array` for every [`TreeNode`] in this tree.
+ fn node_to_id_array<'n>(
+ &self,
+ node: &'n N,
+ node_stats: &mut NodeStats<'n, N>,
+ id_array: &mut IdArray<'n, N>,
+ ) -> Result<bool> {
+ let mut visitor = CSEVisitor {
+ node_stats,
+ id_array,
+ visit_stack: vec![],
+ down_index: 0,
+ up_index: 0,
+ random_state: &self.random_state,
+ found_common: false,
+ conditional: false,
+ controller: &self.controller,
+ };
+ node.visit(&mut visitor)?;
+
+ Ok(visitor.found_common)
+ }
+
+ /// Returns the identifier list for each element in `nodes` and a flag to
indicate if
+ /// the rewrite phase of CSE makes sense.
+ ///
+ /// Returns an array with 1 element for each input node in `nodes`
+ ///
+ /// Each element is itself the result of [`CSE::node_to_id_array`] for
that node
+ /// (e.g. the identifiers for each node in the tree)
+ fn to_arrays<'n>(
+ &self,
+ nodes: &'n [N],
+ node_stats: &mut NodeStats<'n, N>,
+ ) -> Result<(bool, Vec<IdArray<'n, N>>)> {
+ let mut found_common = false;
+ nodes
+ .iter()
+ .map(|n| {
+ let mut id_array = vec![];
+ self.node_to_id_array(n, node_stats, &mut id_array)
+ .map(|fc| {
+ found_common |= fc;
+
+ id_array
+ })
+ })
+ .collect::<Result<Vec<_>>>()
+ .map(|id_arrays| (found_common, id_arrays))
+ }
+
+ /// Replace common subtrees in `node` with the corresponding temporary
+ /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`]
+ fn replace_common_node<'n>(
+ &mut self,
+ node: N,
+ id_array: &IdArray<'n, N>,
+ node_stats: &NodeStats<'n, N>,
+ common_nodes: &mut CommonNodes<'n, N>,
+ ) -> Result<N> {
+ if id_array.is_empty() {
+ Ok(Transformed::no(node))
+ } else {
+ node.rewrite(&mut CSERewriter {
+ node_stats,
+ id_array,
+ common_nodes,
+ down_index: 0,
+ controller: &mut self.controller,
+ })
+ }
+ .data()
+ }
+
+ /// Replace common subtrees in `nodes_list` with the corresponding
temporary
+ /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`].
+ fn rewrite_nodes_list<'n>(
+ &mut self,
+ nodes_list: Vec<Vec<N>>,
+ arrays_list: &[Vec<IdArray<'n, N>>],
+ node_stats: &NodeStats<'n, N>,
+ common_nodes: &mut CommonNodes<'n, N>,
+ ) -> Result<Vec<Vec<N>>> {
+ nodes_list
+ .into_iter()
+ .zip(arrays_list.iter())
+ .map(|(nodes, arrays)| {
+ nodes
+ .into_iter()
+ .zip(arrays.iter())
+ .map(|(node, id_array)| {
+ self.replace_common_node(node, id_array, node_stats,
common_nodes)
+ })
+ .collect::<Result<Vec<_>>>()
+ })
+ .collect::<Result<Vec<_>>>()
+ }
+
+ /// Extracts common [`TreeNode`]s and rewrites `nodes_list`.
+ ///
+ /// Returns [`FoundCommonNodes`] recording the result of the extraction.
+ pub fn extract_common_nodes(
+ &mut self,
+ nodes_list: Vec<Vec<N>>,
+ ) -> Result<FoundCommonNodes<N>> {
+ let mut found_common = false;
+ let mut node_stats = NodeStats::new();
+ let id_arrays_list = nodes_list
+ .iter()
+ .map(|nodes| {
+ self.to_arrays(nodes, &mut node_stats)
+ .map(|(fc, id_arrays)| {
+ found_common |= fc;
+
+ id_arrays
+ })
+ })
+ .collect::<Result<Vec<_>>>()?;
+ if found_common {
+ let mut common_nodes = CommonNodes::new();
+ let new_nodes_list = self.rewrite_nodes_list(
+ // Must clone the list of nodes as Identifiers use references
to original
+ // nodes so we have to keep them intact.
+ nodes_list.clone(),
+ &id_arrays_list,
+ &node_stats,
+ &mut common_nodes,
+ )?;
+ assert!(!common_nodes.is_empty());
+
+ Ok(FoundCommonNodes::Yes {
+ common_nodes: common_nodes.into_values().collect(),
+ new_nodes_list,
+ original_nodes_list: nodes_list,
+ })
+ } else {
+ Ok(FoundCommonNodes::No {
+ original_nodes_list: nodes_list,
+ })
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use crate::alias::AliasGenerator;
+ use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats,
CSE};
+ use crate::tree_node::tests::TestTreeNode;
+ use crate::Result;
+ use std::collections::HashSet;
+ use std::hash::{Hash, Hasher};
+
+ const CSE_PREFIX: &str = "__common_node";
+
+ #[derive(Clone, Copy)]
+ pub enum TestTreeNodeMask {
+ Normal,
+ NormalAndAggregates,
+ }
+
+ pub struct TestTreeNodeCSEController<'a> {
+ alias_generator: &'a AliasGenerator,
+ mask: TestTreeNodeMask,
+ }
+
+ impl<'a> TestTreeNodeCSEController<'a> {
+ fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) ->
Self {
+ Self {
+ alias_generator,
+ mask,
+ }
+ }
+ }
+
+ impl CSEController for TestTreeNodeCSEController<'_> {
+ type Node = TestTreeNode<String>;
+
+ fn conditional_children(
+ _: &Self::Node,
+ ) -> Option<(Vec<&Self::Node>, Vec<&Self::Node>)> {
+ None
+ }
+
+ fn is_valid(_node: &Self::Node) -> bool {
+ true
+ }
+
+ fn is_ignored(&self, node: &Self::Node) -> bool {
+ let is_leaf = node.is_leaf();
+ let is_aggr = node.data == "avg" || node.data == "sum";
+
+ match self.mask {
+ TestTreeNodeMask::Normal => is_leaf || is_aggr,
+ TestTreeNodeMask::NormalAndAggregates => is_leaf,
+ }
+ }
+
+ fn generate_alias(&self) -> String {
+ self.alias_generator.next(CSE_PREFIX)
+ }
+
+ fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
+ TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
+ }
+ }
+
+ impl HashNode for TestTreeNode<String> {
+ fn hash_node<H: Hasher>(&self, state: &mut H) {
+ self.data.hash(state);
+ }
+ }
+
+ #[test]
+ fn id_array_visitor() -> Result<()> {
+ let alias_generator = AliasGenerator::new();
+ let eliminator = CSE::new(TestTreeNodeCSEController::new(
+ &alias_generator,
+ TestTreeNodeMask::Normal,
+ ));
+
+ let a_plus_1 = TestTreeNode::new(
+ vec![
+ TestTreeNode::new_leaf("a".to_string()),
+ TestTreeNode::new_leaf("1".to_string()),
+ ],
+ "+".to_string(),
+ );
+ let avg_c = TestTreeNode::new(
+ vec![TestTreeNode::new_leaf("c".to_string())],
+ "avg".to_string(),
+ );
+ let sum_a_plus_1 = TestTreeNode::new(vec![a_plus_1],
"sum".to_string());
+ let sum_a_plus_1_minus_avg_c =
+ TestTreeNode::new(vec![sum_a_plus_1, avg_c], "-".to_string());
+ let root = TestTreeNode::new(
+ vec![
+ sum_a_plus_1_minus_avg_c,
+ TestTreeNode::new_leaf("2".to_string()),
+ ],
+ "*".to_string(),
+ );
+
+ let [sum_a_plus_1_minus_avg_c, _] = root.children.as_slice() else {
+ panic!("Cannot extract subtree references")
+ };
+ let [sum_a_plus_1, avg_c] =
sum_a_plus_1_minus_avg_c.children.as_slice() else {
+ panic!("Cannot extract subtree references")
+ };
+ let [a_plus_1] = sum_a_plus_1.children.as_slice() else {
+ panic!("Cannot extract subtree references")
+ };
+
+ // skip aggregates
+ let mut id_array = vec![];
+ eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut
id_array)?;
+
+ // Collect distinct hashes and set them to 0 in `id_array`
+ fn collect_hashes(
+ id_array: &mut IdArray<'_, TestTreeNode<String>>,
+ ) -> HashSet<u64> {
+ id_array
+ .iter_mut()
+ .flat_map(|(_, id_option)| {
+ id_option.as_mut().map(|node_id| {
+ let hash = node_id.hash;
+ node_id.hash = 0;
+ hash
+ })
+ })
+ .collect::<HashSet<_>>()
+ }
+
+ let hashes = collect_hashes(&mut id_array);
+ assert_eq!(hashes.len(), 3);
+
+ let expected = vec![
+ (
+ 8,
+ Some(Identifier {
+ hash: 0,
+ node: &root,
+ }),
+ ),
+ (
+ 6,
+ Some(Identifier {
+ hash: 0,
+ node: sum_a_plus_1_minus_avg_c,
+ }),
+ ),
+ (3, None),
+ (
+ 2,
+ Some(Identifier {
+ hash: 0,
+ node: a_plus_1,
+ }),
+ ),
+ (0, None),
+ (1, None),
+ (5, None),
+ (4, None),
+ (7, None),
+ ];
+ assert_eq!(expected, id_array);
+
+ // include aggregates
+ let eliminator = CSE::new(TestTreeNodeCSEController::new(
+ &alias_generator,
+ TestTreeNodeMask::NormalAndAggregates,
+ ));
+
+ let mut id_array = vec![];
+ eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut
id_array)?;
+
+ let hashes = collect_hashes(&mut id_array);
+ assert_eq!(hashes.len(), 5);
+
+ let expected = vec![
+ (
+ 8,
+ Some(Identifier {
+ hash: 0,
+ node: &root,
+ }),
+ ),
+ (
+ 6,
+ Some(Identifier {
+ hash: 0,
+ node: sum_a_plus_1_minus_avg_c,
+ }),
+ ),
+ (
+ 3,
+ Some(Identifier {
+ hash: 0,
+ node: sum_a_plus_1,
+ }),
+ ),
+ (
+ 2,
+ Some(Identifier {
+ hash: 0,
+ node: a_plus_1,
+ }),
+ ),
+ (0, None),
+ (1, None),
+ (
+ 5,
+ Some(Identifier {
+ hash: 0,
+ node: avg_c,
+ }),
+ ),
+ (4, None),
+ (7, None),
+ ];
+ assert_eq!(expected, id_array);
+
+ Ok(())
+ }
+}
diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs
index 8323f5efc8..e4575038ab 100644
--- a/datafusion/common/src/lib.rs
+++ b/datafusion/common/src/lib.rs
@@ -31,6 +31,7 @@ mod unnest;
pub mod alias;
pub mod cast;
pub mod config;
+pub mod cse;
pub mod display;
pub mod error;
pub mod file_options;
diff --git a/datafusion/common/src/tree_node.rs
b/datafusion/common/src/tree_node.rs
index b4d3251fd2..563f1fa856 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -1027,7 +1027,7 @@ impl<T: ConcreteTreeNode> TreeNode for T {
}
#[cfg(test)]
-mod tests {
+pub(crate) mod tests {
use std::collections::HashMap;
use std::fmt::Display;
@@ -1037,16 +1037,27 @@ mod tests {
};
use crate::Result;
- #[derive(Debug, Eq, Hash, PartialEq)]
- struct TestTreeNode<T> {
- children: Vec<TestTreeNode<T>>,
- data: T,
+ #[derive(Debug, Eq, Hash, PartialEq, Clone)]
+ pub struct TestTreeNode<T> {
+ pub(crate) children: Vec<TestTreeNode<T>>,
+ pub(crate) data: T,
}
impl<T> TestTreeNode<T> {
- fn new(children: Vec<TestTreeNode<T>>, data: T) -> Self {
+ pub(crate) fn new(children: Vec<TestTreeNode<T>>, data: T) -> Self {
Self { children, data }
}
+
+ pub(crate) fn new_leaf(data: T) -> Self {
+ Self {
+ children: vec![],
+ data,
+ }
+ }
+
+ pub(crate) fn is_leaf(&self) -> bool {
+ self.children.is_empty()
+ }
}
impl<T> TreeNode for TestTreeNode<T> {
@@ -1086,12 +1097,12 @@ mod tests {
// |
// A
fn test_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "a".to_string());
- let node_b = TestTreeNode::new(vec![], "b".to_string());
+ let node_a = TestTreeNode::new_leaf("a".to_string());
+ let node_b = TestTreeNode::new_leaf("b".to_string());
let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
- let node_h = TestTreeNode::new(vec![], "h".to_string());
+ let node_h = TestTreeNode::new_leaf("h".to_string());
let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
@@ -1130,13 +1141,13 @@ mod tests {
// Expected transformed tree after a combined traversal
fn transformed_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string());
- let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string());
+ let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
+ let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
let node_d = TestTreeNode::new(vec![node_a],
"f_up(f_down(d))".to_string());
let node_c =
TestTreeNode::new(vec![node_b, node_d],
"f_up(f_down(c))".to_string());
let node_e = TestTreeNode::new(vec![node_c],
"f_up(f_down(e))".to_string());
- let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string());
+ let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
let node_g = TestTreeNode::new(vec![node_h],
"f_up(f_down(g))".to_string());
let node_f =
TestTreeNode::new(vec![node_e, node_g],
"f_up(f_down(f))".to_string());
@@ -1146,12 +1157,12 @@ mod tests {
// Expected transformed tree after a top-down traversal
fn transformed_down_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string());
- let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string());
+ let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
+ let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d],
"f_down(c)".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
- let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string());
+ let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g],
"f_down(f)".to_string());
let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1160,12 +1171,12 @@ mod tests {
// Expected transformed tree after a bottom-up traversal
fn transformed_up_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string());
- let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string());
+ let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
+ let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d],
"f_up(c)".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string());
- let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string());
+ let node_h = TestTreeNode::new_leaf("f_up(h)".to_string());
let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g],
"f_up(f)".to_string());
let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string());
@@ -1202,12 +1213,12 @@ mod tests {
}
fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string());
- let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string());
+ let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
+ let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d],
"f_down(c)".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
- let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string());
+ let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g],
"f_down(f)".to_string());
let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1236,12 +1247,12 @@ mod tests {
}
fn f_down_jump_on_e_transformed_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "a".to_string());
- let node_b = TestTreeNode::new(vec![], "b".to_string());
+ let node_a = TestTreeNode::new_leaf("a".to_string());
+ let node_b = TestTreeNode::new_leaf("b".to_string());
let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
let node_e = TestTreeNode::new(vec![node_c],
"f_up(f_down(e))".to_string());
- let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string());
+ let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
let node_g = TestTreeNode::new(vec![node_h],
"f_up(f_down(g))".to_string());
let node_f =
TestTreeNode::new(vec![node_e, node_g],
"f_up(f_down(f))".to_string());
@@ -1250,12 +1261,12 @@ mod tests {
}
fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "a".to_string());
- let node_b = TestTreeNode::new(vec![], "b".to_string());
+ let node_a = TestTreeNode::new_leaf("a".to_string());
+ let node_b = TestTreeNode::new_leaf("b".to_string());
let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
- let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string());
+ let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g],
"f_down(f)".to_string());
let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1289,12 +1300,12 @@ mod tests {
}
fn f_up_jump_on_a_transformed_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string());
- let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string());
+ let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
+ let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d],
"f_down(c)".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
- let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string());
+ let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
let node_g = TestTreeNode::new(vec![node_h],
"f_up(f_down(g))".to_string());
let node_f =
TestTreeNode::new(vec![node_e, node_g],
"f_up(f_down(f))".to_string());
@@ -1303,12 +1314,12 @@ mod tests {
}
fn f_up_jump_on_a_transformed_up_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string());
- let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string());
+ let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
+ let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
- let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string());
+ let node_h = TestTreeNode::new_leaf("f_up(h)".to_string());
let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g],
"f_up(f)".to_string());
let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string());
@@ -1372,12 +1383,12 @@ mod tests {
}
fn f_down_stop_on_a_transformed_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string());
- let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string());
+ let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
+ let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d],
"f_down(c)".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
- let node_h = TestTreeNode::new(vec![], "h".to_string());
+ let node_h = TestTreeNode::new_leaf("h".to_string());
let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g],
"f_down(f)".to_string());
let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1385,12 +1396,12 @@ mod tests {
}
fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string());
- let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string());
+ let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
+ let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d],
"f_down(c)".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
- let node_h = TestTreeNode::new(vec![], "h".to_string());
+ let node_h = TestTreeNode::new_leaf("h".to_string());
let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g],
"f_down(f)".to_string());
let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1406,12 +1417,12 @@ mod tests {
}
fn f_down_stop_on_e_transformed_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "a".to_string());
- let node_b = TestTreeNode::new(vec![], "b".to_string());
+ let node_a = TestTreeNode::new_leaf("a".to_string());
+ let node_b = TestTreeNode::new_leaf("b".to_string());
let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
- let node_h = TestTreeNode::new(vec![], "h".to_string());
+ let node_h = TestTreeNode::new_leaf("h".to_string());
let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g],
"f_down(f)".to_string());
let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1419,12 +1430,12 @@ mod tests {
}
fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "a".to_string());
- let node_b = TestTreeNode::new(vec![], "b".to_string());
+ let node_a = TestTreeNode::new_leaf("a".to_string());
+ let node_b = TestTreeNode::new_leaf("b".to_string());
let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
- let node_h = TestTreeNode::new(vec![], "h".to_string());
+ let node_h = TestTreeNode::new_leaf("h".to_string());
let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g],
"f_down(f)".to_string());
let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1451,12 +1462,12 @@ mod tests {
}
fn f_up_stop_on_a_transformed_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string());
- let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string());
+ let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
+ let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d],
"f_down(c)".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
- let node_h = TestTreeNode::new(vec![], "h".to_string());
+ let node_h = TestTreeNode::new_leaf("h".to_string());
let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g],
"f_down(f)".to_string());
let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1464,12 +1475,12 @@ mod tests {
}
fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string());
- let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string());
+ let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
+ let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
- let node_h = TestTreeNode::new(vec![], "h".to_string());
+ let node_h = TestTreeNode::new_leaf("h".to_string());
let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
@@ -1499,13 +1510,13 @@ mod tests {
}
fn f_up_stop_on_e_transformed_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string());
- let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string());
+ let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
+ let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
let node_d = TestTreeNode::new(vec![node_a],
"f_up(f_down(d))".to_string());
let node_c =
TestTreeNode::new(vec![node_b, node_d],
"f_up(f_down(c))".to_string());
let node_e = TestTreeNode::new(vec![node_c],
"f_up(f_down(e))".to_string());
- let node_h = TestTreeNode::new(vec![], "h".to_string());
+ let node_h = TestTreeNode::new_leaf("h".to_string());
let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g],
"f_down(f)".to_string());
let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1513,12 +1524,12 @@ mod tests {
}
fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode<String> {
- let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string());
- let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string());
+ let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
+ let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d],
"f_up(c)".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string());
- let node_h = TestTreeNode::new(vec![], "h".to_string());
+ let node_h = TestTreeNode::new_leaf("h".to_string());
let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
@@ -2016,16 +2027,16 @@ mod tests {
// A
#[test]
fn test_apply_and_visit_references() -> Result<()> {
- let node_a = TestTreeNode::new(vec![], "a".to_string());
- let node_b = TestTreeNode::new(vec![], "b".to_string());
+ let node_a = TestTreeNode::new_leaf("a".to_string());
+ let node_b = TestTreeNode::new_leaf("b".to_string());
let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
- let node_a_2 = TestTreeNode::new(vec![], "a".to_string());
- let node_b_2 = TestTreeNode::new(vec![], "b".to_string());
+ let node_a_2 = TestTreeNode::new_leaf("a".to_string());
+ let node_b_2 = TestTreeNode::new_leaf("b".to_string());
let node_d_2 = TestTreeNode::new(vec![node_a_2], "d".to_string());
let node_c_2 = TestTreeNode::new(vec![node_b_2, node_d_2],
"c".to_string());
- let node_a_3 = TestTreeNode::new(vec![], "a".to_string());
+ let node_a_3 = TestTreeNode::new_leaf("a".to_string());
let tree = TestTreeNode::new(vec![node_e, node_c_2, node_a_3],
"f".to_string());
let node_f_ref = &tree;
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index f3f71a8727..691b65d344 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -34,6 +34,7 @@ use crate::{
};
use arrow::datatypes::{DataType, FieldRef};
+use datafusion_common::cse::HashNode;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
@@ -1652,47 +1653,39 @@ impl Expr {
| Expr::Placeholder(..) => false,
}
}
+}
- /// Hashes the direct content of an `Expr` without recursing into its
children.
- ///
- /// This method is useful to incrementally compute hashes, such as in
- /// `CommonSubexprEliminate` which builds a deep hash of a node and its
descendants
- /// during the bottom-up phase of the first traversal and so avoid
computing the hash
- /// of the node and then the hash of its descendants separately.
- ///
- /// If a node doesn't have any children then this method is similar to
`.hash()`, but
- /// not necessarily returns the same value.
- ///
+impl HashNode for Expr {
/// As it is pretty easy to forget changing this method when `Expr`
changes the
/// implementation doesn't use wildcard patterns (`..`, `_`) to catch
changes
/// compile time.
- pub fn hash_node<H: Hasher>(&self, hasher: &mut H) {
- mem::discriminant(self).hash(hasher);
+ fn hash_node<H: Hasher>(&self, state: &mut H) {
+ mem::discriminant(self).hash(state);
match self {
Expr::Alias(Alias {
expr: _expr,
relation,
name,
}) => {
- relation.hash(hasher);
- name.hash(hasher);
+ relation.hash(state);
+ name.hash(state);
}
Expr::Column(column) => {
- column.hash(hasher);
+ column.hash(state);
}
Expr::ScalarVariable(data_type, name) => {
- data_type.hash(hasher);
- name.hash(hasher);
+ data_type.hash(state);
+ name.hash(state);
}
Expr::Literal(scalar_value) => {
- scalar_value.hash(hasher);
+ scalar_value.hash(state);
}
Expr::BinaryExpr(BinaryExpr {
left: _left,
op,
right: _right,
}) => {
- op.hash(hasher);
+ op.hash(state);
}
Expr::Like(Like {
negated,
@@ -1708,9 +1701,9 @@ impl Expr {
escape_char,
case_insensitive,
}) => {
- negated.hash(hasher);
- escape_char.hash(hasher);
- case_insensitive.hash(hasher);
+ negated.hash(state);
+ escape_char.hash(state);
+ case_insensitive.hash(state);
}
Expr::Not(_expr)
| Expr::IsNotNull(_expr)
@@ -1728,7 +1721,7 @@ impl Expr {
low: _low,
high: _high,
}) => {
- negated.hash(hasher);
+ negated.hash(state);
}
Expr::Case(Case {
expr: _expr,
@@ -1743,10 +1736,10 @@ impl Expr {
expr: _expr,
data_type,
}) => {
- data_type.hash(hasher);
+ data_type.hash(state);
}
Expr::ScalarFunction(ScalarFunction { func, args: _args }) => {
- func.hash(hasher);
+ func.hash(state);
}
Expr::AggregateFunction(AggregateFunction {
func,
@@ -1756,9 +1749,9 @@ impl Expr {
order_by: _order_by,
null_treatment,
}) => {
- func.hash(hasher);
- distinct.hash(hasher);
- null_treatment.hash(hasher);
+ func.hash(state);
+ distinct.hash(state);
+ null_treatment.hash(state);
}
Expr::WindowFunction(WindowFunction {
fun,
@@ -1768,49 +1761,49 @@ impl Expr {
window_frame,
null_treatment,
}) => {
- fun.hash(hasher);
- window_frame.hash(hasher);
- null_treatment.hash(hasher);
+ fun.hash(state);
+ window_frame.hash(state);
+ null_treatment.hash(state);
}
Expr::InList(InList {
expr: _expr,
list: _list,
negated,
}) => {
- negated.hash(hasher);
+ negated.hash(state);
}
Expr::Exists(Exists { subquery, negated }) => {
- subquery.hash(hasher);
- negated.hash(hasher);
+ subquery.hash(state);
+ negated.hash(state);
}
Expr::InSubquery(InSubquery {
expr: _expr,
subquery,
negated,
}) => {
- subquery.hash(hasher);
- negated.hash(hasher);
+ subquery.hash(state);
+ negated.hash(state);
}
Expr::ScalarSubquery(subquery) => {
- subquery.hash(hasher);
+ subquery.hash(state);
}
Expr::Wildcard { qualifier, options } => {
- qualifier.hash(hasher);
- options.hash(hasher);
+ qualifier.hash(state);
+ options.hash(state);
}
Expr::GroupingSet(grouping_set) => {
- mem::discriminant(grouping_set).hash(hasher);
+ mem::discriminant(grouping_set).hash(state);
match grouping_set {
GroupingSet::Rollup(_exprs) | GroupingSet::Cube(_exprs) =>
{}
GroupingSet::GroupingSets(_exprs) => {}
}
}
Expr::Placeholder(place_holder) => {
- place_holder.hash(hasher);
+ place_holder.hash(state);
}
Expr::OuterReferenceColumn(data_type, column) => {
- data_type.hash(hasher);
- column.hash(hasher);
+ data_type.hash(state);
+ column.hash(state);
}
Expr::Unnest(Unnest { expr: _expr }) => {}
};
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs
b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index c13cb3a8e9..921011d33f 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -17,8 +17,8 @@
//! [`CommonSubexprEliminate`] to avoid redundant computation of common
sub-expressions
-use std::collections::{BTreeSet, HashMap};
-use std::hash::{BuildHasher, Hash, Hasher, RandomState};
+use std::collections::BTreeSet;
+use std::fmt::Debug;
use std::sync::Arc;
use crate::{OptimizerConfig, OptimizerRule};
@@ -26,11 +26,9 @@ use crate::{OptimizerConfig, OptimizerRule};
use crate::optimizer::ApplyOrder;
use crate::utils::NamePreserver;
use datafusion_common::alias::AliasGenerator;
-use datafusion_common::hash_utils::combine_hashes;
-use datafusion_common::tree_node::{
- Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
TreeNodeRewriter,
- TreeNodeVisitor,
-};
+
+use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE};
+use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result};
use datafusion_expr::expr::{Alias, ScalarFunction};
use datafusion_expr::logical_plan::{
@@ -38,81 +36,9 @@ use datafusion_expr::logical_plan::{
};
use datafusion_expr::tree_node::replace_sort_expressions;
use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator};
-use indexmap::IndexMap;
const CSE_PREFIX: &str = "__common_expr";
-/// Identifier that represents a subexpression tree.
-///
-/// This identifier is designed to be efficient and "hash", "accumulate",
"equal" and
-/// "have no collision (as low as possible)"
-#[derive(Clone, Copy, Debug, Eq, PartialEq)]
-struct Identifier<'n> {
- // Hash of `expr` built up incrementally during the first, visiting
traversal, but its
- // value is not necessarily equal to `expr.hash()`.
- hash: u64,
- expr: &'n Expr,
-}
-
-impl<'n> Identifier<'n> {
- fn new(expr: &'n Expr, random_state: &RandomState) -> Self {
- let mut hasher = random_state.build_hasher();
- expr.hash_node(&mut hasher);
- let hash = hasher.finish();
- Self { hash, expr }
- }
-
- fn combine(mut self, other: Option<Self>) -> Self {
- other.map_or(self, |other_id| {
- self.hash = combine_hashes(self.hash, other_id.hash);
- self
- })
- }
-}
-
-impl Hash for Identifier<'_> {
- fn hash<H: Hasher>(&self, state: &mut H) {
- state.write_u64(self.hash);
- }
-}
-
-/// A cache that contains the postorder index and the identifier of expression
tree nodes
-/// by the preorder index of the nodes.
-///
-/// This cache is filled by `ExprIdentifierVisitor` during the first traversal
and is used
-/// by `CommonSubexprRewriter` during the second traversal.
-///
-/// The purpose of this cache is to quickly find the identifier of a node
during the
-/// second traversal.
-///
-/// Elements in this array are added during `f_down` so the indexes represent
the preorder
-/// index of expression nodes and thus element 0 belongs to the root of the
expression
-/// tree.
-/// The elements of the array are tuples that contain:
-/// - Postorder index that belongs to the preorder index. Assigned during
`f_up`, start
-/// from 0.
-/// - Identifier of the expression. If empty (`""`), expr should not be
considered for
-/// CSE.
-///
-/// # Example
-/// An expression like `(a + b)` would have the following `IdArray`:
-/// ```text
-/// [
-/// (2, "a + b"),
-/// (1, "a"),
-/// (0, "b")
-/// ]
-/// ```
-type IdArray<'n> = Vec<(usize, Option<Identifier<'n>>)>;
-
-/// A map that contains the number of normal and conditional occurrences of
expressions by
-/// their identifiers.
-type ExprStats<'n> = HashMap<Identifier<'n>, (usize, usize)>;
-
-/// A map that contains the common expressions and their alias extracted
during the
-/// second, rewriting traversal.
-type CommonExprs<'n> = IndexMap<Identifier<'n>, (Expr, String)>;
-
/// Performs Common Sub-expression Elimination optimization.
///
/// This optimization improves query performance by computing expressions that
@@ -140,168 +66,11 @@ type CommonExprs<'n> = IndexMap<Identifier<'n>, (Expr,
String)>;
/// ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once
/// ```
#[derive(Debug)]
-pub struct CommonSubexprEliminate {
- random_state: RandomState,
-}
-
-/// The result of potentially rewriting a list of expressions to eliminate
common
-/// subexpressions.
-#[derive(Debug)]
-enum FoundCommonExprs {
- /// No common expressions were found
- No { original_exprs_list: Vec<Vec<Expr>> },
- /// Common expressions were found
- Yes {
- /// extracted common expressions
- common_exprs: Vec<(Expr, String)>,
- /// new expressions with common subexpressions replaced
- new_exprs_list: Vec<Vec<Expr>>,
- /// original expressions
- original_exprs_list: Vec<Vec<Expr>>,
- },
-}
+pub struct CommonSubexprEliminate {}
impl CommonSubexprEliminate {
pub fn new() -> Self {
- Self {
- random_state: RandomState::new(),
- }
- }
-
- /// Returns the identifier list for each element in `exprs` and a flag to
indicate if
- /// rewrite phase of CSE make sense.
- ///
- /// Returns and array with 1 element for each input expr in `exprs`
- ///
- /// Each element is itself the result of
[`CommonSubexprEliminate::expr_to_identifier`] for that expr
- /// (e.g. the identifiers for each node in the tree)
- fn to_arrays<'n>(
- &self,
- exprs: &'n [Expr],
- expr_stats: &mut ExprStats<'n>,
- expr_mask: ExprMask,
- ) -> Result<(bool, Vec<IdArray<'n>>)> {
- let mut found_common = false;
- exprs
- .iter()
- .map(|e| {
- let mut id_array = vec![];
- self.expr_to_identifier(e, expr_stats, &mut id_array,
expr_mask)
- .map(|fc| {
- found_common |= fc;
-
- id_array
- })
- })
- .collect::<Result<Vec<_>>>()
- .map(|id_arrays| (found_common, id_arrays))
- }
-
- /// Add an identifier to `id_array` for every subexpression in this tree.
- fn expr_to_identifier<'n>(
- &self,
- expr: &'n Expr,
- expr_stats: &mut ExprStats<'n>,
- id_array: &mut IdArray<'n>,
- expr_mask: ExprMask,
- ) -> Result<bool> {
- let mut visitor = ExprIdentifierVisitor {
- expr_stats,
- id_array,
- visit_stack: vec![],
- down_index: 0,
- up_index: 0,
- expr_mask,
- random_state: &self.random_state,
- found_common: false,
- conditional: false,
- };
- expr.visit(&mut visitor)?;
-
- Ok(visitor.found_common)
- }
-
- /// Rewrites `exprs_list` with common sub-expressions replaced with a new
- /// column.
- ///
- /// `common_exprs` is updated with any sub expressions that were replaced.
- ///
- /// Returns the rewritten expressions
- fn rewrite_exprs_list<'n>(
- &self,
- exprs_list: Vec<Vec<Expr>>,
- arrays_list: &[Vec<IdArray<'n>>],
- expr_stats: &ExprStats<'n>,
- common_exprs: &mut CommonExprs<'n>,
- alias_generator: &AliasGenerator,
- ) -> Result<Vec<Vec<Expr>>> {
- exprs_list
- .into_iter()
- .zip(arrays_list.iter())
- .map(|(exprs, arrays)| {
- exprs
- .into_iter()
- .zip(arrays.iter())
- .map(|(expr, id_array)| {
- replace_common_expr(
- expr,
- id_array,
- expr_stats,
- common_exprs,
- alias_generator,
- )
- })
- .collect::<Result<Vec<_>>>()
- })
- .collect::<Result<Vec<_>>>()
- }
-
- /// Extracts common sub-expressions and rewrites `exprs_list`.
- ///
- /// Returns `FoundCommonExprs` recording the result of the extraction
- fn find_common_exprs(
- &self,
- exprs_list: Vec<Vec<Expr>>,
- config: &dyn OptimizerConfig,
- expr_mask: ExprMask,
- ) -> Result<Transformed<FoundCommonExprs>> {
- let mut found_common = false;
- let mut expr_stats = ExprStats::new();
- let id_arrays_list = exprs_list
- .iter()
- .map(|exprs| {
- self.to_arrays(exprs, &mut expr_stats, expr_mask).map(
- |(fc, id_arrays)| {
- found_common |= fc;
-
- id_arrays
- },
- )
- })
- .collect::<Result<Vec<_>>>()?;
- if found_common {
- let mut common_exprs = CommonExprs::new();
- let new_exprs_list = self.rewrite_exprs_list(
- // Must clone as Identifiers use references to original
expressions so we have
- // to keep the original expressions intact.
- exprs_list.clone(),
- &id_arrays_list,
- &expr_stats,
- &mut common_exprs,
- config.alias_generator().as_ref(),
- )?;
- assert!(!common_exprs.is_empty());
-
- Ok(Transformed::yes(FoundCommonExprs::Yes {
- common_exprs: common_exprs.into_values().collect(),
- new_exprs_list,
- original_exprs_list: exprs_list,
- }))
- } else {
- Ok(Transformed::no(FoundCommonExprs::No {
- original_exprs_list: exprs_list,
- }))
- }
+ Self {}
}
fn try_optimize_proj(
@@ -372,80 +141,83 @@ impl CommonSubexprEliminate {
get_consecutive_window_exprs(window);
// Extract common sub-expressions from the list.
- self.find_common_exprs(window_expr_list, config, ExprMask::Normal)?
- .map_data(|common| match common {
- // If there are common sub-expressions, then the insert a
projection node
- // with the common expressions between the new window nodes
and the
- // original input.
- FoundCommonExprs::Yes {
- common_exprs,
- new_exprs_list,
- original_exprs_list,
- } => {
- build_common_expr_project_plan(input,
common_exprs).map(|new_input| {
- (new_exprs_list, new_input, Some(original_exprs_list))
+
+ match CSE::new(ExprCSEController::new(
+ config.alias_generator().as_ref(),
+ ExprMask::Normal,
+ ))
+ .extract_common_nodes(window_expr_list)?
+ {
+ // If there are common sub-expressions, then the insert a
projection node
+ // with the common expressions between the new window nodes and the
+ // original input.
+ FoundCommonNodes::Yes {
+ common_nodes: common_exprs,
+ new_nodes_list: new_exprs_list,
+ original_nodes_list: original_exprs_list,
+ } => build_common_expr_project_plan(input,
common_exprs).map(|new_input| {
+ Transformed::yes((new_exprs_list, new_input,
Some(original_exprs_list)))
+ }),
+ FoundCommonNodes::No {
+ original_nodes_list: original_exprs_list,
+ } => Ok(Transformed::no((original_exprs_list, input, None))),
+ }?
+ // Recurse into the new input.
+ // (This is similar to what a `ApplyOrder::TopDown` optimizer rule
would do.)
+ .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
+ self.rewrite(new_input, config)?.map_data(|new_input| {
+ Ok((new_window_expr_list, new_input, window_expr_list))
+ })
+ })?
+ // Rebuild the consecutive window nodes.
+ .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
+ // If there were common expressions extracted, then we need to
make sure
+ // we restore the original column names.
+ // TODO: Although `find_common_exprs()` inserts aliases around
extracted
+ // common expressions this doesn't mean that the original column
names
+ // (schema) are preserved due to the inserted aliases are not
always at
+ // the top of the expression.
+ // Let's consider improving `find_common_exprs()` to always keep
column
+ // names and get rid of additional name preserving logic here.
+ if let Some(window_expr_list) = window_expr_list {
+ let name_preserver = NamePreserver::new_for_projection();
+ let saved_names = window_expr_list
+ .iter()
+ .map(|exprs| {
+ exprs
+ .iter()
+ .map(|expr| name_preserver.save(expr))
+ .collect::<Vec<_>>()
})
- }
- FoundCommonExprs::No {
- original_exprs_list,
- } => Ok((original_exprs_list, input, None)),
- })?
- // Recurse into the new input.
- // (This is similar to what a `ApplyOrder::TopDown` optimizer rule
would do.)
- .transform_data(|(new_window_expr_list, new_input,
window_expr_list)| {
- self.rewrite(new_input, config)?.map_data(|new_input| {
- Ok((new_window_expr_list, new_input, window_expr_list))
- })
- })?
- // Rebuild the consecutive window nodes.
- .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
- // If there were common expressions extracted, then we need to
make sure
- // we restore the original column names.
- // TODO: Although `find_common_exprs()` inserts aliases around
extracted
- // common expressions this doesn't mean that the original
column names
- // (schema) are preserved due to the inserted aliases are not
always at
- // the top of the expression.
- // Let's consider improving `find_common_exprs()` to always
keep column
- // names and get rid of additional name preserving logic here.
- if let Some(window_expr_list) = window_expr_list {
- let name_preserver = NamePreserver::new_for_projection();
- let saved_names = window_expr_list
- .iter()
- .map(|exprs| {
- exprs
- .iter()
- .map(|expr| name_preserver.save(expr))
- .collect::<Vec<_>>()
- })
- .collect::<Vec<_>>();
-
new_window_expr_list.into_iter().zip(saved_names).try_rfold(
- new_input,
- |plan, (new_window_expr, saved_names)| {
- let new_window_expr = new_window_expr
- .into_iter()
- .zip(saved_names)
- .map(|(new_window_expr, saved_name)| {
- saved_name.restore(new_window_expr)
- })
- .collect::<Vec<_>>();
- Window::try_new(new_window_expr, Arc::new(plan))
- .map(LogicalPlan::Window)
- },
- )
- } else {
- new_window_expr_list
- .into_iter()
- .zip(window_schemas)
- .try_rfold(new_input, |plan, (new_window_expr,
schema)| {
- Window::try_new_with_schema(
- new_window_expr,
- Arc::new(plan),
- schema,
- )
+ .collect::<Vec<_>>();
+ new_window_expr_list.into_iter().zip(saved_names).try_rfold(
+ new_input,
+ |plan, (new_window_expr, saved_names)| {
+ let new_window_expr = new_window_expr
+ .into_iter()
+ .zip(saved_names)
+ .map(|(new_window_expr, saved_name)| {
+ saved_name.restore(new_window_expr)
+ })
+ .collect::<Vec<_>>();
+ Window::try_new(new_window_expr, Arc::new(plan))
.map(LogicalPlan::Window)
- })
- }
- })
+ },
+ )
+ } else {
+ new_window_expr_list
+ .into_iter()
+ .zip(window_schemas)
+ .try_rfold(new_input, |plan, (new_window_expr, schema)| {
+ Window::try_new_with_schema(
+ new_window_expr,
+ Arc::new(plan),
+ schema,
+ )
+ .map(LogicalPlan::Window)
+ })
+ }
+ })
}
fn try_optimize_aggregate(
@@ -462,174 +234,175 @@ impl CommonSubexprEliminate {
} = aggregate;
let input = Arc::unwrap_or_clone(input);
// Extract common sub-expressions from the aggregate and grouping
expressions.
- self.find_common_exprs(vec![group_expr, aggr_expr], config,
ExprMask::Normal)?
- .map_data(|common| {
- match common {
- // If there are common sub-expressions, then insert a
projection node
- // with the common expressions between the new aggregate
node and the
- // original input.
- FoundCommonExprs::Yes {
- common_exprs,
- mut new_exprs_list,
- mut original_exprs_list,
- } => {
- let new_aggr_expr = new_exprs_list.pop().unwrap();
- let new_group_expr = new_exprs_list.pop().unwrap();
-
- build_common_expr_project_plan(input,
common_exprs).map(
- |new_input| {
- let aggr_expr =
original_exprs_list.pop().unwrap();
- (
- new_aggr_expr,
- new_group_expr,
- new_input,
- Some(aggr_expr),
- )
- },
- )
- }
-
- FoundCommonExprs::No {
- mut original_exprs_list,
- } => {
- let new_aggr_expr = original_exprs_list.pop().unwrap();
- let new_group_expr =
original_exprs_list.pop().unwrap();
-
- Ok((new_aggr_expr, new_group_expr, input, None))
- }
- }
- })?
- // Recurse into the new input.
- // (This is similar to what a `ApplyOrder::TopDown` optimizer rule
would do.)
- .transform_data(|(new_aggr_expr, new_group_expr, new_input,
aggr_expr)| {
- self.rewrite(new_input, config)?.map_data(|new_input| {
- Ok((
+ match CSE::new(ExprCSEController::new(
+ config.alias_generator().as_ref(),
+ ExprMask::Normal,
+ ))
+ .extract_common_nodes(vec![group_expr, aggr_expr])?
+ {
+ // If there are common sub-expressions, then insert a projection
node
+ // with the common expressions between the new aggregate node and
the
+ // original input.
+ FoundCommonNodes::Yes {
+ common_nodes: common_exprs,
+ new_nodes_list: mut new_exprs_list,
+ original_nodes_list: mut original_exprs_list,
+ } => {
+ let new_aggr_expr = new_exprs_list.pop().unwrap();
+ let new_group_expr = new_exprs_list.pop().unwrap();
+
+ build_common_expr_project_plan(input,
common_exprs).map(|new_input| {
+ let aggr_expr = original_exprs_list.pop().unwrap();
+ Transformed::yes((
new_aggr_expr,
new_group_expr,
- aggr_expr,
- Arc::new(new_input),
+ new_input,
+ Some(aggr_expr),
))
})
- })?
- // Try extracting common aggregate expressions and rebuild the
aggregate node.
- .transform_data(|(new_aggr_expr, new_group_expr, aggr_expr,
new_input)| {
+ }
+
+ FoundCommonNodes::No {
+ original_nodes_list: mut original_exprs_list,
+ } => {
+ let new_aggr_expr = original_exprs_list.pop().unwrap();
+ let new_group_expr = original_exprs_list.pop().unwrap();
+
+ Ok(Transformed::no((
+ new_aggr_expr,
+ new_group_expr,
+ input,
+ None,
+ )))
+ }
+ }?
+ // Recurse into the new input.
+ // (This is similar to what a `ApplyOrder::TopDown` optimizer rule
would do.)
+ .transform_data(|(new_aggr_expr, new_group_expr, new_input,
aggr_expr)| {
+ self.rewrite(new_input, config)?.map_data(|new_input| {
+ Ok((
+ new_aggr_expr,
+ new_group_expr,
+ aggr_expr,
+ Arc::new(new_input),
+ ))
+ })
+ })?
+ // Try extracting common aggregate expressions and rebuild the
aggregate node.
+ .transform_data(
+ |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
// Extract common aggregate sub-expressions from the aggregate
expressions.
- self.find_common_exprs(
- vec![new_aggr_expr],
- config,
+ match CSE::new(ExprCSEController::new(
+ config.alias_generator().as_ref(),
ExprMask::NormalAndAggregates,
- )?
- .map_data(|common| {
- match common {
- FoundCommonExprs::Yes {
- common_exprs,
- mut new_exprs_list,
- mut original_exprs_list,
- } => {
- let rewritten_aggr_expr =
new_exprs_list.pop().unwrap();
- let new_aggr_expr =
original_exprs_list.pop().unwrap();
-
- let mut agg_exprs = common_exprs
- .into_iter()
- .map(|(expr, expr_alias)|
expr.alias(expr_alias))
- .collect::<Vec<_>>();
+ ))
+ .extract_common_nodes(vec![new_aggr_expr])?
+ {
+ FoundCommonNodes::Yes {
+ common_nodes: common_exprs,
+ new_nodes_list: mut new_exprs_list,
+ original_nodes_list: mut original_exprs_list,
+ } => {
+ let rewritten_aggr_expr =
new_exprs_list.pop().unwrap();
+ let new_aggr_expr = original_exprs_list.pop().unwrap();
- let mut proj_exprs = vec![];
- for expr in &new_group_expr {
- extract_expressions(expr, &mut proj_exprs)
- }
- for (expr_rewritten, expr_orig) in
-
rewritten_aggr_expr.into_iter().zip(new_aggr_expr)
- {
- if expr_rewritten == expr_orig {
- if let Expr::Alias(Alias { expr, name, ..
}) =
- expr_rewritten
- {
- agg_exprs.push(expr.alias(&name));
- proj_exprs
-
.push(Expr::Column(Column::from_name(name)));
- } else {
- let expr_alias =
-
config.alias_generator().next(CSE_PREFIX);
- let (qualifier, field_name) =
- expr_rewritten.qualified_name();
- let out_name = qualified_name(
- qualifier.as_ref(),
- &field_name,
- );
-
-
agg_exprs.push(expr_rewritten.alias(&expr_alias));
- proj_exprs.push(
-
Expr::Column(Column::from_name(expr_alias))
- .alias(out_name),
- );
- }
+ let mut agg_exprs = common_exprs
+ .into_iter()
+ .map(|(expr, expr_alias)| expr.alias(expr_alias))
+ .collect::<Vec<_>>();
+
+ let mut proj_exprs = vec![];
+ for expr in &new_group_expr {
+ extract_expressions(expr, &mut proj_exprs)
+ }
+ for (expr_rewritten, expr_orig) in
+ rewritten_aggr_expr.into_iter().zip(new_aggr_expr)
+ {
+ if expr_rewritten == expr_orig {
+ if let Expr::Alias(Alias { expr, name, .. }) =
+ expr_rewritten
+ {
+ agg_exprs.push(expr.alias(&name));
+ proj_exprs
+
.push(Expr::Column(Column::from_name(name)));
} else {
- proj_exprs.push(expr_rewritten);
+ let expr_alias =
+
config.alias_generator().next(CSE_PREFIX);
+ let (qualifier, field_name) =
+ expr_rewritten.qualified_name();
+ let out_name =
+ qualified_name(qualifier.as_ref(),
&field_name);
+
+
agg_exprs.push(expr_rewritten.alias(&expr_alias));
+ proj_exprs.push(
+
Expr::Column(Column::from_name(expr_alias))
+ .alias(out_name),
+ );
}
+ } else {
+ proj_exprs.push(expr_rewritten);
}
-
- let agg =
LogicalPlan::Aggregate(Aggregate::try_new(
- new_input,
- new_group_expr,
- agg_exprs,
- )?);
- Projection::try_new(proj_exprs, Arc::new(agg))
- .map(LogicalPlan::Projection)
}
- // If there aren't any common aggregate
sub-expressions, then just
- // rebuild the aggregate node.
- FoundCommonExprs::No {
- mut original_exprs_list,
- } => {
- let rewritten_aggr_expr =
original_exprs_list.pop().unwrap();
-
- // If there were common expressions extracted,
then we need to
- // make sure we restore the original column names.
- // TODO: Although `find_common_exprs()` inserts
aliases around
- // extracted common expressions this doesn't mean
that the
- // original column names (schema) are preserved
due to the
- // inserted aliases are not always at the top of
the
- // expression.
- // Let's consider improving `find_common_exprs()`
to always
- // keep column names and get rid of additional
name
- // preserving logic here.
- if let Some(aggr_expr) = aggr_expr {
- let name_perserver =
NamePreserver::new_for_projection();
- let saved_names = aggr_expr
- .iter()
- .map(|expr| name_perserver.save(expr))
- .collect::<Vec<_>>();
- let new_aggr_expr = rewritten_aggr_expr
- .into_iter()
- .zip(saved_names)
- .map(|(new_expr, saved_name)| {
- saved_name.restore(new_expr)
- })
- .collect::<Vec<Expr>>();
-
- // Since `group_expr` may have changed, schema
may also.
- // Use `try_new()` method.
- Aggregate::try_new(
- new_input,
- new_group_expr,
- new_aggr_expr,
- )
- .map(LogicalPlan::Aggregate)
- } else {
- Aggregate::try_new_with_schema(
- new_input,
- new_group_expr,
- rewritten_aggr_expr,
- schema,
- )
+ let agg = LogicalPlan::Aggregate(Aggregate::try_new(
+ new_input,
+ new_group_expr,
+ agg_exprs,
+ )?);
+ Projection::try_new(proj_exprs, Arc::new(agg))
+ .map(|p|
Transformed::yes(LogicalPlan::Projection(p)))
+ }
+
+ // If there aren't any common aggregate sub-expressions,
then just
+ // rebuild the aggregate node.
+ FoundCommonNodes::No {
+ original_nodes_list: mut original_exprs_list,
+ } => {
+ let rewritten_aggr_expr =
original_exprs_list.pop().unwrap();
+
+ // If there were common expressions extracted, then we
need to
+ // make sure we restore the original column names.
+ // TODO: Although `find_common_exprs()` inserts
aliases around
+ // extracted common expressions this doesn't mean
that the
+ // original column names (schema) are preserved due
to the
+ // inserted aliases are not always at the top of the
+ // expression.
+ // Let's consider improving `find_common_exprs()` to
always
+ // keep column names and get rid of additional name
+ // preserving logic here.
+ if let Some(aggr_expr) = aggr_expr {
+ let name_perserver =
NamePreserver::new_for_projection();
+ let saved_names = aggr_expr
+ .iter()
+ .map(|expr| name_perserver.save(expr))
+ .collect::<Vec<_>>();
+ let new_aggr_expr = rewritten_aggr_expr
+ .into_iter()
+ .zip(saved_names)
+ .map(|(new_expr, saved_name)| {
+ saved_name.restore(new_expr)
+ })
+ .collect::<Vec<Expr>>();
+
+ // Since `group_expr` may have changed, schema may
also.
+ // Use `try_new()` method.
+ Aggregate::try_new(new_input, new_group_expr,
new_aggr_expr)
.map(LogicalPlan::Aggregate)
- }
+ .map(Transformed::no)
+ } else {
+ Aggregate::try_new_with_schema(
+ new_input,
+ new_group_expr,
+ rewritten_aggr_expr,
+ schema,
+ )
+ .map(LogicalPlan::Aggregate)
+ .map(Transformed::no)
}
}
- })
- })
+ }
+ },
+ )
}
/// Rewrites the expr list and input to remove common subexpressions
@@ -653,30 +426,34 @@ impl CommonSubexprEliminate {
config: &dyn OptimizerConfig,
) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
// Extract common sub-expressions from the expressions.
- self.find_common_exprs(vec![exprs], config, ExprMask::Normal)?
- .map_data(|common| match common {
- FoundCommonExprs::Yes {
- common_exprs,
- mut new_exprs_list,
- original_exprs_list: _,
- } => {
- let new_exprs = new_exprs_list.pop().unwrap();
- build_common_expr_project_plan(input, common_exprs)
- .map(|new_input| (new_exprs, new_input))
- }
- FoundCommonExprs::No {
- mut original_exprs_list,
- } => {
- let new_exprs = original_exprs_list.pop().unwrap();
- Ok((new_exprs, input))
- }
- })?
- // Recurse into the new input.
- // (This is similar to what a `ApplyOrder::TopDown` optimizer rule
would do.)
- .transform_data(|(new_exprs, new_input)| {
- self.rewrite(new_input, config)?
- .map_data(|new_input| Ok((new_exprs, new_input)))
- })
+ match CSE::new(ExprCSEController::new(
+ config.alias_generator().as_ref(),
+ ExprMask::Normal,
+ ))
+ .extract_common_nodes(vec![exprs])?
+ {
+ FoundCommonNodes::Yes {
+ common_nodes: common_exprs,
+ new_nodes_list: mut new_exprs_list,
+ original_nodes_list: _,
+ } => {
+ let new_exprs = new_exprs_list.pop().unwrap();
+ build_common_expr_project_plan(input, common_exprs)
+ .map(|new_input| Transformed::yes((new_exprs, new_input)))
+ }
+ FoundCommonNodes::No {
+ original_nodes_list: mut original_exprs_list,
+ } => {
+ let new_exprs = original_exprs_list.pop().unwrap();
+ Ok(Transformed::no((new_exprs, input)))
+ }
+ }?
+ // Recurse into the new input.
+ // (This is similar to what a `ApplyOrder::TopDown` optimizer rule
would do.)
+ .transform_data(|(new_exprs, new_input)| {
+ self.rewrite(new_input, config)?
+ .map_data(|new_input| Ok((new_exprs, new_input)))
+ })
}
}
@@ -800,71 +577,6 @@ impl OptimizerRule for CommonSubexprEliminate {
}
}
-impl Default for CommonSubexprEliminate {
- fn default() -> Self {
- Self::new()
- }
-}
-
-/// Build the "intermediate" projection plan that evaluates the extracted
common
-/// expressions.
-///
-/// # Arguments
-/// input: the input plan
-///
-/// common_exprs: which common subexpressions were used (and thus are added to
-/// intermediate projection)
-///
-/// expr_stats: the set of common subexpressions
-fn build_common_expr_project_plan(
- input: LogicalPlan,
- common_exprs: Vec<(Expr, String)>,
-) -> Result<LogicalPlan> {
- let mut fields_set = BTreeSet::new();
- let mut project_exprs = common_exprs
- .into_iter()
- .map(|(expr, expr_alias)| {
- fields_set.insert(expr_alias.clone());
- Ok(expr.alias(expr_alias))
- })
- .collect::<Result<Vec<_>>>()?;
-
- for (qualifier, field) in input.schema().iter() {
- if fields_set.insert(qualified_name(qualifier, field.name())) {
- project_exprs.push(Expr::from((qualifier, field)));
- }
- }
-
- Projection::try_new(project_exprs,
Arc::new(input)).map(LogicalPlan::Projection)
-}
-
-/// Build the projection plan to eliminate unnecessary columns produced by
-/// the "intermediate" projection plan built in
[build_common_expr_project_plan].
-///
-/// This is required to keep the schema the same for plans that pass the input
-/// on to the output, such as `Filter` or `Sort`.
-fn build_recover_project_plan(
- schema: &DFSchema,
- input: LogicalPlan,
-) -> Result<LogicalPlan> {
- let col_exprs = schema.iter().map(Expr::from).collect();
- Projection::try_new(col_exprs,
Arc::new(input)).map(LogicalPlan::Projection)
-}
-
-fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
- if let Expr::GroupingSet(groupings) = expr {
- for e in groupings.distinct_expr() {
- let (qualifier, field_name) = e.qualified_name();
- let col = Column::new(qualifier, field_name);
- result.push(Expr::Column(col))
- }
- } else {
- let (qualifier, field_name) = expr.qualified_name();
- let col = Column::new(qualifier, field_name);
- result.push(Expr::Column(col));
- }
-}
-
/// Which type of [expressions](Expr) should be considered for rewriting?
#[derive(Debug, Clone, Copy)]
enum ExprMask {
@@ -882,156 +594,36 @@ enum ExprMask {
NormalAndAggregates,
}
-impl ExprMask {
- fn ignores(&self, expr: &Expr) -> bool {
- let is_normal_minus_aggregates = matches!(
- expr,
- Expr::Literal(..)
- | Expr::Column(..)
- | Expr::ScalarVariable(..)
- | Expr::Alias(..)
- | Expr::Wildcard { .. }
- );
-
- let is_aggr = matches!(expr, Expr::AggregateFunction(..));
-
- match self {
- Self::Normal => is_normal_minus_aggregates || is_aggr,
- Self::NormalAndAggregates => is_normal_minus_aggregates,
- }
- }
-}
-
-/// Go through an expression tree and generate identifiers for each
subexpression.
-///
-/// An identifier contains information of the expression itself and its
sub-expression.
-/// This visitor implementation use a stack `visit_stack` to track traversal,
which
-/// lets us know when a sub-tree's visiting is finished. When `pre_visit` is
called
-/// (traversing to a new node), an `EnterMark` and an `ExprItem` will be
pushed into stack.
-/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All
`ExprItem`
-/// before the first `EnterMark` is considered to be sub-tree of the leaving
node.
-///
-/// This visitor also records identifier in `id_array`. Makes the following
traverse
-/// pass can get the identifier of a node without recalculate it. We assign
each node
-/// in the expr tree a series number, start from 1, maintained by
`series_number`.
-/// Series number represents the order we left (`f_up()`) a node. Has the
property
-/// that child node's series number always smaller than parent's. While
`id_array` is
-/// organized in the order we enter (`f_down()`) a node. `node_count` helps us
to
-/// get the index of `id_array` for each node.
-///
-/// `Expr` without sub-expr (column, literal etc.) will not have identifier
-/// because they should not be recognized as common sub-expr.
-struct ExprIdentifierVisitor<'a, 'n> {
- // statistics of expressions
- expr_stats: &'a mut ExprStats<'n>,
- // cache to speed up second traversal
- id_array: &'a mut IdArray<'n>,
- // inner states
- visit_stack: Vec<VisitRecord<'n>>,
- // preorder index, start from 0.
- down_index: usize,
- // postorder index, start from 0.
- up_index: usize,
- // which expression should be skipped?
- expr_mask: ExprMask,
- // a `RandomState` to generate hashes during the first traversal
- random_state: &'a RandomState,
- // a flag to indicate that common expression found
- found_common: bool,
- // if we are in a conditional branch. A conditional branch means that the
expression
- // might not be executed depending on the runtime values of other
expressions, and
- // thus can not be extracted as a common expression.
- conditional: bool,
-}
+struct ExprCSEController<'a> {
+ alias_generator: &'a AliasGenerator,
+ mask: ExprMask,
-/// Record item that used when traversing an expression tree.
-enum VisitRecord<'n> {
- /// Marks the beginning of expression. It contains:
- /// - The post-order index assigned during the first, visiting traversal.
- EnterMark(usize),
-
- /// Marks an accumulated subexpression tree. It contains:
- /// - The accumulated identifier of a subexpression.
- /// - A boolean flag if the expression is valid for subexpression
elimination.
- /// The flag is propagated up from children to parent. (E.g. volatile
expressions
- /// are not valid and can't be extracted, but non-volatile children of
volatile
- /// expressions can be extracted.)
- ExprItem(Identifier<'n>, bool),
+ // how many aliases have we seen so far
+ alias_counter: usize,
}
-impl<'n> ExprIdentifierVisitor<'_, 'n> {
- /// Find the first `EnterMark` in the stack, and accumulates every
`ExprItem` before
- /// it. Returns a tuple that contains:
- /// - The pre-order index of the expression we marked.
- /// - The accumulated identifier of the children of the marked expression.
- /// - An accumulated boolean flag from the children of the marked
expression if all
- /// children are valid for subexpression elimination (i.e. it is safe to
extract the
- /// expression as a common expression from its children POV).
- /// (E.g. if any of the children of the marked expression is not valid
(e.g. is
- /// volatile) then the expression is also not valid, so we can propagate
this
- /// information up from children to parents via `visit_stack` during the
first,
- /// visiting traversal and no need to test the expression's validity
beforehand with
- /// an extra traversal).
- fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n>>, bool) {
- let mut expr_id = None;
- let mut is_valid = true;
-
- while let Some(item) = self.visit_stack.pop() {
- match item {
- VisitRecord::EnterMark(down_index) => {
- return (down_index, expr_id, is_valid);
- }
- VisitRecord::ExprItem(sub_expr_id, sub_expr_is_valid) => {
- expr_id = Some(sub_expr_id.combine(expr_id));
- is_valid &= sub_expr_is_valid;
- }
- }
+impl<'a> ExprCSEController<'a> {
+ fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
+ Self {
+ alias_generator,
+ mask,
+ alias_counter: 0,
}
- unreachable!("Enter mark should paired with node number");
- }
-
- /// Save the current `conditional` status and run `f` with `conditional`
set to true.
- fn conditionally<F: FnMut(&mut Self) -> Result<()>>(
- &mut self,
- mut f: F,
- ) -> Result<()> {
- let conditional = self.conditional;
- self.conditional = true;
- f(self)?;
- self.conditional = conditional;
-
- Ok(())
}
}
-impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
+impl CSEController for ExprCSEController<'_> {
type Node = Expr;
- fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
- self.id_array.push((0, None));
- self.visit_stack
- .push(VisitRecord::EnterMark(self.down_index));
- self.down_index += 1;
-
- // If an expression can short-circuit then some of its children might
not be
- // executed so count the occurrence of subexpressions as conditional
in all
- // children.
- Ok(match expr {
- // If we are already in a conditionally evaluated subtree then
continue
- // traversal.
- _ if self.conditional => TreeNodeRecursion::Continue,
-
+ fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
+ match node {
// In case of `ScalarFunction`s we don't know which children are
surely
// executed so start visiting all children conditionally and stop
the
// recursion with `TreeNodeRecursion::Jump`.
Expr::ScalarFunction(ScalarFunction { func, args })
if func.short_circuits() =>
{
- self.conditionally(|visitor| {
- args.iter().try_for_each(|e| e.visit(visitor).map(|_| ()))
- })?;
-
- TreeNodeRecursion::Jump
+ Some((vec![], args.iter().collect()))
}
// In case of `And` and `Or` the first child is surely executed,
but we
@@ -1040,12 +632,7 @@ impl<'n> TreeNodeVisitor<'n> for
ExprIdentifierVisitor<'_, 'n> {
left,
op: Operator::And | Operator::Or,
right,
- }) => {
- left.visit(self)?;
- self.conditionally(|visitor| right.visit(visitor).map(|_|
()))?;
-
- TreeNodeRecursion::Jump
- }
+ }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
// In case of `Case` the optional base expression and the first
when
// expressions are surely executed, but we account subexpressions
as
@@ -1054,167 +641,151 @@ impl<'n> TreeNodeVisitor<'n> for
ExprIdentifierVisitor<'_, 'n> {
expr,
when_then_expr,
else_expr,
- }) => {
- expr.iter().try_for_each(|e| e.visit(self).map(|_| ()))?;
- when_then_expr.iter().take(1).try_for_each(|(when, then)| {
- when.visit(self)?;
- self.conditionally(|visitor| then.visit(visitor).map(|_|
()))
- })?;
- self.conditionally(|visitor| {
- when_then_expr.iter().skip(1).try_for_each(|(when, then)| {
- when.visit(visitor)?;
- then.visit(visitor).map(|_| ())
- })?;
- else_expr
- .iter()
- .try_for_each(|e| e.visit(visitor).map(|_| ()))
- })?;
-
- TreeNodeRecursion::Jump
- }
+ }) => Some((
+ expr.iter()
+ .map(|e| e.as_ref())
+ .chain(when_then_expr.iter().take(1).map(|(when, _)|
when.as_ref()))
+ .collect(),
+ when_then_expr
+ .iter()
+ .take(1)
+ .map(|(_, then)| then.as_ref())
+ .chain(
+ when_then_expr
+ .iter()
+ .skip(1)
+ .flat_map(|(when, then)| [when.as_ref(),
then.as_ref()]),
+ )
+ .chain(else_expr.iter().map(|e| e.as_ref()))
+ .collect(),
+ )),
+ _ => None,
+ }
+ }
- // In case of non-short-circuit expressions continue the traversal.
- _ => TreeNodeRecursion::Continue,
- })
+ fn is_valid(node: &Expr) -> bool {
+ !node.is_volatile_node()
}
- fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
- let (down_index, sub_expr_id, sub_expr_is_valid) =
self.pop_enter_mark();
+ fn is_ignored(&self, node: &Expr) -> bool {
+ let is_normal_minus_aggregates = matches!(
+ node,
+ Expr::Literal(..)
+ | Expr::Column(..)
+ | Expr::ScalarVariable(..)
+ | Expr::Alias(..)
+ | Expr::Wildcard { .. }
+ );
- let expr_id = Identifier::new(expr,
self.random_state).combine(sub_expr_id);
- let is_valid = !expr.is_volatile_node() && sub_expr_is_valid;
+ let is_aggr = matches!(node, Expr::AggregateFunction(..));
- self.id_array[down_index].0 = self.up_index;
- if is_valid && !self.expr_mask.ignores(expr) {
- self.id_array[down_index].1 = Some(expr_id);
- let (count, conditional_count) =
- self.expr_stats.entry(expr_id).or_insert((0, 0));
- if self.conditional {
- *conditional_count += 1;
- } else {
- *count += 1;
- }
- if *count > 1 || (*count == 1 && *conditional_count > 0) {
- self.found_common = true;
- }
+ match self.mask {
+ ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
+ ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
}
- self.visit_stack
- .push(VisitRecord::ExprItem(expr_id, is_valid));
- self.up_index += 1;
-
- Ok(TreeNodeRecursion::Continue)
}
-}
-/// Rewrite expression by replacing detected common sub-expression with
-/// the corresponding temporary column name. That column contains the
-/// evaluate result of replaced expression.
-struct CommonSubexprRewriter<'a, 'n> {
- // statistics of expressions
- expr_stats: &'a ExprStats<'n>,
- // cache to speed up second traversal
- id_array: &'a IdArray<'n>,
- // common expression, that are replaced during the second traversal, are
collected to
- // this map
- common_exprs: &'a mut CommonExprs<'n>,
- // preorder index, starts from 0.
- down_index: usize,
- // how many aliases have we seen so far
- alias_counter: usize,
- // alias generator for extracted common expressions
- alias_generator: &'a AliasGenerator,
-}
+ fn generate_alias(&self) -> String {
+ self.alias_generator.next(CSE_PREFIX)
+ }
-impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
- type Node = Expr;
+ fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
+ // alias the expressions without an `Alias` ancestor node
+ if self.alias_counter > 0 {
+ col(alias)
+ } else {
+ self.alias_counter += 1;
+ col(alias).alias(node.schema_name().to_string())
+ }
+ }
- fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
- if matches!(expr, Expr::Alias(_)) {
+ fn rewrite_f_down(&mut self, node: &Expr) {
+ if matches!(node, Expr::Alias(_)) {
self.alias_counter += 1;
}
+ }
+ fn rewrite_f_up(&mut self, node: &Expr) {
+ if matches!(node, Expr::Alias(_)) {
+ self.alias_counter -= 1
+ }
+ }
+}
- let (up_index, expr_id) = self.id_array[self.down_index];
- self.down_index += 1;
+impl Default for CommonSubexprEliminate {
+ fn default() -> Self {
+ Self::new()
+ }
+}
- // Handle `Expr`s with identifiers only
- if let Some(expr_id) = expr_id {
- let (count, conditional_count) =
self.expr_stats.get(&expr_id).unwrap();
- if *count > 1 || *count == 1 && *conditional_count > 0 {
- // step index to skip all sub-node (which has smaller series
number).
- while self.down_index < self.id_array.len()
- && self.id_array[self.down_index].0 < up_index
- {
- self.down_index += 1;
- }
+/// Build the "intermediate" projection plan that evaluates the extracted
common
+/// expressions.
+///
+/// # Arguments
+/// input: the input plan
+///
+/// common_exprs: which common subexpressions were used (and thus are added to
+/// intermediate projection)
+///
+/// expr_stats: the set of common subexpressions
+fn build_common_expr_project_plan(
+ input: LogicalPlan,
+ common_exprs: Vec<(Expr, String)>,
+) -> Result<LogicalPlan> {
+ let mut fields_set = BTreeSet::new();
+ let mut project_exprs = common_exprs
+ .into_iter()
+ .map(|(expr, expr_alias)| {
+ fields_set.insert(expr_alias.clone());
+ Ok(expr.alias(expr_alias))
+ })
+ .collect::<Result<Vec<_>>>()?;
- let expr_name = expr.schema_name().to_string();
- let (_, expr_alias) =
- self.common_exprs.entry(expr_id).or_insert_with(|| {
- let expr_alias = self.alias_generator.next(CSE_PREFIX);
- (expr, expr_alias)
- });
-
- // alias the expressions without an `Alias` ancestor node
- let rewritten = if self.alias_counter > 0 {
- col(expr_alias.clone())
- } else {
- self.alias_counter += 1;
- col(expr_alias.clone()).alias(expr_name)
- };
-
- return Ok(Transformed::new(rewritten, true,
TreeNodeRecursion::Jump));
- }
+ for (qualifier, field) in input.schema().iter() {
+ if fields_set.insert(qualified_name(qualifier, field.name())) {
+ project_exprs.push(Expr::from((qualifier, field)));
}
-
- Ok(Transformed::no(expr))
}
- fn f_up(&mut self, expr: Expr) -> Result<Transformed<Self::Node>> {
- if matches!(expr, Expr::Alias(_)) {
- self.alias_counter -= 1
- }
+ Projection::try_new(project_exprs,
Arc::new(input)).map(LogicalPlan::Projection)
+}
- Ok(Transformed::no(expr))
- }
+/// Build the projection plan to eliminate unnecessary columns produced by
+/// the "intermediate" projection plan built in
[build_common_expr_project_plan].
+///
+/// This is required to keep the schema the same for plans that pass the input
+/// on to the output, such as `Filter` or `Sort`.
+fn build_recover_project_plan(
+ schema: &DFSchema,
+ input: LogicalPlan,
+) -> Result<LogicalPlan> {
+ let col_exprs = schema.iter().map(Expr::from).collect();
+ Projection::try_new(col_exprs,
Arc::new(input)).map(LogicalPlan::Projection)
}
-/// Replace common sub-expression in `expr` with the corresponding temporary
-/// column name, updating `common_exprs` with any replaced expressions
-fn replace_common_expr<'n>(
- expr: Expr,
- id_array: &IdArray<'n>,
- expr_stats: &ExprStats<'n>,
- common_exprs: &mut CommonExprs<'n>,
- alias_generator: &AliasGenerator,
-) -> Result<Expr> {
- if id_array.is_empty() {
- Ok(Transformed::no(expr))
+fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
+ if let Expr::GroupingSet(groupings) = expr {
+ for e in groupings.distinct_expr() {
+ let (qualifier, field_name) = e.qualified_name();
+ let col = Column::new(qualifier, field_name);
+ result.push(Expr::Column(col))
+ }
} else {
- expr.rewrite(&mut CommonSubexprRewriter {
- expr_stats,
- id_array,
- common_exprs,
- down_index: 0,
- alias_counter: 0,
- alias_generator,
- })
+ let (qualifier, field_name) = expr.qualified_name();
+ let col = Column::new(qualifier, field_name);
+ result.push(Expr::Column(col));
}
- .data()
}
#[cfg(test)]
mod test {
use std::any::Any;
- use std::collections::HashSet;
use std::iter;
use arrow::datatypes::{DataType, Field, Schema};
- use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::logical_plan::{table_scan, JoinType};
use datafusion_expr::{
- grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr,
- ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF,
- Volatility,
+ grouping_set, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue,
ScalarUDF,
+ ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility,
};
use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
@@ -1238,154 +809,6 @@ mod test {
assert_eq!(expected, formatted_plan);
}
- #[test]
- fn id_array_visitor() -> Result<()> {
- let optimizer = CommonSubexprEliminate::new();
-
- let a_plus_1 = col("a") + lit(1);
- let avg_c = avg(col("c"));
- let sum_a_plus_1 = sum(a_plus_1);
- let sum_a_plus_1_minus_avg_c = sum_a_plus_1 - avg_c;
- let expr = sum_a_plus_1_minus_avg_c * lit(2);
-
- let Expr::BinaryExpr(BinaryExpr {
- left: sum_a_plus_1_minus_avg_c,
- ..
- }) = &expr
- else {
- panic!("Cannot extract subexpression reference")
- };
- let Expr::BinaryExpr(BinaryExpr {
- left: sum_a_plus_1,
- right: avg_c,
- ..
- }) = sum_a_plus_1_minus_avg_c.as_ref()
- else {
- panic!("Cannot extract subexpression reference")
- };
- let Expr::AggregateFunction(AggregateFunction {
- args: a_plus_1_vec, ..
- }) = sum_a_plus_1.as_ref()
- else {
- panic!("Cannot extract subexpression reference")
- };
- let a_plus_1 = &a_plus_1_vec.as_slice()[0];
-
- // skip aggregates
- let mut id_array = vec![];
- optimizer.expr_to_identifier(
- &expr,
- &mut ExprStats::new(),
- &mut id_array,
- ExprMask::Normal,
- )?;
-
- // Collect distinct hashes and set them to 0 in `id_array`
- fn collect_hashes(id_array: &mut IdArray) -> HashSet<u64> {
- id_array
- .iter_mut()
- .flat_map(|(_, expr_id_option)| {
- expr_id_option.as_mut().map(|expr_id| {
- let hash = expr_id.hash;
- expr_id.hash = 0;
- hash
- })
- })
- .collect::<HashSet<_>>()
- }
-
- let hashes = collect_hashes(&mut id_array);
- assert_eq!(hashes.len(), 3);
-
- let expected = vec![
- (
- 8,
- Some(Identifier {
- hash: 0,
- expr: &expr,
- }),
- ),
- (
- 6,
- Some(Identifier {
- hash: 0,
- expr: sum_a_plus_1_minus_avg_c,
- }),
- ),
- (3, None),
- (
- 2,
- Some(Identifier {
- hash: 0,
- expr: a_plus_1,
- }),
- ),
- (0, None),
- (1, None),
- (5, None),
- (4, None),
- (7, None),
- ];
- assert_eq!(expected, id_array);
-
- // include aggregates
- let mut id_array = vec![];
- optimizer.expr_to_identifier(
- &expr,
- &mut ExprStats::new(),
- &mut id_array,
- ExprMask::NormalAndAggregates,
- )?;
-
- let hashes = collect_hashes(&mut id_array);
- assert_eq!(hashes.len(), 5);
-
- let expected = vec![
- (
- 8,
- Some(Identifier {
- hash: 0,
- expr: &expr,
- }),
- ),
- (
- 6,
- Some(Identifier {
- hash: 0,
- expr: sum_a_plus_1_minus_avg_c,
- }),
- ),
- (
- 3,
- Some(Identifier {
- hash: 0,
- expr: sum_a_plus_1,
- }),
- ),
- (
- 2,
- Some(Identifier {
- hash: 0,
- expr: a_plus_1,
- }),
- ),
- (0, None),
- (1, None),
- (
- 5,
- Some(Identifier {
- hash: 0,
- expr: avg_c,
- }),
- ),
- (4, None),
- (7, None),
- ];
- assert_eq!(expected, id_array);
-
- Ok(())
- }
-
#[test]
fn tpch_q1_simplified() -> Result<()> {
// SQL:
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]