cloud-fan commented on a change in pull request #28496:
URL: https://github.com/apache/spark/pull/28496#discussion_r425132885
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
##########
@@ -973,4 +973,39 @@ class DataFrameAggregateSuite extends QueryTest
assert(error.message.contains("function count_if requires boolean type"))
}
}
+
+ Seq(true, false).foreach { value =>
+ test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
+ withTempView("t1", "t2") {
+ sql("create temporary view t1 as select * from values (1, 2) as
t1(a, b)")
+ sql("create temporary view t2 as select * from values (3, 4) as
t2(c, d)")
+
+ // test without grouping keys
+ checkAnswer(sql("select sum(if(c > (select a from t1), d, 0)) as
csum from t2"),
+ Row(4) :: Nil)
+
+ // test with grouping keys
+ checkAnswer(sql("select c, sum(if(c > (select a from t1), d, 0)) as
csum from " +
+ "t2 group by c"), Row(3, 4) :: Nil)
+
+ // test with distinct
+ checkAnswer(sql("select avg(distinct(d)), sum(distinct(if(c >
(select a from t1)," +
+ " d, 0))) as csum from t2 group by c"), Row(4, 4) :: Nil)
+
+ // test subquery with agg
+ checkAnswer(sql("select sum(distinct(if(c > (select sum(distinct(a))
from t1)," +
+ " d, 0))) as csum from t2 group by c"), Row(4) :: Nil)
+
+ // test SortAggregateExec
+ checkAnswer(sql("select max(if(c > (select a from t1), 'str1',
'str2')) as csum from t2"),
Review comment:
It's better to check the physical plan and make sure it actually uses a sort aggregate (SortAggregateExec).
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
##########
@@ -973,4 +973,39 @@ class DataFrameAggregateSuite extends QueryTest
assert(error.message.contains("function count_if requires boolean type"))
}
}
+
+ Seq(true, false).foreach { value =>
+ test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
+ withTempView("t1", "t2") {
+ sql("create temporary view t1 as select * from values (1, 2) as
t1(a, b)")
+ sql("create temporary view t2 as select * from values (3, 4) as
t2(c, d)")
+
+ // test without grouping keys
+ checkAnswer(sql("select sum(if(c > (select a from t1), d, 0)) as
csum from t2"),
+ Row(4) :: Nil)
+
+ // test with grouping keys
+ checkAnswer(sql("select c, sum(if(c > (select a from t1), d, 0)) as
csum from " +
+ "t2 group by c"), Row(3, 4) :: Nil)
+
+ // test with distinct
+ checkAnswer(sql("select avg(distinct(d)), sum(distinct(if(c >
(select a from t1)," +
+ " d, 0))) as csum from t2 group by c"), Row(4, 4) :: Nil)
+
+ // test subquery with agg
+ checkAnswer(sql("select sum(distinct(if(c > (select sum(distinct(a))
from t1)," +
+ " d, 0))) as csum from t2 group by c"), Row(4) :: Nil)
+
+ // test SortAggregateExec
+ checkAnswer(sql("select max(if(c > (select a from t1), 'str1',
'str2')) as csum from t2"),
+ Row("str1") :: Nil)
+
+ // test ObjectHashAggregateExec
Review comment:
Ditto: it's better to check the physical plan and make sure it actually uses ObjectHashAggregateExec.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]