maropu commented on a change in pull request #32303:
URL: https://github.com/apache/spark/pull/32303#discussion_r647092456



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
##########
@@ -784,11 +791,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
            failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in" +
                 s" Filter/Join and a few commands: $plan")
         }
+        // Validate to make sure the correlations appearing in the query are valid and
+        // allowed by spark.
+        checkCorrelationsInSubquery(expr.plan)
     }
-
-    // Validate to make sure the correlations appearing in the query are valid and
-    // allowed by spark.
-    checkCorrelationsInSubquery(expr.plan)

Review comment:
       Instead of moving this check into each `case` statement, how about just checking here whether `plan` is lateral or not? e.g.:
   ```
   checkCorrelationsInSubquery(expr.plan, isLateral = plan.isInstanceOf[LateralJoin])
   ```
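       For reference, a rough sketch of the single-call-site shape this would imply (the `isLateral` flag is only a suggestion here, not something the PR defines yet):
   ```
   import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

   // Hypothetical single entry point: the lateral-specific relaxations are gated on the
   // flag instead of calling the check from every `case` branch. The body is elided;
   // the existing correlation checks would branch on `isLateral` where lateral joins
   // permit correlations that other subqueries do not.
   def checkCorrelationsInSubquery(sub: LogicalPlan, isLateral: Boolean = false): Unit = {
   }
   ```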

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
##########
@@ -315,19 +314,28 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
      case ListQuery(sub, children, exprId, childOutputs, conditions) if children.nonEmpty =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
        ListQuery(newPlan, children, exprId, childOutputs, getJoinCondition(newCond, conditions))
+      case LateralSubquery(sub, children, exprId, conditions) if children.nonEmpty =>
+        val (newPlan, newCond) = decorrelate(sub, outerPlans)
+        LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions))
     }
   }
 
   /**
    * Pull up the correlated predicates and rewrite all subqueries in an operator tree..
    */
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
-    _.containsAnyPattern(SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) {
+    _.containsPattern(PLAN_EXPRESSION)) {
     case f @ Filter(_, a: Aggregate) =>
       rewriteSubQueries(f, Seq(a, a.child))
-    // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
+    // Only a few unary nodes (Project/Filter/Aggregate/LateralJoin) can contain subqueries.
     case q: UnaryNode =>
-      rewriteSubQueries(q, q.children)
+      val newPlan = rewriteSubQueries(q, q.children)
+      // Preserve the original output of the node.
+      if (newPlan.output != q.output) {

Review comment:
       How about using `q.sameOutput(newPlan)` here?
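       Something along these lines is what I have in mind (only a sketch reusing the names from the diff above, not the exact change):
   ```
   import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, UnaryNode}

   // Keep the node's original output after pulling up correlated predicates:
   // if the rewrite exposed extra attributes, prune them with a Project on top.
   def preserveOutput(q: UnaryNode, newPlan: LogicalPlan): LogicalPlan = {
     if (q.sameOutput(newPlan)) newPlan else Project(q.output, newPlan)
   }
   ```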

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
##########
@@ -285,6 +285,39 @@ object ScalarSubquery {
   }
 }
 
+/**
+ * A subquery that can return multiple rows and columns. This should be rewritten as a join
+ * with the outer query during the optimization phase.
+ *
+ * Note: `exprId` is used to have a unique name in explain string output.
+ */
+case class LateralSubquery(
+    plan: LogicalPlan,
+    outerAttrs: Seq[Expression] = Seq.empty,
+    exprId: ExprId = NamedExpression.newExprId,
+    joinCond: Seq[Expression] = Seq.empty)
+  extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Unevaluable {
+  override def dataType: DataType = plan.output.toStructType
+  override def nullable: Boolean = true
+  override def withNewPlan(plan: LogicalPlan): LateralSubquery = copy(plan = plan)
+  override def toString: String = s"lateral-subquery#${exprId.id} $conditionString"

Review comment:
       Could you add simple tests to check that an explain can print this new node correctly?
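       Even a minimal catalyst-level check would be a good start, e.g. (only a sketch, not the PR's actual test; suite wiring is omitted):
   ```
   import org.apache.spark.sql.catalyst.dsl.expressions._
   import org.apache.spark.sql.catalyst.expressions.LateralSubquery
   import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

   // The plan/explain string renders the expression via toString, so at minimum the
   // prefix defined above should show up (the exprId suffix will vary per run).
   val sub = LateralSubquery(LocalRelation('a.int))
   assert(sub.toString.startsWith("lateral-subquery#"))
   ```
       An end-to-end test that `EXPLAIN` on a query with a `LATERAL` join prints this node would be even better.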

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
##########
@@ -871,7 +871,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
   override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
     val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
       val right = plan(relation.relationPrimary)
-      val join = right.optionalMap(left)(Join(_, _, Inner, None, JoinHint.NONE))
+      val join = right.optionalMap(left) { (left, right) =>
+        if (relation.LATERAL != null) {
+          LateralJoin(left, LateralSubquery(right), Inner, None)

Review comment:
       We do not handle join hints for the lateral case?

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/PlanHelper.scala
##########
@@ -45,6 +45,7 @@ object PlanHelper {
         case e: AggregateExpression
           if !(plan.isInstanceOf[Aggregate] ||
                plan.isInstanceOf[Window] ||
+               plan.isInstanceOf[LateralJoin] ||

Review comment:
       How about checking more strictly that the aggregate expressions come from the right plan of the lateral join, rather than allowing any `LateralJoin` here?
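       Roughly something like this (only a sketch; the helper name is made up and it assumes `LateralJoin` keeps the lateral side as an expression, e.g. `right: LateralSubquery`):
   ```
   import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
   import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LateralJoin, LogicalPlan, Window}

   // Only tolerate an AggregateExpression when the host node is an Aggregate/Window,
   // or when it actually appears inside the lateral join's right-side subquery
   // expression, rather than whitelisting every LateralJoin.
   def aggregateAllowed(plan: LogicalPlan, agg: AggregateExpression): Boolean = plan match {
     case _: Aggregate | _: Window => true
     case j: LateralJoin => j.right.find(_.semanticEquals(agg)).isDefined
     case _ => false
   }
   ```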



