This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 671539de00c [SPARK-37670][SQL] Support predicate pushdown and column pruning for de-duped CTEs
671539de00c is described below

commit 671539de00c1da817859de66345e122cac01a2ee
Author: Maryann Xue <maryann....@gmail.com>
AuthorDate: Tue Apr 19 10:50:07 2022 +0800

    [SPARK-37670][SQL] Support predicate pushdown and column pruning for de-duped CTEs
    
    This PR adds predicate push-down and column pruning to CTEs that are not
    inlined, and also fixes a few potential correctness issues:
      1) Replace (previously not inlined) CTE refs with Repartition operations
         at the end of logical plan optimization so that WithCTE is not carried
         over to the physical plan. As a result, we can simplify the logic of
         physical planning and avoid a correctness issue where the logical link
         of a physical plan node can point to `WithCTE` and lead to unexpected
         behaviors in AQE, e.g., class cast exceptions in DPP.
      2) Pull (not inlined) CTE defs from subqueries up to the main query
         level, to avoid creating copies of the same CTE def during predicate
         push-downs and other transformations.
      3) Make CTE IDs more deterministic by starting from 0 for each query.
    
    Motivation: improve the performance of de-duped CTEs through predicate
    pushdown and column pruning, and fix correctness issues in de-duped CTEs.
    
    User-facing change: No.
    
    Testing: added unit tests (UTs).
    
    Closes #34929 from maryannxue/cte-followup.
    
    Lead-authored-by: Maryann Xue <maryann....@gmail.com>
    Co-authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 175e429cca29c2314ee029bf009ed5222c0bffad)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/analysis/CTESubstitution.scala    |  30 ++-
 .../sql/catalyst/analysis/CheckAnalysis.scala      |   8 +-
 .../spark/sql/catalyst/optimizer/InlineCTE.scala   |  56 ++---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  64 +++---
 ...ushdownPredicatesAndPruneColumnsForCTEDef.scala | 175 ++++++++++++++++
 .../optimizer/ReplaceCTERefWithRepartition.scala   |  84 ++++++++
 .../spark/sql/catalyst/plans/QueryPlan.scala       |  31 +++
 .../plans/logical/basicLogicalOperators.scala      |   9 +-
 .../spark/sql/catalyst/analysis/AnalysisTest.scala |   3 +-
 .../spark/sql/execution/QueryExecution.scala       |  23 +--
 .../spark/sql/execution/SparkOptimizer.scala       |   3 +-
 .../apache/spark/sql/execution/SparkPlanner.scala  |   1 -
 .../spark/sql/execution/SparkStrategies.scala      |  31 ---
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |   7 +-
 .../scalar-subquery/scalar-subquery-select.sql     |  42 ++++
 .../scalar-subquery/scalar-subquery-select.sql.out | 103 ++++++++-
 .../approved-plans-v1_4/q23a.sf100/explain.txt     | 166 +++++++--------
 .../approved-plans-v1_4/q23b.sf100/explain.txt     | 190 ++++++++---------
 .../org/apache/spark/sql/CTEInlineSuite.scala      | 229 ++++++++++++++++++++-
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala |  15 ++
 20 files changed, 962 insertions(+), 308 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
index c0ba3598e4b..976a5d385d8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
@@ -69,13 +69,13 @@ object CTESubstitution extends Rule[LogicalPlan] {
     if (cteDefs.isEmpty) {
       substituted
     } else if (substituted eq lastSubstituted.get) {
-      WithCTE(substituted, cteDefs.toSeq)
+      WithCTE(substituted, cteDefs.sortBy(_.id).toSeq)
     } else {
       var done = false
       substituted.resolveOperatorsWithPruning(_ => !done) {
         case p if p eq lastSubstituted.get =>
           done = true
-          WithCTE(p, cteDefs.toSeq)
+          WithCTE(p, cteDefs.sortBy(_.id).toSeq)
       }
     }
   }
@@ -203,6 +203,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
       cteDefs: mutable.ArrayBuffer[CTERelationDef]): Seq[(String, 
CTERelationDef)] = {
     val resolvedCTERelations = new mutable.ArrayBuffer[(String, 
CTERelationDef)](relations.size)
     for ((name, relation) <- relations) {
+      val lastCTEDefCount = cteDefs.length
       val innerCTEResolved = if (isLegacy) {
         // In legacy mode, outer CTE relations take precedence. Here we don't 
resolve the inner
         // `With` nodes, later we will substitute `UnresolvedRelation`s with 
outer CTE relations.
@@ -211,8 +212,33 @@ object CTESubstitution extends Rule[LogicalPlan] {
       } else {
         // A CTE definition might contain an inner CTE that has a higher 
priority, so traverse and
         // substitute CTE defined in `relation` first.
+        // NOTE: we must call `traverseAndSubstituteCTE` before 
`substituteCTE`, as the relations
+        // in the inner CTE have higher priority over the relations in the 
outer CTE when resolving
+        // inner CTE relations. For example:
+        // WITH t1 AS (SELECT 1)
+        // t2 AS (
+        //   WITH t1 AS (SELECT 2)
+        //   WITH t3 AS (SELECT * FROM t1)
+        // )
+        // t3 should resolve the t1 to `SELECT 2` instead of `SELECT 1`.
         traverseAndSubstituteCTE(relation, isCommand, cteDefs)._1
       }
+
+      if (cteDefs.length > lastCTEDefCount) {
+        // We have added more CTE relations to the `cteDefs` from the inner 
CTE, and these relations
+        // should also be substituted with `resolvedCTERelations` as inner CTE 
relation can refer to
+        // outer CTE relation. For example:
+        // WITH t1 AS (SELECT 1)
+        // t2 AS (
+        //   WITH t3 AS (SELECT * FROM t1)
+        // )
+        for (i <- lastCTEDefCount until cteDefs.length) {
+          val substituted =
+            substituteCTE(cteDefs(i).child, isLegacy || isCommand, 
resolvedCTERelations.toSeq)
+          cteDefs(i) = cteDefs(i).copy(child = substituted)
+        }
+      }
+
       // CTE definition can reference a previous one
       val substituted =
         substituteCTE(innerCTEResolved, isLegacy || isCommand, 
resolvedCTERelations.toSeq)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 3b8a73717af..1c2de771a3d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, 
DecorrelateInnerQuery}
+import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, 
DecorrelateInnerQuery, InlineCTE}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -94,8 +94,10 @@ trait CheckAnalysis extends PredicateHelper with 
LookupCatalog {
 
   def checkAnalysis(plan: LogicalPlan): Unit = {
     // We transform up and order the rules so as to catch the first possible 
failure instead
-    // of the result of cascading resolution failures.
-    plan.foreachUp {
+    // of the result of cascading resolution failures. Inline all CTEs in the 
plan to help check
+    // query plan structures in subqueries.
+    val inlineCTE = InlineCTE(alwaysInline = true)
+    inlineCTE(plan).foreachUp {
 
       case p if p.analyzed => // Skip already analyzed sub-plans
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
index 61577b1d21e..a740b92933f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
@@ -28,26 +28,37 @@ import 
org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
 
 /**
  * Inlines CTE definitions into corresponding references if either of the 
conditions satisfies:
- * 1. The CTE definition does not contain any non-deterministic expressions. 
If this CTE
- *    definition references another CTE definition that has non-deterministic 
expressions, it
- *    is still OK to inline the current CTE definition.
+ * 1. The CTE definition does not contain any non-deterministic expressions or 
contains attribute
+ *    references to an outer query. If this CTE definition references another 
CTE definition that
+ *    has non-deterministic expressions, it is still OK to inline the current 
CTE definition.
  * 2. The CTE definition is only referenced once throughout the main query and 
all the subqueries.
  *
- * In addition, due to the complexity of correlated subqueries, all CTE 
references in correlated
- * subqueries are inlined regardless of the conditions above.
+ * CTE definitions that appear in subqueries and are not inlined will be 
pulled up to the main
+ * query level.
+ *
+ * @param alwaysInline if true, inline all CTEs in the query plan.
  */
-object InlineCTE extends Rule[LogicalPlan] {
+case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
+
   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) {
       val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int)]
       buildCTEMap(plan, cteMap)
-      inlineCTE(plan, cteMap, forceInline = false)
+      val notInlined = mutable.ArrayBuffer.empty[CTERelationDef]
+      val inlined = inlineCTE(plan, cteMap, notInlined)
+      // CTEs in SQL Commands have been inlined by `CTESubstitution` already, 
so it is safe to add
+      // WithCTE as top node here.
+      if (notInlined.isEmpty) {
+        inlined
+      } else {
+        WithCTE(inlined, notInlined.toSeq)
+      }
     } else {
       plan
     }
   }
 
-  private def shouldInline(cteDef: CTERelationDef, refCount: Int): Boolean = {
+  private def shouldInline(cteDef: CTERelationDef, refCount: Int): Boolean = 
alwaysInline || {
     // We do not need to check enclosed `CTERelationRef`s for `deterministic` 
or `OuterReference`,
     // because:
     // 1) It is fine to inline a CTE if it references another CTE that is 
non-deterministic;
@@ -93,25 +104,24 @@ object InlineCTE extends Rule[LogicalPlan] {
   private def inlineCTE(
       plan: LogicalPlan,
       cteMap: mutable.HashMap[Long, (CTERelationDef, Int)],
-      forceInline: Boolean): LogicalPlan = {
-    val (stripped, notInlined) = plan match {
+      notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = {
+    plan match {
       case WithCTE(child, cteDefs) =>
-        val notInlined = mutable.ArrayBuffer.empty[CTERelationDef]
         cteDefs.foreach { cteDef =>
           val (cte, refCount) = cteMap(cteDef.id)
           if (refCount > 0) {
-            val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, 
forceInline))
+            val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, 
notInlined))
             cteMap.update(cteDef.id, (inlined, refCount))
-            if (!forceInline && !shouldInline(inlined, refCount)) {
+            if (!shouldInline(inlined, refCount)) {
               notInlined.append(inlined)
             }
           }
         }
-        (inlineCTE(child, cteMap, forceInline), notInlined.toSeq)
+        inlineCTE(child, cteMap, notInlined)
 
       case ref: CTERelationRef =>
         val (cteDef, refCount) = cteMap(ref.cteId)
-        val newRef = if (forceInline || shouldInline(cteDef, refCount)) {
+        if (shouldInline(cteDef, refCount)) {
           if (ref.outputSet == cteDef.outputSet) {
             cteDef.child
           } else {
@@ -125,24 +135,16 @@ object InlineCTE extends Rule[LogicalPlan] {
         } else {
           ref
         }
-        (newRef, Seq.empty)
 
       case _ if plan.containsPattern(CTE) =>
-        val newPlan = plan
-          .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, 
forceInline)))
+        plan
+          .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, 
notInlined)))
           
.transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) {
             case e: SubqueryExpression =>
-              e.withNewPlan(inlineCTE(e.plan, cteMap, forceInline = 
e.isCorrelated))
+              e.withNewPlan(inlineCTE(e.plan, cteMap, notInlined))
           }
-        (newPlan, Seq.empty)
 
-      case _ => (plan, Seq.empty)
-    }
-
-    if (notInlined.isEmpty) {
-      stripped
-    } else {
-      WithCTE(stripped, notInlined)
+      case _ => plan
     }
   }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 66c2ad84cce..dc3e4c3da34 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -128,7 +128,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
         OptimizeUpdateFields,
         SimplifyExtractValueOps,
         OptimizeCsvJsonExprs,
-        CombineConcats) ++
+        CombineConcats,
+        PushdownPredicatesAndPruneColumnsForCTEDef) ++
         extendedOperatorOptimizationRules
 
     val operatorOptimizationBatch: Seq[Batch] = {
@@ -147,22 +148,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
     }
 
     val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) ::
-    // Technically some of the rules in Finish Analysis are not optimizer 
rules and belong more
-    // in the analyzer, because they are needed for correctness (e.g. 
ComputeCurrentTime).
-    // However, because we also use the analyzer to canonicalized queries (for 
view definition),
-    // we do not eliminate subqueries or compute current time in the analyzer.
-    Batch("Finish Analysis", Once,
-      EliminateResolvedHint,
-      EliminateSubqueryAliases,
-      EliminateView,
-      InlineCTE,
-      ReplaceExpressions,
-      RewriteNonCorrelatedExists,
-      PullOutGroupingExpressions,
-      ComputeCurrentTime,
-      ReplaceCurrentLike(catalogManager),
-      SpecialDatetimeValues,
-      RewriteAsOfJoin) ::
+    Batch("Finish Analysis", Once, FinishAnalysis) ::
     
//////////////////////////////////////////////////////////////////////////////////////////
     // Optimizer rules start here
     
//////////////////////////////////////////////////////////////////////////////////////////
@@ -171,6 +157,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
     //   extra operators between two adjacent Union operators.
     // - Call CombineUnions again in Batch("Operator Optimizations"),
     //   since the other rules might make two separate Unions operators 
adjacent.
+    Batch("Inline CTE", Once,
+      InlineCTE()) ::
     Batch("Union", Once,
       RemoveNoopOperators,
       CombineUnions,
@@ -207,6 +195,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RemoveLiteralFromGroupExpressions,
       RemoveRepetitionFromGroupExpressions) :: Nil ++
     operatorOptimizationBatch) :+
+    Batch("Clean Up Temporary CTE Info", Once, CleanUpTempCTEInfo) :+
     // This batch rewrites plans after the operator optimization and
     // before any batches that depend on stats.
     Batch("Pre CBO Rules", Once, preCBORules: _*) :+
@@ -265,14 +254,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
    * (defaultBatches - (excludedRules - nonExcludableRules)).
    */
   def nonExcludableRules: Seq[String] =
-    EliminateDistinct.ruleName ::
-      EliminateResolvedHint.ruleName ::
-      EliminateSubqueryAliases.ruleName ::
-      EliminateView.ruleName ::
-      ReplaceExpressions.ruleName ::
-      ComputeCurrentTime.ruleName ::
-      SpecialDatetimeValues.ruleName ::
-      ReplaceCurrentLike(catalogManager).ruleName ::
+    FinishAnalysis.ruleName ::
       RewriteDistinctAggregates.ruleName ::
       ReplaceDeduplicateWithAggregate.ruleName ::
       ReplaceIntersectWithSemiJoin.ruleName ::
@@ -286,10 +268,38 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewritePredicateSubquery.ruleName ::
       NormalizeFloatingNumbers.ruleName ::
       ReplaceUpdateFieldsExpression.ruleName ::
-      PullOutGroupingExpressions.ruleName ::
-      RewriteAsOfJoin.ruleName ::
       RewriteLateralSubquery.ruleName :: Nil
 
+  /**
+   * Apply finish-analysis rules for the entire plan including all subqueries.
+   */
+  object FinishAnalysis extends Rule[LogicalPlan] {
+    // Technically some of the rules in Finish Analysis are not optimizer 
rules and belong more
+    // in the analyzer, because they are needed for correctness (e.g. 
ComputeCurrentTime).
+    // However, because we also use the analyzer to canonicalized queries (for 
view definition),
+    // we do not eliminate subqueries or compute current time in the analyzer.
+    private val rules = Seq(
+      EliminateResolvedHint,
+      EliminateSubqueryAliases,
+      EliminateView,
+      ReplaceExpressions,
+      RewriteNonCorrelatedExists,
+      PullOutGroupingExpressions,
+      ComputeCurrentTime,
+      ReplaceCurrentLike(catalogManager),
+      SpecialDatetimeValues,
+      RewriteAsOfJoin)
+
+    override def apply(plan: LogicalPlan): LogicalPlan = {
+      rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+        
.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
+          case s: SubqueryExpression =>
+            val Subquery(newPlan, _) = apply(Subquery.fromExpression(s))
+            s.withNewPlan(newPlan)
+        }
+    }
+  }
+
   /**
    * Optimize all the subqueries inside expression.
    */
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala
new file mode 100644
index 00000000000..ab9f20edb0b
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushdownPredicatesAndPruneColumnsForCTEDef.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, 
AttributeSet, Expression, Literal, Or, SubqueryExpression}
+import org.apache.spark.sql.catalyst.planning.ScanOperation
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.CTE
+
+/**
+ * Infer predicates and column pruning for [[CTERelationDef]] from its 
reference points, and push
+ * the disjunctive predicates as well as the union of attributes down the CTE 
plan.
+ */
+object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] {
+
+  // CTE_id - (CTE_definition, precedence, predicates_to_push_down, 
attributes_to_prune)
+  private type CTEMap = mutable.HashMap[Long, (CTERelationDef, Int, 
Seq[Expression], AttributeSet)]
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) {
+      val cteMap = new CTEMap
+      gatherPredicatesAndAttributes(plan, cteMap)
+      pushdownPredicatesAndAttributes(plan, cteMap)
+    } else {
+      plan
+    }
+  }
+
+  private def restoreCTEDefAttrs(
+      input: Seq[Expression],
+      mapping: Map[Attribute, Expression]): Seq[Expression] = {
+    input.map(e => e.transform {
+      case a: Attribute =>
+        mapping.keys.find(_.semanticEquals(a)).map(mapping).getOrElse(a)
+    })
+  }
+
+  /**
+   * Gather all the predicates and referenced attributes on different points 
of CTE references
+   * using pattern `ScanOperation` (which takes care of determinism) and 
combine those predicates
+   * and attributes that belong to the same CTE definition.
+   * For the same CTE definition, if any of its references does not have 
predicates, the combined
+   * predicate will be a TRUE literal, which means there will be no predicate 
push-down.
+   */
+  private def gatherPredicatesAndAttributes(plan: LogicalPlan, cteMap: 
CTEMap): Unit = {
+    plan match {
+      case WithCTE(child, cteDefs) =>
+        cteDefs.zipWithIndex.foreach { case (cteDef, precedence) =>
+          gatherPredicatesAndAttributes(cteDef.child, cteMap)
+          cteMap.put(cteDef.id, (cteDef, precedence, Seq.empty, 
AttributeSet.empty))
+        }
+        gatherPredicatesAndAttributes(child, cteMap)
+
+      case ScanOperation(projects, predicates, ref: CTERelationRef) =>
+        val (cteDef, precedence, preds, attrs) = cteMap(ref.cteId)
+        val attrMapping = ref.output.zip(cteDef.output).map{ case (r, d) => r 
-> d }.toMap
+        val newPredicates = if (isTruePredicate(preds)) {
+          preds
+        } else {
+          // Make sure we only push down predicates that do not contain 
forward CTE references.
+          val filteredPredicates = restoreCTEDefAttrs(predicates.filter(_.find 
{
+            case s: SubqueryExpression => s.plan.find {
+              case r: CTERelationRef =>
+                // If the ref's ID does not exist in the map or if ref's 
corresponding precedence
+                // is bigger than that of the current CTE we are pushing 
predicates for, it
+                // indicates a forward reference and we should exclude this 
predicate.
+                !cteMap.contains(r.cteId) || cteMap(r.cteId)._2 >= precedence
+              case _ => false
+            }.nonEmpty
+            case _ => false
+          }.isEmpty), 
attrMapping).filter(_.references.forall(cteDef.outputSet.contains))
+          if (filteredPredicates.isEmpty) {
+            Seq(Literal.TrueLiteral)
+          } else {
+            preds :+ filteredPredicates.reduce(And)
+          }
+        }
+        val newAttributes = attrs ++
+          AttributeSet(restoreCTEDefAttrs(projects.flatMap(_.references), 
attrMapping)) ++
+          AttributeSet(restoreCTEDefAttrs(predicates.flatMap(_.references), 
attrMapping))
+
+        cteMap.update(ref.cteId, (cteDef, precedence, newPredicates, 
newAttributes))
+        plan.subqueriesAll.foreach(s => gatherPredicatesAndAttributes(s, 
cteMap))
+
+      case _ =>
+        plan.children.foreach(c => gatherPredicatesAndAttributes(c, cteMap))
+        plan.subqueries.foreach(s => gatherPredicatesAndAttributes(s, cteMap))
+    }
+  }
+
+  /**
+   * Push down the combined predicate and attribute references to each CTE 
definition plan.
+   *
+   * In order to guarantee idempotency, we keep the predicates (if any) being 
pushed down by the
+   * last iteration of this rule in a temporary field of `CTERelationDef`, so 
that on the current
+   * iteration, we only push down predicates for a CTE def if there exists any 
new predicate that
+   * has not been pushed before. Also, since part of a new predicate might 
overlap with some
+   * existing predicate and it can be hard to extract only the non-overlapping 
part, we also keep
+   * the original CTE definition plan without any predicate push-down in that 
temporary field so
+   * that when we do a new predicate push-down, we can construct a new plan 
with all latest
+   * predicates over the original plan without having to figure out the exact 
predicate difference.
+   */
+  private def pushdownPredicatesAndAttributes(
+      plan: LogicalPlan,
+      cteMap: CTEMap): LogicalPlan = plan.transformWithSubqueries {
+    case cteDef @ CTERelationDef(child, id, originalPlanWithPredicates) =>
+      val (_, _, newPreds, newAttrSet) = cteMap(id)
+      val originalPlan = originalPlanWithPredicates.map(_._1).getOrElse(child)
+      val preds = originalPlanWithPredicates.map(_._2).getOrElse(Seq.empty)
+      if (!isTruePredicate(newPreds) &&
+          newPreds.exists(newPred => 
!preds.exists(_.semanticEquals(newPred)))) {
+        val newCombinedPred = newPreds.reduce(Or)
+        val newChild = if (needsPruning(originalPlan, newAttrSet)) {
+          Project(newAttrSet.toSeq, originalPlan)
+        } else {
+          originalPlan
+        }
+        CTERelationDef(Filter(newCombinedPred, newChild), id, 
Some((originalPlan, newPreds)))
+      } else if (needsPruning(cteDef.child, newAttrSet)) {
+        CTERelationDef(Project(newAttrSet.toSeq, cteDef.child), id, 
Some((originalPlan, preds)))
+      } else {
+        cteDef
+      }
+
+    case cteRef @ CTERelationRef(cteId, _, output, _) =>
+      val (cteDef, _, _, newAttrSet) = cteMap(cteId)
+      if (newAttrSet.size < output.size) {
+        val indices = newAttrSet.toSeq.map(cteDef.output.indexOf)
+        val newOutput = indices.map(output)
+        cteRef.copy(output = newOutput)
+      } else {
+        // Do not change the order of output columns if no column is pruned, 
in which case there
+        // might be no Project and the order is important.
+        cteRef
+      }
+  }
+
+  private def isTruePredicate(predicates: Seq[Expression]): Boolean = {
+    predicates.length == 1 && predicates.head == Literal.TrueLiteral
+  }
+
+  private def needsPruning(sourcePlan: LogicalPlan, attributeSet: 
AttributeSet): Boolean = {
+    attributeSet.size < sourcePlan.outputSet.size && 
attributeSet.subsetOf(sourcePlan.outputSet)
+  }
+}
+
+/**
+ * Clean up temporary info from [[CTERelationDef]] nodes. This rule should be 
called after all
+ * iterations of [[PushdownPredicatesAndPruneColumnsForCTEDef]] are done.
+ */
+object CleanUpTempCTEInfo extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan =
+    plan.transformWithPruning(_.containsPattern(CTE)) {
+      case cteDef @ CTERelationDef(_, _, Some(_)) =>
+        cteDef.copy(originalPlanWithPredicates = None)
+    }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceCTERefWithRepartition.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceCTERefWithRepartition.scala
new file mode 100644
index 00000000000..e0d0417ce51
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceCTERefWithRepartition.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.analysis.DeduplicateRelations
+import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression}
+import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
+
+/**
+ * Replaces CTE references that have not been previously inlined with 
[[Repartition]] operations
+ * which will then be planned as shuffles and reused across different 
reference points.
+ *
+ * Note that this rule should be called at the very end of the optimization 
phase to best guarantee
+ * that CTE repartition shuffles are reused.
+ */
+object ReplaceCTERefWithRepartition extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    case _: Subquery => plan
+    case _ =>
+      replaceWithRepartition(plan, mutable.HashMap.empty[Long, LogicalPlan])
+  }
+
+  private def replaceWithRepartition(
+      plan: LogicalPlan,
+      cteMap: mutable.HashMap[Long, LogicalPlan]): LogicalPlan = plan match {
+    case WithCTE(child, cteDefs) =>
+      cteDefs.foreach { cteDef =>
+        val inlined = replaceWithRepartition(cteDef.child, cteMap)
+        val withRepartition = if (inlined.isInstanceOf[RepartitionOperation]) {
+          // If the CTE definition plan itself is a repartition operation, we 
do not need to add an
+          // extra repartition shuffle.
+          inlined
+        } else {
+          Repartition(conf.numShufflePartitions, shuffle = true, inlined)
+        }
+        cteMap.put(cteDef.id, withRepartition)
+      }
+      replaceWithRepartition(child, cteMap)
+
+    case ref: CTERelationRef =>
+      val cteDefPlan = cteMap(ref.cteId)
+      if (ref.outputSet == cteDefPlan.outputSet) {
+        cteDefPlan
+      } else {
+        val ctePlan = DeduplicateRelations(
+          Join(cteDefPlan, cteDefPlan, Inner, None, JoinHint(None, 
None))).children(1)
+        val projectList = ref.output.zip(ctePlan.output).map { case (tgtAttr, 
srcAttr) =>
+          Alias(srcAttr, tgtAttr.name)(exprId = tgtAttr.exprId)
+        }
+        Project(projectList, ctePlan)
+      }
+
+    case _ if plan.containsPattern(CTE) =>
+      plan
+        .withNewChildren(plan.children.map(c => replaceWithRepartition(c, 
cteMap)))
+        
.transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) {
+          case e: SubqueryExpression =>
+            e.withNewPlan(replaceWithRepartition(e.plan, cteMap))
+        }
+
+    case _ => plan
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 5d749b8fc4b..0f8df5df376 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -448,6 +448,14 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
     subqueries ++ subqueries.flatMap(_.subqueriesAll)
   }
 
+  /**
+   * This method is similar to the transform method, but also applies the given partial function
+   * to all the plans in the subqueries of a node. This method is useful when we want
+   * to rewrite the whole plan, including its subqueries, in one go.
+   *
+   * The traversal is top-down: this simply delegates to `transformDownWithSubqueries`.
+   */
+  def transformWithSubqueries(f: PartialFunction[PlanType, PlanType]): 
PlanType =
+    transformDownWithSubqueries(f)
+
   /**
    * Returns a copy of this node where the given partial function has been 
recursively applied
    * first to the subqueries in this node's children, then this node's 
children, and finally
@@ -465,6 +473,29 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
     }
   }
 
+  /**
+   * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries.
+   * Returns a copy of this node where the given partial function has been recursively applied
+   * first to this node, then this node's subqueries and finally this node's children.
+   * When the partial function does not apply to a given node, it is left unchanged.
+   */
+  def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): 
PlanType = {
+    // Wrap `f` in a total function so that `transformDown` applies it to every node. Nodes where
+    // `f` is not defined are kept (identity), but their subquery plans are still rewritten.
+    val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, 
PlanType] {
+      override def isDefinedAt(x: PlanType): Boolean = true
+
+      override def apply(plan: PlanType): PlanType = {
+        val transformed = f.applyOrElse[PlanType, PlanType](plan, identity)
+        // Recurse into the plan of every PlanExpression of the transformed node (for example a
+        // subquery). NOTE(review): the `PlanType` type argument is erased at runtime, so this
+        // pattern is unchecked and matches any PlanExpression.
+        transformed transformExpressionsDown {
+          case planExpression: PlanExpression[PlanType] =>
+            val newPlan = planExpression.plan.transformDownWithSubqueries(f)
+            planExpression.withNewPlan(newPlan)
+        }
+      }
+    }
+
+    // transformDown applies `g` to this node before recursing into its children.
+    transformDown(g)
+  }
+
   /**
    * A variant of `collect`. This method not only apply the given function to 
all elements in this
    * plan, also considering all the plans in its (nested) subqueries
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 895eeb77207..e5eab691d14 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -659,8 +659,15 @@ case class UnresolvedWith(
  * A wrapper for CTE definition plan with a unique ID.
  * @param child The CTE definition query plan.
  * @param id    The unique ID for this CTE definition.
+ * @param originalPlanWithPredicates The original query plan before predicate 
pushdown and the
+ *                                   predicates that have been pushed down 
into `child`. This is
+ *                                   a temporary field used by optimization 
rules for CTE predicate
+ *                                   pushdown to help ensure rule idempotency.
  */
-case class CTERelationDef(child: LogicalPlan, id: Long = CTERelationDef.newId) 
extends UnaryNode {
+case class CTERelationDef(
+    child: LogicalPlan,
+    id: Long = CTERelationDef.newId,
+    originalPlanWithPredicates: Option[(LogicalPlan, Seq[Expression])] = None) 
extends UnaryNode {
 
   final override val nodePatterns: Seq[TreePattern] = Seq(CTE)
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index 804f1edbe06..7dde85014e7 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -108,7 +108,8 @@ trait AnalysisTest extends PlanTest {
         case v: View if v.isTempViewStoringAnalyzedPlan => v.child
       }
       val actualPlan = if (inlineCTE) {
-        InlineCTE(transformed)
+        val inlineCTE = InlineCTE()
+        inlineCTE(transformed)
       } else {
         transformed
       }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 9bf8de5ea6c..5dcdebfbe0e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -21,8 +21,6 @@ import java.io.{BufferedWriter, OutputStreamWriter}
 import java.util.UUID
 import java.util.concurrent.atomic.AtomicLong
 
-import scala.collection.mutable
-
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.internal.Logging
@@ -32,7 +30,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, 
QueryPlanningTracker}
 import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
 import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, 
CommandResult, CreateTableAsSelect, CTERelationDef, LogicalPlan, 
OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, 
ReturnAnswer}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, 
CommandResult, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, 
OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer}
 import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
 import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
 import org.apache.spark.sql.catalyst.util.truncatedString
@@ -64,17 +62,6 @@ class QueryExecution(
   // TODO: Move the planner an optimizer into here from SessionState.
   protected def planner = sparkSession.sessionState.planner
 
-  // The CTE map for the planner shared by the main query and all subqueries.
-  private val cteMap = mutable.HashMap.empty[Long, CTERelationDef]
-
-  def withCteMap[T](f: => T): T = {
-    val old = QueryExecution.currentCteMap.get()
-    QueryExecution.currentCteMap.set(cteMap)
-    try f finally {
-      QueryExecution.currentCteMap.set(old)
-    }
-  }
-
   def assertAnalyzed(): Unit = analyzed
 
   def assertSupported(): Unit = {
@@ -147,7 +134,7 @@ class QueryExecution(
 
   private def assertOptimized(): Unit = optimizedPlan
 
-  lazy val sparkPlan: SparkPlan = withCteMap {
+  lazy val sparkPlan: SparkPlan = {
     // We need to materialize the optimizedPlan here because sparkPlan is also 
tracked under
     // the planning phase
     assertOptimized()
@@ -160,7 +147,7 @@ class QueryExecution(
 
   // executedPlan should not be used to initialize any SparkPlan. It should be
   // only used for execution.
-  lazy val executedPlan: SparkPlan = withCteMap {
+  lazy val executedPlan: SparkPlan = {
     // We need to materialize the optimizedPlan here, before tracking the 
planning phase, to ensure
     // that the optimization time is not counted as part of the planning phase.
     assertOptimized()
@@ -497,8 +484,4 @@ object QueryExecution {
     val preparationRules = preparations(session, 
Option(InsertAdaptiveSparkPlan(context)), true)
     prepareForExecution(preparationRules, sparkPlan.clone())
   }
-
-  private val currentCteMap = new ThreadLocal[mutable.HashMap[Long, 
CTERelationDef]]()
-
-  def cteMap: mutable.HashMap[Long, CTERelationDef] = currentCteMap.get()
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 8c134363af1..d9457a20d91 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -76,7 +76,8 @@ class SparkOptimizer(
       ColumnPruning,
       PushPredicateThroughNonJoin,
       RemoveNoopOperators) :+
-    Batch("User Provided Optimizers", fixedPoint, 
experimentalMethods.extraOptimizations: _*)
+    Batch("User Provided Optimizers", fixedPoint, 
experimentalMethods.extraOptimizations: _*) :+
+    Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)
 
   override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+
     ExtractPythonUDFFromJoinCondition.ruleName :+
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 32ac58f8353..6994aaf47df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -44,7 +44,6 @@ class SparkPlanner(val session: SparkSession, val 
experimentalMethods: Experimen
       JoinSelection ::
       InMemoryScans ::
       SparkScripts ::
-      WithCTEStrategy ::
       BasicOperators :: Nil)
 
   /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 675b1581003..3b8a70ffe94 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, 
BuildRight, BuildSide
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.physical.RoundRobinPartitioning
 import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, 
StreamingRelationV2}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
 import org.apache.spark.sql.execution.aggregate.AggUtils
@@ -675,36 +674,6 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
     }
   }
 
-  /**
-   * Strategy to plan CTE relations left not inlined.
-   */
-  object WithCTEStrategy extends Strategy {
-    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case WithCTE(plan, cteDefs) =>
-        val cteMap = QueryExecution.cteMap
-        cteDefs.foreach { cteDef =>
-          cteMap.put(cteDef.id, cteDef)
-        }
-        planLater(plan) :: Nil
-
-      case r: CTERelationRef =>
-        val ctePlan = QueryExecution.cteMap(r.cteId).child
-        val projectList = r.output.zip(ctePlan.output).map { case (tgtAttr, 
srcAttr) =>
-          Alias(srcAttr, tgtAttr.name)(exprId = tgtAttr.exprId)
-        }
-        val newPlan = Project(projectList, ctePlan)
-        // Plan CTE ref as a repartition shuffle so that all refs of the same 
CTE def will share
-        // an Exchange reuse at runtime.
-        // TODO create a new identity partitioning instead of using 
RoundRobinPartitioning.
-        exchange.ShuffleExchangeExec(
-          RoundRobinPartitioning(conf.numShufflePartitions),
-          planLater(newPlan),
-          REPARTITION_BY_COL) :: Nil
-
-      case _ => Nil
-    }
-  }
-
   object BasicOperators extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case d: DataWritingCommand => DataWritingCommandExec(d, 
planLater(d.query)) :: Nil
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index c6505a0ea5f..df302e5dc75 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -148,9 +148,7 @@ case class AdaptiveSparkPlanExec(
     collapseCodegenStagesRule
   )
 
-  private def optimizeQueryStage(
-      plan: SparkPlan,
-      isFinalStage: Boolean): SparkPlan = context.qe.withCteMap {
+  private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): 
SparkPlan = {
     val optimized = queryStageOptimizerRules.foldLeft(plan) { case 
(latestPlan, rule) =>
       val applied = rule.apply(latestPlan)
       val result = rule match {
@@ -640,8 +638,7 @@ case class AdaptiveSparkPlanExec(
   /**
    * Re-optimize and run physical planning on the current logical plan based 
on the latest stats.
    */
-  private def reOptimize(
-      logicalPlan: LogicalPlan): (SparkPlan, LogicalPlan) = 
context.qe.withCteMap {
+  private def reOptimize(logicalPlan: LogicalPlan): (SparkPlan, LogicalPlan) = 
{
     logicalPlan.invalidateStatsCache()
     val optimized = optimizer.execute(logicalPlan)
     val sparkPlan = 
context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql
 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql
index a76a0107220..4c80b268c20 100644
--- 
a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql
+++ 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql
@@ -145,3 +145,45 @@ SELECT t1c, (SELECT t1c WHERE t1c = 8) FROM t1;
 SELECT t1c, t1d, (SELECT c + d FROM (SELECT t1c AS c, t1d AS d)) FROM t1;
 SELECT t1c, (SELECT SUM(c) FROM (SELECT t1c AS c)) FROM t1;
 SELECT t1a, (SELECT SUM(t2b) FROM t2 JOIN (SELECT t1a AS a) ON t2a = a) FROM 
t1;
+
+-- CTE in correlated scalar subqueries
+CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t1(c1, c2);
+CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (0, 2), (0, 3) t2(c1, c2);
+
+-- Single row subquery
+SELECT c1, (WITH t AS (SELECT 1 AS a) SELECT a + c1 FROM t) FROM t1;
+-- Correlation in CTE.
+SELECT c1, (WITH t AS (SELECT * FROM t2 WHERE c1 = t1.c1) SELECT SUM(c2) FROM 
t) FROM t1;
+-- Multiple CTE definitions.
+SELECT c1, (
+    WITH t3 AS (SELECT c1 + 1 AS c1, c2 + 1 AS c2 FROM t2),
+    t4 AS (SELECT * FROM t3 WHERE t1.c1 = c1)
+    SELECT SUM(c2) FROM t4
+) FROM t1;
+-- Multiple CTE references.
+SELECT c1, (
+    WITH t AS (SELECT * FROM t2)
+    SELECT SUM(c2) FROM (SELECT c1, c2 FROM t UNION SELECT c2, c1 FROM t) 
r(c1, c2)
+    WHERE c1 = t1.c1
+) FROM t1;
+-- Reference CTE in both the main query and the subquery.
+WITH v AS (SELECT * FROM t2)
+SELECT * FROM t1 WHERE c1 > (
+    WITH t AS (SELECT * FROM t2)
+    SELECT COUNT(*) FROM v WHERE c1 = t1.c1 AND c1 > (SELECT SUM(c2) FROM t 
WHERE c1 = v.c1)
+);
+-- Single row subquery that references CTE in the main query.
+WITH t AS (SELECT 1 AS a)
+SELECT c1, (SELECT a FROM t WHERE a = c1) FROM t1;
+-- Multiple CTE references with non-deterministic CTEs.
+WITH
+v1 AS (SELECT c1, c2, rand(0) c3 FROM t1),
+v2 AS (SELECT c1, c2, rand(0) c4 FROM v1 WHERE c3 IN (SELECT c3 FROM v1))
+SELECT c1, (
+    WITH v3 AS (SELECT c1, c2, rand(0) c5 FROM t2)
+    SELECT COUNT(*) FROM (
+        SELECT * FROM v2 WHERE c1 > 0
+        UNION SELECT * FROM v2 WHERE c2 > 0
+        UNION SELECT * FROM v3 WHERE c2 > 0
+    ) WHERE c1 = v1.c1
+) FROM v1;
diff --git 
a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out
 
b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out
index 8fac940f8ef..3eb1c6ffba1 100644
--- 
a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 17
+-- Number of queries: 26
 
 
 -- !query
@@ -317,3 +317,104 @@ val1d     NULL
 val1e  8
 val1e  8
 val1e  8
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t1(c1, c2)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (0, 2), (0, 3) t2(c1, c2)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT c1, (WITH t AS (SELECT 1 AS a) SELECT a + c1 FROM t) FROM t1
+-- !query schema
+struct<c1:int,scalarsubquery(c1):int>
+-- !query output
+0      1
+1      2
+
+
+-- !query
+SELECT c1, (WITH t AS (SELECT * FROM t2 WHERE c1 = t1.c1) SELECT SUM(c2) FROM 
t) FROM t1
+-- !query schema
+struct<c1:int,scalarsubquery(c1):bigint>
+-- !query output
+0      5
+1      NULL
+
+
+-- !query
+SELECT c1, (
+    WITH t3 AS (SELECT c1 + 1 AS c1, c2 + 1 AS c2 FROM t2),
+    t4 AS (SELECT * FROM t3 WHERE t1.c1 = c1)
+    SELECT SUM(c2) FROM t4
+) FROM t1
+-- !query schema
+struct<c1:int,scalarsubquery(c1):bigint>
+-- !query output
+0      NULL
+1      7
+
+
+-- !query
+SELECT c1, (
+    WITH t AS (SELECT * FROM t2)
+    SELECT SUM(c2) FROM (SELECT c1, c2 FROM t UNION SELECT c2, c1 FROM t) 
r(c1, c2)
+    WHERE c1 = t1.c1
+) FROM t1
+-- !query schema
+struct<c1:int,scalarsubquery(c1):bigint>
+-- !query output
+0      5
+1      NULL
+
+
+-- !query
+WITH v AS (SELECT * FROM t2)
+SELECT * FROM t1 WHERE c1 > (
+    WITH t AS (SELECT * FROM t2)
+    SELECT COUNT(*) FROM v WHERE c1 = t1.c1 AND c1 > (SELECT SUM(c2) FROM t 
WHERE c1 = v.c1)
+)
+-- !query schema
+struct<c1:int,c2:int>
+-- !query output
+1      2
+
+
+-- !query
+WITH t AS (SELECT 1 AS a)
+SELECT c1, (SELECT a FROM t WHERE a = c1) FROM t1
+-- !query schema
+struct<c1:int,scalarsubquery(c1):int>
+-- !query output
+0      NULL
+1      1
+
+
+-- !query
+WITH
+v1 AS (SELECT c1, c2, rand(0) c3 FROM t1),
+v2 AS (SELECT c1, c2, rand(0) c4 FROM v1 WHERE c3 IN (SELECT c3 FROM v1))
+SELECT c1, (
+    WITH v3 AS (SELECT c1, c2, rand(0) c5 FROM t2)
+    SELECT COUNT(*) FROM (
+        SELECT * FROM v2 WHERE c1 > 0
+        UNION SELECT * FROM v2 WHERE c2 > 0
+        UNION SELECT * FROM v3 WHERE c2 > 0
+    ) WHERE c1 = v1.c1
+) FROM v1
+-- !query schema
+struct<c1:int,scalarsubquery(c1):bigint>
+-- !query output
+0      3
+1      1
diff --git 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt
 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt
index 5bf5193487b..7f419ce3eaf 100644
--- 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt
+++ 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23a.sf100/explain.txt
@@ -360,19 +360,19 @@ Right keys [1]: [i_item_sk#14]
 Join condition: None
 
 (61) Project [codegen id : 25]
-Output [3]: [d_date#12, i_item_sk#14, substr(i_item_desc#15, 1, 30) AS 
_groupingexpression#47]
+Output [3]: [d_date#12, i_item_sk#14, substr(i_item_desc#15, 1, 30) AS 
_groupingexpression#17]
 Input [4]: [ss_item_sk#8, d_date#12, i_item_sk#14, i_item_desc#15]
 
 (62) HashAggregate [codegen id : 25]
-Input [3]: [d_date#12, i_item_sk#14, _groupingexpression#47]
-Keys [3]: [_groupingexpression#47, i_item_sk#14, d_date#12]
+Input [3]: [d_date#12, i_item_sk#14, _groupingexpression#17]
+Keys [3]: [_groupingexpression#17, i_item_sk#14, d_date#12]
 Functions [1]: [partial_count(1)]
 Aggregate Attributes [1]: [count#18]
-Results [4]: [_groupingexpression#47, i_item_sk#14, d_date#12, count#19]
+Results [4]: [_groupingexpression#17, i_item_sk#14, d_date#12, count#19]
 
 (63) HashAggregate [codegen id : 25]
-Input [4]: [_groupingexpression#47, i_item_sk#14, d_date#12, count#19]
-Keys [3]: [_groupingexpression#47, i_item_sk#14, d_date#12]
+Input [4]: [_groupingexpression#17, i_item_sk#14, d_date#12, count#19]
+Keys [3]: [_groupingexpression#17, i_item_sk#14, d_date#12]
 Functions [1]: [count(1)]
 Aggregate Attributes [1]: [count(1)#20]
 Results [2]: [i_item_sk#14 AS item_sk#21, count(1)#20 AS cnt#22]
@@ -400,7 +400,7 @@ Input [5]: [ws_item_sk#41, ws_bill_customer_sk#42, 
ws_quantity#43, ws_list_price
 
 (69) Exchange
 Input [4]: [ws_bill_customer_sk#42, ws_quantity#43, ws_list_price#44, 
ws_sold_date_sk#45]
-Arguments: hashpartitioning(ws_bill_customer_sk#42, 5), ENSURE_REQUIREMENTS, 
[id=#48]
+Arguments: hashpartitioning(ws_bill_customer_sk#42, 5), ENSURE_REQUIREMENTS, 
[id=#47]
 
 (70) Sort [codegen id : 27]
 Input [4]: [ws_bill_customer_sk#42, ws_quantity#43, ws_list_price#44, 
ws_sold_date_sk#45]
@@ -433,11 +433,11 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, 
ss_sales_price#26, c_customer_sk#
 Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#29]
 Keys [1]: [c_customer_sk#29]
 Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), 
DecimalType(18,2)))]
-Aggregate Attributes [2]: [sum#49, isEmpty#50]
-Results [3]: [c_customer_sk#29, sum#51, isEmpty#52]
+Aggregate Attributes [2]: [sum#48, isEmpty#49]
+Results [3]: [c_customer_sk#29, sum#50, isEmpty#51]
 
 (78) HashAggregate [codegen id : 32]
-Input [3]: [c_customer_sk#29, sum#51, isEmpty#52]
+Input [3]: [c_customer_sk#29, sum#50, isEmpty#51]
 Keys [1]: [c_customer_sk#29]
 Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), 
DecimalType(18,2)))]
 Aggregate Attributes [1]: 
[sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * 
promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), 
DecimalType(18,2)))#35]
@@ -465,16 +465,16 @@ Output [3]: [ws_quantity#43, ws_list_price#44, 
ws_sold_date_sk#45]
 Input [4]: [ws_bill_customer_sk#42, ws_quantity#43, ws_list_price#44, 
ws_sold_date_sk#45]
 
 (84) ReusedExchange [Reuses operator id: 95]
-Output [1]: [d_date_sk#53]
+Output [1]: [d_date_sk#52]
 
 (85) BroadcastHashJoin [codegen id : 34]
 Left keys [1]: [ws_sold_date_sk#45]
-Right keys [1]: [d_date_sk#53]
+Right keys [1]: [d_date_sk#52]
 Join condition: None
 
 (86) Project [codegen id : 34]
-Output [1]: [CheckOverflow((promote_precision(cast(ws_quantity#43 as 
decimal(12,2))) * promote_precision(cast(ws_list_price#44 as decimal(12,2)))), 
DecimalType(18,2)) AS sales#54]
-Input [4]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45, d_date_sk#53]
+Output [1]: [CheckOverflow((promote_precision(cast(ws_quantity#43 as 
decimal(12,2))) * promote_precision(cast(ws_list_price#44 as decimal(12,2)))), 
DecimalType(18,2)) AS sales#53]
+Input [4]: [ws_quantity#43, ws_list_price#44, ws_sold_date_sk#45, d_date_sk#52]
 
 (87) Union
 
@@ -482,19 +482,19 @@ Input [4]: [ws_quantity#43, ws_list_price#44, 
ws_sold_date_sk#45, d_date_sk#53]
 Input [1]: [sales#40]
 Keys: []
 Functions [1]: [partial_sum(sales#40)]
-Aggregate Attributes [2]: [sum#55, isEmpty#56]
-Results [2]: [sum#57, isEmpty#58]
+Aggregate Attributes [2]: [sum#54, isEmpty#55]
+Results [2]: [sum#56, isEmpty#57]
 
 (89) Exchange
-Input [2]: [sum#57, isEmpty#58]
-Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#59]
+Input [2]: [sum#56, isEmpty#57]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#58]
 
 (90) HashAggregate [codegen id : 36]
-Input [2]: [sum#57, isEmpty#58]
+Input [2]: [sum#56, isEmpty#57]
 Keys: []
 Functions [1]: [sum(sales#40)]
-Aggregate Attributes [1]: [sum(sales#40)#60]
-Results [1]: [sum(sales#40)#60 AS sum(sales)#61]
+Aggregate Attributes [1]: [sum(sales#40)#59]
+Results [1]: [sum(sales#40)#59 AS sum(sales)#60]
 
 ===== Subqueries =====
 
@@ -507,26 +507,26 @@ BroadcastExchange (95)
 
 
 (91) Scan parquet default.date_dim
-Output [3]: [d_date_sk#39, d_year#62, d_moy#63]
+Output [3]: [d_date_sk#39, d_year#61, d_moy#62]
 Batched: true
 Location [not included in comparison]/{warehouse_dir}/date_dim]
 PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), EqualTo(d_year,2000), 
EqualTo(d_moy,2), IsNotNull(d_date_sk)]
 ReadSchema: struct<d_date_sk:int,d_year:int,d_moy:int>
 
 (92) ColumnarToRow [codegen id : 1]
-Input [3]: [d_date_sk#39, d_year#62, d_moy#63]
+Input [3]: [d_date_sk#39, d_year#61, d_moy#62]
 
 (93) Filter [codegen id : 1]
-Input [3]: [d_date_sk#39, d_year#62, d_moy#63]
-Condition : ((((isnotnull(d_year#62) AND isnotnull(d_moy#63)) AND (d_year#62 = 
2000)) AND (d_moy#63 = 2)) AND isnotnull(d_date_sk#39))
+Input [3]: [d_date_sk#39, d_year#61, d_moy#62]
+Condition : ((((isnotnull(d_year#61) AND isnotnull(d_moy#62)) AND (d_year#61 = 
2000)) AND (d_moy#62 = 2)) AND isnotnull(d_date_sk#39))
 
 (94) Project [codegen id : 1]
 Output [1]: [d_date_sk#39]
-Input [3]: [d_date_sk#39, d_year#62, d_moy#63]
+Input [3]: [d_date_sk#39, d_year#61, d_moy#62]
 
 (95) BroadcastExchange
 Input [1]: [d_date_sk#39]
-Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#64]
+Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#63]
 
 Subquery:2 Hosting operator id = 5 Hosting Expression = ss_sold_date_sk#9 IN 
dynamicpruning#10
 BroadcastExchange (100)
@@ -537,26 +537,26 @@ BroadcastExchange (100)
 
 
 (96) Scan parquet default.date_dim
-Output [3]: [d_date_sk#11, d_date#12, d_year#65]
+Output [3]: [d_date_sk#11, d_date#12, d_year#64]
 Batched: true
 Location [not included in comparison]/{warehouse_dir}/date_dim]
 PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)]
 ReadSchema: struct<d_date_sk:int,d_date:date,d_year:int>
 
 (97) ColumnarToRow [codegen id : 1]
-Input [3]: [d_date_sk#11, d_date#12, d_year#65]
+Input [3]: [d_date_sk#11, d_date#12, d_year#64]
 
 (98) Filter [codegen id : 1]
-Input [3]: [d_date_sk#11, d_date#12, d_year#65]
-Condition : (d_year#65 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#11))
+Input [3]: [d_date_sk#11, d_date#12, d_year#64]
+Condition : (d_year#64 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#11))
 
 (99) Project [codegen id : 1]
 Output [2]: [d_date_sk#11, d_date#12]
-Input [3]: [d_date_sk#11, d_date#12, d_year#65]
+Input [3]: [d_date_sk#11, d_date#12, d_year#64]
 
 (100) BroadcastExchange
 Input [2]: [d_date_sk#11, d_date#12]
-Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#66]
+Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#65]
 
 Subquery:3 Hosting operator id = 44 Hosting Expression = Subquery 
scalar-subquery#37, [id=#38]
 * HashAggregate (117)
@@ -579,89 +579,89 @@ Subquery:3 Hosting operator id = 44 Hosting Expression = 
Subquery scalar-subquer
 
 
 (101) Scan parquet default.store_sales
-Output [4]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, 
ss_sold_date_sk#70]
+Output [4]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68, 
ss_sold_date_sk#69]
 Batched: true
 Location: InMemoryFileIndex []
-PartitionFilters: [isnotnull(ss_sold_date_sk#70), 
dynamicpruningexpression(ss_sold_date_sk#70 IN dynamicpruning#71)]
+PartitionFilters: [isnotnull(ss_sold_date_sk#69), 
dynamicpruningexpression(ss_sold_date_sk#69 IN dynamicpruning#70)]
 PushedFilters: [IsNotNull(ss_customer_sk)]
 ReadSchema: 
struct<ss_customer_sk:int,ss_quantity:int,ss_sales_price:decimal(7,2)>
 
 (102) ColumnarToRow [codegen id : 2]
-Input [4]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, 
ss_sold_date_sk#70]
+Input [4]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68, 
ss_sold_date_sk#69]
 
 (103) Filter [codegen id : 2]
-Input [4]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, 
ss_sold_date_sk#70]
-Condition : isnotnull(ss_customer_sk#67)
+Input [4]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68, 
ss_sold_date_sk#69]
+Condition : isnotnull(ss_customer_sk#66)
 
 (104) ReusedExchange [Reuses operator id: 122]
-Output [1]: [d_date_sk#72]
+Output [1]: [d_date_sk#71]
 
 (105) BroadcastHashJoin [codegen id : 2]
-Left keys [1]: [ss_sold_date_sk#70]
-Right keys [1]: [d_date_sk#72]
+Left keys [1]: [ss_sold_date_sk#69]
+Right keys [1]: [d_date_sk#71]
 Join condition: None
 
 (106) Project [codegen id : 2]
-Output [3]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69]
-Input [5]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, 
ss_sold_date_sk#70, d_date_sk#72]
+Output [3]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68]
+Input [5]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68, 
ss_sold_date_sk#69, d_date_sk#71]
 
 (107) Exchange
-Input [3]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69]
-Arguments: hashpartitioning(ss_customer_sk#67, 5), ENSURE_REQUIREMENTS, 
[id=#73]
+Input [3]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68]
+Arguments: hashpartitioning(ss_customer_sk#66, 5), ENSURE_REQUIREMENTS, 
[id=#72]
 
 (108) Sort [codegen id : 3]
-Input [3]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69]
-Arguments: [ss_customer_sk#67 ASC NULLS FIRST], false, 0
+Input [3]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68]
+Arguments: [ss_customer_sk#66 ASC NULLS FIRST], false, 0
 
 (109) ReusedExchange [Reuses operator id: 38]
-Output [1]: [c_customer_sk#74]
+Output [1]: [c_customer_sk#73]
 
 (110) Sort [codegen id : 5]
-Input [1]: [c_customer_sk#74]
-Arguments: [c_customer_sk#74 ASC NULLS FIRST], false, 0
+Input [1]: [c_customer_sk#73]
+Arguments: [c_customer_sk#73 ASC NULLS FIRST], false, 0
 
 (111) SortMergeJoin [codegen id : 6]
-Left keys [1]: [ss_customer_sk#67]
-Right keys [1]: [c_customer_sk#74]
+Left keys [1]: [ss_customer_sk#66]
+Right keys [1]: [c_customer_sk#73]
 Join condition: None
 
 (112) Project [codegen id : 6]
-Output [3]: [ss_quantity#68, ss_sales_price#69, c_customer_sk#74]
-Input [4]: [ss_customer_sk#67, ss_quantity#68, ss_sales_price#69, 
c_customer_sk#74]
+Output [3]: [ss_quantity#67, ss_sales_price#68, c_customer_sk#73]
+Input [4]: [ss_customer_sk#66, ss_quantity#67, ss_sales_price#68, 
c_customer_sk#73]
 
 (113) HashAggregate [codegen id : 6]
-Input [3]: [ss_quantity#68, ss_sales_price#69, c_customer_sk#74]
-Keys [1]: [c_customer_sk#74]
-Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#69 as decimal(12,2)))), 
DecimalType(18,2)))]
-Aggregate Attributes [2]: [sum#75, isEmpty#76]
-Results [3]: [c_customer_sk#74, sum#77, isEmpty#78]
+Input [3]: [ss_quantity#67, ss_sales_price#68, c_customer_sk#73]
+Keys [1]: [c_customer_sk#73]
+Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#67 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#68 as decimal(12,2)))), 
DecimalType(18,2)))]
+Aggregate Attributes [2]: [sum#74, isEmpty#75]
+Results [3]: [c_customer_sk#73, sum#76, isEmpty#77]
 
 (114) HashAggregate [codegen id : 6]
-Input [3]: [c_customer_sk#74, sum#77, isEmpty#78]
-Keys [1]: [c_customer_sk#74]
-Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#69 as decimal(12,2)))), 
DecimalType(18,2)))]
-Aggregate Attributes [1]: 
[sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as decimal(12,2))) * 
promote_precision(cast(ss_sales_price#69 as decimal(12,2)))), 
DecimalType(18,2)))#79]
-Results [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#68 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#69 as decimal(12,2)))), 
DecimalType(18,2)))#79 AS csales#80]
+Input [3]: [c_customer_sk#73, sum#76, isEmpty#77]
+Keys [1]: [c_customer_sk#73]
+Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#67 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#68 as decimal(12,2)))), 
DecimalType(18,2)))]
+Aggregate Attributes [1]: 
[sum(CheckOverflow((promote_precision(cast(ss_quantity#67 as decimal(12,2))) * 
promote_precision(cast(ss_sales_price#68 as decimal(12,2)))), 
DecimalType(18,2)))#78]
+Results [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#67 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#68 as decimal(12,2)))), 
DecimalType(18,2)))#78 AS csales#79]
 
 (115) HashAggregate [codegen id : 6]
-Input [1]: [csales#80]
+Input [1]: [csales#79]
 Keys: []
-Functions [1]: [partial_max(csales#80)]
-Aggregate Attributes [1]: [max#81]
-Results [1]: [max#82]
+Functions [1]: [partial_max(csales#79)]
+Aggregate Attributes [1]: [max#80]
+Results [1]: [max#81]
 
 (116) Exchange
-Input [1]: [max#82]
-Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#83]
+Input [1]: [max#81]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#82]
 
 (117) HashAggregate [codegen id : 7]
-Input [1]: [max#82]
+Input [1]: [max#81]
 Keys: []
-Functions [1]: [max(csales#80)]
-Aggregate Attributes [1]: [max(csales#80)#84]
-Results [1]: [max(csales#80)#84 AS tpcds_cmax#85]
+Functions [1]: [max(csales#79)]
+Aggregate Attributes [1]: [max(csales#79)#83]
+Results [1]: [max(csales#79)#83 AS tpcds_cmax#84]
 
-Subquery:4 Hosting operator id = 101 Hosting Expression = ss_sold_date_sk#70 
IN dynamicpruning#71
+Subquery:4 Hosting operator id = 101 Hosting Expression = ss_sold_date_sk#69 
IN dynamicpruning#70
 BroadcastExchange (122)
 +- * Project (121)
    +- * Filter (120)
@@ -670,26 +670,26 @@ BroadcastExchange (122)
 
 
 (118) Scan parquet default.date_dim
-Output [2]: [d_date_sk#72, d_year#86]
+Output [2]: [d_date_sk#71, d_year#85]
 Batched: true
 Location [not included in comparison]/{warehouse_dir}/date_dim]
 PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)]
 ReadSchema: struct<d_date_sk:int,d_year:int>
 
 (119) ColumnarToRow [codegen id : 1]
-Input [2]: [d_date_sk#72, d_year#86]
+Input [2]: [d_date_sk#71, d_year#85]
 
 (120) Filter [codegen id : 1]
-Input [2]: [d_date_sk#72, d_year#86]
-Condition : (d_year#86 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#72))
+Input [2]: [d_date_sk#71, d_year#85]
+Condition : (d_year#85 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#71))
 
 (121) Project [codegen id : 1]
-Output [1]: [d_date_sk#72]
-Input [2]: [d_date_sk#72, d_year#86]
+Output [1]: [d_date_sk#71]
+Input [2]: [d_date_sk#71, d_year#85]
 
 (122) BroadcastExchange
-Input [1]: [d_date_sk#72]
-Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#87]
+Input [1]: [d_date_sk#71]
+Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#86]
 
 Subquery:5 Hosting operator id = 52 Hosting Expression = ws_sold_date_sk#45 IN 
dynamicpruning#6
 
diff --git 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/explain.txt
 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/explain.txt
index 3de1f246134..4d1109078e3 100644
--- 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/explain.txt
+++ 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q23b.sf100/explain.txt
@@ -508,19 +508,19 @@ Right keys [1]: [i_item_sk#14]
 Join condition: None
 
 (84) Project [codegen id : 35]
-Output [3]: [d_date#12, i_item_sk#14, substr(i_item_desc#15, 1, 30) AS 
_groupingexpression#57]
+Output [3]: [d_date#12, i_item_sk#14, substr(i_item_desc#15, 1, 30) AS 
_groupingexpression#17]
 Input [4]: [ss_item_sk#8, d_date#12, i_item_sk#14, i_item_desc#15]
 
 (85) HashAggregate [codegen id : 35]
-Input [3]: [d_date#12, i_item_sk#14, _groupingexpression#57]
-Keys [3]: [_groupingexpression#57, i_item_sk#14, d_date#12]
+Input [3]: [d_date#12, i_item_sk#14, _groupingexpression#17]
+Keys [3]: [_groupingexpression#17, i_item_sk#14, d_date#12]
 Functions [1]: [partial_count(1)]
 Aggregate Attributes [1]: [count#18]
-Results [4]: [_groupingexpression#57, i_item_sk#14, d_date#12, count#19]
+Results [4]: [_groupingexpression#17, i_item_sk#14, d_date#12, count#19]
 
 (86) HashAggregate [codegen id : 35]
-Input [4]: [_groupingexpression#57, i_item_sk#14, d_date#12, count#19]
-Keys [3]: [_groupingexpression#57, i_item_sk#14, d_date#12]
+Input [4]: [_groupingexpression#17, i_item_sk#14, d_date#12, count#19]
+Keys [3]: [_groupingexpression#17, i_item_sk#14, d_date#12]
 Functions [1]: [count(1)]
 Aggregate Attributes [1]: [count(1)#20]
 Results [2]: [i_item_sk#14 AS item_sk#21, count(1)#20 AS cnt#22]
@@ -548,7 +548,7 @@ Input [5]: [ws_item_sk#51, ws_bill_customer_sk#52, 
ws_quantity#53, ws_list_price
 
 (92) Exchange
 Input [4]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, 
ws_sold_date_sk#55]
-Arguments: hashpartitioning(ws_bill_customer_sk#52, 5), ENSURE_REQUIREMENTS, 
[id=#58]
+Arguments: hashpartitioning(ws_bill_customer_sk#52, 5), ENSURE_REQUIREMENTS, 
[id=#57]
 
 (93) Sort [codegen id : 37]
 Input [4]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, 
ws_sold_date_sk#55]
@@ -581,11 +581,11 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, 
ss_sales_price#26, c_customer_sk#
 Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#29]
 Keys [1]: [c_customer_sk#29]
 Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), 
DecimalType(18,2)))]
-Aggregate Attributes [2]: [sum#59, isEmpty#60]
-Results [3]: [c_customer_sk#29, sum#61, isEmpty#62]
+Aggregate Attributes [2]: [sum#58, isEmpty#59]
+Results [3]: [c_customer_sk#29, sum#60, isEmpty#61]
 
 (101) HashAggregate [codegen id : 42]
-Input [3]: [c_customer_sk#29, sum#61, isEmpty#62]
+Input [3]: [c_customer_sk#29, sum#60, isEmpty#61]
 Keys [1]: [c_customer_sk#29]
 Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), 
DecimalType(18,2)))]
 Aggregate Attributes [1]: 
[sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * 
promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), 
DecimalType(18,2)))#35]
@@ -609,23 +609,23 @@ Right keys [1]: [c_customer_sk#29]
 Join condition: None
 
 (106) ReusedExchange [Reuses operator id: 134]
-Output [1]: [d_date_sk#63]
+Output [1]: [d_date_sk#62]
 
 (107) BroadcastHashJoin [codegen id : 44]
 Left keys [1]: [ws_sold_date_sk#55]
-Right keys [1]: [d_date_sk#63]
+Right keys [1]: [d_date_sk#62]
 Join condition: None
 
 (108) Project [codegen id : 44]
 Output [3]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54]
-Input [5]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, 
ws_sold_date_sk#55, d_date_sk#63]
+Input [5]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, 
ws_sold_date_sk#55, d_date_sk#62]
 
 (109) ReusedExchange [Reuses operator id: 55]
-Output [3]: [c_customer_sk#64, c_first_name#65, c_last_name#66]
+Output [3]: [c_customer_sk#63, c_first_name#64, c_last_name#65]
 
 (110) Sort [codegen id : 46]
-Input [3]: [c_customer_sk#64, c_first_name#65, c_last_name#66]
-Arguments: [c_customer_sk#64 ASC NULLS FIRST], false, 0
+Input [3]: [c_customer_sk#63, c_first_name#64, c_last_name#65]
+Arguments: [c_customer_sk#63 ASC NULLS FIRST], false, 0
 
 (111) ReusedExchange [Reuses operator id: 34]
 Output [3]: [ss_customer_sk#24, ss_quantity#25, ss_sales_price#26]
@@ -654,11 +654,11 @@ Input [4]: [ss_customer_sk#24, ss_quantity#25, 
ss_sales_price#26, c_customer_sk#
 Input [3]: [ss_quantity#25, ss_sales_price#26, c_customer_sk#29]
 Keys [1]: [c_customer_sk#29]
 Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), 
DecimalType(18,2)))]
-Aggregate Attributes [2]: [sum#59, isEmpty#60]
-Results [3]: [c_customer_sk#29, sum#61, isEmpty#62]
+Aggregate Attributes [2]: [sum#58, isEmpty#59]
+Results [3]: [c_customer_sk#29, sum#60, isEmpty#61]
 
 (118) HashAggregate [codegen id : 51]
-Input [3]: [c_customer_sk#29, sum#61, isEmpty#62]
+Input [3]: [c_customer_sk#29, sum#60, isEmpty#61]
 Keys [1]: [c_customer_sk#29]
 Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), 
DecimalType(18,2)))]
 Aggregate Attributes [1]: 
[sum(CheckOverflow((promote_precision(cast(ss_quantity#25 as decimal(12,2))) * 
promote_precision(cast(ss_sales_price#26 as decimal(12,2)))), 
DecimalType(18,2)))#35]
@@ -677,36 +677,36 @@ Input [1]: [c_customer_sk#29]
 Arguments: [c_customer_sk#29 ASC NULLS FIRST], false, 0
 
 (122) SortMergeJoin [codegen id : 52]
-Left keys [1]: [c_customer_sk#64]
+Left keys [1]: [c_customer_sk#63]
 Right keys [1]: [c_customer_sk#29]
 Join condition: None
 
 (123) SortMergeJoin [codegen id : 53]
 Left keys [1]: [ws_bill_customer_sk#52]
-Right keys [1]: [c_customer_sk#64]
+Right keys [1]: [c_customer_sk#63]
 Join condition: None
 
 (124) Project [codegen id : 53]
-Output [4]: [ws_quantity#53, ws_list_price#54, c_first_name#65, c_last_name#66]
-Input [6]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, 
c_customer_sk#64, c_first_name#65, c_last_name#66]
+Output [4]: [ws_quantity#53, ws_list_price#54, c_first_name#64, c_last_name#65]
+Input [6]: [ws_bill_customer_sk#52, ws_quantity#53, ws_list_price#54, 
c_customer_sk#63, c_first_name#64, c_last_name#65]
 
 (125) HashAggregate [codegen id : 53]
-Input [4]: [ws_quantity#53, ws_list_price#54, c_first_name#65, c_last_name#66]
-Keys [2]: [c_last_name#66, c_first_name#65]
+Input [4]: [ws_quantity#53, ws_list_price#54, c_first_name#64, c_last_name#65]
+Keys [2]: [c_last_name#65, c_first_name#64]
 Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as 
decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), 
DecimalType(18,2)))]
-Aggregate Attributes [2]: [sum#67, isEmpty#68]
-Results [4]: [c_last_name#66, c_first_name#65, sum#69, isEmpty#70]
+Aggregate Attributes [2]: [sum#66, isEmpty#67]
+Results [4]: [c_last_name#65, c_first_name#64, sum#68, isEmpty#69]
 
 (126) Exchange
-Input [4]: [c_last_name#66, c_first_name#65, sum#69, isEmpty#70]
-Arguments: hashpartitioning(c_last_name#66, c_first_name#65, 5), 
ENSURE_REQUIREMENTS, [id=#71]
+Input [4]: [c_last_name#65, c_first_name#64, sum#68, isEmpty#69]
+Arguments: hashpartitioning(c_last_name#65, c_first_name#64, 5), 
ENSURE_REQUIREMENTS, [id=#70]
 
 (127) HashAggregate [codegen id : 54]
-Input [4]: [c_last_name#66, c_first_name#65, sum#69, isEmpty#70]
-Keys [2]: [c_last_name#66, c_first_name#65]
+Input [4]: [c_last_name#65, c_first_name#64, sum#68, isEmpty#69]
+Keys [2]: [c_last_name#65, c_first_name#64]
 Functions [1]: [sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as 
decimal(12,2))) * promote_precision(cast(ws_list_price#54 as decimal(12,2)))), 
DecimalType(18,2)))]
-Aggregate Attributes [1]: 
[sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * 
promote_precision(cast(ws_list_price#54 as decimal(12,2)))), 
DecimalType(18,2)))#72]
-Results [3]: [c_last_name#66, c_first_name#65, 
sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * 
promote_precision(cast(ws_list_price#54 as decimal(12,2)))), 
DecimalType(18,2)))#72 AS sales#73]
+Aggregate Attributes [1]: 
[sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * 
promote_precision(cast(ws_list_price#54 as decimal(12,2)))), 
DecimalType(18,2)))#71]
+Results [3]: [c_last_name#65, c_first_name#64, 
sum(CheckOverflow((promote_precision(cast(ws_quantity#53 as decimal(12,2))) * 
promote_precision(cast(ws_list_price#54 as decimal(12,2)))), 
DecimalType(18,2)))#71 AS sales#72]
 
 (128) Union
 
@@ -725,26 +725,26 @@ BroadcastExchange (134)
 
 
 (130) Scan parquet default.date_dim
-Output [3]: [d_date_sk#39, d_year#74, d_moy#75]
+Output [3]: [d_date_sk#39, d_year#73, d_moy#74]
 Batched: true
 Location [not included in comparison]/{warehouse_dir}/date_dim]
 PushedFilters: [IsNotNull(d_year), IsNotNull(d_moy), EqualTo(d_year,2000), 
EqualTo(d_moy,2), IsNotNull(d_date_sk)]
 ReadSchema: struct<d_date_sk:int,d_year:int,d_moy:int>
 
 (131) ColumnarToRow [codegen id : 1]
-Input [3]: [d_date_sk#39, d_year#74, d_moy#75]
+Input [3]: [d_date_sk#39, d_year#73, d_moy#74]
 
 (132) Filter [codegen id : 1]
-Input [3]: [d_date_sk#39, d_year#74, d_moy#75]
-Condition : ((((isnotnull(d_year#74) AND isnotnull(d_moy#75)) AND (d_year#74 = 
2000)) AND (d_moy#75 = 2)) AND isnotnull(d_date_sk#39))
+Input [3]: [d_date_sk#39, d_year#73, d_moy#74]
+Condition : ((((isnotnull(d_year#73) AND isnotnull(d_moy#74)) AND (d_year#73 = 
2000)) AND (d_moy#74 = 2)) AND isnotnull(d_date_sk#39))
 
 (133) Project [codegen id : 1]
 Output [1]: [d_date_sk#39]
-Input [3]: [d_date_sk#39, d_year#74, d_moy#75]
+Input [3]: [d_date_sk#39, d_year#73, d_moy#74]
 
 (134) BroadcastExchange
 Input [1]: [d_date_sk#39]
-Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#76]
+Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#75]
 
 Subquery:2 Hosting operator id = 6 Hosting Expression = ss_sold_date_sk#9 IN 
dynamicpruning#10
 BroadcastExchange (139)
@@ -755,26 +755,26 @@ BroadcastExchange (139)
 
 
 (135) Scan parquet default.date_dim
-Output [3]: [d_date_sk#11, d_date#12, d_year#77]
+Output [3]: [d_date_sk#11, d_date#12, d_year#76]
 Batched: true
 Location [not included in comparison]/{warehouse_dir}/date_dim]
 PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)]
 ReadSchema: struct<d_date_sk:int,d_date:date,d_year:int>
 
 (136) ColumnarToRow [codegen id : 1]
-Input [3]: [d_date_sk#11, d_date#12, d_year#77]
+Input [3]: [d_date_sk#11, d_date#12, d_year#76]
 
 (137) Filter [codegen id : 1]
-Input [3]: [d_date_sk#11, d_date#12, d_year#77]
-Condition : (d_year#77 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#11))
+Input [3]: [d_date_sk#11, d_date#12, d_year#76]
+Condition : (d_year#76 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#11))
 
 (138) Project [codegen id : 1]
 Output [2]: [d_date_sk#11, d_date#12]
-Input [3]: [d_date_sk#11, d_date#12, d_year#77]
+Input [3]: [d_date_sk#11, d_date#12, d_year#76]
 
 (139) BroadcastExchange
 Input [2]: [d_date_sk#11, d_date#12]
-Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#78]
+Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#77]
 
 Subquery:3 Hosting operator id = 45 Hosting Expression = Subquery 
scalar-subquery#37, [id=#38]
 * HashAggregate (156)
@@ -797,89 +797,89 @@ Subquery:3 Hosting operator id = 45 Hosting Expression = 
Subquery scalar-subquer
 
 
 (140) Scan parquet default.store_sales
-Output [4]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, 
ss_sold_date_sk#82]
+Output [4]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80, 
ss_sold_date_sk#81]
 Batched: true
 Location: InMemoryFileIndex []
-PartitionFilters: [isnotnull(ss_sold_date_sk#82), 
dynamicpruningexpression(ss_sold_date_sk#82 IN dynamicpruning#83)]
+PartitionFilters: [isnotnull(ss_sold_date_sk#81), 
dynamicpruningexpression(ss_sold_date_sk#81 IN dynamicpruning#82)]
 PushedFilters: [IsNotNull(ss_customer_sk)]
 ReadSchema: 
struct<ss_customer_sk:int,ss_quantity:int,ss_sales_price:decimal(7,2)>
 
 (141) ColumnarToRow [codegen id : 2]
-Input [4]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, 
ss_sold_date_sk#82]
+Input [4]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80, 
ss_sold_date_sk#81]
 
 (142) Filter [codegen id : 2]
-Input [4]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, 
ss_sold_date_sk#82]
-Condition : isnotnull(ss_customer_sk#79)
+Input [4]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80, 
ss_sold_date_sk#81]
+Condition : isnotnull(ss_customer_sk#78)
 
 (143) ReusedExchange [Reuses operator id: 161]
-Output [1]: [d_date_sk#84]
+Output [1]: [d_date_sk#83]
 
 (144) BroadcastHashJoin [codegen id : 2]
-Left keys [1]: [ss_sold_date_sk#82]
-Right keys [1]: [d_date_sk#84]
+Left keys [1]: [ss_sold_date_sk#81]
+Right keys [1]: [d_date_sk#83]
 Join condition: None
 
 (145) Project [codegen id : 2]
-Output [3]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81]
-Input [5]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, 
ss_sold_date_sk#82, d_date_sk#84]
+Output [3]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80]
+Input [5]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80, 
ss_sold_date_sk#81, d_date_sk#83]
 
 (146) Exchange
-Input [3]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81]
-Arguments: hashpartitioning(ss_customer_sk#79, 5), ENSURE_REQUIREMENTS, 
[id=#85]
+Input [3]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80]
+Arguments: hashpartitioning(ss_customer_sk#78, 5), ENSURE_REQUIREMENTS, 
[id=#84]
 
 (147) Sort [codegen id : 3]
-Input [3]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81]
-Arguments: [ss_customer_sk#79 ASC NULLS FIRST], false, 0
+Input [3]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80]
+Arguments: [ss_customer_sk#78 ASC NULLS FIRST], false, 0
 
 (148) ReusedExchange [Reuses operator id: 39]
-Output [1]: [c_customer_sk#86]
+Output [1]: [c_customer_sk#85]
 
 (149) Sort [codegen id : 5]
-Input [1]: [c_customer_sk#86]
-Arguments: [c_customer_sk#86 ASC NULLS FIRST], false, 0
+Input [1]: [c_customer_sk#85]
+Arguments: [c_customer_sk#85 ASC NULLS FIRST], false, 0
 
 (150) SortMergeJoin [codegen id : 6]
-Left keys [1]: [ss_customer_sk#79]
-Right keys [1]: [c_customer_sk#86]
+Left keys [1]: [ss_customer_sk#78]
+Right keys [1]: [c_customer_sk#85]
 Join condition: None
 
 (151) Project [codegen id : 6]
-Output [3]: [ss_quantity#80, ss_sales_price#81, c_customer_sk#86]
-Input [4]: [ss_customer_sk#79, ss_quantity#80, ss_sales_price#81, 
c_customer_sk#86]
+Output [3]: [ss_quantity#79, ss_sales_price#80, c_customer_sk#85]
+Input [4]: [ss_customer_sk#78, ss_quantity#79, ss_sales_price#80, 
c_customer_sk#85]
 
 (152) HashAggregate [codegen id : 6]
-Input [3]: [ss_quantity#80, ss_sales_price#81, c_customer_sk#86]
-Keys [1]: [c_customer_sk#86]
-Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#80 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#81 as decimal(12,2)))), 
DecimalType(18,2)))]
-Aggregate Attributes [2]: [sum#87, isEmpty#88]
-Results [3]: [c_customer_sk#86, sum#89, isEmpty#90]
+Input [3]: [ss_quantity#79, ss_sales_price#80, c_customer_sk#85]
+Keys [1]: [c_customer_sk#85]
+Functions [1]: 
[partial_sum(CheckOverflow((promote_precision(cast(ss_quantity#79 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#80 as decimal(12,2)))), 
DecimalType(18,2)))]
+Aggregate Attributes [2]: [sum#86, isEmpty#87]
+Results [3]: [c_customer_sk#85, sum#88, isEmpty#89]
 
 (153) HashAggregate [codegen id : 6]
-Input [3]: [c_customer_sk#86, sum#89, isEmpty#90]
-Keys [1]: [c_customer_sk#86]
-Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#80 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#81 as decimal(12,2)))), 
DecimalType(18,2)))]
-Aggregate Attributes [1]: 
[sum(CheckOverflow((promote_precision(cast(ss_quantity#80 as decimal(12,2))) * 
promote_precision(cast(ss_sales_price#81 as decimal(12,2)))), 
DecimalType(18,2)))#91]
-Results [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#80 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#81 as decimal(12,2)))), 
DecimalType(18,2)))#91 AS csales#92]
+Input [3]: [c_customer_sk#85, sum#88, isEmpty#89]
+Keys [1]: [c_customer_sk#85]
+Functions [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#79 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#80 as decimal(12,2)))), 
DecimalType(18,2)))]
+Aggregate Attributes [1]: 
[sum(CheckOverflow((promote_precision(cast(ss_quantity#79 as decimal(12,2))) * 
promote_precision(cast(ss_sales_price#80 as decimal(12,2)))), 
DecimalType(18,2)))#90]
+Results [1]: [sum(CheckOverflow((promote_precision(cast(ss_quantity#79 as 
decimal(12,2))) * promote_precision(cast(ss_sales_price#80 as decimal(12,2)))), 
DecimalType(18,2)))#90 AS csales#91]
 
 (154) HashAggregate [codegen id : 6]
-Input [1]: [csales#92]
+Input [1]: [csales#91]
 Keys: []
-Functions [1]: [partial_max(csales#92)]
-Aggregate Attributes [1]: [max#93]
-Results [1]: [max#94]
+Functions [1]: [partial_max(csales#91)]
+Aggregate Attributes [1]: [max#92]
+Results [1]: [max#93]
 
 (155) Exchange
-Input [1]: [max#94]
-Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#95]
+Input [1]: [max#93]
+Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#94]
 
 (156) HashAggregate [codegen id : 7]
-Input [1]: [max#94]
+Input [1]: [max#93]
 Keys: []
-Functions [1]: [max(csales#92)]
-Aggregate Attributes [1]: [max(csales#92)#96]
-Results [1]: [max(csales#92)#96 AS tpcds_cmax#97]
+Functions [1]: [max(csales#91)]
+Aggregate Attributes [1]: [max(csales#91)#95]
+Results [1]: [max(csales#91)#95 AS tpcds_cmax#96]
 
-Subquery:4 Hosting operator id = 140 Hosting Expression = ss_sold_date_sk#82 
IN dynamicpruning#83
+Subquery:4 Hosting operator id = 140 Hosting Expression = ss_sold_date_sk#81 
IN dynamicpruning#82
 BroadcastExchange (161)
 +- * Project (160)
    +- * Filter (159)
@@ -888,26 +888,26 @@ BroadcastExchange (161)
 
 
 (157) Scan parquet default.date_dim
-Output [2]: [d_date_sk#84, d_year#98]
+Output [2]: [d_date_sk#83, d_year#97]
 Batched: true
 Location [not included in comparison]/{warehouse_dir}/date_dim]
 PushedFilters: [In(d_year, [2000,2001,2002,2003]), IsNotNull(d_date_sk)]
 ReadSchema: struct<d_date_sk:int,d_year:int>
 
 (158) ColumnarToRow [codegen id : 1]
-Input [2]: [d_date_sk#84, d_year#98]
+Input [2]: [d_date_sk#83, d_year#97]
 
 (159) Filter [codegen id : 1]
-Input [2]: [d_date_sk#84, d_year#98]
-Condition : (d_year#98 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#84))
+Input [2]: [d_date_sk#83, d_year#97]
+Condition : (d_year#97 IN (2000,2001,2002,2003) AND isnotnull(d_date_sk#83))
 
 (160) Project [codegen id : 1]
-Output [1]: [d_date_sk#84]
-Input [2]: [d_date_sk#84, d_year#98]
+Output [1]: [d_date_sk#83]
+Input [2]: [d_date_sk#83, d_year#97]
 
 (161) BroadcastExchange
-Input [1]: [d_date_sk#84]
-Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#99]
+Input [1]: [d_date_sk#83]
+Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as 
bigint)),false), [id=#98]
 
 Subquery:5 Hosting operator id = 65 Hosting Expression = ReusedSubquery 
Subquery scalar-subquery#37, [id=#38]
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
index dd30ff68da4..7d45102ac83 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.plans.logical.WithCTE
+import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, LessThan, 
Literal, Or}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project, 
RepartitionOperation, WithCTE}
 import org.apache.spark.sql.execution.adaptive._
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.internal.SQLConf
@@ -42,7 +43,7 @@ abstract class CTEInlineSuiteBase
          """.stripMargin)
       checkAnswer(df, Nil)
       assert(
-        df.queryExecution.optimizedPlan.exists(_.isInstanceOf[WithCTE]),
+        
df.queryExecution.optimizedPlan.exists(_.isInstanceOf[RepartitionOperation]),
         "Non-deterministic With-CTE with multiple references should be not 
inlined.")
     }
   }
@@ -59,7 +60,7 @@ abstract class CTEInlineSuiteBase
          """.stripMargin)
       checkAnswer(df, Nil)
       assert(
-        df.queryExecution.optimizedPlan.exists(_.isInstanceOf[WithCTE]),
+        
df.queryExecution.optimizedPlan.exists(_.isInstanceOf[RepartitionOperation]),
         "Non-deterministic With-CTE with multiple references should be not 
inlined.")
     }
   }
@@ -79,7 +80,7 @@ abstract class CTEInlineSuiteBase
         df.queryExecution.analyzed.exists(_.isInstanceOf[WithCTE]),
         "With-CTE should not be inlined in analyzed plan.")
       assert(
-        !df.queryExecution.optimizedPlan.exists(_.isInstanceOf[WithCTE]),
+        
!df.queryExecution.optimizedPlan.exists(_.isInstanceOf[RepartitionOperation]),
         "With-CTE with one reference should be inlined in optimized plan.")
     }
   }
@@ -107,8 +108,8 @@ abstract class CTEInlineSuiteBase
         "With-CTE should contain 2 CTE defs after analysis.")
       assert(
         df.queryExecution.optimizedPlan.collect {
-          case WithCTE(_, cteDefs) => cteDefs
-        }.head.length == 2,
+          case r: RepartitionOperation => r
+        }.length == 6,
         "With-CTE should contain 2 CTE def after optimization.")
     }
   }
@@ -136,8 +137,8 @@ abstract class CTEInlineSuiteBase
         "With-CTE should contain 2 CTE defs after analysis.")
       assert(
         df.queryExecution.optimizedPlan.collect {
-          case WithCTE(_, cteDefs) => cteDefs
-        }.head.length == 1,
+          case r: RepartitionOperation => r
+        }.length == 4,
         "One CTE def should be inlined after optimization.")
     }
   }
@@ -163,7 +164,7 @@ abstract class CTEInlineSuiteBase
         "With-CTE should contain 2 CTE defs after analysis.")
       assert(
         df.queryExecution.optimizedPlan.collect {
-          case WithCTE(_, cteDefs) => cteDefs
+          case r: RepartitionOperation => r
         }.isEmpty,
         "CTEs with one reference should all be inlined after optimization.")
     }
@@ -248,7 +249,7 @@ abstract class CTEInlineSuiteBase
         "With-CTE should contain 2 CTE defs after analysis.")
       assert(
         df.queryExecution.optimizedPlan.collect {
-          case WithCTE(_, cteDefs) => cteDefs
+          case r: RepartitionOperation => r
         }.isEmpty,
         "Deterministic CTEs should all be inlined after optimization.")
     }
@@ -272,6 +273,214 @@ abstract class CTEInlineSuiteBase
       assert(ex.message.contains("Table or view not found: v1"))
     }
   }
+
+  test("CTE Predicate push-down and column pruning") {
+    withView("t") {
+      Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
+      val df = sql(
+        s"""with
+           |v as (
+           |  select c1, c2, 's' c3, rand() c4 from t
+           |),
+           |vv as (
+           |  select v1.c1, v1.c2, rand() c5 from v v1, v v2
+           |  where v1.c1 > 0 and v1.c3 = 's' and v1.c2 = v2.c2
+           |)
+           |select vv1.c1, vv1.c2, vv2.c1, vv2.c2 from vv vv1, vv vv2
+           |where vv1.c2 > 0 and vv2.c2 > 0 and vv1.c1 = vv2.c1
+         """.stripMargin)
+      checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
+      assert(
+        df.queryExecution.analyzed.collect {
+          case WithCTE(_, cteDefs) => cteDefs
+        }.head.length == 2,
+        "With-CTE should contain 2 CTE defs after analysis.")
+      val cteRepartitions = df.queryExecution.optimizedPlan.collect {
+        case r: RepartitionOperation => r
+      }
+      assert(cteRepartitions.length == 6,
+        "CTE should not be inlined after optimization.")
+      val distinctCteRepartitions = 
cteRepartitions.map(_.canonicalized).distinct
+      // Check column pruning and predicate push-down.
+      assert(distinctCteRepartitions.length == 2)
+      assert(distinctCteRepartitions(1).collectFirst {
+        case p: Project if p.projectList.length == 3 => p
+      }.isDefined, "CTE columns should be pruned.")
+      assert(distinctCteRepartitions(1).collectFirst {
+        case f: Filter if f.condition.semanticEquals(GreaterThan(f.output(1), 
Literal(0))) => f
+      }.isDefined, "Predicate 'c2 > 0' should be pushed down to the CTE def 
'v'.")
+      assert(distinctCteRepartitions(0).collectFirst {
+        case f: Filter if 
f.condition.find(_.semanticEquals(f.output(0))).isDefined => f
+      }.isDefined, "CTE 'vv' definition contains predicate 'c1 > 0'.")
+      assert(distinctCteRepartitions(1).collectFirst {
+        case f: Filter if 
f.condition.find(_.semanticEquals(f.output(0))).isDefined => f
+      }.isEmpty, "Predicate 'c1 > 0' should be not pushed down to the CTE def 
'v'.")
+      // Check runtime repartition reuse.
+      assert(
+        collectWithSubqueries(df.queryExecution.executedPlan) {
+          case r: ReusedExchangeExec => r
+        }.length == 2,
+        "CTE repartition is reused.")
+    }
+  }
+
+  test("CTE Predicate push-down and column pruning - combined predicate") {
+    withView("t") {
+      Seq((0, 1, 2), (1, 2, 3)).toDF("c1", "c2", 
"c3").createOrReplaceTempView("t")
+      val df = sql(
+        s"""with
+           |v as (
+           |  select c1, c2, c3, rand() c4 from t
+           |),
+           |vv as (
+           |  select v1.c1, v1.c2, rand() c5 from v v1, v v2
+           |  where v1.c1 > 0 and v2.c3 < 5 and v1.c2 = v2.c2
+           |)
+           |select vv1.c1, vv1.c2, vv2.c1, vv2.c2 from vv vv1, vv vv2
+           |where vv1.c2 > 0 and vv2.c2 > 0 and vv1.c1 = vv2.c1
+         """.stripMargin)
+      checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
+      assert(
+        df.queryExecution.analyzed.collect {
+          case WithCTE(_, cteDefs) => cteDefs
+        }.head.length == 2,
+        "With-CTE should contain 2 CTE defs after analysis.")
+      val cteRepartitions = df.queryExecution.optimizedPlan.collect {
+        case r: RepartitionOperation => r
+      }
+      assert(cteRepartitions.length == 6,
+        "CTE should not be inlined after optimization.")
+      val distinctCteRepartitions = 
cteRepartitions.map(_.canonicalized).distinct
+      // Check column pruning and predicate push-down.
+      assert(distinctCteRepartitions.length == 2)
+      assert(distinctCteRepartitions(1).collectFirst {
+        case p: Project if p.projectList.length == 3 => p
+      }.isDefined, "CTE columns should be pruned.")
+      assert(
+        distinctCteRepartitions(1).collectFirst {
+          case f: Filter
+              if f.condition.semanticEquals(
+                And(
+                  GreaterThan(f.output(1), Literal(0)),
+                  Or(
+                    GreaterThan(f.output(0), Literal(0)),
+                    LessThan(f.output(2), Literal(5))))) =>
+            f
+        }.isDefined,
+        "Predicate 'c2 > 0 AND (c1 > 0 OR c3 < 5)' should be pushed down to 
the CTE def 'v'.")
+      // Check runtime repartition reuse.
+      assert(
+        collectWithSubqueries(df.queryExecution.executedPlan) {
+          case r: ReusedExchangeExec => r
+        }.length == 2,
+        "CTE repartition is reused.")
+    }
+  }
+
+  test("Views with CTEs - 1 temp view") {
+    withView("t", "t2") {
+      Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
+      sql(
+        s"""with
+           |v as (
+           |  select c1 + c2 c3 from t
+           |)
+           |select sum(c3) s from v
+         """.stripMargin).createOrReplaceTempView("t2")
+      val df = sql(
+        s"""with
+           |v as (
+           |  select c1 * c2 c3 from t
+           |)
+           |select sum(c3) from v except select s from t2
+         """.stripMargin)
+      checkAnswer(df, Row(2) :: Nil)
+    }
+  }
+
+  test("Views with CTEs - 2 temp views") {
+    withView("t", "t2", "t3") {
+      Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
+      sql(
+        s"""with
+           |v as (
+           |  select c1 + c2 c3 from t
+           |)
+           |select sum(c3) s from v
+         """.stripMargin).createOrReplaceTempView("t2")
+      sql(
+        s"""with
+           |v as (
+           |  select c1 * c2 c3 from t
+           |)
+           |select sum(c3) s from v
+         """.stripMargin).createOrReplaceTempView("t3")
+      val df = sql("select s from t3 except select s from t2")
+      checkAnswer(df, Row(2) :: Nil)
+    }
+  }
+
+  test("Views with CTEs - temp view + sql view") {
+    withTable("t") {
+      withView ("t2", "t3") {
+        Seq((0, 1), (1, 2)).toDF("c1", "c2").write.saveAsTable("t")
+        sql(
+          s"""with
+             |v as (
+             |  select c1 + c2 c3 from t
+             |)
+             |select sum(c3) s from v
+           """.stripMargin).createOrReplaceTempView("t2")
+        sql(
+          s"""create view t3 as
+             |with
+             |v as (
+             |  select c1 * c2 c3 from t
+             |)
+             |select sum(c3) s from v
+           """.stripMargin)
+        val df = sql("select s from t3 except select s from t2")
+        checkAnswer(df, Row(2) :: Nil)
+      }
+    }
+  }
+
+  test("Union of Dataframes with CTEs") {
+    val a = spark.sql("with t as (select 1 as n) select * from t ")
+    val b = spark.sql("with t as (select 2 as n) select * from t ")
+    val df = a.union(b)
+    checkAnswer(df, Row(1) :: Row(2) :: Nil)
+  }
+
+  test("CTE definitions out of original order when not inlined") {
+    withView("t1", "t2") {
+      Seq((1, 2, 10, 100), (2, 3, 20, 200)).toDF("workspace_id", "issue_id", 
"shard_id", "field_id")
+        .createOrReplaceTempView("issue_current")
+      withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
+          "org.apache.spark.sql.catalyst.optimizer.InlineCTE") {
+        val df = sql(
+          """
+            |WITH cte_0 AS (
+            |  SELECT workspace_id, issue_id, shard_id, field_id FROM 
issue_current
+            |),
+            |cte_1 AS (
+            |  WITH filtered_source_table AS (
+            |    SELECT * FROM cte_0 WHERE shard_id in ( 10 )
+            |  )
+            |  SELECT source_table.workspace_id, field_id FROM cte_0 
source_table
+            |  INNER JOIN (
+            |    SELECT workspace_id, issue_id FROM filtered_source_table 
GROUP BY 1, 2
+            |  ) target_table
+            |  ON source_table.issue_id = target_table.issue_id
+            |  AND source_table.workspace_id = target_table.workspace_id
+            |  WHERE source_table.shard_id IN ( 10 )
+            |)
+            |SELECT * FROM cte_1
+        """.stripMargin)
+        checkAnswer(df, Row(1, 100) :: Nil)
+      }
+    }
+  }
 }
 
 class CTEInlineSuiteAEOff extends CTEInlineSuiteBase with 
DisableAdaptiveExecutionSuite
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 309396543d4..42945e7f1c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -593,6 +593,21 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
         |select * from q1 union all select * from q2""".stripMargin),
       Row(5, "5") :: Row(4, "4") :: Nil)
 
+    // inner CTE relation refers to outer CTE relation.
+    withSQLConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY.key -> "CORRECTED") {
+      checkAnswer(
+        sql(
+          """
+            |with temp1 as (select 1 col),
+            |temp2 as (
+            |  with temp1 as (select col + 1 AS col from temp1),
+            |  temp3 as (select col + 1 from temp1)
+            |  select * from temp3
+            |)
+            |select * from temp2
+            |""".stripMargin),
+        Row(3))
+      }
   }
 
   test("Allow only a single WITH clause per query") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to