peter-toth commented on code in PR #52835:
URL: https://github.com/apache/spark/pull/52835#discussion_r2485847706
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala:
##########
@@ -117,297 +116,126 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
}
}
- /**
- * An item in the cache of merged scalar subqueries.
- *
- * @param plan The plan of a merged scalar subquery.
- * @param merged A flag to identify if this item is the result of merging
subqueries.
- * Please note that `attributes.size == 1` doesn't always mean
that the plan is not
- * merged as there can be subqueries that are different
([[checkIdenticalPlans]] is
- * false) due to an extra [[Project]] node in one of them. In
that case
- * `attributes.size` remains 1 after merging, but the merged
flag becomes true.
- * @param references A set of subquery indexes in the cache to track all
(including transitive)
- * nested subqueries.
- */
- case class Header(
- plan: LogicalPlan,
- merged: Boolean,
- references: Set[Int])
-
private def extractCommonScalarSubqueries(plan: LogicalPlan) = {
- val cache = ArrayBuffer.empty[Header]
- val planWithReferences = insertReferences(plan, cache)
- cache.zipWithIndex.foreach { case (header, i) =>
- cache(i) = cache(i).copy(plan =
- if (header.merged) {
+ // Collect `ScalarSubquery` plans by level into `PlanMerger`s and insert
references in place of
+ // `ScalarSubquery`s.
+ val planMergers = ArrayBuffer.empty[PlanMerger]
+ val planWithReferences = insertReferences(plan, planMergers)._1
+
+ // Traverse level by level and convert merged plans to `CTERelationDef`s
and keep non-merged
+ // ones. While traversing replace references in plans back to
`CTERelationRef`s or to original
+ // `ScalarSubquery`s. This is safe as a subquery plan at a level can
reference only lower level
+ // other subqueries.
+ val subqueryPlansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]]
+ planMergers.foreach { planMerger =>
+ val mergedPlans = planMerger.mergedPlans()
+ subqueryPlansByLevel += mergedPlans.map { mergedPlan =>
+ val planWithoutReferences = if (subqueryPlansByLevel.isEmpty) {
+ // Level 0 plans can't contain references
+ mergedPlan.plan
+ } else {
+ removeReferences(mergedPlan.plan, subqueryPlansByLevel)
+ }
+ if (mergedPlan.merged) {
CTERelationDef(
- createProject(header.plan.output, removeReferences(header.plan,
cache)),
+ Project(
+ Seq(Alias(
+ CreateNamedStruct(
+ planWithoutReferences.output.flatMap(a =>
Seq(Literal(a.name), a))),
+ "mergedValue")()),
+ planWithoutReferences),
underSubquery = true)
} else {
- removeReferences(header.plan, cache)
- })
+ planWithoutReferences
+ }
+ }
}
- val newPlan = removeReferences(planWithReferences, cache)
- val subqueryCTEs =
cache.filter(_.merged).map(_.plan.asInstanceOf[CTERelationDef])
+
+ // Replace references back to `CTERelationRef`s or to original
`ScalarSubquery`s in the main
+ // plan.
+ val newPlan = removeReferences(planWithReferences, subqueryPlansByLevel)
+
+ // Add `CTERelationDef`s to the plan.
+ val subqueryCTEs = subqueryPlansByLevel.flatMap(_.collect { case cte:
CTERelationDef => cte })
if (subqueryCTEs.nonEmpty) {
WithCTE(newPlan, subqueryCTEs.toSeq)
} else {
newPlan
}
}
- // First traversal builds up the cache and inserts
`ScalarSubqueryReference`s to the plan.
- private def insertReferences(plan: LogicalPlan, cache: ArrayBuffer[Header]):
LogicalPlan = {
- plan.transformUpWithSubqueries {
- case n =>
n.transformExpressionsUpWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) {
- // The subquery could contain a hint that is not propagated once we
cache it, but as a
- // non-correlated scalar subquery won't be turned into a Join the loss
of hints is fine.
+ // First traversal inserts `ScalarSubqueryReference`s to the plan and tries
to merge subquery
+ // plans by each level.
+ private def insertReferences(
+ plan: LogicalPlan,
+ planMergers: ArrayBuffer[PlanMerger]): (LogicalPlan, Int) = {
+ // The level of a subquery plan is maximum level of its inner subqueries +
1 or 0 if it has no
+ // inner subqueries.
+ var maxLevel = 0
+ val planWithReferences =
+
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) {
case s: ScalarSubquery if !s.isCorrelated && s.deterministic =>
- val (subqueryIndex, headerIndex) = cacheSubquery(s.plan, cache)
- ScalarSubqueryReference(subqueryIndex, headerIndex, s.dataType,
s.exprId)
- }
- }
- }
-
- // Caching returns the index of the subquery in the cache and the index of
scalar member in the
- // "Header".
- private def cacheSubquery(plan: LogicalPlan, cache: ArrayBuffer[Header]):
(Int, Int) = {
Review Comment:
Briefly, `cacheSubquery()` became `PlanMerger.merge()` with some
modifications.
Since `PlanMerger.merge()` is now a kind of internal API that can be used by
other rules, it returns a `MergeResult` instead of a tuple after this refactor.
The `Header` objects, which contained the information about the merged plans,
became `MergedPlan`s and can be accessed by the rules via
`PlanMerger.mergedPlans()`.
The main difference between `cacheSubquery()` and the new
`PlanMerger.merge()` is that I removed `Header.references` and the related
safeguards like the `if !references.contains(subqueryIndex)` below.
Previously, we filled up `Header.references` with nested subquery
indices so as to avoid trying to merge a subquery with other subqueries that
are nested into it.
E.g.
```
SELECT (
SELECT avg(a) FROM t WHERE c = ( -- subquery 1
SELECT sum(b) FROM t2 -- subquery 2 nested into subquery 1
)
)
```
Here it doesn't make sense to try merging the plans of `subquery 1` and
`subquery 2`, as the former contains the latter. The problem with trying to merge
the two was not just that it doesn't make sense: in certain weird cases (see the
example in [SPARK-40618](https://issues.apache.org/jira/browse/SPARK-40618))
the merge was successful and resulted in invalid plans.
But because `PlanMerger` should be a general plan merging tool, and in some
future rules it will be used to merge plans other than subqueries, tracking the
nested subqueries in it is not the best way to deal with the problem.
This is why the refactored `MergeScalarSubqueries` uses a different
approach. We have one `PlanMerger` object per subquery level. E.g.
`subquery 2` is a level 0 (leaf) subquery, but `subquery 1` is a level 1
subquery because it has an inner, level 0 subquery. So the two will be added to
different `PlanMerger`s. This way we try to merge a level n subquery only with
other level n subqueries.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]