dtenedor commented on code in PR #41439:
URL: https://github.com/apache/spark/pull/41439#discussion_r1223363720


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala:
##########
@@ -32,9 +32,11 @@ import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, 
IN_SUBQUERY, LATERAL_JOIN, LIST_SUBQUERY, PLAN_EXPRESSION, SCALAR_SUBQUERY}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.REASSIGN_IDS_IN_SCALAR_SUBQUERY
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
+

Review Comment:
   Could you revert the extra blank line?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/AssignNewExprIds.scala:
##########
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
Expression, ExprId, NamedExpression}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Partial, PartialMerge}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Reassigns expression IDs in every expression of the given LogicalPlan 
(including subqueries
+ * contained in the plan).
+ */
+object AssignNewExprIds extends Rule[LogicalPlan] {
+
+  private def assignNewExprIdsinExpr(input: Expression,

Review Comment:
   Scala lets you put method definitions in any order :) Maybe put `apply`
first, since that's the convention for Catalyst rules, and then put the
remaining methods in the order they're referenced from `apply`?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala:
##########
@@ -598,26 +600,43 @@ object RewriteCorrelatedScalarSubquery extends 
Rule[LogicalPlan] with AliasHelpe
         val joinHint = JoinHint(None, subHint)
 
         val resultWithZeroTups = evalSubqueryOnZeroTups(query)
-        lazy val planWithoutCountBug = Project(
-          currentChild.output :+ origOutput,
-          Join(currentChild, query, LeftOuter, conditions.reduceOption(And), 
joinHint))
+
+        // Reassign expression IDs in the future right side of the join to 
avoid ID conflicts.
+        val newQuery = if 
(SQLConf.get.getConf(REASSIGN_IDS_IN_SCALAR_SUBQUERY)) {
+          AssignNewExprIds(query)
+        } else {
+          query
+        }
+        val newOutput = newQuery.output.head.withNullability(true)
+        val replacementMap = 
AttributeMap(query.output.zip(newQuery.output).toMap)
+
+        val newConditions = conditions.map(_.transform {

Review Comment:
   same here?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/AssignNewExprIds.scala:
##########
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
Expression, ExprId, NamedExpression}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Partial, PartialMerge}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Reassigns expression IDs in every expression of the given LogicalPlan 
(including subqueries
+ * contained in the plan).
+ */
+object AssignNewExprIds extends Rule[LogicalPlan] {
+
+  private def assignNewExprIdsinExpr(input: Expression,
+                                     reassignedExprs: mutable.HashMap[ExprId, 
Attribute]):
+  Expression = input match {
+    case a: Attribute =>
+      val newAttribute = reassignedExprs.getOrElse(a.exprId,
+        a.withExprId(NamedExpression.newExprId))
+      reassignedExprs.put(a.exprId, newAttribute)

Review Comment:
   can we avoid the `put` if the key was actually present in the map? e.g.
   
   ```
   reassignedExprs.get(a.exprId).getOrElse {
     val newAttribute = a.withExprId(NamedExpression.newExprId)
     reassignedExprs.put(a.exprId, newAttribute)
     reassignedExprs.put(newAttribute.exprId, newAttribute)
     newAttribute
   }
   ```
   
   same for the `Alias` case below.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/AssignNewExprIds.scala:
##########
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
Expression, ExprId, NamedExpression}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Partial, PartialMerge}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Reassigns expression IDs in every expression of the given LogicalPlan 
(including subqueries
+ * contained in the plan).

Review Comment:
   Maybe also mention that:
   * this rule applies to CTEs as well as to subquery expressions;
   * after this rule finishes, we guarantee that every expression ID in the
     plan differs from the IDs in the original plan, and
   * each original expression ID maps consistently to the same new expression
     ID, and this mapping holds throughout the entire plan;
   * the purpose of this rule is to support cases such as safely adding
     multiple copies of the same subplan to the main plan (e.g. for self-joins).



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala:
##########
@@ -598,26 +600,43 @@ object RewriteCorrelatedScalarSubquery extends 
Rule[LogicalPlan] with AliasHelpe
         val joinHint = JoinHint(None, subHint)
 
         val resultWithZeroTups = evalSubqueryOnZeroTups(query)
-        lazy val planWithoutCountBug = Project(
-          currentChild.output :+ origOutput,
-          Join(currentChild, query, LeftOuter, conditions.reduceOption(And), 
joinHint))
+
+        // Reassign expression IDs in the future right side of the join to 
avoid ID conflicts.
+        val newQuery = if 
(SQLConf.get.getConf(REASSIGN_IDS_IN_SCALAR_SUBQUERY)) {
+          AssignNewExprIds(query)
+        } else {
+          query
+        }
+        val newOutput = newQuery.output.head.withNullability(true)

Review Comment:
   add a comment to say why we use `withNullability(true)` here?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/AssignNewExprIds.scala:
##########
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
Expression, ExprId, NamedExpression}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Partial, PartialMerge}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Reassigns expression IDs in every expression of the given LogicalPlan 
(including subqueries
+ * contained in the plan).
+ */
+object AssignNewExprIds extends Rule[LogicalPlan] {
+
+  private def assignNewExprIdsinExpr(input: Expression,
+                                     reassignedExprs: mutable.HashMap[ExprId, 
Attribute]):
+  Expression = input match {
+    case a: Attribute =>
+      val newAttribute = reassignedExprs.getOrElse(a.exprId,
+        a.withExprId(NamedExpression.newExprId))
+      reassignedExprs.put(a.exprId, newAttribute)
+      reassignedExprs.put(newAttribute.exprId, newAttribute)
+      newAttribute
+    case a: Alias =>
+      val newAlias = Alias(a.child, a.name)(NamedExpression.newExprId,
+        a.qualifier, a.explicitMetadata, a.nonInheritableMetadataKeys)
+      reassignedExprs.put(a.exprId, newAlias.toAttribute)
+      reassignedExprs.put(newAlias.exprId, newAlias.toAttribute)
+      newAlias
+    case a: AggregateExpression =>
+      if (a.mode == PartialMerge || a.mode == Partial) {
+        // Partial aggregation's attributes are going to be reused in the 
final aggregations.
+        // In order to avoid renaming attributes of final aggregations, keep 
the intermediate
+        // attributes as is.
+        reassignedExprs.put(a.resultAttribute.exprId, a.resultAttribute)
+        return a

Review Comment:
   `return` is dangerous to use in Catalyst rules because its behavior is very
confusing inside PartialFunctions. Please drop the `return`: make `a` the value
of the `if` branch and move the remaining logic into an `} else { ... }` branch.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/AssignNewExprIds.scala:
##########
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
Expression, ExprId, NamedExpression}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Partial, PartialMerge}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Reassigns expression IDs in every expression of the given LogicalPlan 
(including subqueries
+ * contained in the plan).
+ */
+object AssignNewExprIds extends Rule[LogicalPlan] {
+
+  private def assignNewExprIdsinExpr(input: Expression,
+                                     reassignedExprs: mutable.HashMap[ExprId, 
Attribute]):
+  Expression = input match {

Review Comment:
   The style guide says we should put the `: ReturnType` on the same line as
the last parameter.



##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AssignNewExprIdsSuite.scala:
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{AttributeSet, InSubquery, 
ListQuery, ScalarSubquery}
+import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class AssignNewExprIdsSuite extends PlanTest {
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("Assign Expr Ids", Once,
+      AssignNewExprIds) :: Nil
+  }
+
+  val testRelation1 = LocalRelation($"a".int, $"b".int)
+  val testRelation2 = LocalRelation($"c".int, $"d".int)
+
+  private def collectAllReferences(plan: LogicalPlan): AttributeSet =
+    AttributeSet(plan.collectWithSubqueries({
+      case e: LogicalPlan => e.references ++ e.producedAttributes
+    }).flatten)
+
+  private def check(input: LogicalPlan): Unit = {
+    val output = AssignNewExprIds(input)
+    comparePlans(input, output)
+    val oldRefs = collectAllReferences(input)
+    val newRefs = collectAllReferences(output)
+    assert(oldRefs.size == newRefs.size)
+    assert(oldRefs.intersect(newRefs).isEmpty)
+  }
+
+  test("Reassign IDs in a simple plan") {
+    val plan =
+      testRelation1
+        .sortBy($"a".asc, $"b".asc)
+        .distribute($"a")(2)
+        .where($"a" === 10)
+        .groupBy($"a")(sum($"b")).analyze
+    check(plan)
+  }
+
+  test("Plan with scalar subquery") {
+    val testRelation3 = LocalRelation($"e".int, $"f".int)
+    val subqPlan =
+      testRelation3
+        .groupBy($"e")(sum($"f").as("sum"))
+        .where($"e" === $"a")
+    val subqExpr = ScalarSubquery(subqPlan)
+    val originalQuery =
+      testRelation1
+        .select(subqExpr.as("sum"))
+        .join(testRelation2, joinType = LeftSemi, condition = Some($"sum" === 
$"d")).analyze
+    check(originalQuery)
+  }
+
+  test("Plan with IN subquery") {

Review Comment:
   other test ideas:
   * EXISTS subquery
   * plan with CTEs
   * self-join



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -4280,6 +4280,13 @@ object SQLConf {
       .checkValue(_ >= 0, "The threshold of cached local relations must not be 
negative")
       .createWithDefault(64 * 1024 * 1024)
 
+  val REASSIGN_IDS_IN_SCALAR_SUBQUERY =
+    buildConf("spark.sql.optimizer.reassignIdsInScalarSubquery.enabled")
+      .internal()
+      .doc("Reassigns expression IDs in scalar subqueries.")

Review Comment:
   This just restates the conf name :) Maybe also mention why we want to
perform this reassignment (e.g. to avoid expression ID conflicts)?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/AssignNewExprIds.scala:
##########
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
Expression, ExprId, NamedExpression}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Partial, PartialMerge}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Reassigns expression IDs in every expression of the given LogicalPlan 
(including subqueries
+ * contained in the plan).
+ */
+object AssignNewExprIds extends Rule[LogicalPlan] {
+
+  private def assignNewExprIdsinExpr(input: Expression,
+                                     reassignedExprs: mutable.HashMap[ExprId, 
Attribute]):
+  Expression = input match {
+    case a: Attribute =>
+      val newAttribute = reassignedExprs.getOrElse(a.exprId,
+        a.withExprId(NamedExpression.newExprId))
+      reassignedExprs.put(a.exprId, newAttribute)
+      reassignedExprs.put(newAttribute.exprId, newAttribute)
+      newAttribute
+    case a: Alias =>
+      val newAlias = Alias(a.child, a.name)(NamedExpression.newExprId,
+        a.qualifier, a.explicitMetadata, a.nonInheritableMetadataKeys)
+      reassignedExprs.put(a.exprId, newAlias.toAttribute)
+      reassignedExprs.put(newAlias.exprId, newAlias.toAttribute)
+      newAlias
+    case a: AggregateExpression =>
+      if (a.mode == PartialMerge || a.mode == Partial) {
+        // Partial aggregation's attributes are going to be reused in the 
final aggregations.
+        // In order to avoid renaming attributes of final aggregations, keep 
the intermediate
+        // attributes as is.
+        reassignedExprs.put(a.resultAttribute.exprId, a.resultAttribute)
+        return a
+      }
+      val newResultId = NamedExpression.newExprId
+      val updatedExpression = a.copy(resultId = newResultId)
+      reassignedExprs.put(newResultId, updatedExpression.resultAttribute)
+      reassignedExprs.put(a.resultId, updatedExpression.resultAttribute)
+      updatedExpression
+    case p: Expression => p
+  }
+
+  private def transformUpAllExpressions(plan: LogicalPlan,
+                                        rule: PartialFunction[Expression, 
Expression]):
+  LogicalPlan = {
+    plan.transformUpWithSubqueries {
+      case q => q.transformExpressionsUp(rule)
+    }
+  }
+
+  private def assignNewExprIds(plan: LogicalPlan,
+                               reassignedExprs: mutable.HashMap[ExprId, 
Attribute]):
+  LogicalPlan = {
+    transformUpAllExpressions(plan, {
+      case e: Expression => assignNewExprIdsinExpr(e, reassignedExprs)
+    })
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    val reassignedExprs = mutable.HashMap.empty[ExprId, Attribute]
+    assignNewExprIds(plan, reassignedExprs)

Review Comment:
   This is the only call site. Maybe just inline it here instead; then we can
get rid of `assignNewExprIds` and `transformUpAllExpressions`:
   
   ```
   plan.transformUpWithSubqueries {
     case q => q.transformExpressionsUp {
       case e => assignNewExprIdsinExpr(e, reassignedExprs)
     }
   }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to