This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 06c741a0061b [SPARK-47129][CONNECT][SQL] Make `ResolveRelations` cache
connect plan properly
06c741a0061b is described below
commit 06c741a0061bcf2c6e2c08212cab9f4e774cb70a
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Feb 23 09:26:13 2024 -0800
[SPARK-47129][CONNECT][SQL] Make `ResolveRelations` cache connect plan
properly
### What changes were proposed in this pull request?
Make `ResolveRelations` handle plan id properly
### Why are the changes needed?
Bug fix for Spark Connect; it won't affect classic Spark SQL.
before this PR:
```
from pyspark.sql import functions as sf
spark.range(10).withColumn("value_1",
sf.lit(1)).write.saveAsTable("test_table_1")
spark.range(10).withColumnRenamed("id", "index").withColumn("value_2",
sf.lit(2)).write.saveAsTable("test_table_2")
df1 = spark.read.table("test_table_1")
df2 = spark.read.table("test_table_2")
df3 = spark.read.table("test_table_1")
join1 = df1.join(df2, on=df1.id==df2.index).select(df2.index, df2.value_2)
join2 = df3.join(join1, how="left", on=join1.index==df3.id)
join2.schema
```
fails with
```
AnalysisException: [CANNOT_RESOLVE_DATAFRAME_COLUMN] Cannot resolve
dataframe column "id". It's probably because of illegal references like
`df1.select(df2.col("a"))`. SQLSTATE: 42704
```
That is because the existing plan caching in `ResolveRelations` doesn't work
with Spark Connect
```
=== Applying Rule
org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations ===
'[#12]Join LeftOuter, '`==`('index, 'id) '[#12]Join
LeftOuter, '`==`('index, 'id)
!:- '[#9]UnresolvedRelation [test_table_1], [], false :-
'[#9]SubqueryAlias spark_catalog.default.test_table_1
!+- '[#11]Project ['index, 'value_2] : +-
'UnresolvedCatalogRelation `spark_catalog`.`default`.`test_table_1`, [], false
! +- '[#10]Join Inner, '`==`('id, 'index) +-
'[#11]Project ['index, 'value_2]
! :- '[#7]UnresolvedRelation [test_table_1], [], false +-
'[#10]Join Inner, '`==`('id, 'index)
! +- '[#8]UnresolvedRelation [test_table_2], [], false :-
'[#9]SubqueryAlias spark_catalog.default.test_table_1
! : +-
'UnresolvedCatalogRelation `spark_catalog`.`default`.`test_table_1`, [], false
! +-
'[#8]SubqueryAlias spark_catalog.default.test_table_2
! +-
'UnresolvedCatalogRelation `spark_catalog`.`default`.`test_table_2`, [], false
Can not resolve 'id with plan 7
```
`[#7]UnresolvedRelation [test_table_1], [], false` was wrongly resolved to
the cached one
```
:- '[#9]SubqueryAlias spark_catalog.default.test_table_1
+- 'UnresolvedCatalogRelation `spark_catalog`.`default`.`test_table_1`,
[], false
```
### Does this PR introduce _any_ user-facing change?
Yes, this is a bug fix.
### How was this patch tested?
Added a unit test.
### Was this patch authored or co-authored using generative AI tooling?
ci
Closes #45214 from zhengruifeng/connect_fix_read_join.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/sql/tests/test_readwriter.py | 23 +++++++++++++++++-
.../spark/sql/catalyst/analysis/Analyzer.scala | 27 ++++++++++++++++------
2 files changed, 42 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/sql/tests/test_readwriter.py
b/python/pyspark/sql/tests/test_readwriter.py
index 70a320fc53b6..85057f37a181 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -20,7 +20,7 @@ import shutil
import tempfile
from pyspark.errors import AnalysisException
-from pyspark.sql.functions import col
+from pyspark.sql.functions import col, lit
from pyspark.sql.readwriter import DataFrameWriterV2
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -181,6 +181,27 @@ class ReadwriterTestsMixin:
df.write.mode("overwrite").insertInto("test_table", False)
self.assertEqual(6, self.spark.sql("select * from
test_table").count())
+ def test_cached_table(self):
+ with self.table("test_cached_table_1"):
+ self.spark.range(10).withColumn(
+ "value_1",
+ lit(1),
+ ).write.saveAsTable("test_cached_table_1")
+
+ with self.table("test_cached_table_2"):
+ self.spark.range(10).withColumnRenamed("id",
"index").withColumn(
+ "value_2", lit(2)
+ ).write.saveAsTable("test_cached_table_2")
+
+ df1 = self.spark.read.table("test_cached_table_1")
+ df2 = self.spark.read.table("test_cached_table_2")
+ df3 = self.spark.read.table("test_cached_table_1")
+
+ join1 = df1.join(df2, on=df1.id ==
df2.index).select(df2.index, df2.value_2)
+ join2 = df3.join(join1, how="left", on=join1.index == df3.id)
+
+ self.assertEqual(join2.columns, ["id", "value_1", "index",
"value_2"])
+
class ReadwriterV2TestsMixin:
def test_api(self):
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d8127fe03da4..1fb5d00bdf39 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1275,16 +1275,29 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
val key =
((catalog.name +: ident.namespace :+
ident.name).toImmutableArraySeq,
finalTimeTravelSpec)
- AnalysisContext.get.relationCache.get(key).map(_.transform {
- case multi: MultiInstanceRelation =>
- val newRelation = multi.newInstance()
- newRelation.copyTagsFrom(multi)
- newRelation
- }).orElse {
+ AnalysisContext.get.relationCache.get(key).map { cache =>
+ val cachedRelation = cache.transform {
+ case multi: MultiInstanceRelation =>
+ val newRelation = multi.newInstance()
+ newRelation.copyTagsFrom(multi)
+ newRelation
+ }
+ u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
+ val cachedConnectRelation = cachedRelation.clone()
+ cachedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG,
planId)
+ cachedConnectRelation
+ }.getOrElse(cachedRelation)
+ }.orElse {
val table = CatalogV2Util.loadTable(catalog, ident,
finalTimeTravelSpec)
val loaded = createRelation(catalog, ident, table, u.options,
u.isStreaming)
loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
- loaded
+ u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
+ loaded.map { loadedRelation =>
+ val loadedConnectRelation = loadedRelation.clone()
+ loadedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG,
planId)
+ loadedConnectRelation
+ }
+ }.getOrElse(loaded)
}
case _ => None
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]