peter-toth commented on code in PR #52835:
URL: https://github.com/apache/spark/pull/52835#discussion_r2485976169
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala:
##########
@@ -117,297 +116,126 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
}
}
- /**
- * An item in the cache of merged scalar subqueries.
- *
- * @param plan The plan of a merged scalar subquery.
- * @param merged A flag to identify if this item is the result of merging
subqueries.
- * Please note that `attributes.size == 1` doesn't always mean
that the plan is not
- * merged as there can be subqueries that are different
([[checkIdenticalPlans]] is
- * false) due to an extra [[Project]] node in one of them. In
that case
- * `attributes.size` remains 1 after merging, but the merged
flag becomes true.
- * @param references A set of subquery indexes in the cache to track all
(including transitive)
- * nested subqueries.
- */
- case class Header(
- plan: LogicalPlan,
- merged: Boolean,
- references: Set[Int])
-
private def extractCommonScalarSubqueries(plan: LogicalPlan) = {
- val cache = ArrayBuffer.empty[Header]
- val planWithReferences = insertReferences(plan, cache)
- cache.zipWithIndex.foreach { case (header, i) =>
- cache(i) = cache(i).copy(plan =
- if (header.merged) {
+ // Collect `ScalarSubquery` plans by level into `PlanMerger`s and insert
references in place of
+ // `ScalarSubquery`s.
+ val planMergers = ArrayBuffer.empty[PlanMerger]
+ val planWithReferences = insertReferences(plan, planMergers)._1
+
+ // Traverse level by level and convert merged plans to `CTERelationDef`s
and keep non-merged
+ // ones. While traversing replace references in plans back to
`CTERelationRef`s or to original
+ // `ScalarSubquery`s. This is safe as a subquery plan at a level can
reference only lower level
+ // other subqueries.
+ val subqueryPlansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]]
+ planMergers.foreach { planMerger =>
+ val mergedPlans = planMerger.mergedPlans()
+ subqueryPlansByLevel += mergedPlans.map { mergedPlan =>
+ val planWithoutReferences = if (subqueryPlansByLevel.isEmpty) {
+ // Level 0 plans can't contain references
+ mergedPlan.plan
+ } else {
+ removeReferences(mergedPlan.plan, subqueryPlansByLevel)
+ }
+ if (mergedPlan.merged) {
CTERelationDef(
- createProject(header.plan.output, removeReferences(header.plan,
cache)),
+ Project(
+ Seq(Alias(
+ CreateNamedStruct(
+ planWithoutReferences.output.flatMap(a =>
Seq(Literal(a.name), a))),
+ "mergedValue")()),
+ planWithoutReferences),
underSubquery = true)
} else {
- removeReferences(header.plan, cache)
- })
+ planWithoutReferences
+ }
+ }
}
- val newPlan = removeReferences(planWithReferences, cache)
- val subqueryCTEs =
cache.filter(_.merged).map(_.plan.asInstanceOf[CTERelationDef])
+
+ // Replace references back to `CTERelationRef`s or to original
`ScalarSubquery`s in the main
+ // plan.
+ val newPlan = removeReferences(planWithReferences, subqueryPlansByLevel)
+
+ // Add `CTERelationDef`s to the plan.
+ val subqueryCTEs = subqueryPlansByLevel.flatMap(_.collect { case cte:
CTERelationDef => cte })
if (subqueryCTEs.nonEmpty) {
WithCTE(newPlan, subqueryCTEs.toSeq)
} else {
newPlan
}
}
- // First traversal builds up the cache and inserts
`ScalarSubqueryReference`s to the plan.
- private def insertReferences(plan: LogicalPlan, cache: ArrayBuffer[Header]):
LogicalPlan = {
- plan.transformUpWithSubqueries {
- case n =>
n.transformExpressionsUpWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) {
- // The subquery could contain a hint that is not propagated once we
cache it, but as a
- // non-correlated scalar subquery won't be turned into a Join the loss
of hints is fine.
+ // First traversal inserts `ScalarSubqueryReference`s to the plan and tries
to merge subquery
+ // plans by each level.
+ private def insertReferences(
+ plan: LogicalPlan,
+ planMergers: ArrayBuffer[PlanMerger]): (LogicalPlan, Int) = {
+ // The level of a subquery plan is maximum level of its inner subqueries +
1 or 0 if it has no
+ // inner subqueries.
+ var maxLevel = 0
+ val planWithReferences =
+
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) {
case s: ScalarSubquery if !s.isCorrelated && s.deterministic =>
- val (subqueryIndex, headerIndex) = cacheSubquery(s.plan, cache)
- ScalarSubqueryReference(subqueryIndex, headerIndex, s.dataType,
s.exprId)
- }
- }
- }
-
- // Caching returns the index of the subquery in the cache and the index of
scalar member in the
- // "Header".
- private def cacheSubquery(plan: LogicalPlan, cache: ArrayBuffer[Header]):
(Int, Int) = {
- val output = plan.output.head
- val references = mutable.HashSet.empty[Int]
-
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE))
{
- case ssr: ScalarSubqueryReference =>
- references += ssr.subqueryIndex
- references ++= cache(ssr.subqueryIndex).references
- ssr
- }
-
- cache.zipWithIndex.collectFirst(Function.unlift {
- case (header, subqueryIndex) if !references.contains(subqueryIndex) =>
- checkIdenticalPlans(plan, header.plan).map { outputMap =>
- val mappedOutput = mapAttributes(output, outputMap)
- val headerIndex = header.plan.output.indexWhere(_.exprId ==
mappedOutput.exprId)
- subqueryIndex -> headerIndex
- }.orElse{
- tryMergePlans(plan, header.plan).map {
- case (mergedPlan, outputMap) =>
- val mappedOutput = mapAttributes(output, outputMap)
- val headerIndex = mergedPlan.output.indexWhere(_.exprId ==
mappedOutput.exprId)
- cache(subqueryIndex) = Header(mergedPlan, true,
header.references ++ references)
- subqueryIndex -> headerIndex
- }
- }
- case _ => None
- }).getOrElse {
- cache += Header(plan, false, references.toSet)
- cache.length - 1 -> 0
- }
- }
-
- // If 2 plans are identical return the attribute mapping from the new to the
cached version.
- private def checkIdenticalPlans(
- newPlan: LogicalPlan,
- cachedPlan: LogicalPlan): Option[AttributeMap[Attribute]] = {
- if (newPlan.canonicalized == cachedPlan.canonicalized) {
- Some(AttributeMap(newPlan.output.zip(cachedPlan.output)))
- } else {
- None
- }
- }
-
- // Recursively traverse down and try merging 2 plans. If merge is possible
then return the merged
- // plan with the attribute mapping from the new to the merged version.
- // Please note that merging arbitrary plans can be complicated, the current
version supports only
- // some of the most important nodes.
- private def tryMergePlans(
- newPlan: LogicalPlan,
- cachedPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])]
= {
- checkIdenticalPlans(newPlan, cachedPlan).map(cachedPlan -> _).orElse(
- (newPlan, cachedPlan) match {
- case (np: Project, cp: Project) =>
- tryMergePlans(np.child, cp.child).map { case (mergedChild,
outputMap) =>
- val (mergedProjectList, newOutputMap) =
- mergeNamedExpressions(np.projectList, outputMap, cp.projectList)
- val mergedPlan = Project(mergedProjectList, mergedChild)
- mergedPlan -> newOutputMap
- }
- case (np, cp: Project) =>
- tryMergePlans(np, cp.child).map { case (mergedChild, outputMap) =>
- val (mergedProjectList, newOutputMap) =
- mergeNamedExpressions(np.output, outputMap, cp.projectList)
- val mergedPlan = Project(mergedProjectList, mergedChild)
- mergedPlan -> newOutputMap
- }
- case (np: Project, cp) =>
- tryMergePlans(np.child, cp).map { case (mergedChild, outputMap) =>
- val (mergedProjectList, newOutputMap) =
- mergeNamedExpressions(np.projectList, outputMap, cp.output)
- val mergedPlan = Project(mergedProjectList, mergedChild)
- mergedPlan -> newOutputMap
- }
- case (np: Aggregate, cp: Aggregate) if supportedAggregateMerge(np, cp)
=>
- tryMergePlans(np.child, cp.child).flatMap { case (mergedChild,
outputMap) =>
- val mappedNewGroupingExpression =
- np.groupingExpressions.map(mapAttributes(_, outputMap))
- // Order of grouping expression does matter as merging different
grouping orders can
- // introduce "extra" shuffles/sorts that might not present in all
of the original
- // subqueries.
- if (mappedNewGroupingExpression.map(_.canonicalized) ==
- cp.groupingExpressions.map(_.canonicalized)) {
- val (mergedAggregateExpressions, newOutputMap) =
- mergeNamedExpressions(np.aggregateExpressions, outputMap,
cp.aggregateExpressions)
- val mergedPlan =
- Aggregate(cp.groupingExpressions, mergedAggregateExpressions,
mergedChild)
- Some(mergedPlan -> newOutputMap)
- } else {
- None
- }
- }
-
- case (np: Filter, cp: Filter) =>
- tryMergePlans(np.child, cp.child).flatMap { case (mergedChild,
outputMap) =>
- val mappedNewCondition = mapAttributes(np.condition, outputMap)
- // Comparing the canonicalized form is required to ignore
different forms of the same
- // expression.
- if (mappedNewCondition.canonicalized ==
cp.condition.canonicalized) {
- val mergedPlan = cp.withNewChildren(Seq(mergedChild))
- Some(mergedPlan -> outputMap)
- } else {
- None
- }
- }
-
- case (np: Join, cp: Join) if np.joinType == cp.joinType && np.hint ==
cp.hint =>
- tryMergePlans(np.left, cp.left).flatMap { case (mergedLeft,
leftOutputMap) =>
- tryMergePlans(np.right, cp.right).flatMap { case (mergedRight,
rightOutputMap) =>
- val outputMap = leftOutputMap ++ rightOutputMap
- val mappedNewCondition = np.condition.map(mapAttributes(_,
outputMap))
- // Comparing the canonicalized form is required to ignore
different forms of the same
- // expression and `AttributeReference.quailifier`s in
`cp.condition`.
- if (mappedNewCondition.map(_.canonicalized) ==
cp.condition.map(_.canonicalized)) {
- val mergedPlan = cp.withNewChildren(Seq(mergedLeft,
mergedRight))
- Some(mergedPlan -> outputMap)
- } else {
- None
- }
- }
- }
-
- // Otherwise merging is not possible.
- case _ => None
- })
- }
-
- private def createProject(attributes: Seq[Attribute], plan: LogicalPlan):
Project = {
- Project(
- Seq(Alias(
- CreateNamedStruct(attributes.flatMap(a => Seq(Literal(a.name), a))),
- "mergedValue")()),
- plan)
- }
-
- private def mapAttributes[T <: Expression](expr: T, outputMap:
AttributeMap[Attribute]) = {
- expr.transform {
- case a: Attribute => outputMap.getOrElse(a, a)
- }.asInstanceOf[T]
- }
-
- // Applies `outputMap` attribute mapping on attributes of `newExpressions`
and merges them into
- // `cachedExpressions`. Returns the merged expressions and the attribute
mapping from the new to
- // the merged version that can be propagated up during merging nodes.
- private def mergeNamedExpressions(
- newExpressions: Seq[NamedExpression],
- outputMap: AttributeMap[Attribute],
- cachedExpressions: Seq[NamedExpression]) = {
- val mergedExpressions = ArrayBuffer[NamedExpression](cachedExpressions: _*)
- val newOutputMap = AttributeMap(newExpressions.map { ne =>
- val mapped = mapAttributes(ne, outputMap)
- val withoutAlias = mapped match {
- case Alias(child, _) => child
- case e => e
- }
- ne.toAttribute -> mergedExpressions.find {
- case Alias(child, _) => child semanticEquals withoutAlias
- case e => e semanticEquals withoutAlias
- }.getOrElse {
- mergedExpressions += mapped
- mapped
- }.toAttribute
- })
- (mergedExpressions.toSeq, newOutputMap)
- }
-
- // Only allow aggregates of the same implementation because merging
different implementations
- // could cause performance regression.
- private def supportedAggregateMerge(newPlan: Aggregate, cachedPlan:
Aggregate) = {
- val aggregateExpressionsSeq = Seq(newPlan, cachedPlan).map { plan =>
- plan.aggregateExpressions.flatMap(_.collect {
- case a: AggregateExpression => a
- })
- }
- val groupByExpressionSeq = Seq(newPlan,
cachedPlan).map(_.groupingExpressions)
-
- val Seq(newPlanSupportsHashAggregate, cachedPlanSupportsHashAggregate) =
- aggregateExpressionsSeq.zip(groupByExpressionSeq).map {
- case (aggregateExpressions, groupByExpressions) =>
- Aggregate.supportsHashAggregate(
- aggregateExpressions.flatMap(
- _.aggregateFunction.aggBufferAttributes), groupByExpressions)
- }
-
- newPlanSupportsHashAggregate && cachedPlanSupportsHashAggregate ||
- newPlanSupportsHashAggregate == cachedPlanSupportsHashAggregate && {
- val Seq(newPlanSupportsObjectHashAggregate,
cachedPlanSupportsObjectHashAggregate) =
- aggregateExpressionsSeq.zip(groupByExpressionSeq).map {
- case (aggregateExpressions, groupByExpressions) =>
- Aggregate.supportsObjectHashAggregate(aggregateExpressions,
groupByExpressions)
- }
- newPlanSupportsObjectHashAggregate &&
cachedPlanSupportsObjectHashAggregate ||
- newPlanSupportsObjectHashAggregate ==
cachedPlanSupportsObjectHashAggregate
+ val (planWithReferences, level) = insertReferences(s.plan,
planMergers)
+
+ while (level >= planMergers.size) planMergers += PlanMerger()
+ // The subquery could contain a hint that is not propagated once we
merge it, but as a
+ // non-correlated scalar subquery won't be turned into a Join the
loss of hints is fine.
+ val planMergeResult = planMergers(level).merge(planWithReferences)
+
+ maxLevel = maxLevel.max(level + 1)
+
+ val mergedOutput =
planMergeResult.outputMap(planWithReferences.output.head)
+ val headerIndex =
+ planMergeResult.mergedPlan.output.indexWhere(_.exprId ==
mergedOutput.exprId)
+ ScalarSubqueryReference(
+ level,
+ planMergeResult.mergedPlanIndex,
+ headerIndex,
+ s.dataType,
+ s.exprId)
+ case o => o
}
+ (planWithReferences, maxLevel)
}
// Second traversal replaces `ScalarSubqueryReference`s to either
// `GetStructField(ScalarSubquery(CTERelationRef to the merged plan)` if the
plan is merged from
// multiple subqueries or `ScalarSubquery(original plan)` if it isn't.
private def removeReferences(
plan: LogicalPlan,
- cache: ArrayBuffer[Header]) = {
- plan.transformUpWithSubqueries {
- case n =>
-
n.transformExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE))
{
- case ssr: ScalarSubqueryReference =>
- val header = cache(ssr.subqueryIndex)
- if (header.merged) {
- val subqueryCTE = header.plan.asInstanceOf[CTERelationDef]
- GetStructField(
- ScalarSubquery(
- CTERelationRef(subqueryCTE.id, _resolved = true,
subqueryCTE.output,
- subqueryCTE.isStreaming),
- exprId = ssr.exprId),
- ssr.headerIndex)
- } else {
- ScalarSubquery(header.plan, exprId = ssr.exprId)
- }
+ subqueryPlansByLevel: ArrayBuffer[IndexedSeq[LogicalPlan]]) = {
+
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE))
{
+ case ssr: ScalarSubqueryReference =>
+ subqueryPlansByLevel(ssr.level)(ssr.mergedPlanIndex) match {
+ case cte: CTERelationDef =>
+ GetStructField(
+ ScalarSubquery(
+ CTERelationRef(cte.id, _resolved = true, cte.output,
cte.isStreaming),
+ exprId = ssr.exprId),
+ ssr.outputIndex)
+ case o => ScalarSubquery(o, exprId = ssr.exprId)
}
}
}
}
/**
- * Temporal reference to a cached subquery.
+ * Temporal reference to a subquery which is added to a `PlanMerger`.
*
- * @param subqueryIndex A subquery index in the cache.
- * @param headerIndex An index in the output of merged subquery.
- * @param dataType The dataType of origin scalar subquery.
+ * @param level The level of the replaced subquery. It defines the
`PlanMerger` instance into which
+ * the subquery is merged.
+ * @param mergedPlanIndex The index of the merged plan in the `PlanMerger`.
+ * @param outputIndex The index of the output attribute of the merged plan.
+ * @param dataType The dataType of original scalar subquery.
+ * @param exprId The expression id of the original scalar subquery.
*/
case class ScalarSubqueryReference(
- subqueryIndex: Int,
- headerIndex: Int,
+ level: Int,
+ mergedPlanIndex: Int,
+ outputIndex: Int,
dataType: DataType,
exprId: ExprId) extends LeafExpression with Unevaluable {
override def nullable: Boolean = true
final override val nodePatterns: Seq[TreePattern] =
Seq(SCALAR_SUBQUERY_REFERENCE)
-
- override def stringArgs: Iterator[Any] = Iterator(subqueryIndex,
headerIndex, dataType, exprId.id)
Review Comment:
`ScalarSubqueryReference` never shows up in any stringified plan. It's just
a temporary expression that `MergeScalarSubqueries` inserts and then removes
again within the same run, so I don't think we need this `stringArgs` override.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]