alamb commented on code in PR #6129:
URL: https://github.com/apache/arrow-datafusion/pull/6129#discussion_r1179407218


##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -708,22 +889,149 @@ mod test {
     fn aggregate() -> Result<()> {
         let table_scan = test_table_scan()?;
 
-        let plan = LogicalPlanBuilder::from(table_scan)
+        let return_type: ReturnTypeFunction = Arc::new(|inputs| {
+            assert_eq!(inputs, &[DataType::UInt32]);
+            Ok(Arc::new(DataType::UInt32))
+        });
+        let accumulator: AccumulatorFunctionImplementation =
+            Arc::new(|_| unimplemented!());
+        let state_type: StateTypeFunction = Arc::new(|_| unimplemented!());
+        let udf_agg = |inner: Expr| Expr::AggregateUDF {
+            fun: Arc::new(AggregateUDF::new(
+                "my_agg",
+                &Signature::exact(vec![DataType::UInt32], Volatility::Stable),
+                &return_type,
+                &accumulator,
+                &state_type,
+            )),
+            args: vec![inner],
+            filter: None,
+        };
+
+        // test: common aggregates
+        let plan = LogicalPlanBuilder::from(table_scan.clone())
+            .aggregate(
+                iter::empty::<Expr>(),
+                vec![
+                    // common: avg(col("a"))
+                    avg(col("a")).alias("col1"),
+                    avg(col("a")).alias("col2"),
+                    // no common
+                    avg(col("b")).alias("col3"),
+                    avg(col("c")),
+                    // common: udf_agg(col("a"))
+                    udf_agg(col("a")).alias("col4"),
+                    udf_agg(col("a")).alias("col5"),
+                    // no common
+                    udf_agg(col("b")).alias("col6"),
+                    udf_agg(col("c")),
+                ],
+            )?
+            .build()?;
+
+        let expected = "Projection: AVG(test.a)test.a AS AVG(test.a) AS col1, 
AVG(test.a)test.a AS AVG(test.a) AS col2, col3, AVG(test.c) AS AVG(test.c), 
my_agg(test.a)test.a AS my_agg(test.a) AS col4, my_agg(test.a)test.a AS 
my_agg(test.a) AS col5, col6, my_agg(test.c) AS my_agg(test.c)\

Review Comment:
   👍 



##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -708,22 +889,149 @@ mod test {
     fn aggregate() -> Result<()> {
         let table_scan = test_table_scan()?;
 
-        let plan = LogicalPlanBuilder::from(table_scan)
+        let return_type: ReturnTypeFunction = Arc::new(|inputs| {
+            assert_eq!(inputs, &[DataType::UInt32]);
+            Ok(Arc::new(DataType::UInt32))
+        });
+        let accumulator: AccumulatorFunctionImplementation =
+            Arc::new(|_| unimplemented!());
+        let state_type: StateTypeFunction = Arc::new(|_| unimplemented!());
+        let udf_agg = |inner: Expr| Expr::AggregateUDF {
+            fun: Arc::new(AggregateUDF::new(
+                "my_agg",
+                &Signature::exact(vec![DataType::UInt32], Volatility::Stable),
+                &return_type,
+                &accumulator,
+                &state_type,
+            )),
+            args: vec![inner],
+            filter: None,
+        };
+
+        // test: common aggregates
+        let plan = LogicalPlanBuilder::from(table_scan.clone())
+            .aggregate(
+                iter::empty::<Expr>(),
+                vec![
+                    // common: avg(col("a"))
+                    avg(col("a")).alias("col1"),
+                    avg(col("a")).alias("col2"),
+                    // no common
+                    avg(col("b")).alias("col3"),
+                    avg(col("c")),
+                    // common: udf_agg(col("a"))
+                    udf_agg(col("a")).alias("col4"),
+                    udf_agg(col("a")).alias("col5"),
+                    // no common
+                    udf_agg(col("b")).alias("col6"),
+                    udf_agg(col("c")),
+                ],
+            )?
+            .build()?;
+
+        let expected = "Projection: AVG(test.a)test.a AS AVG(test.a) AS col1, 
AVG(test.a)test.a AS AVG(test.a) AS col2, col3, AVG(test.c) AS AVG(test.c), 
my_agg(test.a)test.a AS my_agg(test.a) AS col4, my_agg(test.a)test.a AS 
my_agg(test.a) AS col5, col6, my_agg(test.c) AS my_agg(test.c)\
+        \n  Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, 
my_agg(test.a) AS my_agg(test.a)test.a, AVG(test.b) AS col3, AVG(test.c) AS 
AVG(test.c), my_agg(test.b) AS col6, my_agg(test.c) AS my_agg(test.c)]]\
+        \n    TableScan: test";
+
+        assert_optimized_plan_eq(expected, &plan);
+
+        // test: trafo after aggregate
+        let plan = LogicalPlanBuilder::from(table_scan.clone())
             .aggregate(
                 iter::empty::<Expr>(),
                 vec![
                     binary_expr(lit(1), Operator::Plus, avg(col("a"))),
                     binary_expr(lit(1), Operator::Minus, avg(col("a"))),
+                    binary_expr(lit(1), Operator::Plus, udf_agg(col("a"))),
+                    binary_expr(lit(1), Operator::Minus, udf_agg(col("a"))),
                 ],
             )?
             .build()?;
 
-        let expected = "Aggregate: groupBy=[[]], aggr=[[Int32(1) + 
AVG(test.a)test.a AS AVG(test.a), Int32(1) - AVG(test.a)test.a AS AVG(test.a)]]\
-        \n  Projection: AVG(test.a) AS AVG(test.a)test.a, test.a, test.b, 
test.c\
+        let expected = "Projection: Int32(1) + AVG(test.a)test.a AS 
AVG(test.a), Int32(1) - AVG(test.a)test.a AS AVG(test.a), Int32(1) + 
my_agg(test.a)test.a AS my_agg(test.a), Int32(1) - my_agg(test.a)test.a AS 
my_agg(test.a)\
+        \n  Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, 
my_agg(test.a) AS my_agg(test.a)test.a]]\
         \n    TableScan: test";
 
         assert_optimized_plan_eq(expected, &plan);
 
+        // test: trafo before aggregate

Review Comment:
   What is a `trafo`?



##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -352,6 +462,49 @@ fn build_recover_project_plan(schema: &DFSchema, input: 
LogicalPlan) -> LogicalP
     )
 }
 
+/// Which type of [expressions](Expr) should be considered for rewriting?

Review Comment:
   Another enum also adds an excellent place for more documentation 👍 



##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -708,22 +889,149 @@ mod test {
     fn aggregate() -> Result<()> {
         let table_scan = test_table_scan()?;
 
-        let plan = LogicalPlanBuilder::from(table_scan)
+        let return_type: ReturnTypeFunction = Arc::new(|inputs| {
+            assert_eq!(inputs, &[DataType::UInt32]);
+            Ok(Arc::new(DataType::UInt32))
+        });
+        let accumulator: AccumulatorFunctionImplementation =
+            Arc::new(|_| unimplemented!());
+        let state_type: StateTypeFunction = Arc::new(|_| unimplemented!());
+        let udf_agg = |inner: Expr| Expr::AggregateUDF {
+            fun: Arc::new(AggregateUDF::new(
+                "my_agg",
+                &Signature::exact(vec![DataType::UInt32], Volatility::Stable),
+                &return_type,
+                &accumulator,
+                &state_type,
+            )),
+            args: vec![inner],
+            filter: None,
+        };
+
+        // test: common aggregates
+        let plan = LogicalPlanBuilder::from(table_scan.clone())
+            .aggregate(
+                iter::empty::<Expr>(),
+                vec![
+                    // common: avg(col("a"))
+                    avg(col("a")).alias("col1"),
+                    avg(col("a")).alias("col2"),
+                    // no common
+                    avg(col("b")).alias("col3"),
+                    avg(col("c")),
+                    // common: udf_agg(col("a"))
+                    udf_agg(col("a")).alias("col4"),
+                    udf_agg(col("a")).alias("col5"),
+                    // no common
+                    udf_agg(col("b")).alias("col6"),
+                    udf_agg(col("c")),
+                ],
+            )?
+            .build()?;
+
+        let expected = "Projection: AVG(test.a)test.a AS AVG(test.a) AS col1, 
AVG(test.a)test.a AS AVG(test.a) AS col2, col3, AVG(test.c) AS AVG(test.c), 
my_agg(test.a)test.a AS my_agg(test.a) AS col4, my_agg(test.a)test.a AS 
my_agg(test.a) AS col5, col6, my_agg(test.c) AS my_agg(test.c)\
+        \n  Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, 
my_agg(test.a) AS my_agg(test.a)test.a, AVG(test.b) AS col3, AVG(test.c) AS 
AVG(test.c), my_agg(test.b) AS col6, my_agg(test.c) AS my_agg(test.c)]]\
+        \n    TableScan: test";
+
+        assert_optimized_plan_eq(expected, &plan);
+
+        // test: trafo after aggregate
+        let plan = LogicalPlanBuilder::from(table_scan.clone())
             .aggregate(
                 iter::empty::<Expr>(),
                 vec![
                     binary_expr(lit(1), Operator::Plus, avg(col("a"))),
                     binary_expr(lit(1), Operator::Minus, avg(col("a"))),
+                    binary_expr(lit(1), Operator::Plus, udf_agg(col("a"))),
+                    binary_expr(lit(1), Operator::Minus, udf_agg(col("a"))),

Review Comment:
   FWIW you can also write these expressions like this, which I find easier to 
read and follow the intent of, as it matches the output more closely:
   
   ```suggestion
                       lit(1) + avg(col("a")),
                       lit(1) - avg(col("a")),
                       lit(1) + udf_agg(col("a")),
                       lit(1) - udf_agg(col("a")),
   ```
   
   The same comment applies to the new tests below as well



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to