This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 36ab645  [SPARK-35454][SQL][3.1] One LogicalPlan can match multiple dataset ids
36ab645 is described below

commit 36ab645f2a02961f6f248bfc1250594dcc252b03
Author: yi.wu <[email protected]>
AuthorDate: Fri May 28 13:04:55 2021 +0000

    [SPARK-35454][SQL][3.1] One LogicalPlan can match multiple dataset ids
    
    ### What changes were proposed in this pull request?
    
    Change the type of `DATASET_ID_TAG` from `Long` to `HashSet[Long]` to allow the logical plan to match multiple datasets.
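
    As a quick illustration of the new semantics, here is a minimal sketch (a hypothetical `FakePlan` stand-in, not the real `TreeNode`/`TreeNodeTag` API): with a set-valued tag, every Dataset built on the same plan can register its id, instead of only the first one winning.

    ```scala
    import scala.collection.mutable.HashSet

    // Hypothetical stand-in for a logical plan node carrying the tag.
    class FakePlan {
      var datasetIds: Option[HashSet[Long]] = None
    }

    // Mirrors the patched Dataset.logicalPlan logic: accumulate every
    // Dataset's id in a set rather than keeping only the first id.
    def registerDataset(plan: FakePlan, id: Long): Unit = {
      val ids = plan.datasetIds.getOrElse(new HashSet[Long])
      ids.add(id)
      plan.datasetIds = Some(ids)
    }
    ```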
    
    ### Why are the changes needed?
    
    During the transformation from one Dataset to another, the `DATASET_ID_TAG` of the logical plan won't change if the plan itself doesn't change:

    https://github.com/apache/spark/blob/b5241c97b17a1139a4ff719bfce7f68aef094d95/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L234-L237
    
    However, the dataset id always changes, even if the logical plan doesn't change:

    https://github.com/apache/spark/blob/b5241c97b17a1139a4ff719bfce7f68aef094d95/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L207-L208
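
    As a simplified model of that id allocation (a hypothetical `FakeDataset`, not the real class), every new wrapper draws a fresh id from the shared `AtomicLong`, even when it reuses the same plan object:

    ```scala
    import java.util.concurrent.atomic.AtomicLong

    val curId = new AtomicLong()

    // Every new Dataset-like wrapper gets a fresh id, even when it
    // shares the same underlying plan object with its parent.
    case class FakeDataset(plan: AnyRef) {
      val id: Long = curId.getAndIncrement()
    }

    val sharedPlan = new Object
    val ds1 = FakeDataset(sharedPlan) // id = 0
    val ds2 = FakeDataset(sharedPlan) // id = 1, same plan as ds1
    ```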
    
    This can lead to a mismatch between a Dataset's id and a Column's `__dataset_id`. E.g.,
    
    ```scala
      test("SPARK-28344: fail ambiguous self join - Dataset.colRegex as column 
ref") {
        // The test can fail if we change it to:
        // val df1 = spark.range(3).toDF()
        // val df2 = df1.filter($"id" > 0).toDF()
        val df1 = spark.range(3)
        val df2 = df1.filter($"id" > 0)
    
        withSQLConf(
          SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
          SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
          assertAmbiguousSelfJoin(df1.join(df2, df1.colRegex("id") > df2.colRegex("id")))
        }
      }
    ```
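
    For reference, the detection side changes in the same spirit. Below is a minimal sketch of the new membership checks (hypothetical helper names; the real logic lives in the `DetectAmbiguousSelfJoin` hunk in the diff below):

    ```scala
    import scala.collection.mutable

    // A plan node is a candidate if any of its dataset ids intersects
    // the ids referenced by the joining columns.
    def planIsCandidate(planIds: mutable.HashSet[Long], refIds: Set[Long]): Boolean =
      refIds.exists(planIds.contains)

    // A column reference matches a node when the node's id set contains
    // the reference's __dataset_id.
    def refMatches(planIds: mutable.HashSet[Long], refDatasetId: Long): Boolean =
      planIds.contains(refDatasetId)
    ```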
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added unit tests.
    
    Closes #32692 from Ngone51/spark-35454-3.1.
    
    Authored-by: yi.wu <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../catalyst/plans/logical/AnalysisHelper.scala    |  4 +-
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 11 +--
 .../analysis/DetectAmbiguousSelfJoin.scala         |  8 +-
 .../apache/spark/sql/DataFrameSelfJoinSuite.scala  | 87 ++++++++++++++++++++++
 4 files changed, 100 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
index 54b0141..b31b3e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
@@ -91,7 +91,9 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
           }
         } else {
           CurrentOrigin.withOrigin(origin) {
-            rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan])
+            val afterRule = rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan])
+            afterRule.copyTagsFrom(self)
+            afterRule
           }
         }
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6afbbce..1c76f4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashSet}
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
@@ -69,7 +69,7 @@ private[sql] object Dataset {
   val curId = new java.util.concurrent.atomic.AtomicLong()
   val DATASET_ID_KEY = "__dataset_id"
   val COL_POS_KEY = "__col_position"
-  val DATASET_ID_TAG = TreeNodeTag[Long]("dataset_id")
+  val DATASET_ID_TAG = TreeNodeTag[HashSet[Long]]("dataset_id")
 
   def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
     val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
@@ -231,9 +231,10 @@ class Dataset[T] private[sql](
       case _ =>
         queryExecution.analyzed
     }
-    if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) &&
-        plan.getTagValue(Dataset.DATASET_ID_TAG).isEmpty) {
-      plan.setTagValue(Dataset.DATASET_ID_TAG, id)
+    if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
+      val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new HashSet[Long])
+      dsIds.add(id)
+      plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds)
     }
     plan
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala
index b26a078..781a7ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala
@@ -57,8 +57,8 @@ object DetectAmbiguousSelfJoin extends Rule[LogicalPlan] {
   }
 
   object LogicalPlanWithDatasetId {
-    def unapply(p: LogicalPlan): Option[(LogicalPlan, Long)] = {
-      p.getTagValue(Dataset.DATASET_ID_TAG).map(id => p -> id)
+    def unapply(p: LogicalPlan): Option[(LogicalPlan, mutable.HashSet[Long])] = {
+      p.getTagValue(Dataset.DATASET_ID_TAG).map(ids => p -> ids)
     }
   }
 
@@ -89,9 +89,9 @@ object DetectAmbiguousSelfJoin extends Rule[LogicalPlan] {
       val inputAttrs = AttributeSet(plan.children.flatMap(_.output))
 
       plan.foreach {
-        case LogicalPlanWithDatasetId(p, id) if dsIdSet.contains(id) =>
+        case LogicalPlanWithDatasetId(p, ids) if dsIdSet.intersect(ids).nonEmpty =>
           colRefs.foreach { ref =>
-            if (id == ref.datasetId) {
+            if (ids.contains(ref.datasetId)) {
               if (ref.colPos < 0 || ref.colPos >= p.output.length) {
                 throw new IllegalStateException("[BUG] Hit an invalid Dataset column reference: " +
                   s"$ref. Please open a JIRA ticket to report it.")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 76f07b5..062404f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.{count, sum}
 import org.apache.spark.sql.internal.SQLConf
@@ -257,4 +258,90 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       checkAnswer(df1.join(df2, df1("b") === 2), Row(1, 2, 1))
     }
   }
+
+  test("SPARK-35454: __dataset_id and __col_position should be correctly set") 
{
+    val ds = Seq[TestData](
+      TestData(1, "sales"),
+      TestData(2, "personnel"),
+      TestData(3, "develop"),
+      TestData(4, "IT")).toDS()
+    var dsIdSetOpt = ds.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG)
+    assert(dsIdSetOpt.get.size === 1)
+    var col1DsId = -1L
+    val col1 = ds.col("key")
+    col1.expr.foreach {
+      case a: AttributeReference =>
+        col1DsId = a.metadata.getLong(Dataset.DATASET_ID_KEY)
+        assert(dsIdSetOpt.get.contains(col1DsId))
+        assert(a.metadata.getLong(Dataset.COL_POS_KEY) === 0)
+    }
+
+    val df = ds.toDF()
+    dsIdSetOpt = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG)
+    assert(dsIdSetOpt.get.size === 2)
+    var col2DsId = -1L
+    val col2 = df.col("key")
+    col2.expr.foreach {
+      case a: AttributeReference =>
+        col2DsId = a.metadata.getLong(Dataset.DATASET_ID_KEY)
+        assert(dsIdSetOpt.get.contains(a.metadata.getLong(Dataset.DATASET_ID_KEY)))
+        assert(a.metadata.getLong(Dataset.COL_POS_KEY) === 0)
+    }
+    assert(col1DsId !== col2DsId)
+  }
+
+  test("SPARK-35454: fail ambiguous self join - toDF") {
+    val df1 = spark.range(3).toDF()
+    val df2 = df1.filter($"id" > 0).toDF()
+
+    withSQLConf(
+      SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
+      SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      assertAmbiguousSelfJoin(df1.join(df2, df1.col("id") > df2.col("id")))
+    }
+  }
+
+  test("SPARK-35454: fail ambiguous self join - join four tables") {
+    val df1 = spark.range(3).select($"id".as("a"), $"id".as("b"))
+    val df2 = df1.filter($"a" > 0).select("b")
+    val df3 = df1.filter($"a" <= 2).select("b")
+    val df4 = df1.filter($"b" <= 2)
+    val df5 = spark.range(1)
+
+    withSQLConf(
+      SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "false",
+      SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      // `df2("b") < df4("b")` is always false
+      checkAnswer(df1.join(df2).join(df3).join(df4, df2("b") < df4("b")), Nil)
+      // `df2("b")` actually points to the column of `df1`.
+      checkAnswer(
+        df1.join(df2).join(df5).join(df4).select(df2("b")),
+        Seq(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2).map(Row(_)))
+      // `df5("id")` is not ambiguous.
+      checkAnswer(
+        df1.join(df5).join(df3).select(df5("id")),
+        Seq(0, 0, 0, 0, 0, 0, 0, 0, 0).map(Row(_)))
+
+      // Alias the dataframe and use qualified column names can fix ambiguous self-join.
+      val aliasedDf1 = df1.alias("w")
+      val aliasedDf2 = df2.as("x")
+      val aliasedDf3 = df3.as("y")
+      val aliasedDf4 = df3.as("z")
+      checkAnswer(
+        aliasedDf1.join(aliasedDf2).join(aliasedDf3).join(aliasedDf4, $"x.b" < $"y.b"),
+        Seq(Row(0, 0, 1, 2, 0), Row(0, 0, 1, 2, 1), Row(0, 0, 1, 2, 2),
+          Row(1, 1, 1, 2, 0), Row(1, 1, 1, 2, 1), Row(1, 1, 1, 2, 2),
+          Row(2, 2, 1, 2, 0), Row(2, 2, 1, 2, 1), Row(2, 2, 1, 2, 2)))
+      checkAnswer(
+        aliasedDf1.join(df5).join(aliasedDf3).select($"y.b"),
+        Seq(0, 0, 0, 1, 1, 1, 2, 2, 2).map(Row(_)))
+    }
+
+    withSQLConf(
+      SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
+      SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      assertAmbiguousSelfJoin(df1.join(df2).join(df3).join(df4, df2("b") < df4("b")))
+      assertAmbiguousSelfJoin(df1.join(df2).join(df5).join(df4).select(df2("b")))
+    }
+  }
 }
