waynexia commented on code in PR #9719:
URL: https://github.com/apache/datafusion/pull/9719#discussion_r1586020132
##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -829,6 +849,213 @@ fn replace_common_expr(
.data()
}
+struct ProjectionAdder {
+ // Keeps track of cumulative usage of common expressions with its
corresponding data type.
+ // accross plan where key is unsafe nodes that cumulative tracking is
invalidated.
+ insertion_point_map: HashMap<usize, HashMap<Expr, (DataType, u32)>>,
+ depth: usize,
+ // Keeps track of cumulative usage of the common expressions with its
corresponding data type.
+ // between safe nodes.
+ complex_exprs: HashMap<Expr, (DataType, u32)>,
+}
+pub fn is_not_complex(op: &Operator) -> bool {
+ matches!(
+ op,
Review Comment:
Sugg:
- `Operator::LtEq` and `Operator::GtEq` can also be considered not complex
- `Operator` is `Copy`, so changing the function parameter to `op: Operator`
seems better
##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -829,6 +849,213 @@ fn replace_common_expr(
.data()
}
+struct ProjectionAdder {
+ // Keeps track of cumulative usage of common expressions with its
corresponding data type.
+ // accross plan where key is unsafe nodes that cumulative tracking is
invalidated.
+ insertion_point_map: HashMap<usize, HashMap<Expr, (DataType, u32)>>,
+ depth: usize,
+ // Keeps track of cumulative usage of the common expressions with its
corresponding data type.
+ // between safe nodes.
+ complex_exprs: HashMap<Expr, (DataType, u32)>,
+}
+pub fn is_not_complex(op: &Operator) -> bool {
+ matches!(
+ op,
+ &Operator::Eq | &Operator::NotEq | &Operator::Lt | &Operator::Gt |
&Operator::And
+ )
+}
+
+impl ProjectionAdder {
+ // TODO: adding more expressions for sub query, currently only support for
Simple Binary Expressions
+ fn get_complex_expressions(
+ exprs: Vec<Expr>,
+ schema: DFSchemaRef,
+ ) -> HashSet<(Expr, DataType)> {
+ let mut res = HashSet::new();
+ for expr in exprs {
+ match expr {
+ Expr::BinaryExpr(BinaryExpr {
+ left: ref l_box,
+ op,
+ right: ref r_box,
+ }) if !is_not_complex(&op) => {
+ if let (Expr::Column(l), Expr::Column(_r)) = (&**l_box,
&**r_box) {
+ let l_field = schema
+ .field_from_column(l)
+ .expect("Field not found for left column");
+ res.insert((expr.clone(),
l_field.data_type().clone()));
+ }
+ }
+ Expr::Cast(Cast { expr, data_type: _ }) => {
+ let exprs_with_type =
+ Self::get_complex_expressions(vec![*expr],
schema.clone());
+ res.extend(exprs_with_type);
+ }
+ Expr::Alias(Alias {
+ expr,
+ relation: _,
+ name: _,
+ }) => {
+ let exprs_with_type =
+ Self::get_complex_expressions(vec![*expr],
schema.clone());
+ res.extend(exprs_with_type);
+ }
+ Expr::WindowFunction(WindowFunction { fun: _, args, .. }) => {
+ let exprs_with_type =
+ Self::get_complex_expressions(args, schema.clone());
+ res.extend(exprs_with_type);
+ }
+ _ => {}
+ }
+ }
+ res
+ }
+
+ fn update_expr_with_available_columns(
+ expr: &mut Expr,
+ available_columns: &[Column],
+ ) -> Result<()> {
+ match expr {
+ Expr::BinaryExpr(_) => {
+ for available_col in available_columns {
+ if available_col.flat_name() == expr.display_name()? {
+ *expr = Expr::Column(available_col.clone());
+ }
+ }
+ }
+ Expr::WindowFunction(WindowFunction { fun: _, args, .. }) => {
+ args.iter_mut().try_for_each(|arg| {
+ Self::update_expr_with_available_columns(arg,
available_columns)
+ })?
+ }
+ Expr::Cast(Cast { expr, .. }) => {
+ Self::update_expr_with_available_columns(expr,
available_columns)?
+ }
+ Expr::Alias(alias) => {
+ Self::update_expr_with_available_columns(
+ &mut alias.expr,
+ available_columns,
+ )?;
+ }
+ _ => {
+ // cannot rewrite
+ }
+ }
+ Ok(())
+ }
+
+ // Assumes operators doesn't modify name of the fields.
+ // Otherwise this operation is not safe.
+ fn extend_with_exprs(&mut self, node: &LogicalPlan) {
+ // use depth to trace where we are in the LogicalPlan tree
+ // extract all expressions + check whether it contains in depth_sets
+ let exprs = node.expressions();
+ let mut schema = node.schema().deref().clone();
+ for ip in node.inputs() {
+ schema.merge(ip.schema());
+ }
+ let expr_with_type = Self::get_complex_expressions(exprs,
Arc::new(schema));
+ for (expr, dtype) in expr_with_type {
+ let (_, count) = self.complex_exprs.entry(expr).or_insert_with(||
(dtype, 0));
+ *count += 1;
+ }
+ }
+}
+impl TreeNodeRewriter for ProjectionAdder {
+ type Node = LogicalPlan;
+ /// currently we just collect the complex bianryOP
+
+ fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+ // Insert for other end points
+ self.depth += 1;
+ match node {
+ LogicalPlan::TableScan(_) => {
+ // Stop tracking cumulative usage at the source.
+ let complex_exprs = std::mem::take(&mut self.complex_exprs);
+ self.insertion_point_map
+ .insert(self.depth - 1, complex_exprs);
+ Ok(Transformed::no(node))
+ }
+ LogicalPlan::Sort(_) | LogicalPlan::Filter(_) |
LogicalPlan::Window(_) => {
+ // These are safe operators where, expression identity is
preserved during operation.
+ self.extend_with_exprs(&node);
+ Ok(Transformed::no(node))
+ }
+ LogicalPlan::Projection(_) => {
+ // Stop tracking cumulative usage at the projection since it
may invalidate expression identity.
+ let complex_exprs = std::mem::take(&mut self.complex_exprs);
+ self.insertion_point_map
+ .insert(self.depth - 1, complex_exprs);
+ // Start tracking common expressions from now on including
projection.
+ self.extend_with_exprs(&node);
+ Ok(Transformed::no(node))
+ }
+ _ => {
+ // Unsupported operators
+ self.complex_exprs.clear();
+ Ok(Transformed::no(node))
+ }
+ }
+ }
+
+ fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+ let cached_exprs = self
+ .insertion_point_map
+ .get(&self.depth)
+ .cloned()
+ .unwrap_or_default();
+ self.depth -= 1;
+ // do not do extra things
+ let should_add_projection =
+ cached_exprs.iter().any(|(_expr, (_, count))| *count > 1);
+
+ let children = node.inputs();
+ if children.len() != 1 {
+ // Only can rewrite node with single child
+ return Ok(Transformed::no(node));
+ }
+ let child = children[0].clone();
+ let child = if should_add_projection {
+ let mut field_set = HashSet::new();
+ let mut project_exprs = vec![];
+ for (expr, (dtype, count)) in &cached_exprs {
+ if *count > 1 {
+ let f =
+ DFField::new_unqualified(&expr.to_string(),
dtype.clone(), true);
+ field_set.insert(f.name().to_owned());
Review Comment:
Since only the field name is used, can we skip constructing the `DFField`?
```suggestion
field_set.insert(expr.to_string());
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]