This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8c416290335 [SPARK-45509][SQL] Fix df column reference behavior for Spark Connect
8c416290335 is described below

commit 8c41629033508197b80f6a50c78fcf4d5f9752ff
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Tue Nov 7 16:35:48 2023 +0800

    [SPARK-45509][SQL] Fix df column reference behavior for Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR fixes a few problems with column resolution for Spark Connect,
    making the behavior closer to classic Spark SQL (unfortunately, some
    behavior differences remain in corner cases):
    1. Resolve df column references in both `resolveExpressionByPlanChildren`
       and `resolveExpressionByPlanOutput`. Previously this was done only in
       `resolveExpressionByPlanChildren`.
    2. When the plan id has multiple matches, fail with
       `AMBIGUOUS_COLUMN_REFERENCE` (see the sketch below).
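    
    For illustration, a minimal sketch of the new behavior in the Spark
    Connect Scala client, adapted from the test added in this patch (the
    `spark` session value and `import session.implicits._` are assumed to
    be in scope):
    
    ```scala
    import org.apache.spark.sql.functions.col
    
    val df1 = Seq(1 -> "a").toDF("i", "j")
    
    // df1("i") is ambiguous here: df1 appears on both sides of the
    // self-join, so this now fails with AMBIGUOUS_COLUMN_REFERENCE.
    // df1.join(df1, df1("i") === 1)
    
    // Workaround: alias each side and use qualified column names.
    df1.alias("a").join(df1.alias("b"), col("a.i") > col("b.i"))
    ```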
    
    ### Why are the changes needed?
    
    Fix behavior differences between Spark Connect and classic Spark SQL.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, for the Spark Connect Scala client.
    
    ### How was this patch tested?
    
    New tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43465 from cloud-fan/column.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../src/main/resources/error/error-classes.json    |  9 ++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  | 62 +++++++++++++-
 docs/sql-error-conditions.md                       |  9 ++
 python/pyspark/pandas/indexes/multi.py             |  2 +-
 python/pyspark/sql/connect/plan.py                 |  4 +-
 .../catalyst/analysis/ColumnResolutionHelper.scala | 98 +++++++++++++---------
 .../sql/catalyst/analysis/ResolveSetVariable.scala |  1 -
 7 files changed, 140 insertions(+), 45 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 3e0743d366a..db46ee8ca20 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -31,6 +31,15 @@
     ],
     "sqlState" : "42702"
   },
+  "AMBIGUOUS_COLUMN_REFERENCE" : {
+    "message" : [
+      "Column <name> is ambiguous. It's because you joined several DataFrame 
together, and some of these DataFrames are the same.",
+      "This column points to one of the DataFrame but Spark is unable to 
figure out which one.",
+      "Please alias the DataFrames with different names via `DataFrame.alias` 
before joining them,",
+      "and specify the column using qualified name, e.g. 
`df.alias(\"a\").join(df.alias(\"b\"), col(\"a.id\") > col(\"b.id\"))`."
+    ],
+    "sqlState" : "42702"
+  },
   "AMBIGUOUS_LATERAL_COLUMN_ALIAS" : {
     "message" : [
       "Lateral column alias <name> is ambiguous and has <n> matches."
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index d9a77f2830b..b9fa415034c 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.types._
 
 class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {
 
-  test(s"throw SparkException with null filename in stack trace elements") {
+  test("throw SparkException with null filename in stack trace elements") {
     withSQLConf("spark.sql.connect.enrichError.enabled" -> "true") {
       val session = spark
       import session.implicits._
@@ -94,7 +94,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
     }
   }
 
-  test(s"throw SparkException with large cause exception") {
+  test("throw SparkException with large cause exception") {
     withSQLConf("spark.sql.connect.enrichError.enabled" -> "true") {
       val session = spark
       import session.implicits._
@@ -872,6 +872,64 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
     assert(joined2.schema.catalogString === "struct<id:bigint,a:double>")
   }
 
+  test("SPARK-45509: ambiguous column reference") {
+    val session = spark
+    import session.implicits._
+    val df1 = Seq(1 -> "a").toDF("i", "j")
+    val df1_filter = df1.filter(df1("i") > 0)
+    val df2 = Seq(2 -> "b").toDF("i", "y")
+
+    checkSameResult(
+      Seq(Row(1)),
+      // df1("i") is not ambiguous, and it's still valid in the filtered df.
+      df1_filter.select(df1("i")))
+
+    val e1 = intercept[AnalysisException] {
+      // df1("i") is not ambiguous, but it's not valid in the projected df.
+      df1.select((df1("i") + 1).as("plus")).select(df1("i")).collect()
+    }
+    assert(e1.getMessage.contains("MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_MISSING_FROM_INPUT"))
+
+    checkSameResult(
+      Seq(Row(1, "a")),
+      // All these column references are not ambiguous and are still valid after the join.
+      df1.join(df2, df1("i") + 1 === df2("i")).sort(df1("i").desc).select(df1("i"), df1("j")))
+
+    val e2 = intercept[AnalysisException] {
+      // df1("i") is ambiguous as df1 appears in both join sides.
+      df1.join(df1, df1("i") === 1).collect()
+    }
+    assert(e2.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))
+
+    val e3 = intercept[AnalysisException] {
+      // df1("i") is ambiguous as df1 appears in both join sides.
+      df1.join(df1).select(df1("i")).collect()
+    }
+    assert(e3.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))
+
+    val e4 = intercept[AnalysisException] {
+      // df1("i") is ambiguous as df1 appears in both join sides (df1_filter 
contains df1).
+      df1.join(df1_filter, df1("i") === 1).collect()
+    }
+    assert(e4.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))
+
+    checkSameResult(
+      Seq(Row("a")),
+      // df1_filter("i") is not ambiguous as df1_filter does not exist in the 
join left side.
+      df1.join(df1_filter, df1_filter("i") === 1).select(df1_filter("j")))
+
+    val e5 = intercept[AnalysisException] {
+      // df1("i") is ambiguous as df1 appears in both sides of the first join.
+      df1.join(df1_filter, df1_filter("i") === 1).join(df2, df1("i") === 1).collect()
+    }
+    assert(e5.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))
+
+    checkSameResult(
+      Seq(Row("a")),
+      // df1_filter("i") is not ambiguous as df1_filter only appears once.
+      df1.join(df1_filter).join(df2, df1_filter("i") === 1).select(df1_filter("j")))
+  }
+
   test("broadcast join") {
     withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
       val left = spark.range(100).select(col("id"), rand(10).as("a"))
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index a6f003647dd..7b0bc8ceb2b 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -55,6 +55,15 @@ See '`<docroot>`/sql-migration-guide.html#query-engine'.
 
 Column or field `<name>` is ambiguous and has `<n>` matches.
 
+### AMBIGUOUS_COLUMN_REFERENCE
+
+[SQLSTATE: 42702](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Column `<name>` is ambiguous. This is because you joined several DataFrames together, and some of them are the same.
+This column points to one of the DataFrames, but Spark is unable to figure out which one.
+Please alias the DataFrames with different names via `DataFrame.alias` before joining them,
+and specify the column using a qualified name, e.g. `df.alias("a").join(df.alias("b"), col("a.id") > col("b.id"))`.
+
 ### AMBIGUOUS_LATERAL_COLUMN_ALIAS
 
 [SQLSTATE: 42702](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py
index 41c3b93ed51..9fbc608c12a 100644
--- a/python/pyspark/pandas/indexes/multi.py
+++ b/python/pyspark/pandas/indexes/multi.py
@@ -813,7 +813,7 @@ class MultiIndex(Index):
         sdf_symdiff = sdf_self.union(sdf_other).subtract(sdf_self.intersect(sdf_other))
 
         if sort:
-            sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_columns)
+            sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_column_names)
 
         internal = InternalFrame(
             spark_frame=sdf_symdiff,
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index d888422d29f..607d1429a9e 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -2251,7 +2251,9 @@ class CoGroupMap(LogicalPlan):
         self._input_grouping_cols = input_grouping_cols
         self._other_grouping_cols = other_grouping_cols
         self._other = cast(LogicalPlan, other)
-        self._func = function._build_common_inline_user_defined_function(*cols)
+        # The function takes the entire DataFrame as input, so no column
+        # binding is needed (there are no input columns).
+        self._func = function._build_common_inline_user_defined_function()
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 54a9c6ca018..edfc60fc6ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -30,10 +30,10 @@ import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier}
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors}
 import org.apache.spark.sql.internal.SQLConf
 
-trait ColumnResolutionHelper extends Logging {
+trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
 
   def conf: SQLConf
 
@@ -426,7 +426,7 @@ trait ColumnResolutionHelper extends Logging {
       throws: Boolean = false,
       includeLastResort: Boolean = false): Expression = {
     resolveExpression(
-      expr,
+      tryResolveColumnByPlanId(expr, plan),
       resolveColumnByName = nameParts => {
         plan.resolve(nameParts, conf.resolver)
       },
@@ -447,21 +447,8 @@ trait ColumnResolutionHelper extends Logging {
       e: Expression,
       q: LogicalPlan,
       includeLastResort: Boolean = false): Expression = {
-    val newE = if (e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) {
-      // If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and
-      // expression are from Spark Connect, and need to be resolved in this way:
-      //    1, extract the attached plan id from the expression (UnresolvedAttribute only for now);
-      //    2, top-down traverse the query plan to find the plan node that matches the plan id;
-      //    3, if can not find the matching node, fail the analysis due to illegal references;
-      //    4, resolve the expression with the matching node, if any error occurs here, apply the
-      //    old code path;
-      resolveExpressionByPlanId(e, q)
-    } else {
-      e
-    }
-
     resolveExpression(
-      newE,
+      tryResolveColumnByPlanId(e, q),
       resolveColumnByName = nameParts => {
         q.resolveChildren(nameParts, conf.resolver)
       },
@@ -490,39 +477,46 @@ trait ColumnResolutionHelper extends Logging {
     }
   }
 
-  private def resolveExpressionByPlanId(
+  // If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and
+  // expression are from Spark Connect, and need to be resolved in this way:
+  //    1. extract the attached plan id from the UnresolvedAttribute;
+  //    2. traverse the query plan top-down to find the plan node that matches the plan id;
+  //    3. if no matching node is found, fail the analysis due to an illegal reference;
+  //    4. if more than one matching node is found, fail due to an ambiguous column reference;
+  //    5. resolve the expression with the matching node; if any error occurs here, return the
+  //       original expression as it is.
+  private def tryResolveColumnByPlanId(
       e: Expression,
-      q: LogicalPlan): Expression = {
-    if (!e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) {
-      return e
-    }
-
-    e match {
-      case u: UnresolvedAttribute =>
-        resolveUnresolvedAttributeByPlanId(u, q).getOrElse(u)
-      case _ =>
-        e.mapChildren(c => resolveExpressionByPlanId(c, q))
-    }
+      q: LogicalPlan,
+      idToPlan: mutable.HashMap[Long, LogicalPlan] = mutable.HashMap.empty): Expression = e match {
+    case u: UnresolvedAttribute =>
+      resolveUnresolvedAttributeByPlanId(u, q, idToPlan).getOrElse(u)
+    case _ if e.containsPattern(UNRESOLVED_ATTRIBUTE) =>
+      e.mapChildren(c => tryResolveColumnByPlanId(c, q, idToPlan))
+    case _ => e
   }
 
   private def resolveUnresolvedAttributeByPlanId(
       u: UnresolvedAttribute,
-      q: LogicalPlan): Option[NamedExpression] = {
+      q: LogicalPlan,
+      idToPlan: mutable.HashMap[Long, LogicalPlan]): Option[NamedExpression] = {
     val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
     if (planIdOpt.isEmpty) return None
     val planId = planIdOpt.get
     logDebug(s"Extract plan_id $planId from $u")
 
-    val planOpt = q.find(_.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(planId))
-    if (planOpt.isEmpty) {
-      // For example:
-      //  df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
-      //  df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
-      //  df1.select(df2.a)   <-   illegal reference df2.a
-      throw new AnalysisException(s"When resolving $u, " +
-        s"fail to find subplan with plan_id=$planId in $q")
-    }
-    val plan = planOpt.get
+    val plan = idToPlan.getOrElseUpdate(planId, {
+      findPlanById(u, planId, q).getOrElse {
+        // For example:
+        //  df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)])
+        //  df2 = spark.createDataFrame([Row(a = 1, b = 2)])
+        //  df1.select(df2.a)   <-   illegal reference df2.a
+        throw new AnalysisException(s"When resolving $u, " +
+          s"fail to find subplan with plan_id=$planId in $q")
+      }
+    })
 
     val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).isDefined
     try {
@@ -539,4 +533,28 @@ trait ColumnResolutionHelper extends Logging {
         None
     }
   }
+
+  private def findPlanById(
+      u: UnresolvedAttribute,
+      id: Long,
+      plan: LogicalPlan): Option[LogicalPlan] = {
+    if (plan.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) {
+      Some(plan)
+    } else if (plan.children.length == 1) {
+      findPlanById(u, id, plan.children.head)
+    } else if (plan.children.length > 1) {
+      val matched = plan.children.flatMap(findPlanById(u, id, _))
+      if (matched.length > 1) {
+        throw new AnalysisException(
+          errorClass = "AMBIGUOUS_COLUMN_REFERENCE",
+          messageParameters = Map("name" -> toSQLId(u.nameParts)),
+          origin = u.origin
+        )
+      } else {
+        matched.headOption
+      }
+    } else {
+      None
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala
index ebf56ef1cc4..bd0204ba06f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan, SetVaria
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
 import org.apache.spark.sql.connector.catalog.CatalogManager
-import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
 import org.apache.spark.sql.errors.QueryCompilationErrors.unresolvedVariableError
 import org.apache.spark.sql.types.IntegerType
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
