cloud-fan commented on a change in pull request #35651:
URL: https://github.com/apache/spark/pull/35651#discussion_r815988303



##########
File path: 
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
##########
@@ -166,12 +162,137 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
   }
 
   test("Keep non-redundant aggregate - upper references non-deterministic non-grouping") {
-    val relation = LocalRelation('a.int, 'b.int)
     val query = relation
       .groupBy('a)('a, ('a + rand(0)) as 'c)
       .groupBy('a, 'c)('a, 'c)
       .analyze
     val optimized = Optimize.execute(query)
     comparePlans(optimized, query)
   }
+
+  test("SPARK-36194: Remove aggregation from left semi/anti join if aggregation the same") {
+    Seq(LeftSemi, LeftAnti).foreach { joinType =>
+      val originalQuery = x.groupBy('a, 'b)('a, 'b)
+        .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === 
"y.b".attr))
+        .groupBy("x.a".attr, "x.b".attr)("x.a".attr, "x.b".attr)
+      val correctAnswer = x.groupBy('a, 'b)('a, 'b)
+        .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === 
"y.b".attr))
+        .select("x.a".attr, "x.b".attr)
+
+      val optimized = Optimize.execute(originalQuery.analyze)
+      comparePlans(optimized, correctAnswer.analyze)
+    }
+  }
+
+  test("SPARK-36194: Remove aggregation from left semi/anti join with alias") {
+    Seq(LeftSemi, LeftAnti).foreach { joinType =>
+      val originalQuery = x.groupBy('a, 'b)('a, 'b.as("d"))
+        .join(y, joinType, Some("x.a".attr === "y.a".attr && "d".attr === 
"y.b".attr))
+        .groupBy("x.a".attr, "d".attr)("x.a".attr, "d".attr)
+      val correctAnswer = x.groupBy('a, 'b)('a, 'b.as("d"))
+        .join(y, joinType, Some("x.a".attr === "y.a".attr && "d".attr === 
"y.b".attr))
+        .select("x.a".attr, "d".attr)
+
+      val optimized = Optimize.execute(originalQuery.analyze)
+      comparePlans(optimized, correctAnswer.analyze)
+    }
+  }
+
+  test("SPARK-36194: Remove aggregation from left semi/anti join if it is the sub aggregateExprs") {
+    Seq(LeftSemi, LeftAnti).foreach { joinType =>
+      val originalQuery = x.groupBy('a, 'b)('a, 'b)
+        .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === 
"y.b".attr))
+        .groupBy("x.a".attr, "x.b".attr)("x.a".attr)
+      val correctAnswer = x.groupBy('a, 'b)('a, 'b)
+        .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === 
"y.b".attr))
+        .select("x.a".attr)
+
+      val optimized = Optimize.execute(originalQuery.analyze)
+      comparePlans(optimized, correctAnswer.analyze)
+    }
+  }
+
+  test("SPARK-36194: Transform down to remove more aggregates") {
+    Seq(LeftSemi, LeftAnti).foreach { joinType =>
+      val originalQuery = x.groupBy('a, 'b)('a, 'b)
+        .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === 
"y.b".attr))
+        .groupBy("x.a".attr, "x.b".attr)("x.a".attr, "x.b".attr)
+        .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === 
"y.b".attr))
+        .groupBy("x.a".attr, "x.b".attr)("x.a".attr)
+      val correctAnswer = x.groupBy('a, 'b)('a, 'b)
+        .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === 
"y.b".attr))
+        .select("x.a".attr, "x.b".attr)
+        .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === 
"y.b".attr))
+        .select("x.a".attr)
+
+      val optimized = Optimize.execute(originalQuery.analyze)
+      comparePlans(optimized, correctAnswer.analyze)
+    }
+  }
+
+  test("SPARK-36194: Remove aggregation from aggregation") {
+    val originalQuery = relation
+      .groupBy('a)('a, count('b).as("cnt"))
+      .groupBy('a, 'cnt)('a, 'cnt)
+      .analyze
+    val correctAnswer = relation
+      .groupBy('a)('a, count('b).as("cnt"))
+      .select('a, 'cnt)
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-36194: Negative case: The grouping expressions not same") {
+    Seq(LeftSemi, LeftAnti).foreach { joinType =>
+      val originalQuery = x.groupBy('a, 'b)('a, 'b)
+        .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === 
"y.b".attr))
+        .groupBy("x.a".attr)("x.a".attr)
+
+      val optimized = Optimize.execute(originalQuery.analyze)
+      comparePlans(optimized, originalQuery.analyze)
+    }
+  }
+
+  test("SPARK-36194: Negative case: The aggregate expressions not the sub aggregateExprs") {
+    Seq(LeftSemi, LeftAnti).foreach { joinType =>
+      val originalQuery = x.groupBy('a, 'b)('a, 'b)
+        .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === 
"y.b".attr))
+        .groupBy("x.a".attr, "x.b".attr)(TrueLiteral)

Review comment:
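       Hmm, why can't we optimize this query? The upper aggregate groups by the same keys (`x.a`, `x.b`) as the lower aggregate and outputs only a literal, so it looks like it could be removed as well.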




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to