This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 45a316ca7e Extract CSE logic to `datafusion_common` (#13002)
45a316ca7e is described below

commit 45a316ca7e3b5faa21cc80bab6d879c2be3d90c8
Author: Peter Toth <[email protected]>
AuthorDate: Mon Oct 21 21:47:05 2024 +0200

    Extract CSE logic to `datafusion_common` (#13002)
    
    * Extract CSE logic
    
    * address review comments, move `HashNode` to `datafusion_common::cse`, 
shorter names for eliminator and controller, change 
`CSE::extract_common_nodes()` to return `Result<FoundCommonNodes<N>>` (instead 
of `Result<Transformed<FoundCommonNodes<N>>>`)
---
 datafusion-cli/Cargo.lock                          |   21 +-
 datafusion/common/Cargo.toml                       |    1 +
 datafusion/common/src/cse.rs                       |  800 ++++++++++++
 datafusion/common/src/lib.rs                       |    1 +
 datafusion/common/src/tree_node.rs                 |  135 +-
 datafusion/expr/src/expr.rs                        |   79 +-
 .../optimizer/src/common_subexpr_eliminate.rs      | 1361 ++++++--------------
 7 files changed, 1314 insertions(+), 1084 deletions(-)

diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 612209fdd9..401f203dd9 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -406,9 +406,9 @@ dependencies = [
 
 [[package]]
 name = "async-compression"
-version = "0.4.16"
+version = "0.4.17"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "103db485efc3e41214fe4fda9f3dbeae2eb9082f48fd236e6095627a9422066e"
+checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857"
 dependencies = [
  "bzip2",
  "flate2",
@@ -917,9 +917,9 @@ dependencies = [
 
 [[package]]
 name = "cc"
-version = "1.1.30"
+version = "1.1.31"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945"
+checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f"
 dependencies = [
  "jobserver",
  "libc",
@@ -1293,6 +1293,7 @@ dependencies = [
  "chrono",
  "half",
  "hashbrown 0.14.5",
+ "indexmap",
  "instant",
  "libc",
  "num_cpus",
@@ -2615,9 +2616,9 @@ dependencies = [
 
 [[package]]
 name = "object_store"
-version = "0.11.0"
+version = "0.11.1"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "25a0c4b3a0e31f8b66f71ad8064521efa773910196e2cde791436f13409f3b45"
+checksum = "6eb4c22c6154a1e759d7099f9ffad7cc5ef8245f9efbab4a41b92623079c82f3"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -3411,9 +3412,9 @@ dependencies = [
 
 [[package]]
 name = "serde_json"
-version = "1.0.130"
+version = "1.0.132"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "610f75ff4a8e3cb29b85da56eabdd1bff5b06739059a4b8e2967fef32e5d9944"
+checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03"
 dependencies = [
  "itoa",
  "memchr",
@@ -3605,9 +3606,9 @@ checksum = 
"13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
 
 [[package]]
 name = "syn"
-version = "2.0.79"
+version = "2.0.82"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590"
+checksum = "83540f837a8afc019423a8edb95b52a8effe46957ee402287f4292fae35be021"
 dependencies = [
  "proc-macro2",
  "quote",
diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml
index 1ac27b40c2..0747672a18 100644
--- a/datafusion/common/Cargo.toml
+++ b/datafusion/common/Cargo.toml
@@ -56,6 +56,7 @@ arrow-schema = { workspace = true }
 chrono = { workspace = true }
 half = { workspace = true }
 hashbrown = { workspace = true }
+indexmap = { workspace = true }
 libc = "0.2.140"
 num_cpus = { workspace = true }
 object_store = { workspace = true, optional = true }
diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs
new file mode 100644
index 0000000000..453ae26e73
--- /dev/null
+++ b/datafusion/common/src/cse.rs
@@ -0,0 +1,800 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Common Subexpression Elimination logic implemented in [`CSE`] can be 
controlled with
+//! a [`CSEController`], that defines how to eliminate common subtrees from a 
particular
+//! [`TreeNode`] tree.
+
+use crate::hash_utils::combine_hashes;
+use crate::tree_node::{
+    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, 
TreeNodeRewriter,
+    TreeNodeVisitor,
+};
+use crate::Result;
+use indexmap::IndexMap;
+use std::collections::HashMap;
+use std::hash::{BuildHasher, Hash, Hasher, RandomState};
+use std::marker::PhantomData;
+use std::sync::Arc;
+
+/// Hashes the direct content of a [`TreeNode`] without recursing into its 
children.
+///
+/// This method is useful to incrementally compute hashes, such as in [`CSE`] 
which builds
+/// a deep hash of a node and its descendants during the bottom-up phase of 
the first
+/// traversal and so avoid computing the hash of the node and then the hash of 
its
+/// descendants separately.
+///
+/// If a node doesn't have any children then the value returned by 
`hash_node()` is
+/// similar to `.hash()`, but does not necessarily return the same value.
+pub trait HashNode {
+    fn hash_node<H: Hasher>(&self, state: &mut H);
+}
+
+impl<T: HashNode + ?Sized> HashNode for Arc<T> {
+    fn hash_node<H: Hasher>(&self, state: &mut H) {
+        (**self).hash_node(state);
+    }
+}
+
+/// Identifier that represents a [`TreeNode`] tree.
+///
+/// This identifier is designed to be efficient to "hash", "accumulate" and 
"equal", and
+/// to have as low a chance of collision as possible.
+#[derive(Debug, Eq, PartialEq)]
+struct Identifier<'n, N> {
+    // Hash of `node` built up incrementally during the first, visiting 
traversal.
+    // Its value is not necessarily equal to default hash of the node. E.g. it 
is not
+    // equal to `expr.hash()` if the node is `Expr`.
+    hash: u64,
+    node: &'n N,
+}
+
+impl<N> Clone for Identifier<'_, N> {
+    fn clone(&self) -> Self {
+        *self
+    }
+}
+impl<N> Copy for Identifier<'_, N> {}
+
+impl<N> Hash for Identifier<'_, N> {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        state.write_u64(self.hash);
+    }
+}
+
+impl<'n, N: HashNode> Identifier<'n, N> {
+    fn new(node: &'n N, random_state: &RandomState) -> Self {
+        let mut hasher = random_state.build_hasher();
+        node.hash_node(&mut hasher);
+        let hash = hasher.finish();
+        Self { hash, node }
+    }
+
+    fn combine(mut self, other: Option<Self>) -> Self {
+        other.map_or(self, |other_id| {
+            self.hash = combine_hashes(self.hash, other_id.hash);
+            self
+        })
+    }
+}
+
+/// A cache that contains the postorder index and the identifier of 
[`TreeNode`]s by the
+/// preorder index of the nodes.
+///
+/// This cache is filled by [`CSEVisitor`] during the first traversal and is
+/// used by [`CSERewriter`] during the second traversal.
+///
+/// The purpose of this cache is to quickly find the identifier of a node 
during the
+/// second traversal.
+///
+/// Elements in this array are added during `f_down` so the indexes represent 
the preorder
+/// index of nodes and thus element 0 belongs to the root of the tree.
+///
+/// The elements of the array are tuples that contain:
+/// - Postorder index that belongs to the preorder index. Assigned during 
`f_up`, start
+///   from 0.
+/// - The optional [`Identifier`] of the node. If none the node should not be 
considered
+///   for CSE.
+///
+/// # Example
+/// An expression tree like `(a + b)` would have the following `IdArray`:
+/// ```text
+/// [
+///   (2, Some(Identifier(hash_of("a + b"), &"a + b"))),
+///   (0, Some(Identifier(hash_of("a"), &"a"))),
+///   (1, Some(Identifier(hash_of("b"), &"b")))
+/// ]
+/// ```
+type IdArray<'n, N> = Vec<(usize, Option<Identifier<'n, N>>)>;
+
+/// A map that contains the number of normal and conditional occurrences of 
[`TreeNode`]s
+/// by their identifiers.
+type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (usize, usize)>;
+
+/// A map that contains the common [`TreeNode`]s and their alias by their 
identifiers,
+/// extracted during the second, rewriting traversal.
+type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;
+
+type ChildrenList<N> = (Vec<N>, Vec<N>);
+
+/// The [`TreeNode`] specific definition of elimination.
+pub trait CSEController {
+    /// The type of the tree nodes.
+    type Node;
+
+    /// Splits the children into normal and conditionally evaluated ones, or 
returns `None`
+    /// if all are always evaluated.
+    fn conditional_children(node: &Self::Node) -> 
Option<ChildrenList<&Self::Node>>;
+
+    // Returns true if a node is valid. If a node is invalid then it can't be 
eliminated.
+    // Validity is propagated up which means no subtree can be eliminated that 
contains
+    // an invalid node.
+    // (E.g. volatile expressions are not valid and subtrees containing such a 
node can't
+    // be extracted.)
+    fn is_valid(node: &Self::Node) -> bool;
+
+    // Returns true if a node should be ignored during CSE. Contrary to 
validity of a node,
+    // it is not propagated up.
+    fn is_ignored(&self, node: &Self::Node) -> bool;
+
+    // Generates a new name for the extracted subtree.
+    fn generate_alias(&self) -> String;
+
+    // Replaces a node to the generated alias.
+    fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;
+
+    // A helper method called on each node during top-down traversal during 
the second,
+    // rewriting traversal of CSE.
+    fn rewrite_f_down(&mut self, _node: &Self::Node) {}
+
+    // A helper method called on each node during bottom-up traversal during 
the second,
+    // rewriting traversal of CSE.
+    fn rewrite_f_up(&mut self, _node: &Self::Node) {}
+}
+
+/// The result of potentially rewriting a list of [`TreeNode`]s to eliminate 
common
+/// subtrees.
+#[derive(Debug)]
+pub enum FoundCommonNodes<N> {
+    /// No common [`TreeNode`]s were found
+    No { original_nodes_list: Vec<Vec<N>> },
+
+    /// Common [`TreeNode`]s were found
+    Yes {
+        /// extracted common [`TreeNode`]
+        common_nodes: Vec<(N, String)>,
+
+        /// new [`TreeNode`]s with common subtrees replaced
+        new_nodes_list: Vec<Vec<N>>,
+
+        /// original [`TreeNode`]s
+        original_nodes_list: Vec<Vec<N>>,
+    },
+}
+
+/// Go through a [`TreeNode`] tree and generate identifiers for each subtree.
+///
+/// An identifier contains information of the [`TreeNode`] itself and its 
subtrees.
+/// This visitor implementation uses a stack (`visit_stack`) to track 
traversal, which
+/// lets us know when a subtree's visiting is finished. When `pre_visit` is 
called
+/// (traversing to a new node), an `EnterMark` and a `NodeItem` will be 
pushed onto the stack.
+/// An `EnterMark` is popped out when leaving a node (`f_up()`). All 
`NodeItem`s
+/// before the first `EnterMark` are considered to be sub-trees of the leaving 
node.
+///
+/// This visitor also records identifiers in `id_array`, which lets the 
following rewriting
+/// pass get the identifier of a node without recalculating it. We assign each 
node
+/// in the tree a postorder index, starting from 0, maintained by `up_index`.
+/// The postorder index represents the order we left (`f_up()`) a node, and 
has the
+/// property that a child node's postorder index is always smaller than its 
parent's.
+/// `id_array` is organized in the order we enter (`f_down()`) a node, and 
`down_index`
+/// helps us to get the index into `id_array` for each node.
+///
+/// A [`TreeNode`] without any children (column, literal etc.) will not have 
an identifier
+/// because it should not be recognized as a common subtree.
+struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
+    /// statistics of [`TreeNode`]s
+    node_stats: &'a mut NodeStats<'n, N>,
+
+    /// cache to speed up second traversal
+    id_array: &'a mut IdArray<'n, N>,
+
+    /// inner states
+    visit_stack: Vec<VisitRecord<'n, N>>,
+
+    /// preorder index, start from 0.
+    down_index: usize,
+
+    /// postorder index, start from 0.
+    up_index: usize,
+
+    /// a [`RandomState`] to generate hashes during the first traversal
+    random_state: &'a RandomState,
+
+    /// a flag to indicate that common [`TreeNode`]s found
+    found_common: bool,
+
+    /// if we are in a conditional branch. A conditional branch means that the 
[`TreeNode`]
+    /// might not be executed depending on the runtime values of other 
[`TreeNode`]s, and
+    /// thus can not be extracted as a common [`TreeNode`].
+    conditional: bool,
+
+    controller: &'a C,
+}
+
+/// Record item that is used when traversing a [`TreeNode`] tree.
+enum VisitRecord<'n, N> {
+    /// Marks the beginning of [`TreeNode`]. It contains:
+    /// - The post-order index assigned during the first, visiting traversal.
+    EnterMark(usize),
+
+    /// Marks an accumulated subtree. It contains:
+    /// - The accumulated identifier of a subtree.
+    /// - An accumulated boolean flag indicating if the subtree is valid for CSE.
+    ///   The flag is propagated up from children to parent. (E.g. volatile 
expressions
+    ///   are not valid and can't be extracted, but non-volatile children of 
volatile
+    ///   expressions can be extracted.)
+    NodeItem(Identifier<'n, N>, bool),
+}
+
+impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 
'n, N, C> {
+    /// Find the first `EnterMark` in the stack, and accumulates every 
`NodeItem` before
+    /// it. Returns a tuple that contains:
+    /// - The pre-order index of the [`TreeNode`] we marked.
+    /// - The accumulated identifier of the children of the marked 
[`TreeNode`].
+    /// - An accumulated boolean flag from the children of the marked 
[`TreeNode`] if all
+    ///   children are valid for CSE (i.e. it is safe to extract the 
[`TreeNode`] as a
+    ///   common [`TreeNode`] from its children POV).
+    ///   (E.g. if any of the children of the marked expression is not valid 
(e.g. is
+    ///   volatile) then the expression is also not valid, so we can propagate 
this
+    ///   information up from children to parents via `visit_stack` during the 
first,
+    ///   visiting traversal and no need to test the expression's validity 
beforehand with
+    ///   an extra traversal).
+    fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n, N>>, bool) {
+        let mut node_id = None;
+        let mut is_valid = true;
+
+        while let Some(item) = self.visit_stack.pop() {
+            match item {
+                VisitRecord::EnterMark(down_index) => {
+                    return (down_index, node_id, is_valid);
+                }
+                VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => {
+                    node_id = Some(sub_node_id.combine(node_id));
+                    is_valid &= sub_node_is_valid;
+                }
+            }
+        }
+        unreachable!("EnterMark should paired with NodeItem");
+    }
+}
+
+impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> 
TreeNodeVisitor<'n>
+    for CSEVisitor<'_, 'n, N, C>
+{
+    type Node = N;
+
+    fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
+        self.id_array.push((0, None));
+        self.visit_stack
+            .push(VisitRecord::EnterMark(self.down_index));
+        self.down_index += 1;
+
+        // If a node can short-circuit then some of its children might not be 
executed, so
+        // count the occurrence as either normal or conditional.
+        Ok(if self.conditional {
+            // If we are already in a conditionally evaluated subtree then 
continue
+            // traversal.
+            TreeNodeRecursion::Continue
+        } else {
+            // If we are in a node that can short-circuit then start new
+            // traversals on its normal and conditional children.
+            match C::conditional_children(node) {
+                Some((normal, conditional)) => {
+                    normal
+                        .into_iter()
+                        .try_for_each(|n| n.visit(self).map(|_| ()))?;
+                    self.conditional = true;
+                    conditional
+                        .into_iter()
+                        .try_for_each(|n| n.visit(self).map(|_| ()))?;
+                    self.conditional = false;
+
+                    TreeNodeRecursion::Jump
+                }
+
+                // In case of non-short-circuit node continue the traversal.
+                _ => TreeNodeRecursion::Continue,
+            }
+        })
+    }
+
+    fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
+        let (down_index, sub_node_id, sub_node_is_valid) = 
self.pop_enter_mark();
+
+        let node_id = Identifier::new(node, 
self.random_state).combine(sub_node_id);
+        let is_valid = C::is_valid(node) && sub_node_is_valid;
+
+        self.id_array[down_index].0 = self.up_index;
+        if is_valid && !self.controller.is_ignored(node) {
+            self.id_array[down_index].1 = Some(node_id);
+            let (count, conditional_count) =
+                self.node_stats.entry(node_id).or_insert((0, 0));
+            if self.conditional {
+                *conditional_count += 1;
+            } else {
+                *count += 1;
+            }
+            if *count > 1 || (*count == 1 && *conditional_count > 0) {
+                self.found_common = true;
+            }
+        }
+        self.visit_stack
+            .push(VisitRecord::NodeItem(node_id, is_valid));
+        self.up_index += 1;
+
+        Ok(TreeNodeRecursion::Continue)
+    }
+}
+
+/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the
+/// corresponding temporary [`TreeNode`], which refers to the evaluation 
result of the
+/// replaced [`TreeNode`] tree.
+struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
+    /// statistics of [`TreeNode`]s
+    node_stats: &'a NodeStats<'n, N>,
+
+    /// cache to speed up second traversal
+    id_array: &'a IdArray<'n, N>,
+
+    /// common [`TreeNode`]s, that are replaced during the second traversal, 
are collected
+    /// to this map
+    common_nodes: &'a mut CommonNodes<'n, N>,
+
+    // preorder index, starts from 0.
+    down_index: usize,
+
+    controller: &'a mut C,
+}
+
+impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
+    for CSERewriter<'_, '_, N, C>
+{
+    type Node = N;
+
+    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+        self.controller.rewrite_f_down(&node);
+
+        let (up_index, node_id) = self.id_array[self.down_index];
+        self.down_index += 1;
+
+        // Handle nodes with identifiers only
+        if let Some(node_id) = node_id {
+            let (count, conditional_count) = 
self.node_stats.get(&node_id).unwrap();
+            if *count > 1 || *count == 1 && *conditional_count > 0 {
+                // Step the index to skip all sub-nodes (which have a smaller 
postorder index).
+                while self.down_index < self.id_array.len()
+                    && self.id_array[self.down_index].0 < up_index
+                {
+                    self.down_index += 1;
+                }
+
+                let (node, alias) =
+                    self.common_nodes.entry(node_id).or_insert_with(|| {
+                        let node_alias = self.controller.generate_alias();
+                        (node, node_alias)
+                    });
+
+                let rewritten = self.controller.rewrite(node, alias);
+
+                return Ok(Transformed::new(rewritten, true, 
TreeNodeRecursion::Jump));
+            }
+        }
+
+        Ok(Transformed::no(node))
+    }
+
+    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+        self.controller.rewrite_f_up(&node);
+
+        Ok(Transformed::no(node))
+    }
+}
+
+/// The main entry point of Common Subexpression Elimination.
+///
+/// [`CSE`] requires a [`CSEController`], that defines how common subtrees of 
a particular
+/// [`TreeNode`] tree can be eliminated. The elimination process can be 
started with the
+/// [`CSE::extract_common_nodes()`] method.
+pub struct CSE<N, C: CSEController<Node = N>> {
+    random_state: RandomState,
+    phantom_data: PhantomData<N>,
+    controller: C,
+}
+
+impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, 
C> {
+    pub fn new(controller: C) -> Self {
+        Self {
+            random_state: RandomState::new(),
+            phantom_data: PhantomData,
+            controller,
+        }
+    }
+
+    /// Add an identifier to `id_array` for every [`TreeNode`] in this tree.
+    fn node_to_id_array<'n>(
+        &self,
+        node: &'n N,
+        node_stats: &mut NodeStats<'n, N>,
+        id_array: &mut IdArray<'n, N>,
+    ) -> Result<bool> {
+        let mut visitor = CSEVisitor {
+            node_stats,
+            id_array,
+            visit_stack: vec![],
+            down_index: 0,
+            up_index: 0,
+            random_state: &self.random_state,
+            found_common: false,
+            conditional: false,
+            controller: &self.controller,
+        };
+        node.visit(&mut visitor)?;
+
+        Ok(visitor.found_common)
+    }
+
+    /// Returns the identifier list for each element in `nodes` and a flag to 
indicate if
+    /// the rewrite phase of CSE makes sense.
+    ///
+    /// Returns an array with 1 element for each input node in `nodes`
+    ///
+    /// Each element is itself the result of [`CSE::node_to_id_array`] for 
that node
+    /// (e.g. the identifiers for each node in the tree)
+    fn to_arrays<'n>(
+        &self,
+        nodes: &'n [N],
+        node_stats: &mut NodeStats<'n, N>,
+    ) -> Result<(bool, Vec<IdArray<'n, N>>)> {
+        let mut found_common = false;
+        nodes
+            .iter()
+            .map(|n| {
+                let mut id_array = vec![];
+                self.node_to_id_array(n, node_stats, &mut id_array)
+                    .map(|fc| {
+                        found_common |= fc;
+
+                        id_array
+                    })
+            })
+            .collect::<Result<Vec<_>>>()
+            .map(|id_arrays| (found_common, id_arrays))
+    }
+
+    /// Replace common subtrees in `node` with the corresponding temporary
+    /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`]
+    fn replace_common_node<'n>(
+        &mut self,
+        node: N,
+        id_array: &IdArray<'n, N>,
+        node_stats: &NodeStats<'n, N>,
+        common_nodes: &mut CommonNodes<'n, N>,
+    ) -> Result<N> {
+        if id_array.is_empty() {
+            Ok(Transformed::no(node))
+        } else {
+            node.rewrite(&mut CSERewriter {
+                node_stats,
+                id_array,
+                common_nodes,
+                down_index: 0,
+                controller: &mut self.controller,
+            })
+        }
+        .data()
+    }
+
+    /// Replace common subtrees in `nodes_list` with the corresponding 
temporary
+    /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`].
+    fn rewrite_nodes_list<'n>(
+        &mut self,
+        nodes_list: Vec<Vec<N>>,
+        arrays_list: &[Vec<IdArray<'n, N>>],
+        node_stats: &NodeStats<'n, N>,
+        common_nodes: &mut CommonNodes<'n, N>,
+    ) -> Result<Vec<Vec<N>>> {
+        nodes_list
+            .into_iter()
+            .zip(arrays_list.iter())
+            .map(|(nodes, arrays)| {
+                nodes
+                    .into_iter()
+                    .zip(arrays.iter())
+                    .map(|(node, id_array)| {
+                        self.replace_common_node(node, id_array, node_stats, 
common_nodes)
+                    })
+                    .collect::<Result<Vec<_>>>()
+            })
+            .collect::<Result<Vec<_>>>()
+    }
+
+    /// Extracts common [`TreeNode`]s and rewrites `nodes_list`.
+    ///
+    /// Returns [`FoundCommonNodes`] recording the result of the extraction.
+    pub fn extract_common_nodes(
+        &mut self,
+        nodes_list: Vec<Vec<N>>,
+    ) -> Result<FoundCommonNodes<N>> {
+        let mut found_common = false;
+        let mut node_stats = NodeStats::new();
+        let id_arrays_list = nodes_list
+            .iter()
+            .map(|nodes| {
+                self.to_arrays(nodes, &mut node_stats)
+                    .map(|(fc, id_arrays)| {
+                        found_common |= fc;
+
+                        id_arrays
+                    })
+            })
+            .collect::<Result<Vec<_>>>()?;
+        if found_common {
+            let mut common_nodes = CommonNodes::new();
+            let new_nodes_list = self.rewrite_nodes_list(
+                // Must clone the list of nodes as Identifiers use references 
to original
+                // nodes so we have to keep them intact.
+                nodes_list.clone(),
+                &id_arrays_list,
+                &node_stats,
+                &mut common_nodes,
+            )?;
+            assert!(!common_nodes.is_empty());
+
+            Ok(FoundCommonNodes::Yes {
+                common_nodes: common_nodes.into_values().collect(),
+                new_nodes_list,
+                original_nodes_list: nodes_list,
+            })
+        } else {
+            Ok(FoundCommonNodes::No {
+                original_nodes_list: nodes_list,
+            })
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::alias::AliasGenerator;
+    use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats, 
CSE};
+    use crate::tree_node::tests::TestTreeNode;
+    use crate::Result;
+    use std::collections::HashSet;
+    use std::hash::{Hash, Hasher};
+
+    const CSE_PREFIX: &str = "__common_node";
+
+    #[derive(Clone, Copy)]
+    pub enum TestTreeNodeMask {
+        Normal,
+        NormalAndAggregates,
+    }
+
+    pub struct TestTreeNodeCSEController<'a> {
+        alias_generator: &'a AliasGenerator,
+        mask: TestTreeNodeMask,
+    }
+
+    impl<'a> TestTreeNodeCSEController<'a> {
+        fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) -> 
Self {
+            Self {
+                alias_generator,
+                mask,
+            }
+        }
+    }
+
+    impl CSEController for TestTreeNodeCSEController<'_> {
+        type Node = TestTreeNode<String>;
+
+        fn conditional_children(
+            _: &Self::Node,
+        ) -> Option<(Vec<&Self::Node>, Vec<&Self::Node>)> {
+            None
+        }
+
+        fn is_valid(_node: &Self::Node) -> bool {
+            true
+        }
+
+        fn is_ignored(&self, node: &Self::Node) -> bool {
+            let is_leaf = node.is_leaf();
+            let is_aggr = node.data == "avg" || node.data == "sum";
+
+            match self.mask {
+                TestTreeNodeMask::Normal => is_leaf || is_aggr,
+                TestTreeNodeMask::NormalAndAggregates => is_leaf,
+            }
+        }
+
+        fn generate_alias(&self) -> String {
+            self.alias_generator.next(CSE_PREFIX)
+        }
+
+        fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
+            TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
+        }
+    }
+
+    impl HashNode for TestTreeNode<String> {
+        fn hash_node<H: Hasher>(&self, state: &mut H) {
+            self.data.hash(state);
+        }
+    }
+
+    #[test]
+    fn id_array_visitor() -> Result<()> {
+        let alias_generator = AliasGenerator::new();
+        let eliminator = CSE::new(TestTreeNodeCSEController::new(
+            &alias_generator,
+            TestTreeNodeMask::Normal,
+        ));
+
+        let a_plus_1 = TestTreeNode::new(
+            vec![
+                TestTreeNode::new_leaf("a".to_string()),
+                TestTreeNode::new_leaf("1".to_string()),
+            ],
+            "+".to_string(),
+        );
+        let avg_c = TestTreeNode::new(
+            vec![TestTreeNode::new_leaf("c".to_string())],
+            "avg".to_string(),
+        );
+        let sum_a_plus_1 = TestTreeNode::new(vec![a_plus_1], 
"sum".to_string());
+        let sum_a_plus_1_minus_avg_c =
+            TestTreeNode::new(vec![sum_a_plus_1, avg_c], "-".to_string());
+        let root = TestTreeNode::new(
+            vec![
+                sum_a_plus_1_minus_avg_c,
+                TestTreeNode::new_leaf("2".to_string()),
+            ],
+            "*".to_string(),
+        );
+
+        let [sum_a_plus_1_minus_avg_c, _] = root.children.as_slice() else {
+            panic!("Cannot extract subtree references")
+        };
+        let [sum_a_plus_1, avg_c] = 
sum_a_plus_1_minus_avg_c.children.as_slice() else {
+            panic!("Cannot extract subtree references")
+        };
+        let [a_plus_1] = sum_a_plus_1.children.as_slice() else {
+            panic!("Cannot extract subtree references")
+        };
+
+        // skip aggregates
+        let mut id_array = vec![];
+        eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut 
id_array)?;
+
+        // Collect distinct hashes and set them to 0 in `id_array`
+        fn collect_hashes(
+            id_array: &mut IdArray<'_, TestTreeNode<String>>,
+        ) -> HashSet<u64> {
+            id_array
+                .iter_mut()
+                .flat_map(|(_, id_option)| {
+                    id_option.as_mut().map(|node_id| {
+                        let hash = node_id.hash;
+                        node_id.hash = 0;
+                        hash
+                    })
+                })
+                .collect::<HashSet<_>>()
+        }
+
+        let hashes = collect_hashes(&mut id_array);
+        assert_eq!(hashes.len(), 3);
+
+        let expected = vec![
+            (
+                8,
+                Some(Identifier {
+                    hash: 0,
+                    node: &root,
+                }),
+            ),
+            (
+                6,
+                Some(Identifier {
+                    hash: 0,
+                    node: sum_a_plus_1_minus_avg_c,
+                }),
+            ),
+            (3, None),
+            (
+                2,
+                Some(Identifier {
+                    hash: 0,
+                    node: a_plus_1,
+                }),
+            ),
+            (0, None),
+            (1, None),
+            (5, None),
+            (4, None),
+            (7, None),
+        ];
+        assert_eq!(expected, id_array);
+
+        // include aggregates
+        let eliminator = CSE::new(TestTreeNodeCSEController::new(
+            &alias_generator,
+            TestTreeNodeMask::NormalAndAggregates,
+        ));
+
+        let mut id_array = vec![];
+        eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut 
id_array)?;
+
+        let hashes = collect_hashes(&mut id_array);
+        assert_eq!(hashes.len(), 5);
+
+        let expected = vec![
+            (
+                8,
+                Some(Identifier {
+                    hash: 0,
+                    node: &root,
+                }),
+            ),
+            (
+                6,
+                Some(Identifier {
+                    hash: 0,
+                    node: sum_a_plus_1_minus_avg_c,
+                }),
+            ),
+            (
+                3,
+                Some(Identifier {
+                    hash: 0,
+                    node: sum_a_plus_1,
+                }),
+            ),
+            (
+                2,
+                Some(Identifier {
+                    hash: 0,
+                    node: a_plus_1,
+                }),
+            ),
+            (0, None),
+            (1, None),
+            (
+                5,
+                Some(Identifier {
+                    hash: 0,
+                    node: avg_c,
+                }),
+            ),
+            (4, None),
+            (7, None),
+        ];
+        assert_eq!(expected, id_array);
+
+        Ok(())
+    }
+}
diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs
index 8323f5efc8..e4575038ab 100644
--- a/datafusion/common/src/lib.rs
+++ b/datafusion/common/src/lib.rs
@@ -31,6 +31,7 @@ mod unnest;
 pub mod alias;
 pub mod cast;
 pub mod config;
+pub mod cse;
 pub mod display;
 pub mod error;
 pub mod file_options;
diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs
index b4d3251fd2..563f1fa856 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -1027,7 +1027,7 @@ impl<T: ConcreteTreeNode> TreeNode for T {
 }
 
 #[cfg(test)]
-mod tests {
+pub(crate) mod tests {
     use std::collections::HashMap;
     use std::fmt::Display;
 
@@ -1037,16 +1037,27 @@ mod tests {
     };
     use crate::Result;
 
-    #[derive(Debug, Eq, Hash, PartialEq)]
-    struct TestTreeNode<T> {
-        children: Vec<TestTreeNode<T>>,
-        data: T,
+    #[derive(Debug, Eq, Hash, PartialEq, Clone)]
+    pub struct TestTreeNode<T> {
+        pub(crate) children: Vec<TestTreeNode<T>>,
+        pub(crate) data: T,
     }
 
     impl<T> TestTreeNode<T> {
-        fn new(children: Vec<TestTreeNode<T>>, data: T) -> Self {
+        pub(crate) fn new(children: Vec<TestTreeNode<T>>, data: T) -> Self {
             Self { children, data }
         }
+
+        pub(crate) fn new_leaf(data: T) -> Self {
+            Self {
+                children: vec![],
+                data,
+            }
+        }
+
+        pub(crate) fn is_leaf(&self) -> bool {
+            self.children.is_empty()
+        }
     }
 
     impl<T> TreeNode for TestTreeNode<T> {
@@ -1086,12 +1097,12 @@ mod tests {
     //       |
     //       A
     fn test_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "a".to_string());
-        let node_b = TestTreeNode::new(vec![], "b".to_string());
+        let node_a = TestTreeNode::new_leaf("a".to_string());
+        let node_b = TestTreeNode::new_leaf("b".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
-        let node_h = TestTreeNode::new(vec![], "h".to_string());
+        let node_h = TestTreeNode::new_leaf("h".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
@@ -1130,13 +1141,13 @@ mod tests {
 
     // Expected transformed tree after a combined traversal
     fn transformed_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string());
-        let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string());
+        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
+        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string());
         let node_c =
             TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
-        let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string());
+        let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
         let node_f =
             TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string());
@@ -1146,12 +1157,12 @@ mod tests {
 
     // Expected transformed tree after a top-down traversal
     fn transformed_down_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string());
-        let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string());
+        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
+        let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], 
"f_down(c)".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
-        let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string());
+        let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], 
"f_down(f)".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1160,12 +1171,12 @@ mod tests {
 
     // Expected transformed tree after a bottom-up traversal
     fn transformed_up_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string());
-        let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string());
+        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
+        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], 
"f_up(c)".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string());
-        let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string());
+        let node_h = TestTreeNode::new_leaf("f_up(h)".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], 
"f_up(f)".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string());
@@ -1202,12 +1213,12 @@ mod tests {
     }
 
     fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string());
-        let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string());
+        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
+        let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], 
"f_down(c)".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
-        let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string());
+        let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], 
"f_down(f)".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1236,12 +1247,12 @@ mod tests {
     }
 
     fn f_down_jump_on_e_transformed_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "a".to_string());
-        let node_b = TestTreeNode::new(vec![], "b".to_string());
+        let node_a = TestTreeNode::new_leaf("a".to_string());
+        let node_b = TestTreeNode::new_leaf("b".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
         let node_e = TestTreeNode::new(vec![node_c], 
"f_up(f_down(e))".to_string());
-        let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string());
+        let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
         let node_g = TestTreeNode::new(vec![node_h], 
"f_up(f_down(g))".to_string());
         let node_f =
             TestTreeNode::new(vec![node_e, node_g], 
"f_up(f_down(f))".to_string());
@@ -1250,12 +1261,12 @@ mod tests {
     }
 
     fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "a".to_string());
-        let node_b = TestTreeNode::new(vec![], "b".to_string());
+        let node_a = TestTreeNode::new_leaf("a".to_string());
+        let node_b = TestTreeNode::new_leaf("b".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
-        let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string());
+        let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], 
"f_down(f)".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1289,12 +1300,12 @@ mod tests {
     }
 
     fn f_up_jump_on_a_transformed_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string());
-        let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string());
+        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
+        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], 
"f_down(c)".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
-        let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string());
+        let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
         let node_g = TestTreeNode::new(vec![node_h], 
"f_up(f_down(g))".to_string());
         let node_f =
             TestTreeNode::new(vec![node_e, node_g], 
"f_up(f_down(f))".to_string());
@@ -1303,12 +1314,12 @@ mod tests {
     }
 
     fn f_up_jump_on_a_transformed_up_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string());
-        let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string());
+        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
+        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
-        let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string());
+        let node_h = TestTreeNode::new_leaf("f_up(h)".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], 
"f_up(f)".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string());
@@ -1372,12 +1383,12 @@ mod tests {
     }
 
     fn f_down_stop_on_a_transformed_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string());
-        let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string());
+        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
+        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], 
"f_down(c)".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
-        let node_h = TestTreeNode::new(vec![], "h".to_string());
+        let node_h = TestTreeNode::new_leaf("h".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], 
"f_down(f)".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1385,12 +1396,12 @@ mod tests {
     }
 
     fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string());
-        let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string());
+        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
+        let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], 
"f_down(c)".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
-        let node_h = TestTreeNode::new(vec![], "h".to_string());
+        let node_h = TestTreeNode::new_leaf("h".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], 
"f_down(f)".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1406,12 +1417,12 @@ mod tests {
     }
 
     fn f_down_stop_on_e_transformed_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "a".to_string());
-        let node_b = TestTreeNode::new(vec![], "b".to_string());
+        let node_a = TestTreeNode::new_leaf("a".to_string());
+        let node_b = TestTreeNode::new_leaf("b".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
-        let node_h = TestTreeNode::new(vec![], "h".to_string());
+        let node_h = TestTreeNode::new_leaf("h".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], 
"f_down(f)".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1419,12 +1430,12 @@ mod tests {
     }
 
     fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "a".to_string());
-        let node_b = TestTreeNode::new(vec![], "b".to_string());
+        let node_a = TestTreeNode::new_leaf("a".to_string());
+        let node_b = TestTreeNode::new_leaf("b".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
-        let node_h = TestTreeNode::new(vec![], "h".to_string());
+        let node_h = TestTreeNode::new_leaf("h".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], 
"f_down(f)".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1451,12 +1462,12 @@ mod tests {
     }
 
     fn f_up_stop_on_a_transformed_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string());
-        let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string());
+        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
+        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], 
"f_down(c)".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
-        let node_h = TestTreeNode::new(vec![], "h".to_string());
+        let node_h = TestTreeNode::new_leaf("h".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], 
"f_down(f)".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1464,12 +1475,12 @@ mod tests {
     }
 
     fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string());
-        let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string());
+        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
+        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
-        let node_h = TestTreeNode::new(vec![], "h".to_string());
+        let node_h = TestTreeNode::new_leaf("h".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
@@ -1499,13 +1510,13 @@ mod tests {
     }
 
     fn f_up_stop_on_e_transformed_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string());
-        let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string());
+        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
+        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
         let node_d = TestTreeNode::new(vec![node_a], 
"f_up(f_down(d))".to_string());
         let node_c =
             TestTreeNode::new(vec![node_b, node_d], 
"f_up(f_down(c))".to_string());
         let node_e = TestTreeNode::new(vec![node_c], 
"f_up(f_down(e))".to_string());
-        let node_h = TestTreeNode::new(vec![], "h".to_string());
+        let node_h = TestTreeNode::new_leaf("h".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], 
"f_down(f)".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
@@ -1513,12 +1524,12 @@ mod tests {
     }
 
     fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode<String> {
-        let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string());
-        let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string());
+        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
+        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], 
"f_up(c)".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string());
-        let node_h = TestTreeNode::new(vec![], "h".to_string());
+        let node_h = TestTreeNode::new_leaf("h".to_string());
         let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
         let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
         let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
@@ -2016,16 +2027,16 @@ mod tests {
     //       A
     #[test]
     fn test_apply_and_visit_references() -> Result<()> {
-        let node_a = TestTreeNode::new(vec![], "a".to_string());
-        let node_b = TestTreeNode::new(vec![], "b".to_string());
+        let node_a = TestTreeNode::new_leaf("a".to_string());
+        let node_b = TestTreeNode::new_leaf("b".to_string());
         let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
         let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
         let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
-        let node_a_2 = TestTreeNode::new(vec![], "a".to_string());
-        let node_b_2 = TestTreeNode::new(vec![], "b".to_string());
+        let node_a_2 = TestTreeNode::new_leaf("a".to_string());
+        let node_b_2 = TestTreeNode::new_leaf("b".to_string());
         let node_d_2 = TestTreeNode::new(vec![node_a_2], "d".to_string());
         let node_c_2 = TestTreeNode::new(vec![node_b_2, node_d_2], 
"c".to_string());
-        let node_a_3 = TestTreeNode::new(vec![], "a".to_string());
+        let node_a_3 = TestTreeNode::new_leaf("a".to_string());
         let tree = TestTreeNode::new(vec![node_e, node_c_2, node_a_3], 
"f".to_string());
 
         let node_f_ref = &tree;
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index f3f71a8727..691b65d344 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -34,6 +34,7 @@ use crate::{
 };
 
 use arrow::datatypes::{DataType, FieldRef};
+use datafusion_common::cse::HashNode;
 use datafusion_common::tree_node::{
     Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
 };
@@ -1652,47 +1653,39 @@ impl Expr {
             | Expr::Placeholder(..) => false,
         }
     }
+}
 
-    /// Hashes the direct content of an `Expr` without recursing into its children.
-    ///
-    /// This method is useful to incrementally compute hashes, such as  in
-    /// `CommonSubexprEliminate` which builds a deep hash of a node and its descendants
-    /// during the bottom-up phase of the first traversal and so avoid computing the hash
-    /// of the node and then the hash of its descendants separately.
-    ///
-    /// If a node doesn't have any children then this method is similar to `.hash()`, but
-    /// not necessarily returns the same value.
-    ///
+impl HashNode for Expr {
     /// As it is pretty easy to forget changing this method when `Expr` changes the
     /// implementation doesn't use wildcard patterns (`..`, `_`) to catch changes
     /// compile time.
-    pub fn hash_node<H: Hasher>(&self, hasher: &mut H) {
-        mem::discriminant(self).hash(hasher);
+    fn hash_node<H: Hasher>(&self, state: &mut H) {
+        mem::discriminant(self).hash(state);
         match self {
             Expr::Alias(Alias {
                 expr: _expr,
                 relation,
                 name,
             }) => {
-                relation.hash(hasher);
-                name.hash(hasher);
+                relation.hash(state);
+                name.hash(state);
             }
             Expr::Column(column) => {
-                column.hash(hasher);
+                column.hash(state);
             }
             Expr::ScalarVariable(data_type, name) => {
-                data_type.hash(hasher);
-                name.hash(hasher);
+                data_type.hash(state);
+                name.hash(state);
             }
             Expr::Literal(scalar_value) => {
-                scalar_value.hash(hasher);
+                scalar_value.hash(state);
             }
             Expr::BinaryExpr(BinaryExpr {
                 left: _left,
                 op,
                 right: _right,
             }) => {
-                op.hash(hasher);
+                op.hash(state);
             }
             Expr::Like(Like {
                 negated,
@@ -1708,9 +1701,9 @@ impl Expr {
                 escape_char,
                 case_insensitive,
             }) => {
-                negated.hash(hasher);
-                escape_char.hash(hasher);
-                case_insensitive.hash(hasher);
+                negated.hash(state);
+                escape_char.hash(state);
+                case_insensitive.hash(state);
             }
             Expr::Not(_expr)
             | Expr::IsNotNull(_expr)
@@ -1728,7 +1721,7 @@ impl Expr {
                 low: _low,
                 high: _high,
             }) => {
-                negated.hash(hasher);
+                negated.hash(state);
             }
             Expr::Case(Case {
                 expr: _expr,
@@ -1743,10 +1736,10 @@ impl Expr {
                 expr: _expr,
                 data_type,
             }) => {
-                data_type.hash(hasher);
+                data_type.hash(state);
             }
             Expr::ScalarFunction(ScalarFunction { func, args: _args }) => {
-                func.hash(hasher);
+                func.hash(state);
             }
             Expr::AggregateFunction(AggregateFunction {
                 func,
@@ -1756,9 +1749,9 @@ impl Expr {
                 order_by: _order_by,
                 null_treatment,
             }) => {
-                func.hash(hasher);
-                distinct.hash(hasher);
-                null_treatment.hash(hasher);
+                func.hash(state);
+                distinct.hash(state);
+                null_treatment.hash(state);
             }
             Expr::WindowFunction(WindowFunction {
                 fun,
@@ -1768,49 +1761,49 @@ impl Expr {
                 window_frame,
                 null_treatment,
             }) => {
-                fun.hash(hasher);
-                window_frame.hash(hasher);
-                null_treatment.hash(hasher);
+                fun.hash(state);
+                window_frame.hash(state);
+                null_treatment.hash(state);
             }
             Expr::InList(InList {
                 expr: _expr,
                 list: _list,
                 negated,
             }) => {
-                negated.hash(hasher);
+                negated.hash(state);
             }
             Expr::Exists(Exists { subquery, negated }) => {
-                subquery.hash(hasher);
-                negated.hash(hasher);
+                subquery.hash(state);
+                negated.hash(state);
             }
             Expr::InSubquery(InSubquery {
                 expr: _expr,
                 subquery,
                 negated,
             }) => {
-                subquery.hash(hasher);
-                negated.hash(hasher);
+                subquery.hash(state);
+                negated.hash(state);
             }
             Expr::ScalarSubquery(subquery) => {
-                subquery.hash(hasher);
+                subquery.hash(state);
             }
             Expr::Wildcard { qualifier, options } => {
-                qualifier.hash(hasher);
-                options.hash(hasher);
+                qualifier.hash(state);
+                options.hash(state);
             }
             Expr::GroupingSet(grouping_set) => {
-                mem::discriminant(grouping_set).hash(hasher);
+                mem::discriminant(grouping_set).hash(state);
                 match grouping_set {
                     GroupingSet::Rollup(_exprs) | GroupingSet::Cube(_exprs) => 
{}
                     GroupingSet::GroupingSets(_exprs) => {}
                 }
             }
             Expr::Placeholder(place_holder) => {
-                place_holder.hash(hasher);
+                place_holder.hash(state);
             }
             Expr::OuterReferenceColumn(data_type, column) => {
-                data_type.hash(hasher);
-                column.hash(hasher);
+                data_type.hash(state);
+                column.hash(state);
             }
             Expr::Unnest(Unnest { expr: _expr }) => {}
         };
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index c13cb3a8e9..921011d33f 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -17,8 +17,8 @@
 
 //! [`CommonSubexprEliminate`] to avoid redundant computation of common sub-expressions
 
-use std::collections::{BTreeSet, HashMap};
-use std::hash::{BuildHasher, Hash, Hasher, RandomState};
+use std::collections::BTreeSet;
+use std::fmt::Debug;
 use std::sync::Arc;
 
 use crate::{OptimizerConfig, OptimizerRule};
@@ -26,11 +26,9 @@ use crate::{OptimizerConfig, OptimizerRule};
 use crate::optimizer::ApplyOrder;
 use crate::utils::NamePreserver;
 use datafusion_common::alias::AliasGenerator;
-use datafusion_common::hash_utils::combine_hashes;
-use datafusion_common::tree_node::{
-    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
-    TreeNodeVisitor,
-};
+
+use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE};
+use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result};
 use datafusion_expr::expr::{Alias, ScalarFunction};
 use datafusion_expr::logical_plan::{
@@ -38,81 +36,9 @@ use datafusion_expr::logical_plan::{
 };
 use datafusion_expr::tree_node::replace_sort_expressions;
 use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator};
-use indexmap::IndexMap;
 
 const CSE_PREFIX: &str = "__common_expr";
 
-/// Identifier that represents a subexpression tree.
-///
-/// This identifier is designed to be efficient and  "hash", "accumulate", "equal" and
-/// "have no collision (as low as possible)"
-#[derive(Clone, Copy, Debug, Eq, PartialEq)]
-struct Identifier<'n> {
-    // Hash of `expr` built up incrementally during the first, visiting traversal, but its
-    // value is not necessarily equal to `expr.hash()`.
-    // value is not necessarily equal to `expr.hash()`.
-    hash: u64,
-    expr: &'n Expr,
-}
-
-impl<'n> Identifier<'n> {
-    fn new(expr: &'n Expr, random_state: &RandomState) -> Self {
-        let mut hasher = random_state.build_hasher();
-        expr.hash_node(&mut hasher);
-        let hash = hasher.finish();
-        Self { hash, expr }
-    }
-
-    fn combine(mut self, other: Option<Self>) -> Self {
-        other.map_or(self, |other_id| {
-            self.hash = combine_hashes(self.hash, other_id.hash);
-            self
-        })
-    }
-}
-
-impl Hash for Identifier<'_> {
-    fn hash<H: Hasher>(&self, state: &mut H) {
-        state.write_u64(self.hash);
-    }
-}
-
-/// A cache that contains the postorder index and the identifier of expression 
tree nodes
-/// by the preorder index of the nodes.
-///
-/// This cache is filled by `ExprIdentifierVisitor` during the first traversal 
and is used
-/// by `CommonSubexprRewriter` during the second traversal.
-///
-/// The purpose of this cache is to quickly find the identifier of a node 
during the
-/// second traversal.
-///
-/// Elements in this array are added during `f_down` so the indexes represent 
the preorder
-/// index of expression nodes and thus element 0 belongs to the root of the 
expression
-/// tree.
-/// The elements of the array are tuples that contain:
-/// - Postorder index that belongs to the preorder index. Assigned during 
`f_up`, start
-///   from 0.
-/// - Identifier of the expression. If empty (`""`), expr should not be 
considered for
-///   CSE.
-///
-/// # Example
-/// An expression like `(a + b)` would have the following `IdArray`:
-/// ```text
-/// [
-///   (2, "a + b"),
-///   (1, "a"),
-///   (0, "b")
-/// ]
-/// ```
-type IdArray<'n> = Vec<(usize, Option<Identifier<'n>>)>;
-
-/// A map that contains the number of normal and conditional occurrences of 
expressions by
-/// their identifiers.
-type ExprStats<'n> = HashMap<Identifier<'n>, (usize, usize)>;
-
-/// A map that contains the common expressions and their alias extracted 
during the
-/// second, rewriting traversal.
-type CommonExprs<'n> = IndexMap<Identifier<'n>, (Expr, String)>;
-
 /// Performs Common Sub-expression Elimination optimization.
 ///
 /// This optimization improves query performance by computing expressions that
@@ -140,168 +66,11 @@ type CommonExprs<'n> = IndexMap<Identifier<'n>, (Expr, 
String)>;
 ///   ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once
 /// ```
 #[derive(Debug)]
-pub struct CommonSubexprEliminate {
-    random_state: RandomState,
-}
-
-/// The result of potentially rewriting a list of expressions to eliminate 
common
-/// subexpressions.
-#[derive(Debug)]
-enum FoundCommonExprs {
-    /// No common expressions were found
-    No { original_exprs_list: Vec<Vec<Expr>> },
-    /// Common expressions were found
-    Yes {
-        /// extracted common expressions
-        common_exprs: Vec<(Expr, String)>,
-        /// new expressions with common subexpressions replaced
-        new_exprs_list: Vec<Vec<Expr>>,
-        /// original expressions
-        original_exprs_list: Vec<Vec<Expr>>,
-    },
-}
+pub struct CommonSubexprEliminate {}
 
 impl CommonSubexprEliminate {
     pub fn new() -> Self {
-        Self {
-            random_state: RandomState::new(),
-        }
-    }
-
-    /// Returns the identifier list for each element in `exprs` and a flag to 
indicate if
-    /// rewrite phase of CSE make sense.
-    ///
-    /// Returns and array with 1 element for each input expr in `exprs`
-    ///
-    /// Each element is itself the result of 
[`CommonSubexprEliminate::expr_to_identifier`] for that expr
-    /// (e.g. the identifiers for each node in the tree)
-    fn to_arrays<'n>(
-        &self,
-        exprs: &'n [Expr],
-        expr_stats: &mut ExprStats<'n>,
-        expr_mask: ExprMask,
-    ) -> Result<(bool, Vec<IdArray<'n>>)> {
-        let mut found_common = false;
-        exprs
-            .iter()
-            .map(|e| {
-                let mut id_array = vec![];
-                self.expr_to_identifier(e, expr_stats, &mut id_array, 
expr_mask)
-                    .map(|fc| {
-                        found_common |= fc;
-
-                        id_array
-                    })
-            })
-            .collect::<Result<Vec<_>>>()
-            .map(|id_arrays| (found_common, id_arrays))
-    }
-
-    /// Add an identifier to `id_array` for every subexpression in this tree.
-    fn expr_to_identifier<'n>(
-        &self,
-        expr: &'n Expr,
-        expr_stats: &mut ExprStats<'n>,
-        id_array: &mut IdArray<'n>,
-        expr_mask: ExprMask,
-    ) -> Result<bool> {
-        let mut visitor = ExprIdentifierVisitor {
-            expr_stats,
-            id_array,
-            visit_stack: vec![],
-            down_index: 0,
-            up_index: 0,
-            expr_mask,
-            random_state: &self.random_state,
-            found_common: false,
-            conditional: false,
-        };
-        expr.visit(&mut visitor)?;
-
-        Ok(visitor.found_common)
-    }
-
-    /// Rewrites `exprs_list` with common sub-expressions replaced with a new
-    /// column.
-    ///
-    /// `common_exprs` is updated with any sub expressions that were replaced.
-    ///
-    /// Returns the rewritten expressions
-    fn rewrite_exprs_list<'n>(
-        &self,
-        exprs_list: Vec<Vec<Expr>>,
-        arrays_list: &[Vec<IdArray<'n>>],
-        expr_stats: &ExprStats<'n>,
-        common_exprs: &mut CommonExprs<'n>,
-        alias_generator: &AliasGenerator,
-    ) -> Result<Vec<Vec<Expr>>> {
-        exprs_list
-            .into_iter()
-            .zip(arrays_list.iter())
-            .map(|(exprs, arrays)| {
-                exprs
-                    .into_iter()
-                    .zip(arrays.iter())
-                    .map(|(expr, id_array)| {
-                        replace_common_expr(
-                            expr,
-                            id_array,
-                            expr_stats,
-                            common_exprs,
-                            alias_generator,
-                        )
-                    })
-                    .collect::<Result<Vec<_>>>()
-            })
-            .collect::<Result<Vec<_>>>()
-    }
-
-    /// Extracts common sub-expressions and rewrites `exprs_list`.
-    ///
-    /// Returns `FoundCommonExprs` recording the result of the extraction
-    fn find_common_exprs(
-        &self,
-        exprs_list: Vec<Vec<Expr>>,
-        config: &dyn OptimizerConfig,
-        expr_mask: ExprMask,
-    ) -> Result<Transformed<FoundCommonExprs>> {
-        let mut found_common = false;
-        let mut expr_stats = ExprStats::new();
-        let id_arrays_list = exprs_list
-            .iter()
-            .map(|exprs| {
-                self.to_arrays(exprs, &mut expr_stats, expr_mask).map(
-                    |(fc, id_arrays)| {
-                        found_common |= fc;
-
-                        id_arrays
-                    },
-                )
-            })
-            .collect::<Result<Vec<_>>>()?;
-        if found_common {
-            let mut common_exprs = CommonExprs::new();
-            let new_exprs_list = self.rewrite_exprs_list(
-                // Must clone as Identifiers use references to original 
expressions so we have
-                // to keep the original expressions intact.
-                exprs_list.clone(),
-                &id_arrays_list,
-                &expr_stats,
-                &mut common_exprs,
-                config.alias_generator().as_ref(),
-            )?;
-            assert!(!common_exprs.is_empty());
-
-            Ok(Transformed::yes(FoundCommonExprs::Yes {
-                common_exprs: common_exprs.into_values().collect(),
-                new_exprs_list,
-                original_exprs_list: exprs_list,
-            }))
-        } else {
-            Ok(Transformed::no(FoundCommonExprs::No {
-                original_exprs_list: exprs_list,
-            }))
-        }
+        Self {}
     }
 
     fn try_optimize_proj(
@@ -372,80 +141,83 @@ impl CommonSubexprEliminate {
             get_consecutive_window_exprs(window);
 
         // Extract common sub-expressions from the list.
-        self.find_common_exprs(window_expr_list, config, ExprMask::Normal)?
-            .map_data(|common| match common {
-                // If there are common sub-expressions, then the insert a 
projection node
-                // with the common expressions between the new window nodes 
and the
-                // original input.
-                FoundCommonExprs::Yes {
-                    common_exprs,
-                    new_exprs_list,
-                    original_exprs_list,
-                } => {
-                    build_common_expr_project_plan(input, 
common_exprs).map(|new_input| {
-                        (new_exprs_list, new_input, Some(original_exprs_list))
+
+        match CSE::new(ExprCSEController::new(
+            config.alias_generator().as_ref(),
+            ExprMask::Normal,
+        ))
+        .extract_common_nodes(window_expr_list)?
+        {
+            // If there are common sub-expressions, then insert a 
projection node
+            // with the common expressions between the new window nodes and the
+            // original input.
+            FoundCommonNodes::Yes {
+                common_nodes: common_exprs,
+                new_nodes_list: new_exprs_list,
+                original_nodes_list: original_exprs_list,
+            } => build_common_expr_project_plan(input, 
common_exprs).map(|new_input| {
+                Transformed::yes((new_exprs_list, new_input, 
Some(original_exprs_list)))
+            }),
+            FoundCommonNodes::No {
+                original_nodes_list: original_exprs_list,
+            } => Ok(Transformed::no((original_exprs_list, input, None))),
+        }?
+        // Recurse into the new input.
+        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule 
would do.)
+        .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
+            self.rewrite(new_input, config)?.map_data(|new_input| {
+                Ok((new_window_expr_list, new_input, window_expr_list))
+            })
+        })?
+        // Rebuild the consecutive window nodes.
+        .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
+            // If there were common expressions extracted, then we need to 
make sure
+            // we restore the original column names.
+            // TODO: Although `extract_common_nodes()` inserts aliases around 
extracted
+            //  common expressions this doesn't mean that the original column 
names
+            //  (schema) are preserved, because the inserted aliases are not 
always at
+            //  the top of the expression.
+            //  Let's consider improving `extract_common_nodes()` to always 
keep column
+            //  names and get rid of additional name preserving logic here.
+            if let Some(window_expr_list) = window_expr_list {
+                let name_preserver = NamePreserver::new_for_projection();
+                let saved_names = window_expr_list
+                    .iter()
+                    .map(|exprs| {
+                        exprs
+                            .iter()
+                            .map(|expr| name_preserver.save(expr))
+                            .collect::<Vec<_>>()
                     })
-                }
-                FoundCommonExprs::No {
-                    original_exprs_list,
-                } => Ok((original_exprs_list, input, None)),
-            })?
-            // Recurse into the new input.
-            // (This is similar to what a `ApplyOrder::TopDown` optimizer rule 
would do.)
-            .transform_data(|(new_window_expr_list, new_input, 
window_expr_list)| {
-                self.rewrite(new_input, config)?.map_data(|new_input| {
-                    Ok((new_window_expr_list, new_input, window_expr_list))
-                })
-            })?
-            // Rebuild the consecutive window nodes.
-            .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
-                // If there were common expressions extracted, then we need to 
make sure
-                // we restore the original column names.
-                // TODO: Although `find_common_exprs()` inserts aliases around 
extracted
-                //  common expressions this doesn't mean that the original 
column names
-                //  (schema) are preserved due to the inserted aliases are not 
always at
-                //  the top of the expression.
-                //  Let's consider improving `find_common_exprs()` to always 
keep column
-                //  names and get rid of additional name preserving logic here.
-                if let Some(window_expr_list) = window_expr_list {
-                    let name_preserver = NamePreserver::new_for_projection();
-                    let saved_names = window_expr_list
-                        .iter()
-                        .map(|exprs| {
-                            exprs
-                                .iter()
-                                .map(|expr| name_preserver.save(expr))
-                                .collect::<Vec<_>>()
-                        })
-                        .collect::<Vec<_>>();
-                    
new_window_expr_list.into_iter().zip(saved_names).try_rfold(
-                        new_input,
-                        |plan, (new_window_expr, saved_names)| {
-                            let new_window_expr = new_window_expr
-                                .into_iter()
-                                .zip(saved_names)
-                                .map(|(new_window_expr, saved_name)| {
-                                    saved_name.restore(new_window_expr)
-                                })
-                                .collect::<Vec<_>>();
-                            Window::try_new(new_window_expr, Arc::new(plan))
-                                .map(LogicalPlan::Window)
-                        },
-                    )
-                } else {
-                    new_window_expr_list
-                        .into_iter()
-                        .zip(window_schemas)
-                        .try_rfold(new_input, |plan, (new_window_expr, 
schema)| {
-                            Window::try_new_with_schema(
-                                new_window_expr,
-                                Arc::new(plan),
-                                schema,
-                            )
+                    .collect::<Vec<_>>();
+                new_window_expr_list.into_iter().zip(saved_names).try_rfold(
+                    new_input,
+                    |plan, (new_window_expr, saved_names)| {
+                        let new_window_expr = new_window_expr
+                            .into_iter()
+                            .zip(saved_names)
+                            .map(|(new_window_expr, saved_name)| {
+                                saved_name.restore(new_window_expr)
+                            })
+                            .collect::<Vec<_>>();
+                        Window::try_new(new_window_expr, Arc::new(plan))
                             .map(LogicalPlan::Window)
-                        })
-                }
-            })
+                    },
+                )
+            } else {
+                new_window_expr_list
+                    .into_iter()
+                    .zip(window_schemas)
+                    .try_rfold(new_input, |plan, (new_window_expr, schema)| {
+                        Window::try_new_with_schema(
+                            new_window_expr,
+                            Arc::new(plan),
+                            schema,
+                        )
+                        .map(LogicalPlan::Window)
+                    })
+            }
+        })
     }
 
     fn try_optimize_aggregate(
@@ -462,174 +234,175 @@ impl CommonSubexprEliminate {
         } = aggregate;
         let input = Arc::unwrap_or_clone(input);
         // Extract common sub-expressions from the aggregate and grouping 
expressions.
-        self.find_common_exprs(vec![group_expr, aggr_expr], config, 
ExprMask::Normal)?
-            .map_data(|common| {
-                match common {
-                    // If there are common sub-expressions, then insert a 
projection node
-                    // with the common expressions between the new aggregate 
node and the
-                    // original input.
-                    FoundCommonExprs::Yes {
-                        common_exprs,
-                        mut new_exprs_list,
-                        mut original_exprs_list,
-                    } => {
-                        let new_aggr_expr = new_exprs_list.pop().unwrap();
-                        let new_group_expr = new_exprs_list.pop().unwrap();
-
-                        build_common_expr_project_plan(input, 
common_exprs).map(
-                            |new_input| {
-                                let aggr_expr = 
original_exprs_list.pop().unwrap();
-                                (
-                                    new_aggr_expr,
-                                    new_group_expr,
-                                    new_input,
-                                    Some(aggr_expr),
-                                )
-                            },
-                        )
-                    }
-
-                    FoundCommonExprs::No {
-                        mut original_exprs_list,
-                    } => {
-                        let new_aggr_expr = original_exprs_list.pop().unwrap();
-                        let new_group_expr = 
original_exprs_list.pop().unwrap();
-
-                        Ok((new_aggr_expr, new_group_expr, input, None))
-                    }
-                }
-            })?
-            // Recurse into the new input.
-            // (This is similar to what a `ApplyOrder::TopDown` optimizer rule 
would do.)
-            .transform_data(|(new_aggr_expr, new_group_expr, new_input, 
aggr_expr)| {
-                self.rewrite(new_input, config)?.map_data(|new_input| {
-                    Ok((
+        match CSE::new(ExprCSEController::new(
+            config.alias_generator().as_ref(),
+            ExprMask::Normal,
+        ))
+        .extract_common_nodes(vec![group_expr, aggr_expr])?
+        {
+            // If there are common sub-expressions, then insert a projection 
node
+            // with the common expressions between the new aggregate node and 
the
+            // original input.
+            FoundCommonNodes::Yes {
+                common_nodes: common_exprs,
+                new_nodes_list: mut new_exprs_list,
+                original_nodes_list: mut original_exprs_list,
+            } => {
+                let new_aggr_expr = new_exprs_list.pop().unwrap();
+                let new_group_expr = new_exprs_list.pop().unwrap();
+
+                build_common_expr_project_plan(input, 
common_exprs).map(|new_input| {
+                    let aggr_expr = original_exprs_list.pop().unwrap();
+                    Transformed::yes((
                         new_aggr_expr,
                         new_group_expr,
-                        aggr_expr,
-                        Arc::new(new_input),
+                        new_input,
+                        Some(aggr_expr),
                     ))
                 })
-            })?
-            // Try extracting common aggregate expressions and rebuild the 
aggregate node.
-            .transform_data(|(new_aggr_expr, new_group_expr, aggr_expr, 
new_input)| {
+            }
+
+            FoundCommonNodes::No {
+                original_nodes_list: mut original_exprs_list,
+            } => {
+                let new_aggr_expr = original_exprs_list.pop().unwrap();
+                let new_group_expr = original_exprs_list.pop().unwrap();
+
+                Ok(Transformed::no((
+                    new_aggr_expr,
+                    new_group_expr,
+                    input,
+                    None,
+                )))
+            }
+        }?
+        // Recurse into the new input.
+        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule 
would do.)
+        .transform_data(|(new_aggr_expr, new_group_expr, new_input, 
aggr_expr)| {
+            self.rewrite(new_input, config)?.map_data(|new_input| {
+                Ok((
+                    new_aggr_expr,
+                    new_group_expr,
+                    aggr_expr,
+                    Arc::new(new_input),
+                ))
+            })
+        })?
+        // Try extracting common aggregate expressions and rebuild the 
aggregate node.
+        .transform_data(
+            |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
                 // Extract common aggregate sub-expressions from the aggregate 
expressions.
-                self.find_common_exprs(
-                    vec![new_aggr_expr],
-                    config,
+                match CSE::new(ExprCSEController::new(
+                    config.alias_generator().as_ref(),
                     ExprMask::NormalAndAggregates,
-                )?
-                .map_data(|common| {
-                    match common {
-                        FoundCommonExprs::Yes {
-                            common_exprs,
-                            mut new_exprs_list,
-                            mut original_exprs_list,
-                        } => {
-                            let rewritten_aggr_expr = 
new_exprs_list.pop().unwrap();
-                            let new_aggr_expr = 
original_exprs_list.pop().unwrap();
-
-                            let mut agg_exprs = common_exprs
-                                .into_iter()
-                                .map(|(expr, expr_alias)| 
expr.alias(expr_alias))
-                                .collect::<Vec<_>>();
+                ))
+                .extract_common_nodes(vec![new_aggr_expr])?
+                {
+                    FoundCommonNodes::Yes {
+                        common_nodes: common_exprs,
+                        new_nodes_list: mut new_exprs_list,
+                        original_nodes_list: mut original_exprs_list,
+                    } => {
+                        let rewritten_aggr_expr = 
new_exprs_list.pop().unwrap();
+                        let new_aggr_expr = original_exprs_list.pop().unwrap();
 
-                            let mut proj_exprs = vec![];
-                            for expr in &new_group_expr {
-                                extract_expressions(expr, &mut proj_exprs)
-                            }
-                            for (expr_rewritten, expr_orig) in
-                                
rewritten_aggr_expr.into_iter().zip(new_aggr_expr)
-                            {
-                                if expr_rewritten == expr_orig {
-                                    if let Expr::Alias(Alias { expr, name, .. 
}) =
-                                        expr_rewritten
-                                    {
-                                        agg_exprs.push(expr.alias(&name));
-                                        proj_exprs
-                                            
.push(Expr::Column(Column::from_name(name)));
-                                    } else {
-                                        let expr_alias =
-                                            
config.alias_generator().next(CSE_PREFIX);
-                                        let (qualifier, field_name) =
-                                            expr_rewritten.qualified_name();
-                                        let out_name = qualified_name(
-                                            qualifier.as_ref(),
-                                            &field_name,
-                                        );
-
-                                        
agg_exprs.push(expr_rewritten.alias(&expr_alias));
-                                        proj_exprs.push(
-                                            
Expr::Column(Column::from_name(expr_alias))
-                                                .alias(out_name),
-                                        );
-                                    }
+                        let mut agg_exprs = common_exprs
+                            .into_iter()
+                            .map(|(expr, expr_alias)| expr.alias(expr_alias))
+                            .collect::<Vec<_>>();
+
+                        let mut proj_exprs = vec![];
+                        for expr in &new_group_expr {
+                            extract_expressions(expr, &mut proj_exprs)
+                        }
+                        for (expr_rewritten, expr_orig) in
+                            rewritten_aggr_expr.into_iter().zip(new_aggr_expr)
+                        {
+                            if expr_rewritten == expr_orig {
+                                if let Expr::Alias(Alias { expr, name, .. }) =
+                                    expr_rewritten
+                                {
+                                    agg_exprs.push(expr.alias(&name));
+                                    proj_exprs
+                                        
.push(Expr::Column(Column::from_name(name)));
                                 } else {
-                                    proj_exprs.push(expr_rewritten);
+                                    let expr_alias =
+                                        
config.alias_generator().next(CSE_PREFIX);
+                                    let (qualifier, field_name) =
+                                        expr_rewritten.qualified_name();
+                                    let out_name =
+                                        qualified_name(qualifier.as_ref(), 
&field_name);
+
+                                    
agg_exprs.push(expr_rewritten.alias(&expr_alias));
+                                    proj_exprs.push(
+                                        
Expr::Column(Column::from_name(expr_alias))
+                                            .alias(out_name),
+                                    );
                                 }
+                            } else {
+                                proj_exprs.push(expr_rewritten);
                             }
-
-                            let agg = 
LogicalPlan::Aggregate(Aggregate::try_new(
-                                new_input,
-                                new_group_expr,
-                                agg_exprs,
-                            )?);
-                            Projection::try_new(proj_exprs, Arc::new(agg))
-                                .map(LogicalPlan::Projection)
                         }
 
-                        // If there aren't any common aggregate 
sub-expressions, then just
-                        // rebuild the aggregate node.
-                        FoundCommonExprs::No {
-                            mut original_exprs_list,
-                        } => {
-                            let rewritten_aggr_expr = 
original_exprs_list.pop().unwrap();
-
-                            // If there were common expressions extracted, 
then we need to
-                            // make sure we restore the original column names.
-                            // TODO: Although `find_common_exprs()` inserts 
aliases around
-                            //  extracted common expressions this doesn't mean 
that the
-                            //  original column names (schema) are preserved 
due to the
-                            //  inserted aliases are not always at the top of 
the
-                            //  expression.
-                            //  Let's consider improving `find_common_exprs()` 
to always
-                            //  keep column names and get rid of additional 
name
-                            //  preserving logic here.
-                            if let Some(aggr_expr) = aggr_expr {
-                                let name_perserver = 
NamePreserver::new_for_projection();
-                                let saved_names = aggr_expr
-                                    .iter()
-                                    .map(|expr| name_perserver.save(expr))
-                                    .collect::<Vec<_>>();
-                                let new_aggr_expr = rewritten_aggr_expr
-                                    .into_iter()
-                                    .zip(saved_names)
-                                    .map(|(new_expr, saved_name)| {
-                                        saved_name.restore(new_expr)
-                                    })
-                                    .collect::<Vec<Expr>>();
-
-                                // Since `group_expr` may have changed, schema 
may also.
-                                // Use `try_new()` method.
-                                Aggregate::try_new(
-                                    new_input,
-                                    new_group_expr,
-                                    new_aggr_expr,
-                                )
-                                .map(LogicalPlan::Aggregate)
-                            } else {
-                                Aggregate::try_new_with_schema(
-                                    new_input,
-                                    new_group_expr,
-                                    rewritten_aggr_expr,
-                                    schema,
-                                )
+                        let agg = LogicalPlan::Aggregate(Aggregate::try_new(
+                            new_input,
+                            new_group_expr,
+                            agg_exprs,
+                        )?);
+                        Projection::try_new(proj_exprs, Arc::new(agg))
+                            .map(|p| 
Transformed::yes(LogicalPlan::Projection(p)))
+                    }
+
+                    // If there aren't any common aggregate sub-expressions, 
then just
+                    // rebuild the aggregate node.
+                    FoundCommonNodes::No {
+                        original_nodes_list: mut original_exprs_list,
+                    } => {
+                        let rewritten_aggr_expr = 
original_exprs_list.pop().unwrap();
+
+                        // If there were common expressions extracted, then we 
need to
+                        // make sure we restore the original column names.
+                        // TODO: Although `extract_common_nodes()` inserts 
aliases around
+                        //  extracted common expressions this doesn't mean 
that the
+                        //  original column names (schema) are preserved, 
because the
+                        //  inserted aliases are not always at the top of the
+                        //  expression.
+                        //  Let's consider improving `extract_common_nodes()` 
to always
+                        //  keep column names and get rid of additional name
+                        //  preserving logic here.
+                        if let Some(aggr_expr) = aggr_expr {
+                            let name_perserver = 
NamePreserver::new_for_projection();
+                            let saved_names = aggr_expr
+                                .iter()
+                                .map(|expr| name_perserver.save(expr))
+                                .collect::<Vec<_>>();
+                            let new_aggr_expr = rewritten_aggr_expr
+                                .into_iter()
+                                .zip(saved_names)
+                                .map(|(new_expr, saved_name)| {
+                                    saved_name.restore(new_expr)
+                                })
+                                .collect::<Vec<Expr>>();
+
+                            // Since `group_expr` may have changed, schema may 
also.
+                            // Use `try_new()` method.
+                            Aggregate::try_new(new_input, new_group_expr, 
new_aggr_expr)
                                 .map(LogicalPlan::Aggregate)
-                            }
+                                .map(Transformed::no)
+                        } else {
+                            Aggregate::try_new_with_schema(
+                                new_input,
+                                new_group_expr,
+                                rewritten_aggr_expr,
+                                schema,
+                            )
+                            .map(LogicalPlan::Aggregate)
+                            .map(Transformed::no)
                         }
                     }
-                })
-            })
+                }
+            },
+        )
     }
 
     /// Rewrites the expr list and input to remove common subexpressions
@@ -653,30 +426,34 @@ impl CommonSubexprEliminate {
         config: &dyn OptimizerConfig,
     ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
         // Extract common sub-expressions from the expressions.
-        self.find_common_exprs(vec![exprs], config, ExprMask::Normal)?
-            .map_data(|common| match common {
-                FoundCommonExprs::Yes {
-                    common_exprs,
-                    mut new_exprs_list,
-                    original_exprs_list: _,
-                } => {
-                    let new_exprs = new_exprs_list.pop().unwrap();
-                    build_common_expr_project_plan(input, common_exprs)
-                        .map(|new_input| (new_exprs, new_input))
-                }
-                FoundCommonExprs::No {
-                    mut original_exprs_list,
-                } => {
-                    let new_exprs = original_exprs_list.pop().unwrap();
-                    Ok((new_exprs, input))
-                }
-            })?
-            // Recurse into the new input.
-            // (This is similar to what a `ApplyOrder::TopDown` optimizer rule 
would do.)
-            .transform_data(|(new_exprs, new_input)| {
-                self.rewrite(new_input, config)?
-                    .map_data(|new_input| Ok((new_exprs, new_input)))
-            })
+        match CSE::new(ExprCSEController::new(
+            config.alias_generator().as_ref(),
+            ExprMask::Normal,
+        ))
+        .extract_common_nodes(vec![exprs])?
+        {
+            FoundCommonNodes::Yes {
+                common_nodes: common_exprs,
+                new_nodes_list: mut new_exprs_list,
+                original_nodes_list: _,
+            } => {
+                let new_exprs = new_exprs_list.pop().unwrap();
+                build_common_expr_project_plan(input, common_exprs)
+                    .map(|new_input| Transformed::yes((new_exprs, new_input)))
+            }
+            FoundCommonNodes::No {
+                original_nodes_list: mut original_exprs_list,
+            } => {
+                let new_exprs = original_exprs_list.pop().unwrap();
+                Ok(Transformed::no((new_exprs, input)))
+            }
+        }?
+        // Recurse into the new input.
+        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule 
would do.)
+        .transform_data(|(new_exprs, new_input)| {
+            self.rewrite(new_input, config)?
+                .map_data(|new_input| Ok((new_exprs, new_input)))
+        })
     }
 }
 
@@ -800,71 +577,6 @@ impl OptimizerRule for CommonSubexprEliminate {
     }
 }
 
-impl Default for CommonSubexprEliminate {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-/// Build the "intermediate" projection plan that evaluates the extracted 
common
-/// expressions.
-///
-/// # Arguments
-/// input: the input plan
-///
-/// common_exprs: which common subexpressions were used (and thus are added to
-/// intermediate projection)
-///
-/// expr_stats: the set of common subexpressions
-fn build_common_expr_project_plan(
-    input: LogicalPlan,
-    common_exprs: Vec<(Expr, String)>,
-) -> Result<LogicalPlan> {
-    let mut fields_set = BTreeSet::new();
-    let mut project_exprs = common_exprs
-        .into_iter()
-        .map(|(expr, expr_alias)| {
-            fields_set.insert(expr_alias.clone());
-            Ok(expr.alias(expr_alias))
-        })
-        .collect::<Result<Vec<_>>>()?;
-
-    for (qualifier, field) in input.schema().iter() {
-        if fields_set.insert(qualified_name(qualifier, field.name())) {
-            project_exprs.push(Expr::from((qualifier, field)));
-        }
-    }
-
-    Projection::try_new(project_exprs, 
Arc::new(input)).map(LogicalPlan::Projection)
-}
-
-/// Build the projection plan to eliminate unnecessary columns produced by
-/// the "intermediate" projection plan built in 
[build_common_expr_project_plan].
-///
-/// This is required to keep the schema the same for plans that pass the input
-/// on to the output, such as `Filter` or `Sort`.
-fn build_recover_project_plan(
-    schema: &DFSchema,
-    input: LogicalPlan,
-) -> Result<LogicalPlan> {
-    let col_exprs = schema.iter().map(Expr::from).collect();
-    Projection::try_new(col_exprs, 
Arc::new(input)).map(LogicalPlan::Projection)
-}
-
-fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
-    if let Expr::GroupingSet(groupings) = expr {
-        for e in groupings.distinct_expr() {
-            let (qualifier, field_name) = e.qualified_name();
-            let col = Column::new(qualifier, field_name);
-            result.push(Expr::Column(col))
-        }
-    } else {
-        let (qualifier, field_name) = expr.qualified_name();
-        let col = Column::new(qualifier, field_name);
-        result.push(Expr::Column(col));
-    }
-}
-
 /// Which type of [expressions](Expr) should be considered for rewriting?
 #[derive(Debug, Clone, Copy)]
 enum ExprMask {
@@ -882,156 +594,36 @@ enum ExprMask {
     NormalAndAggregates,
 }
 
-impl ExprMask {
-    fn ignores(&self, expr: &Expr) -> bool {
-        let is_normal_minus_aggregates = matches!(
-            expr,
-            Expr::Literal(..)
-                | Expr::Column(..)
-                | Expr::ScalarVariable(..)
-                | Expr::Alias(..)
-                | Expr::Wildcard { .. }
-        );
-
-        let is_aggr = matches!(expr, Expr::AggregateFunction(..));
-
-        match self {
-            Self::Normal => is_normal_minus_aggregates || is_aggr,
-            Self::NormalAndAggregates => is_normal_minus_aggregates,
-        }
-    }
-}
-
-/// Go through an expression tree and generate identifiers for each 
subexpression.
-///
-/// An identifier contains information of the expression itself and its 
sub-expression.
-/// This visitor implementation use a stack `visit_stack` to track traversal, 
which
-/// lets us know when a sub-tree's visiting is finished. When `pre_visit` is 
called
-/// (traversing to a new node), an `EnterMark` and an `ExprItem` will be 
pushed into stack.
-/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All 
`ExprItem`
-/// before the first `EnterMark` is considered to be sub-tree of the leaving 
node.
-///
-/// This visitor also records identifier in `id_array`. Makes the following 
traverse
-/// pass can get the identifier of a node without recalculate it. We assign 
each node
-/// in the expr tree a series number, start from 1, maintained by 
`series_number`.
-/// Series number represents the order we left (`f_up()`) a node. Has the 
property
-/// that child node's series number always smaller than parent's. While 
`id_array` is
-/// organized in the order we enter (`f_down()`) a node. `node_count` helps us 
to
-/// get the index of `id_array` for each node.
-///
-/// `Expr` without sub-expr (column, literal etc.) will not have identifier
-/// because they should not be recognized as common sub-expr.
-struct ExprIdentifierVisitor<'a, 'n> {
-    // statistics of expressions
-    expr_stats: &'a mut ExprStats<'n>,
-    // cache to speed up second traversal
-    id_array: &'a mut IdArray<'n>,
-    // inner states
-    visit_stack: Vec<VisitRecord<'n>>,
-    // preorder index, start from 0.
-    down_index: usize,
-    // postorder index, start from 0.
-    up_index: usize,
-    // which expression should be skipped?
-    expr_mask: ExprMask,
-    // a `RandomState` to generate hashes during the first traversal
-    random_state: &'a RandomState,
-    // a flag to indicate that common expression found
-    found_common: bool,
-    // if we are in a conditional branch. A conditional branch means that the 
expression
-    // might not be executed depending on the runtime values of other 
expressions, and
-    // thus can not be extracted as a common expression.
-    conditional: bool,
-}
+struct ExprCSEController<'a> {
+    alias_generator: &'a AliasGenerator,
+    mask: ExprMask,
 
-/// Record item that used when traversing an expression tree.
-enum VisitRecord<'n> {
-    /// Marks the beginning of expression. It contains:
-    /// - The post-order index assigned during the first, visiting traversal.
-    EnterMark(usize),
-
-    /// Marks an accumulated subexpression tree. It contains:
-    /// - The accumulated identifier of a subexpression.
-    /// - A boolean flag if the expression is valid for subexpression 
elimination.
-    ///   The flag is propagated up from children to parent. (E.g. volatile 
expressions
-    ///   are not valid and can't be extracted, but non-volatile children of 
volatile
-    ///   expressions can be extracted.)
-    ExprItem(Identifier<'n>, bool),
+    // how many aliases have we seen so far
+    alias_counter: usize,
 }
 
-impl<'n> ExprIdentifierVisitor<'_, 'n> {
-    /// Find the first `EnterMark` in the stack, and accumulates every 
`ExprItem` before
-    /// it. Returns a tuple that contains:
-    /// - The pre-order index of the expression we marked.
-    /// - The accumulated identifier of the children of the marked expression.
-    /// - An accumulated boolean flag from the children of the marked 
expression if all
-    ///   children are valid for subexpression elimination (i.e. it is safe to 
extract the
-    ///   expression as a common expression from its children POV).
-    ///   (E.g. if any of the children of the marked expression is not valid 
(e.g. is
-    ///   volatile) then the expression is also not valid, so we can propagate 
this
-    ///   information up from children to parents via `visit_stack` during the 
first,
-    ///   visiting traversal and no need to test the expression's validity 
beforehand with
-    ///   an extra traversal).
-    fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n>>, bool) {
-        let mut expr_id = None;
-        let mut is_valid = true;
-
-        while let Some(item) = self.visit_stack.pop() {
-            match item {
-                VisitRecord::EnterMark(down_index) => {
-                    return (down_index, expr_id, is_valid);
-                }
-                VisitRecord::ExprItem(sub_expr_id, sub_expr_is_valid) => {
-                    expr_id = Some(sub_expr_id.combine(expr_id));
-                    is_valid &= sub_expr_is_valid;
-                }
-            }
+impl<'a> ExprCSEController<'a> {
+    fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
+        Self {
+            alias_generator,
+            mask,
+            alias_counter: 0,
         }
-        unreachable!("Enter mark should paired with node number");
-    }
-
-    /// Save the current `conditional` status and run `f` with `conditional` 
set to true.
-    fn conditionally<F: FnMut(&mut Self) -> Result<()>>(
-        &mut self,
-        mut f: F,
-    ) -> Result<()> {
-        let conditional = self.conditional;
-        self.conditional = true;
-        f(self)?;
-        self.conditional = conditional;
-
-        Ok(())
     }
 }
 
-impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
+impl CSEController for ExprCSEController<'_> {
     type Node = Expr;
 
-    fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
-        self.id_array.push((0, None));
-        self.visit_stack
-            .push(VisitRecord::EnterMark(self.down_index));
-        self.down_index += 1;
-
-        // If an expression can short-circuit then some of its children might 
not be
-        // executed so count the occurrence of subexpressions as conditional 
in all
-        // children.
-        Ok(match expr {
-            // If we are already in a conditionally evaluated subtree then 
continue
-            // traversal.
-            _ if self.conditional => TreeNodeRecursion::Continue,
-
+    fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
+        match node {
             // In case of `ScalarFunction`s we don't know which children are 
surely
             // executed so start visiting all children conditionally and stop 
the
             // recursion with `TreeNodeRecursion::Jump`.
             Expr::ScalarFunction(ScalarFunction { func, args })
                 if func.short_circuits() =>
             {
-                self.conditionally(|visitor| {
-                    args.iter().try_for_each(|e| e.visit(visitor).map(|_| ()))
-                })?;
-
-                TreeNodeRecursion::Jump
+                Some((vec![], args.iter().collect()))
             }
 
             // In case of `And` and `Or` the first child is surely executed, 
but we
@@ -1040,12 +632,7 @@ impl<'n> TreeNodeVisitor<'n> for 
ExprIdentifierVisitor<'_, 'n> {
                 left,
                 op: Operator::And | Operator::Or,
                 right,
-            }) => {
-                left.visit(self)?;
-                self.conditionally(|visitor| right.visit(visitor).map(|_| 
()))?;
-
-                TreeNodeRecursion::Jump
-            }
+            }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
 
             // In case of `Case` the optional base expression and the first 
when
             // expressions are surely executed, but we account subexpressions 
as
@@ -1054,167 +641,151 @@ impl<'n> TreeNodeVisitor<'n> for 
ExprIdentifierVisitor<'_, 'n> {
                 expr,
                 when_then_expr,
                 else_expr,
-            }) => {
-                expr.iter().try_for_each(|e| e.visit(self).map(|_| ()))?;
-                when_then_expr.iter().take(1).try_for_each(|(when, then)| {
-                    when.visit(self)?;
-                    self.conditionally(|visitor| then.visit(visitor).map(|_| 
()))
-                })?;
-                self.conditionally(|visitor| {
-                    when_then_expr.iter().skip(1).try_for_each(|(when, then)| {
-                        when.visit(visitor)?;
-                        then.visit(visitor).map(|_| ())
-                    })?;
-                    else_expr
-                        .iter()
-                        .try_for_each(|e| e.visit(visitor).map(|_| ()))
-                })?;
-
-                TreeNodeRecursion::Jump
-            }
+            }) => Some((
+                expr.iter()
+                    .map(|e| e.as_ref())
+                    .chain(when_then_expr.iter().take(1).map(|(when, _)| 
when.as_ref()))
+                    .collect(),
+                when_then_expr
+                    .iter()
+                    .take(1)
+                    .map(|(_, then)| then.as_ref())
+                    .chain(
+                        when_then_expr
+                            .iter()
+                            .skip(1)
+                            .flat_map(|(when, then)| [when.as_ref(), 
then.as_ref()]),
+                    )
+                    .chain(else_expr.iter().map(|e| e.as_ref()))
+                    .collect(),
+            )),
+            _ => None,
+        }
+    }
 
-            // In case of non-short-circuit expressions continue the traversal.
-            _ => TreeNodeRecursion::Continue,
-        })
+    fn is_valid(node: &Expr) -> bool {
+        !node.is_volatile_node()
     }
 
-    fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
-        let (down_index, sub_expr_id, sub_expr_is_valid) = 
self.pop_enter_mark();
+    fn is_ignored(&self, node: &Expr) -> bool {
+        let is_normal_minus_aggregates = matches!(
+            node,
+            Expr::Literal(..)
+                | Expr::Column(..)
+                | Expr::ScalarVariable(..)
+                | Expr::Alias(..)
+                | Expr::Wildcard { .. }
+        );
 
-        let expr_id = Identifier::new(expr, 
self.random_state).combine(sub_expr_id);
-        let is_valid = !expr.is_volatile_node() && sub_expr_is_valid;
+        let is_aggr = matches!(node, Expr::AggregateFunction(..));
 
-        self.id_array[down_index].0 = self.up_index;
-        if is_valid && !self.expr_mask.ignores(expr) {
-            self.id_array[down_index].1 = Some(expr_id);
-            let (count, conditional_count) =
-                self.expr_stats.entry(expr_id).or_insert((0, 0));
-            if self.conditional {
-                *conditional_count += 1;
-            } else {
-                *count += 1;
-            }
-            if *count > 1 || (*count == 1 && *conditional_count > 0) {
-                self.found_common = true;
-            }
+        match self.mask {
+            ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
+            ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
         }
-        self.visit_stack
-            .push(VisitRecord::ExprItem(expr_id, is_valid));
-        self.up_index += 1;
-
-        Ok(TreeNodeRecursion::Continue)
     }
-}
 
-/// Rewrite expression by replacing detected common sub-expression with
-/// the corresponding temporary column name. That column contains the
-/// evaluate result of replaced expression.
-struct CommonSubexprRewriter<'a, 'n> {
-    // statistics of expressions
-    expr_stats: &'a ExprStats<'n>,
-    // cache to speed up second traversal
-    id_array: &'a IdArray<'n>,
-    // common expression, that are replaced during the second traversal, are 
collected to
-    // this map
-    common_exprs: &'a mut CommonExprs<'n>,
-    // preorder index, starts from 0.
-    down_index: usize,
-    // how many aliases have we seen so far
-    alias_counter: usize,
-    // alias generator for extracted common expressions
-    alias_generator: &'a AliasGenerator,
-}
+    fn generate_alias(&self) -> String {
+        self.alias_generator.next(CSE_PREFIX)
+    }
 
-impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
-    type Node = Expr;
+    fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
+        // alias the expressions without an `Alias` ancestor node
+        if self.alias_counter > 0 {
+            col(alias)
+        } else {
+            self.alias_counter += 1;
+            col(alias).alias(node.schema_name().to_string())
+        }
+    }
 
-    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
-        if matches!(expr, Expr::Alias(_)) {
+    fn rewrite_f_down(&mut self, node: &Expr) {
+        if matches!(node, Expr::Alias(_)) {
             self.alias_counter += 1;
         }
+    }
+    fn rewrite_f_up(&mut self, node: &Expr) {
+        if matches!(node, Expr::Alias(_)) {
+            self.alias_counter -= 1
+        }
+    }
+}
 
-        let (up_index, expr_id) = self.id_array[self.down_index];
-        self.down_index += 1;
+impl Default for CommonSubexprEliminate {
+    fn default() -> Self {
+        Self::new()
+    }
+}
 
-        // Handle `Expr`s with identifiers only
-        if let Some(expr_id) = expr_id {
-            let (count, conditional_count) = 
self.expr_stats.get(&expr_id).unwrap();
-            if *count > 1 || *count == 1 && *conditional_count > 0 {
-                // step index to skip all sub-node (which has smaller series 
number).
-                while self.down_index < self.id_array.len()
-                    && self.id_array[self.down_index].0 < up_index
-                {
-                    self.down_index += 1;
-                }
+/// Build the "intermediate" projection plan that evaluates the extracted 
common
+/// expressions.
+///
+/// # Arguments
+/// input: the input plan
+///
+/// common_exprs: which common subexpressions were used (and thus are added to
+/// intermediate projection)
+///
+/// expr_stats: the set of common subexpressions
+fn build_common_expr_project_plan(
+    input: LogicalPlan,
+    common_exprs: Vec<(Expr, String)>,
+) -> Result<LogicalPlan> {
+    let mut fields_set = BTreeSet::new();
+    let mut project_exprs = common_exprs
+        .into_iter()
+        .map(|(expr, expr_alias)| {
+            fields_set.insert(expr_alias.clone());
+            Ok(expr.alias(expr_alias))
+        })
+        .collect::<Result<Vec<_>>>()?;
 
-                let expr_name = expr.schema_name().to_string();
-                let (_, expr_alias) =
-                    self.common_exprs.entry(expr_id).or_insert_with(|| {
-                        let expr_alias = self.alias_generator.next(CSE_PREFIX);
-                        (expr, expr_alias)
-                    });
-
-                // alias the expressions without an `Alias` ancestor node
-                let rewritten = if self.alias_counter > 0 {
-                    col(expr_alias.clone())
-                } else {
-                    self.alias_counter += 1;
-                    col(expr_alias.clone()).alias(expr_name)
-                };
-
-                return Ok(Transformed::new(rewritten, true, 
TreeNodeRecursion::Jump));
-            }
+    for (qualifier, field) in input.schema().iter() {
+        if fields_set.insert(qualified_name(qualifier, field.name())) {
+            project_exprs.push(Expr::from((qualifier, field)));
         }
-
-        Ok(Transformed::no(expr))
     }
 
-    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Self::Node>> {
-        if matches!(expr, Expr::Alias(_)) {
-            self.alias_counter -= 1
-        }
+    Projection::try_new(project_exprs, 
Arc::new(input)).map(LogicalPlan::Projection)
+}
 
-        Ok(Transformed::no(expr))
-    }
+/// Build the projection plan to eliminate unnecessary columns produced by
+/// the "intermediate" projection plan built in 
[build_common_expr_project_plan].
+///
+/// This is required to keep the schema the same for plans that pass the input
+/// on to the output, such as `Filter` or `Sort`.
+fn build_recover_project_plan(
+    schema: &DFSchema,
+    input: LogicalPlan,
+) -> Result<LogicalPlan> {
+    let col_exprs = schema.iter().map(Expr::from).collect();
+    Projection::try_new(col_exprs, 
Arc::new(input)).map(LogicalPlan::Projection)
 }
 
-/// Replace common sub-expression in `expr` with the corresponding temporary
-/// column name, updating `common_exprs` with any replaced expressions
-fn replace_common_expr<'n>(
-    expr: Expr,
-    id_array: &IdArray<'n>,
-    expr_stats: &ExprStats<'n>,
-    common_exprs: &mut CommonExprs<'n>,
-    alias_generator: &AliasGenerator,
-) -> Result<Expr> {
-    if id_array.is_empty() {
-        Ok(Transformed::no(expr))
+fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
+    if let Expr::GroupingSet(groupings) = expr {
+        for e in groupings.distinct_expr() {
+            let (qualifier, field_name) = e.qualified_name();
+            let col = Column::new(qualifier, field_name);
+            result.push(Expr::Column(col))
+        }
     } else {
-        expr.rewrite(&mut CommonSubexprRewriter {
-            expr_stats,
-            id_array,
-            common_exprs,
-            down_index: 0,
-            alias_counter: 0,
-            alias_generator,
-        })
+        let (qualifier, field_name) = expr.qualified_name();
+        let col = Column::new(qualifier, field_name);
+        result.push(Expr::Column(col));
     }
-    .data()
 }
 
 #[cfg(test)]
 mod test {
     use std::any::Any;
-    use std::collections::HashSet;
     use std::iter;
 
     use arrow::datatypes::{DataType, Field, Schema};
-    use datafusion_expr::expr::AggregateFunction;
     use datafusion_expr::logical_plan::{table_scan, JoinType};
     use datafusion_expr::{
-        grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr,
-        ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF,
-        Volatility,
+        grouping_set, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, 
ScalarUDF,
+        ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility,
     };
     use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
 
@@ -1238,154 +809,6 @@ mod test {
         assert_eq!(expected, formatted_plan);
     }
 
-    #[test]
-    fn id_array_visitor() -> Result<()> {
-        let optimizer = CommonSubexprEliminate::new();
-
-        let a_plus_1 = col("a") + lit(1);
-        let avg_c = avg(col("c"));
-        let sum_a_plus_1 = sum(a_plus_1);
-        let sum_a_plus_1_minus_avg_c = sum_a_plus_1 - avg_c;
-        let expr = sum_a_plus_1_minus_avg_c * lit(2);
-
-        let Expr::BinaryExpr(BinaryExpr {
-            left: sum_a_plus_1_minus_avg_c,
-            ..
-        }) = &expr
-        else {
-            panic!("Cannot extract subexpression reference")
-        };
-        let Expr::BinaryExpr(BinaryExpr {
-            left: sum_a_plus_1,
-            right: avg_c,
-            ..
-        }) = sum_a_plus_1_minus_avg_c.as_ref()
-        else {
-            panic!("Cannot extract subexpression reference")
-        };
-        let Expr::AggregateFunction(AggregateFunction {
-            args: a_plus_1_vec, ..
-        }) = sum_a_plus_1.as_ref()
-        else {
-            panic!("Cannot extract subexpression reference")
-        };
-        let a_plus_1 = &a_plus_1_vec.as_slice()[0];
-
-        // skip aggregates
-        let mut id_array = vec![];
-        optimizer.expr_to_identifier(
-            &expr,
-            &mut ExprStats::new(),
-            &mut id_array,
-            ExprMask::Normal,
-        )?;
-
-        // Collect distinct hashes and set them to 0 in `id_array`
-        fn collect_hashes(id_array: &mut IdArray) -> HashSet<u64> {
-            id_array
-                .iter_mut()
-                .flat_map(|(_, expr_id_option)| {
-                    expr_id_option.as_mut().map(|expr_id| {
-                        let hash = expr_id.hash;
-                        expr_id.hash = 0;
-                        hash
-                    })
-                })
-                .collect::<HashSet<_>>()
-        }
-
-        let hashes = collect_hashes(&mut id_array);
-        assert_eq!(hashes.len(), 3);
-
-        let expected = vec![
-            (
-                8,
-                Some(Identifier {
-                    hash: 0,
-                    expr: &expr,
-                }),
-            ),
-            (
-                6,
-                Some(Identifier {
-                    hash: 0,
-                    expr: sum_a_plus_1_minus_avg_c,
-                }),
-            ),
-            (3, None),
-            (
-                2,
-                Some(Identifier {
-                    hash: 0,
-                    expr: a_plus_1,
-                }),
-            ),
-            (0, None),
-            (1, None),
-            (5, None),
-            (4, None),
-            (7, None),
-        ];
-        assert_eq!(expected, id_array);
-
-        // include aggregates
-        let mut id_array = vec![];
-        optimizer.expr_to_identifier(
-            &expr,
-            &mut ExprStats::new(),
-            &mut id_array,
-            ExprMask::NormalAndAggregates,
-        )?;
-
-        let hashes = collect_hashes(&mut id_array);
-        assert_eq!(hashes.len(), 5);
-
-        let expected = vec![
-            (
-                8,
-                Some(Identifier {
-                    hash: 0,
-                    expr: &expr,
-                }),
-            ),
-            (
-                6,
-                Some(Identifier {
-                    hash: 0,
-                    expr: sum_a_plus_1_minus_avg_c,
-                }),
-            ),
-            (
-                3,
-                Some(Identifier {
-                    hash: 0,
-                    expr: sum_a_plus_1,
-                }),
-            ),
-            (
-                2,
-                Some(Identifier {
-                    hash: 0,
-                    expr: a_plus_1,
-                }),
-            ),
-            (0, None),
-            (1, None),
-            (
-                5,
-                Some(Identifier {
-                    hash: 0,
-                    expr: avg_c,
-                }),
-            ),
-            (4, None),
-            (7, None),
-        ];
-        assert_eq!(expected, id_array);
-
-        Ok(())
-    }
-
     #[test]
     fn tpch_q1_simplified() -> Result<()> {
         // SQL:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to