This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new 36ab645 [SPARK-35454][SQL][3.1] One LogicalPlan can match multiple
dataset ids
36ab645 is described below
commit 36ab645f2a02961f6f248bfc1250594dcc252b03
Author: yi.wu <[email protected]>
AuthorDate: Fri May 28 13:04:55 2021 +0000
[SPARK-35454][SQL][3.1] One LogicalPlan can match multiple dataset ids
### What changes were proposed in this pull request?
Change the type of `DATASET_ID_TAG` from `Long` to `HashSet[Long]` to allow
the logical plan to match multiple datasets.
### Why are the changes needed?
During the transformation from one Dataset to another Dataset, the
DATASET_ID_TAG of the logical plan won't change if the plan itself doesn't change:
https://github.com/apache/spark/blob/b5241c97b17a1139a4ff719bfce7f68aef094d95/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L234-L237
However, the dataset id always changes even if the logical plan doesn't change:
https://github.com/apache/spark/blob/b5241c97b17a1139a4ff719bfce7f68aef094d95/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L207-L208
And this can lead to a mismatch between the dataset's id and the col's
__dataset_id. E.g.,
```scala
test("SPARK-28344: fail ambiguous self join - Dataset.colRegex as column
ref") {
// The test can fail if we change it to:
// val df1 = spark.range(3).toDF()
// val df2 = df1.filter($"id" > 0).toDF()
val df1 = spark.range(3)
val df2 = df1.filter($"id" > 0)
withSQLConf(
SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
assertAmbiguousSelfJoin(df1.join(df2, df1.colRegex("id") >
df2.colRegex("id")))
}
}
```
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added unit tests.
Closes #32692 from Ngone51/spark-35454-3.1.
Authored-by: yi.wu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../catalyst/plans/logical/AnalysisHelper.scala | 4 +-
.../main/scala/org/apache/spark/sql/Dataset.scala | 11 +--
.../analysis/DetectAmbiguousSelfJoin.scala | 8 +-
.../apache/spark/sql/DataFrameSelfJoinSuite.scala | 87 ++++++++++++++++++++++
4 files changed, 100 insertions(+), 10 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
index 54b0141..b31b3e6 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
@@ -91,7 +91,9 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self:
LogicalPlan =>
}
} else {
CurrentOrigin.withOrigin(origin) {
- rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan])
+ val afterRule = rule.applyOrElse(afterRuleOnChildren,
identity[LogicalPlan])
+ afterRule.copyTagsFrom(self)
+ afterRule
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6afbbce..1c76f4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}
import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashSet}
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -69,7 +69,7 @@ private[sql] object Dataset {
val curId = new java.util.concurrent.atomic.AtomicLong()
val DATASET_ID_KEY = "__dataset_id"
val COL_POS_KEY = "__col_position"
- val DATASET_ID_TAG = TreeNodeTag[Long]("dataset_id")
+ val DATASET_ID_TAG = TreeNodeTag[HashSet[Long]]("dataset_id")
def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan):
Dataset[T] = {
val dataset = new Dataset(sparkSession, logicalPlan,
implicitly[Encoder[T]])
@@ -231,9 +231,10 @@ class Dataset[T] private[sql](
case _ =>
queryExecution.analyzed
}
- if
(sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)
&&
- plan.getTagValue(Dataset.DATASET_ID_TAG).isEmpty) {
- plan.setTagValue(Dataset.DATASET_ID_TAG, id)
+ if
(sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED))
{
+ val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new
HashSet[Long])
+ dsIds.add(id)
+ plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds)
}
plan
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala
index b26a078..781a7ab 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/analysis/DetectAmbiguousSelfJoin.scala
@@ -57,8 +57,8 @@ object DetectAmbiguousSelfJoin extends Rule[LogicalPlan] {
}
object LogicalPlanWithDatasetId {
- def unapply(p: LogicalPlan): Option[(LogicalPlan, Long)] = {
- p.getTagValue(Dataset.DATASET_ID_TAG).map(id => p -> id)
+ def unapply(p: LogicalPlan): Option[(LogicalPlan, mutable.HashSet[Long])]
= {
+ p.getTagValue(Dataset.DATASET_ID_TAG).map(ids => p -> ids)
}
}
@@ -89,9 +89,9 @@ object DetectAmbiguousSelfJoin extends Rule[LogicalPlan] {
val inputAttrs = AttributeSet(plan.children.flatMap(_.output))
plan.foreach {
- case LogicalPlanWithDatasetId(p, id) if dsIdSet.contains(id) =>
+ case LogicalPlanWithDatasetId(p, ids) if
dsIdSet.intersect(ids).nonEmpty =>
colRefs.foreach { ref =>
- if (id == ref.datasetId) {
+ if (ids.contains(ref.datasetId)) {
if (ref.colPos < 0 || ref.colPos >= p.output.length) {
throw new IllegalStateException("[BUG] Hit an invalid Dataset
column reference: " +
s"$ref. Please open a JIRA ticket to report it.")
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 76f07b5..062404f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, sum}
import org.apache.spark.sql.internal.SQLConf
@@ -257,4 +258,90 @@ class DataFrameSelfJoinSuite extends QueryTest with
SharedSparkSession {
checkAnswer(df1.join(df2, df1("b") === 2), Row(1, 2, 1))
}
}
+
+ test("SPARK-35454: __dataset_id and __col_position should be correctly set")
{
+ val ds = Seq[TestData](
+ TestData(1, "sales"),
+ TestData(2, "personnel"),
+ TestData(3, "develop"),
+ TestData(4, "IT")).toDS()
+ var dsIdSetOpt = ds.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG)
+ assert(dsIdSetOpt.get.size === 1)
+ var col1DsId = -1L
+ val col1 = ds.col("key")
+ col1.expr.foreach {
+ case a: AttributeReference =>
+ col1DsId = a.metadata.getLong(Dataset.DATASET_ID_KEY)
+ assert(dsIdSetOpt.get.contains(col1DsId))
+ assert(a.metadata.getLong(Dataset.COL_POS_KEY) === 0)
+ }
+
+ val df = ds.toDF()
+ dsIdSetOpt = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG)
+ assert(dsIdSetOpt.get.size === 2)
+ var col2DsId = -1L
+ val col2 = df.col("key")
+ col2.expr.foreach {
+ case a: AttributeReference =>
+ col2DsId = a.metadata.getLong(Dataset.DATASET_ID_KEY)
+
assert(dsIdSetOpt.get.contains(a.metadata.getLong(Dataset.DATASET_ID_KEY)))
+ assert(a.metadata.getLong(Dataset.COL_POS_KEY) === 0)
+ }
+ assert(col1DsId !== col2DsId)
+ }
+
+ test("SPARK-35454: fail ambiguous self join - toDF") {
+ val df1 = spark.range(3).toDF()
+ val df2 = df1.filter($"id" > 0).toDF()
+
+ withSQLConf(
+ SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
+ SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ assertAmbiguousSelfJoin(df1.join(df2, df1.col("id") > df2.col("id")))
+ }
+ }
+
+ test("SPARK-35454: fail ambiguous self join - join four tables") {
+ val df1 = spark.range(3).select($"id".as("a"), $"id".as("b"))
+ val df2 = df1.filter($"a" > 0).select("b")
+ val df3 = df1.filter($"a" <= 2).select("b")
+ val df4 = df1.filter($"b" <= 2)
+ val df5 = spark.range(1)
+
+ withSQLConf(
+ SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "false",
+ SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ // `df2("b") < df4("b")` is always false
+ checkAnswer(df1.join(df2).join(df3).join(df4, df2("b") < df4("b")), Nil)
+ // `df2("b")` actually points to the column of `df1`.
+ checkAnswer(
+ df1.join(df2).join(df5).join(df4).select(df2("b")),
+ Seq(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2).map(Row(_)))
+ // `df5("id")` is not ambiguous.
+ checkAnswer(
+ df1.join(df5).join(df3).select(df5("id")),
+ Seq(0, 0, 0, 0, 0, 0, 0, 0, 0).map(Row(_)))
+
+ // Alias the dataframe and use qualified column names can fix ambiguous
self-join.
+ val aliasedDf1 = df1.alias("w")
+ val aliasedDf2 = df2.as("x")
+ val aliasedDf3 = df3.as("y")
+ val aliasedDf4 = df3.as("z")
+ checkAnswer(
+ aliasedDf1.join(aliasedDf2).join(aliasedDf3).join(aliasedDf4, $"x.b" <
$"y.b"),
+ Seq(Row(0, 0, 1, 2, 0), Row(0, 0, 1, 2, 1), Row(0, 0, 1, 2, 2),
+ Row(1, 1, 1, 2, 0), Row(1, 1, 1, 2, 1), Row(1, 1, 1, 2, 2),
+ Row(2, 2, 1, 2, 0), Row(2, 2, 1, 2, 1), Row(2, 2, 1, 2, 2)))
+ checkAnswer(
+ aliasedDf1.join(df5).join(aliasedDf3).select($"y.b"),
+ Seq(0, 0, 0, 1, 1, 1, 2, 2, 2).map(Row(_)))
+ }
+
+ withSQLConf(
+ SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED.key -> "true",
+ SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ assertAmbiguousSelfJoin(df1.join(df2).join(df3).join(df4, df2("b") <
df4("b")))
+
assertAmbiguousSelfJoin(df1.join(df2).join(df5).join(df4).select(df2("b")))
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]