crepererum commented on code in PR #6129:
URL: https://github.com/apache/arrow-datafusion/pull/6129#discussion_r1180181390
##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -708,22 +889,149 @@ mod test {
fn aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
- let plan = LogicalPlanBuilder::from(table_scan)
+ let return_type: ReturnTypeFunction = Arc::new(|inputs| {
+ assert_eq!(inputs, &[DataType::UInt32]);
+ Ok(Arc::new(DataType::UInt32))
+ });
+ let accumulator: AccumulatorFunctionImplementation =
+ Arc::new(|_| unimplemented!());
+ let state_type: StateTypeFunction = Arc::new(|_| unimplemented!());
+ let udf_agg = |inner: Expr| Expr::AggregateUDF {
+ fun: Arc::new(AggregateUDF::new(
+ "my_agg",
+ &Signature::exact(vec![DataType::UInt32], Volatility::Stable),
+ &return_type,
+ &accumulator,
+ &state_type,
+ )),
+ args: vec![inner],
+ filter: None,
+ };
+
+ // test: common aggregates
+ let plan = LogicalPlanBuilder::from(table_scan.clone())
+ .aggregate(
+ iter::empty::<Expr>(),
+ vec![
+ // common: avg(col("a"))
+ avg(col("a")).alias("col1"),
+ avg(col("a")).alias("col2"),
+ // no common
+ avg(col("b")).alias("col3"),
+ avg(col("c")),
+ // common: udf_agg(col("a"))
+ udf_agg(col("a")).alias("col4"),
+ udf_agg(col("a")).alias("col5"),
+ // no common
+ udf_agg(col("b")).alias("col6"),
+ udf_agg(col("c")),
+ ],
+ )?
+ .build()?;
+
+ let expected = "Projection: AVG(test.a)test.a AS AVG(test.a) AS col1,
AVG(test.a)test.a AS AVG(test.a) AS col2, col3, AVG(test.c) AS AVG(test.c),
my_agg(test.a)test.a AS my_agg(test.a) AS col4, my_agg(test.a)test.a AS
my_agg(test.a) AS col5, col6, my_agg(test.c) AS my_agg(test.c)\
+ \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a,
my_agg(test.a) AS my_agg(test.a)test.a, AVG(test.b) AS col3, AVG(test.c) AS
AVG(test.c), my_agg(test.b) AS col6, my_agg(test.c) AS my_agg(test.c)]]\
+ \n TableScan: test";
+
+ assert_optimized_plan_eq(expected, &plan);
+
+ // test: trafo after aggregate
+ let plan = LogicalPlanBuilder::from(table_scan.clone())
.aggregate(
iter::empty::<Expr>(),
vec![
binary_expr(lit(1), Operator::Plus, avg(col("a"))),
binary_expr(lit(1), Operator::Minus, avg(col("a"))),
+ binary_expr(lit(1), Operator::Plus, udf_agg(col("a"))),
+ binary_expr(lit(1), Operator::Minus, udf_agg(col("a"))),
Review Comment:
Done (also applied to the other tests within this module).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]