vladimirg-db commented on code in PR #49351:
URL: https://github.com/apache/spark/pull/49351#discussion_r1911026873


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala:
##########
@@ -37,21 +38,150 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
     }
   }
 
+  // Substitute CTERelationRef with UnionLoopRef.
+  private def transformRefs(plan: LogicalPlan) = {
+    plan.transformWithPruning(_.containsPattern(CTE)) {
+      case r: CTERelationRef if r.recursive =>
+        UnionLoopRef(r.cteId, r.output, false)
+    }
+  }
+
+  // Return the UnionLoop anchor if it is resolved (else None); fail on unsupported shapes.
+  private def recursiveAnchorResolved(cteDef: CTERelationDef): 
Option[LogicalPlan] = {
+    cteDef.child match {
+      case SubqueryAlias(_, ul: UnionLoop) =>
+        if (ul.anchor.resolved) {
+          Some(ul.anchor)
+        } else {
+          None
+        }
+      case SubqueryAlias(_, Distinct(ul: UnionLoop)) =>
+        if (ul.anchor.resolved) {
+          Some(ul.anchor)
+        } else {
+          None
+        }
+      case SubqueryAlias(_, UnresolvedSubqueryColumnAliases(_, ul: UnionLoop)) 
=>
+        if (ul.anchor.resolved) {
+          Some(ul.anchor)
+        } else {
+          None
+        }
+      case SubqueryAlias(_, UnresolvedSubqueryColumnAliases(_, Distinct(ul: 
UnionLoop))) =>
+        if (ul.anchor.resolved) {
+          Some(ul.anchor)
+        } else {
+          None
+        }
+      case _ =>
+        cteDef.failAnalysis(
+          errorClass = "INVALID_RECURSIVE_CTE",
+          messageParameters = Map.empty)
+    }
+  }
+
   private def resolveWithCTE(
       plan: LogicalPlan,
       cteDefMap: mutable.HashMap[Long, CTERelationDef]): LogicalPlan = {
     plan.resolveOperatorsDownWithPruning(_.containsAllPatterns(CTE)) {
       case w @ WithCTE(_, cteDefs) =>
-        cteDefs.foreach { cteDef =>
-          if (cteDef.resolved) {
-            cteDefMap.put(cteDef.id, cteDef)
+        val newCTEDefs = cteDefs.map { cteDef =>
+          val newCTEDef = if (cteDef.recursive) {
+            cteDef.child match {
+              // Substitutions to UnionLoop and UnionLoopRef.
+              case a @ SubqueryAlias(_, Union(Seq(anchor, recursion), false, 
false)) =>
+                cteDef.copy(child =
+                  a.copy(child =
+                    UnionLoop(cteDef.id, anchor, transformRefs(recursion))))
+              case a @ SubqueryAlias(_,
+              ca @ UnresolvedSubqueryColumnAliases(_,
+              Union(Seq(anchor, recursion), false, false))) =>
+                cteDef.copy(child =
+                  a.copy(child =
+                    ca.copy(child =
+                      UnionLoop(cteDef.id, anchor, transformRefs(recursion)))))
+              // If the recursion is described with a UNION (deduplicating) 
clause then the
+              // recursive term should not return those rows that have been 
calculated previously,
+              // and we exclude those rows from the current iteration result.
+              case a @ SubqueryAlias(_, Distinct(Union(Seq(anchor, recursion), 
false, false))) =>
+                cteDef.copy(child =
+                  a.copy(child =
+                    UnionLoop(cteDef.id,
+                      Distinct(anchor),
+                      Except(
+                        transformRefs(recursion),
+                        UnionLoopRef(cteDef.id, cteDef.output, true),
+                        false))))
+              case a @ SubqueryAlias(_,
+              ca @ UnresolvedSubqueryColumnAliases(_, 
Distinct(Union(Seq(anchor, recursion),
+              false, false)))) =>
+                cteDef.copy(child =
+                  a.copy(child =
+                    ca.copy(child =
+                      UnionLoop(cteDef.id,
+                        Distinct(anchor),
+                        Except(
+                          transformRefs(recursion),
+                          UnionLoopRef(cteDef.id, cteDef.output, true),
+                          false)))))
+              case _ =>
+                // We do not support cases of sole Union (needs a 
SubqueryAlias above it), nor
+                // Project (as UnresolvedSubqueryColumnAliases have not been 
substituted with the
+                // Project yet), leaving us with cases of SubqueryAlias->Union 
and SubqueryAlias->
+                // UnresolvedSubqueryColumnAliases->Union. The same applies to 
Distinct Union.
+                cteDef.failAnalysis(
+                    errorClass = "INVALID_RECURSIVE_CTE",
+                    messageParameters = Map.empty)
+            }
+          } else {
+            cteDef
           }
+
+          if (newCTEDef.recursive) {
+            // cteDefMap holds "partially" resolved (only via anchor) CTE 
definitions in the
+            // recursive case.
+            if (newCTEDef.resolved) {
+              newCTEDef.failAnalysis(
+                errorClass = "INVALID_RECURSIVE_CTE",
+                messageParameters = Map.empty)
+            }
+            if (recursiveAnchorResolved(newCTEDef).isDefined) {
+              cteDefMap.put(newCTEDef.id, newCTEDef)
+            }
+          } else {
+            if (newCTEDef.resolved) {
+              cteDefMap.put(newCTEDef.id, newCTEDef)
+            }
+          }
+
+          newCTEDef
         }
-        w
+        w.copy(cteDefs = newCTEDefs)
 
       case ref: CTERelationRef if !ref.resolved =>
         cteDefMap.get(ref.cteId).map { cteDef =>

Review Comment:
   How about we rewrite it a bit:
   
   ```
   cteDefMap.get(ref.cteId) match {
     case Some(cteDef) if !ref.recursive =>
       ref.copy(_resolved = true, output = cteDef.output, isStreaming = 
cteDef.isStreaming)
     case Some(cteDef) =>
       ...
     case None =>
       ref
   }
   ```



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala:
##########
@@ -37,21 +38,150 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
     }
   }
 
+  // Substitute CTERelationRef with UnionLoopRef.
+  private def transformRefs(plan: LogicalPlan) = {
+    plan.transformWithPruning(_.containsPattern(CTE)) {
+      case r: CTERelationRef if r.recursive =>
+        UnionLoopRef(r.cteId, r.output, false)
+    }
+  }
+
+  // Return the UnionLoop anchor if it is resolved (else None); fail on unsupported shapes.
+  private def recursiveAnchorResolved(cteDef: CTERelationDef): 
Option[LogicalPlan] = {
+    cteDef.child match {
+      case SubqueryAlias(_, ul: UnionLoop) =>
+        if (ul.anchor.resolved) {
+          Some(ul.anchor)
+        } else {
+          None
+        }
+      case SubqueryAlias(_, Distinct(ul: UnionLoop)) =>
+        if (ul.anchor.resolved) {
+          Some(ul.anchor)
+        } else {
+          None
+        }
+      case SubqueryAlias(_, UnresolvedSubqueryColumnAliases(_, ul: UnionLoop)) 
=>
+        if (ul.anchor.resolved) {
+          Some(ul.anchor)
+        } else {
+          None
+        }
+      case SubqueryAlias(_, UnresolvedSubqueryColumnAliases(_, Distinct(ul: 
UnionLoop))) =>
+        if (ul.anchor.resolved) {
+          Some(ul.anchor)
+        } else {
+          None
+        }
+      case _ =>
+        cteDef.failAnalysis(
+          errorClass = "INVALID_RECURSIVE_CTE",
+          messageParameters = Map.empty)
+    }
+  }
+
   private def resolveWithCTE(
       plan: LogicalPlan,
       cteDefMap: mutable.HashMap[Long, CTERelationDef]): LogicalPlan = {
     plan.resolveOperatorsDownWithPruning(_.containsAllPatterns(CTE)) {
       case w @ WithCTE(_, cteDefs) =>
-        cteDefs.foreach { cteDef =>
-          if (cteDef.resolved) {
-            cteDefMap.put(cteDef.id, cteDef)
+        val newCTEDefs = cteDefs.map { cteDef =>
+          val newCTEDef = if (cteDef.recursive) {
+            cteDef.child match {
+              // Substitutions to UnionLoop and UnionLoopRef.
+              case a @ SubqueryAlias(_, Union(Seq(anchor, recursion), false, 
false)) =>

Review Comment:
   Maybe we can introduce an extractor object to reduce complexity here:
   
   ```
   object ReplaceUnionWithUnionLoop {
     def unapply(plan: LogicalPlan): Option[UnionLoop] = plan match {
       case Union(Seq(anchor, recursion), false, false) =>
         Some(UnionLoop(cteDef.id, anchor, transformRefs(recursion)))
       case Distinct(Union(Seq(anchor, recursion), false, false)) =>
         Some(UnionLoop(cteDef.id, Distinct(anchor), 
Except(transformRefs(recursion), UnionLoopRef(cteDef.id, cteDef.output, true), 
false)))
       case _ =>
         None
     }
   }
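   ```
   
   (Note: `cteDef` is not in scope inside a standalone `object`, so the extractor would need to be a class instantiated per `cteDef`, or a plain helper method taking `cteDef` as a parameter.)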



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to