peter-toth commented on code in PR #55628:
URL: https://github.com/apache/spark/pull/55628#discussion_r3173342588
##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubplansSuite.scala:
##########
@@ -1542,4 +1543,179 @@ class MergeSubplansSuite extends PlanTest {
comparePlans(Optimize.execute(originalQuery.analyze),
correctAnswer.analyze)
}
+
+ test("SPARK-56677: Merge non-grouping subqueries with filter on left join
child") {
+ // cp (subquery1): Aggregate([], [sum(a)], Join(testRelation,
testRelation2, a=d))
+ // np (subquery2): Aggregate([], [max(a)], Join(Filter(a>1, testRelation),
testRelation2, a=d))
+ // The filter on the left join child propagates as a boolean attribute
through the Join node
+ // and is consumed as a FILTER (WHERE ...) clause on the np-side aggregate
expression.
+ val subquery1 = ScalarSubquery(
+ testRelation.join(testRelation2, Inner, Some($"a" === $"d"))
+ .groupBy()(sum($"a").as("sum_a")))
+ val subquery2 = ScalarSubquery(
+ testRelation.where($"a" > 1).join(testRelation2, Inner, Some($"a" ===
$"d"))
+ .groupBy()(max($"a").as("max_a")))
+ val originalQuery = testRelation.select(subquery1, subquery2)
+
+ val f0Alias = Alias($"a" > 1, "propagatedFilter_0")()
+ val f0 = f0Alias.toAttribute
+ val mergedSubquery = testRelation
+ .select(testRelation.output ++ Seq(f0Alias): _*)
+ .join(testRelation2, Inner, Some($"a" === $"d"))
+ .groupBy()(
+ sum($"a").as("sum_a"),
+ max($"a", Some(f0)).as("max_a"))
+ .select(CreateNamedStruct(Seq(
+ Literal("sum_a"), $"sum_a",
+ Literal("max_a"), $"max_a"
+ )).as("mergedValue"))
+ val analyzedMergedSubquery = mergedSubquery.analyze
+ val correctAnswer = WithCTE(
+ testRelation.select(
+ extractorExpression(0, analyzedMergedSubquery.output, 0),
+ extractorExpression(0, analyzedMergedSubquery.output, 1)),
+ Seq(definitionNode(analyzedMergedSubquery, 0)))
+
+
withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key
-> "true") {
+ comparePlans(Optimize.execute(originalQuery.analyze),
correctAnswer.analyze)
+ }
+ }
+
+ test("SPARK-56677: Merge non-grouping subqueries with filter on right join
child") {
+ // cp (subquery1): Aggregate([], [sum(a)], Join(testRelation,
testRelation2, a=d))
+ // np (subquery2): Aggregate([], [max(d)], Join(testRelation, Filter(d>1,
testRelation2), a=d))
+ // The filter on the right join child propagates analogously to the
left-child case.
+ val subquery1 = ScalarSubquery(
+ testRelation.join(testRelation2, Inner, Some($"a" === $"d"))
+ .groupBy()(sum($"a").as("sum_a")))
+ val subquery2 = ScalarSubquery(
+ testRelation.join(testRelation2.where($"d" > 1), Inner, Some($"a" ===
$"d"))
+ .groupBy()(max($"d").as("max_d")))
+ val originalQuery = testRelation.select(subquery1, subquery2)
+
+ val f0Alias = Alias($"d" > 1, "propagatedFilter_0")()
+ val f0 = f0Alias.toAttribute
+ val mergedSubquery = testRelation
+ .join(
+ testRelation2.select(testRelation2.output ++ Seq(f0Alias): _*),
+ Inner, Some($"a" === $"d"))
+ .groupBy()(
+ sum($"a").as("sum_a"),
+ max($"d", Some(f0)).as("max_d"))
+ .select(CreateNamedStruct(Seq(
+ Literal("sum_a"), $"sum_a",
+ Literal("max_d"), $"max_d"
+ )).as("mergedValue"))
+ val analyzedMergedSubquery = mergedSubquery.analyze
+ val correctAnswer = WithCTE(
+ testRelation.select(
+ extractorExpression(0, analyzedMergedSubquery.output, 0),
+ extractorExpression(0, analyzedMergedSubquery.output, 1)),
+ Seq(definitionNode(analyzedMergedSubquery, 0)))
+
+
withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key
-> "true") {
+ comparePlans(Optimize.execute(originalQuery.analyze),
correctAnswer.analyze)
+ }
+ }
+
+ test("SPARK-56677: Do not merge subqueries when both join children have
independent filters") {
+ // np has filters on BOTH left and right join children simultaneously. The
guard in the
+ // Join case prevents this merge because combining two independent filter
attributes would
+ // require ANDing them into a new alias, which is not yet supported.
+ val subquery1 = ScalarSubquery(
+ testRelation.join(testRelation2, Inner, Some($"a" === $"d"))
+ .groupBy()(sum($"a").as("sum_a")))
+ val subquery2 = ScalarSubquery(
+ testRelation.where($"a" > 1).join(testRelation2.where($"d" > 1), Inner,
Some($"a" === $"d"))
+ .groupBy()(max($"a").as("max_a")))
+ val originalQuery = testRelation.select(subquery1, subquery2)
+
+
withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key
-> "true") {
+ comparePlans(Optimize.execute(originalQuery.analyze),
originalQuery.analyze)
+ }
+ }
+
+ test("SPARK-56677: Merge non-grouping subqueries with filter on left side of
LeftSemi join") {
+ // Left-side filter attributes ARE in the LeftSemi join output, so
propagation is safe.
+ val subquery1 = ScalarSubquery(
+ testRelation.join(testRelation2, LeftSemi, Some($"a" === $"d"))
+ .groupBy()(sum($"a").as("sum_a")))
+ val subquery2 = ScalarSubquery(
+ testRelation.where($"a" > 1).join(testRelation2, LeftSemi, Some($"a" ===
$"d"))
+ .groupBy()(max($"a").as("max_a")))
+ val originalQuery = testRelation.select(subquery1, subquery2)
+
+ val f0Alias = Alias($"a" > 1, "propagatedFilter_0")()
+ val f0 = f0Alias.toAttribute
+ val mergedSubquery = testRelation
+ .select(testRelation.output ++ Seq(f0Alias): _*)
+ .join(testRelation2, LeftSemi, Some($"a" === $"d"))
+ .groupBy()(
+ sum($"a").as("sum_a"),
+ max($"a", Some(f0)).as("max_a"))
+ .select(CreateNamedStruct(Seq(
+ Literal("sum_a"), $"sum_a",
+ Literal("max_a"), $"max_a"
+ )).as("mergedValue"))
+ val analyzedMergedSubquery = mergedSubquery.analyze
+ val correctAnswer = WithCTE(
+ testRelation.select(
+ extractorExpression(0, analyzedMergedSubquery.output, 0),
+ extractorExpression(0, analyzedMergedSubquery.output, 1)),
+ Seq(definitionNode(analyzedMergedSubquery, 0)))
+
+
withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key
-> "true") {
+ comparePlans(Optimize.execute(originalQuery.analyze),
correctAnswer.analyze)
+ }
+ }
+
+ test("SPARK-56677: Do not merge subqueries when filter is on the right side
of a LeftSemi join") {
+ // Right-side filter attributes are NOT in the LeftSemi join output (only
left-side columns
+ // are produced). Propagating such a filter would create an unresolvable
attribute reference
+ // in the parent Aggregate's FILTER clause.
+ val subquery1 = ScalarSubquery(
+ testRelation.join(testRelation2, LeftSemi, Some($"a" === $"d"))
+ .groupBy()(sum($"a").as("sum_a")))
+ val subquery2 = ScalarSubquery(
+ testRelation.join(testRelation2.where($"d" > 1), LeftSemi, Some($"a" ===
$"d"))
+ .groupBy()(max($"a").as("max_a")))
+ val originalQuery = testRelation.select(subquery1, subquery2)
+
+
withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key
-> "true") {
+ comparePlans(Optimize.execute(originalQuery.analyze),
originalQuery.analyze)
+ }
+ }
+
+ test("SPARK-56677: Do not merge subqueries when filter is on the nullable
side of an outer " +
+ "join") {
+ // For a RightOuter join the left side is nullable: unmatched right rows
produce NULL for all
+ // left-side columns including the filter attribute f, so FILTER (WHERE
f=NULL) would
+ // incorrectly exclude those rows from the aggregate even though they
appear in the join result.
+ // The same problem applies to the right side of a LeftOuter join and both
sides of FullOuter.
+ val subquery1 = ScalarSubquery(
+ testRelation.join(testRelation2, RightOuter, Some($"a" === $"d"))
+ .groupBy()(sum($"a").as("sum_a")))
+ val subquery2 = ScalarSubquery(
+ testRelation.where($"a" > 1).join(testRelation2, RightOuter, Some($"a"
=== $"d"))
+ .groupBy()(max($"a").as("max_a")))
+ val originalQuery = testRelation.select(subquery1, subquery2)
+
+
withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key
-> "true") {
+ comparePlans(Optimize.execute(originalQuery.analyze),
originalQuery.analyze)
+ }
+ }
+
+ test("SPARK-56677: Do not merge subqueries with filter propagation through
join when disabled") {
+
withSQLConf(SQLConf.MERGE_SUBPLANS_FILTER_PROPAGATION_THROUGH_JOIN_ENABLED.key
-> "false") {
+ val subquery1 = ScalarSubquery(
+ testRelation.join(testRelation2, Inner, Some($"a" === $"d"))
+ .groupBy()(sum($"a").as("sum_a")))
+ val subquery2 = ScalarSubquery(
+ testRelation.where($"a" > 1).join(testRelation2, Inner, Some($"a" ===
$"d"))
+ .groupBy()(max($"a").as("max_a")))
+ val originalQuery = testRelation.select(subquery1, subquery2)
+
+ comparePlans(Optimize.execute(originalQuery.analyze),
originalQuery.analyze)
+ }
+ }
Review Comment:
Added in commit https://github.com/apache/spark/pull/55628/commits/4d87515c6de9583e4f6999094cb91a6d4dd70312.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]