crepererum commented on code in PR #6129:
URL: https://github.com/apache/arrow-datafusion/pull/6129#discussion_r1180070069
##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -708,22 +889,149 @@ mod test {
fn aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
- let plan = LogicalPlanBuilder::from(table_scan)
+ let return_type: ReturnTypeFunction = Arc::new(|inputs| {
+ assert_eq!(inputs, &[DataType::UInt32]);
+ Ok(Arc::new(DataType::UInt32))
+ });
+ let accumulator: AccumulatorFunctionImplementation =
+ Arc::new(|_| unimplemented!());
+ let state_type: StateTypeFunction = Arc::new(|_| unimplemented!());
+ let udf_agg = |inner: Expr| Expr::AggregateUDF {
+ fun: Arc::new(AggregateUDF::new(
+ "my_agg",
+ &Signature::exact(vec![DataType::UInt32], Volatility::Stable),
+ &return_type,
+ &accumulator,
+ &state_type,
+ )),
+ args: vec![inner],
+ filter: None,
+ };
+
+ // test: common aggregates
+ let plan = LogicalPlanBuilder::from(table_scan.clone())
+ .aggregate(
+ iter::empty::<Expr>(),
+ vec![
+ // common: avg(col("a"))
+ avg(col("a")).alias("col1"),
+ avg(col("a")).alias("col2"),
+ // no common
+ avg(col("b")).alias("col3"),
+ avg(col("c")),
+ // common: udf_agg(col("a"))
+ udf_agg(col("a")).alias("col4"),
+ udf_agg(col("a")).alias("col5"),
+ // no common
+ udf_agg(col("b")).alias("col6"),
+ udf_agg(col("c")),
+ ],
+ )?
+ .build()?;
+
+ let expected = "Projection: AVG(test.a)test.a AS AVG(test.a) AS col1,
AVG(test.a)test.a AS AVG(test.a) AS col2, col3, AVG(test.c) AS AVG(test.c),
my_agg(test.a)test.a AS my_agg(test.a) AS col4, my_agg(test.a)test.a AS
my_agg(test.a) AS col5, col6, my_agg(test.c) AS my_agg(test.c)\
+ \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a,
my_agg(test.a) AS my_agg(test.a)test.a, AVG(test.b) AS col3, AVG(test.c) AS
AVG(test.c), my_agg(test.b) AS col6, my_agg(test.c) AS my_agg(test.c)]]\
+ \n TableScan: test";
+
+ assert_optimized_plan_eq(expected, &plan);
+
+ // test: trafo after aggregate
+ let plan = LogicalPlanBuilder::from(table_scan.clone())
.aggregate(
iter::empty::<Expr>(),
vec![
binary_expr(lit(1), Operator::Plus, avg(col("a"))),
binary_expr(lit(1), Operator::Minus, avg(col("a"))),
+ binary_expr(lit(1), Operator::Plus, udf_agg(col("a"))),
+ binary_expr(lit(1), Operator::Minus, udf_agg(col("a"))),
],
)?
.build()?;
- let expected = "Aggregate: groupBy=[[]], aggr=[[Int32(1) +
AVG(test.a)test.a AS AVG(test.a), Int32(1) - AVG(test.a)test.a AS AVG(test.a)]]\
- \n Projection: AVG(test.a) AS AVG(test.a)test.a, test.a, test.b,
test.c\
+ let expected = "Projection: Int32(1) + AVG(test.a)test.a AS
AVG(test.a), Int32(1) - AVG(test.a)test.a AS AVG(test.a), Int32(1) +
my_agg(test.a)test.a AS my_agg(test.a), Int32(1) - my_agg(test.a)test.a AS
my_agg(test.a)\
+ \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a,
my_agg(test.a) AS my_agg(test.a)test.a]]\
\n TableScan: test";
assert_optimized_plan_eq(expected, &plan);
+ // test: trafo before aggregate
Review Comment:
"trafo" is short for "transformation"; I will fix the comment to spell it out.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]