cloud-fan commented on a change in pull request #35248:
URL: https://github.com/apache/spark/pull/35248#discussion_r802694708



##########
File path: sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
##########
@@ -806,17 +806,124 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(query, Seq(Row(29000.0)))
   }
 
-  test("scan with aggregate push-down: SUM(CASE WHEN) with group by") {
-    val df =
-      sql("SELECT SUM(CASE WHEN SALARY > 0 THEN 1 ELSE 0 END) FROM 
h2.test.employee GROUP BY DEPT")
-    checkAggregateRemoved(df, false)
+  test("scan with aggregate push-down: aggregate function with CASE WHEN") {
+    val df = sql(
+      """
+        |SELECT
+        |  SUM(CASE WHEN SALARY is null THEN 0 ELSE SALARY END),
+        |  COUNT(CASE WHEN SALARY is null THEN 0 ELSE SALARY END),
+        |  AVG(CASE WHEN SALARY is null THEN 0 ELSE SALARY END),
+        |  MAX(CASE WHEN SALARY is null THEN 0 ELSE SALARY END),
+        |  MIN(CASE WHEN SALARY is null THEN 0 ELSE SALARY END)
+        |FROM h2.test.employee GROUP BY DEPT
+      """.stripMargin)
+    checkAggregateRemoved(df)
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment =
-          "PushedFilters: [], "
+          "PushedAggregates: [SUM(CASE WHEN SALARY IS NULL THEN 0.00 ELSE 
SALARY END), " +
+            "COUNT(CASE WHEN SALARY IS NULL THEN 0.0..., " +
+            "PushedFilters: [], " +
+            "PushedGroupByColumns: [DEPT]"
         checkKeywordsExistsInExplain(df, expected_plan_fragment)
     }
-    checkAnswer(df, Seq(Row(1), Row(2), Row(2)))
+    checkAnswer(df, Seq(Row(12000d, 1, 12000d, 12000d, 12000d),
+      Row(19000d, 2, 9500d, 10000d, 9000d), Row(22000d, 2, 11000d, 12000d, 
10000d)))
+
+    val df2 = sql(
+      """
+        |SELECT
+        |  SUM(CASE WHEN SALARY is not null THEN SALARY ELSE 0 END),
+        |  COUNT(CASE WHEN SALARY is not null THEN SALARY ELSE 0 END),
+        |  AVG(CASE WHEN SALARY is not null THEN SALARY ELSE 0 END),
+        |  MAX(CASE WHEN SALARY is not null THEN SALARY ELSE 0 END),
+        |  MIN(CASE WHEN SALARY is not null THEN SALARY ELSE 0 END)
+        |FROM h2.test.employee GROUP BY DEPT
+      """.stripMargin)
+    checkAggregateRemoved(df2)
+    df2.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [SUM(CASE WHEN SALARY IS NOT NULL THEN SALARY 
ELSE 0.00 END), " +
+            "COUNT(CASE WHEN SALARY IS NOT NULL ..., " +
+            "PushedFilters: [], " +
+            "PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df2, expected_plan_fragment)
+    }
+    checkAnswer(df2, Seq(Row(12000d, 1, 12000d, 12000d, 12000d),
+      Row(19000d, 2, 9500d, 10000d, 9000d), Row(22000d, 2, 11000d, 12000d, 
10000d)))
+
+    val df3 = sql(
+      """
+        |SELECT
+        |  SUM(CASE WHEN SALARY > 0 THEN 0 ELSE SALARY END),
+        |  SUM(CASE WHEN SALARY >= 0 THEN 0 ELSE SALARY END),
+        |  SUM(CASE WHEN SALARY < 0 THEN 0 ELSE SALARY END),
+        |  SUM(CASE WHEN SALARY <= 0 THEN 0 ELSE SALARY END),
+        |  SUM(CASE WHEN SALARY = 0 THEN 0 ELSE SALARY END),
+        |  SUM(CASE WHEN NOT(SALARY > 0) THEN 0 ELSE SALARY END),
+        |  SUM(CASE WHEN NOT(SALARY >= 0) THEN 0 ELSE SALARY END),
+        |  SUM(CASE WHEN NOT(SALARY < 0) THEN 0 ELSE SALARY END),
+        |  SUM(CASE WHEN NOT(SALARY <= 0) THEN 0 ELSE SALARY END),
+        |  SUM(CASE WHEN NOT(SALARY = 0) THEN 0 ELSE SALARY END),
+        |  SUM(CASE WHEN SALARY != 0 THEN 0 ELSE SALARY END)
+        |FROM h2.test.employee GROUP BY DEPT
+      """.stripMargin)
+    checkAggregateRemoved(df3)
+    df3.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [SUM(CASE WHEN (SALARY) > (0.00) THEN 0.00 ELSE 
SALARY END), " +
+            "SUM(CASE WHEN (SALARY) >= (0.00) THE..., " +
+            "PushedFilters: [], " +
+            "PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df3, expected_plan_fragment)
+    }
+    checkAnswer(df3, Seq(Row(0d, 0d, 12000d, 12000d, 12000d, 12000d, 12000d, 
0d, 0d, 0d, 0d),
+      Row(0d, 0d, 19000d, 19000d, 19000d, 19000d, 19000d, 0d, 0d, 0d, 0d),
+      Row(0d, 0d, 22000d, 22000d, 22000d, 22000d, 22000d, 0d, 0d, 0d, 0d)))
+
+    val df4 = sql(
+      """
+        |SELECT
+        |  COUNT(CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 
END),

Review comment:
       I think this test query covers everything, so we don't need the three test queries above.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to