dongjoon-hyun commented on code in PR #52835:
URL: https://github.com/apache/spark/pull/52835#discussion_r2484975684
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala:
##########
@@ -117,297 +116,126 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
}
}
- /**
- * An item in the cache of merged scalar subqueries.
- *
- * @param plan The plan of a merged scalar subquery.
- * @param merged A flag to identify if this item is the result of merging
subqueries.
- * Please note that `attributes.size == 1` doesn't always mean
that the plan is not
- * merged as there can be subqueries that are different
([[checkIdenticalPlans]] is
- * false) due to an extra [[Project]] node in one of them. In
that case
- * `attributes.size` remains 1 after merging, but the merged
flag becomes true.
- * @param references A set of subquery indexes in the cache to track all
(including transitive)
- * nested subqueries.
- */
- case class Header(
- plan: LogicalPlan,
- merged: Boolean,
- references: Set[Int])
-
private def extractCommonScalarSubqueries(plan: LogicalPlan) = {
- val cache = ArrayBuffer.empty[Header]
- val planWithReferences = insertReferences(plan, cache)
- cache.zipWithIndex.foreach { case (header, i) =>
- cache(i) = cache(i).copy(plan =
- if (header.merged) {
+ // Collect `ScalarSubquery` plans by level into `PlanMerger`s and insert
references in place of
+ // `ScalarSubquery`s.
+ val planMergers = ArrayBuffer.empty[PlanMerger]
+ val planWithReferences = insertReferences(plan, planMergers)._1
+
+ // Traverse level by level and convert merged plans to `CTERelationDef`s
and keep non-merged
+ // ones. While traversing replace references in plans back to
`CTERelationRef`s or to original
+ // `ScalarSubquery`s. This is safe as a subquery plan at a level can
reference only lower level
+ // other subqueries.
+ val subqueryPlansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]]
+ planMergers.foreach { planMerger =>
+ val mergedPlans = planMerger.mergedPlans()
+ subqueryPlansByLevel += mergedPlans.map { mergedPlan =>
+ val planWithoutReferences = if (subqueryPlansByLevel.isEmpty) {
+ // Level 0 plans can't contain references
+ mergedPlan.plan
+ } else {
+ removeReferences(mergedPlan.plan, subqueryPlansByLevel)
+ }
+ if (mergedPlan.merged) {
CTERelationDef(
- createProject(header.plan.output, removeReferences(header.plan,
cache)),
+ Project(
+ Seq(Alias(
+ CreateNamedStruct(
+ planWithoutReferences.output.flatMap(a =>
Seq(Literal(a.name), a))),
+ "mergedValue")()),
+ planWithoutReferences),
underSubquery = true)
} else {
- removeReferences(header.plan, cache)
- })
+ planWithoutReferences
+ }
+ }
}
- val newPlan = removeReferences(planWithReferences, cache)
- val subqueryCTEs =
cache.filter(_.merged).map(_.plan.asInstanceOf[CTERelationDef])
+
+ // Replace references back to `CTERelationRef`s or to original
`ScalarSubquery`s in the main
+ // plan.
+ val newPlan = removeReferences(planWithReferences, subqueryPlansByLevel)
+
+ // Add `CTERelationDef`s to the plan.
+ val subqueryCTEs = subqueryPlansByLevel.flatMap(_.collect { case cte:
CTERelationDef => cte })
if (subqueryCTEs.nonEmpty) {
WithCTE(newPlan, subqueryCTEs.toSeq)
} else {
newPlan
}
}
- // First traversal builds up the cache and inserts
`ScalarSubqueryReference`s to the plan.
- private def insertReferences(plan: LogicalPlan, cache: ArrayBuffer[Header]):
LogicalPlan = {
- plan.transformUpWithSubqueries {
- case n =>
n.transformExpressionsUpWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) {
- // The subquery could contain a hint that is not propagated once we
cache it, but as a
- // non-correlated scalar subquery won't be turned into a Join the loss
of hints is fine.
+ // First traversal inserts `ScalarSubqueryReference`s to the plan and tries
to merge subquery
+ // plans by each level.
+ private def insertReferences(
+ plan: LogicalPlan,
+ planMergers: ArrayBuffer[PlanMerger]): (LogicalPlan, Int) = {
+ // The level of a subquery plan is maximum level of its inner subqueries +
1 or 0 if it has no
+ // inner subqueries.
+ var maxLevel = 0
+ val planWithReferences =
+
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) {
case s: ScalarSubquery if !s.isCorrelated && s.deterministic =>
- val (subqueryIndex, headerIndex) = cacheSubquery(s.plan, cache)
- ScalarSubqueryReference(subqueryIndex, headerIndex, s.dataType,
s.exprId)
- }
- }
- }
-
- // Caching returns the index of the subquery in the cache and the index of
scalar member in the
- // "Header".
- private def cacheSubquery(plan: LogicalPlan, cache: ArrayBuffer[Header]):
(Int, Int) = {
Review Comment:
Was this method simply removed?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]