This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 609552e19cf [SPARK-45279][PYTHON][CONNECT] Attach plan_id for all logical plans
609552e19cf is described below

commit 609552e19cfe75109b1b4641baadd79360e75443
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Sep 25 08:17:08 2023 +0800

    [SPARK-45279][PYTHON][CONNECT] Attach plan_id for all logical plans
    
    ### What changes were proposed in this pull request?
    Attach plan_id for all logical plans, except `CachedRelation`
    
    ### Why are the changes needed?
    1. All logical plans should carry their plan id in the protos.
    2. Catalog plans also carry the plan id in the Scala client, e.g.
    
    https://github.com/apache/spark/blob/05f5dccbd34218c7d399228529853bdb1595f3a2/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala#L63-L67
    
    where the `newDataset` method sets the plan id.
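    
    The diff below replaces inlined `proto.Relation()` construction with a
    shared `_create_proto_relation()` helper on `LogicalPlan`. The helper's
    definition is not part of this diff; a minimal sketch of what it
    presumably does, inferred from the inlined pattern it replaces:
    
    ```python
    def _create_proto_relation(self) -> proto.Relation:
        # Hypothetical sketch, not shown in this diff: build the Relation
        # proto and attach this node's own plan id to its common metadata.
        plan = proto.Relation()
        plan.common.plan_id = self._plan_id
        return plan
    ```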
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #43055 from zhengruifeng/connect_plan_id.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/plan.py | 79 +++++++++++++++++++-------------------
 1 file changed, 40 insertions(+), 39 deletions(-)

diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 219545cf646..6758b3673f3 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1190,9 +1190,7 @@ class CollectMetrics(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
-        plan.common.plan_id = self._child._plan_id
+        plan = self._create_proto_relation()
         plan.collect_metrics.input.CopyFrom(self._child.plan(session))
         plan.collect_metrics.name = self._name
         plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for x in self._exprs])
@@ -1689,7 +1687,9 @@ class CurrentDatabase(LogicalPlan):
         super().__init__(None)
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        return proto.Relation(catalog=proto.Catalog(current_database=proto.CurrentDatabase()))
+        plan = self._create_proto_relation()
+        plan.catalog.current_database.SetInParent()
+        return plan
 
 
 class SetCurrentDatabase(LogicalPlan):
@@ -1698,7 +1698,7 @@ class SetCurrentDatabase(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.catalog.set_current_database.db_name = self._db_name
         return plan
 
@@ -1709,7 +1709,8 @@ class ListDatabases(LogicalPlan):
         self._pattern = pattern
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(list_databases=proto.ListDatabases()))
+        plan = self._create_proto_relation()
+        plan.catalog.list_databases.SetInParent()
         if self._pattern is not None:
             plan.catalog.list_databases.pattern = self._pattern
         return plan
@@ -1722,7 +1723,8 @@ class ListTables(LogicalPlan):
         self._pattern = pattern
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(list_tables=proto.ListTables()))
+        plan = self._create_proto_relation()
+        plan.catalog.list_tables.SetInParent()
         if self._db_name is not None:
             plan.catalog.list_tables.db_name = self._db_name
         if self._pattern is not None:
@@ -1737,7 +1739,8 @@ class ListFunctions(LogicalPlan):
         self._pattern = pattern
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(list_functions=proto.ListFunctions()))
+        plan = self._create_proto_relation()
+        plan.catalog.list_functions.SetInParent()
         if self._db_name is not None:
             plan.catalog.list_functions.db_name = self._db_name
         if self._pattern is not None:
@@ -1752,7 +1755,7 @@ class ListColumns(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(list_columns=proto.ListColumns()))
+        plan = self._create_proto_relation()
         plan.catalog.list_columns.table_name = self._table_name
         if self._db_name is not None:
             plan.catalog.list_columns.db_name = self._db_name
@@ -1765,7 +1768,7 @@ class GetDatabase(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(get_database=proto.GetDatabase()))
+        plan = self._create_proto_relation()
         plan.catalog.get_database.db_name = self._db_name
         return plan
 
@@ -1777,7 +1780,7 @@ class GetTable(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(get_table=proto.GetTable()))
+        plan = self._create_proto_relation()
         plan.catalog.get_table.table_name = self._table_name
         if self._db_name is not None:
             plan.catalog.get_table.db_name = self._db_name
@@ -1791,7 +1794,7 @@ class GetFunction(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(get_function=proto.GetFunction()))
+        plan = self._create_proto_relation()
         plan.catalog.get_function.function_name = self._function_name
         if self._db_name is not None:
             plan.catalog.get_function.db_name = self._db_name
@@ -1804,7 +1807,7 @@ class DatabaseExists(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(database_exists=proto.DatabaseExists()))
+        plan = self._create_proto_relation()
         plan.catalog.database_exists.db_name = self._db_name
         return plan
 
@@ -1816,7 +1819,7 @@ class TableExists(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(table_exists=proto.TableExists()))
+        plan = self._create_proto_relation()
         plan.catalog.table_exists.table_name = self._table_name
         if self._db_name is not None:
             plan.catalog.table_exists.db_name = self._db_name
@@ -1830,7 +1833,7 @@ class FunctionExists(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(function_exists=proto.FunctionExists()))
+        plan = self._create_proto_relation()
         plan.catalog.function_exists.function_name = self._function_name
         if self._db_name is not None:
             plan.catalog.function_exists.db_name = self._db_name
@@ -1854,9 +1857,7 @@ class CreateExternalTable(LogicalPlan):
         self._options = options
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(
-            catalog=proto.Catalog(create_external_table=proto.CreateExternalTable())
-        )
+        plan = self._create_proto_relation()
         plan.catalog.create_external_table.table_name = self._table_name
         if self._path is not None:
             plan.catalog.create_external_table.path = self._path
@@ -1892,7 +1893,7 @@ class CreateTable(LogicalPlan):
         self._options = options
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(create_table=proto.CreateTable()))
+        plan = self._create_proto_relation()
         plan.catalog.create_table.table_name = self._table_name
         if self._path is not None:
             plan.catalog.create_table.path = self._path
@@ -1915,7 +1916,7 @@ class DropTempView(LogicalPlan):
         self._view_name = view_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(drop_temp_view=proto.DropTempView()))
+        plan = self._create_proto_relation()
         plan.catalog.drop_temp_view.view_name = self._view_name
         return plan
 
@@ -1926,9 +1927,7 @@ class DropGlobalTempView(LogicalPlan):
         self._view_name = view_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(
-            catalog=proto.Catalog(drop_global_temp_view=proto.DropGlobalTempView())
-        )
+        plan = self._create_proto_relation()
         plan.catalog.drop_global_temp_view.view_name = self._view_name
         return plan
 
@@ -1939,11 +1938,8 @@ class RecoverPartitions(LogicalPlan):
         self._table_name = table_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(
-            catalog=proto.Catalog(
-                recover_partitions=proto.RecoverPartitions(table_name=self._table_name)
-            )
-        )
+        plan = self._create_proto_relation()
+        plan.catalog.recover_partitions.table_name = self._table_name
         return plan
 
 
@@ -1953,9 +1949,8 @@ class IsCached(LogicalPlan):
         self._table_name = table_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(
-            catalog=proto.Catalog(is_cached=proto.IsCached(table_name=self._table_name))
-        )
+        plan = self._create_proto_relation()
+        plan.catalog.is_cached.table_name = self._table_name
         return plan
 
 
@@ -1966,10 +1961,11 @@ class CacheTable(LogicalPlan):
         self._storage_level = storage_level
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        plan = self._create_proto_relation()
         _cache_table = proto.CacheTable(table_name=self._table_name)
         if self._storage_level:
             _cache_table.storage_level.CopyFrom(storage_level_to_proto(self._storage_level))
-        plan = proto.Relation(catalog=proto.Catalog(cache_table=_cache_table))
+        plan.catalog.cache_table.CopyFrom(_cache_table)
         return plan
 
 
@@ -1979,7 +1975,7 @@ class UncacheTable(LogicalPlan):
         self._table_name = table_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(uncache_table=proto.UncacheTable()))
+        plan = self._create_proto_relation()
         plan.catalog.uncache_table.table_name = self._table_name
         return plan
 
@@ -1989,7 +1985,9 @@ class ClearCache(LogicalPlan):
         super().__init__(None)
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        return proto.Relation(catalog=proto.Catalog(clear_cache=proto.ClearCache()))
+        plan = self._create_proto_relation()
+        plan.catalog.clear_cache.SetInParent()
+        return plan
 
 
 class RefreshTable(LogicalPlan):
@@ -1998,7 +1996,7 @@ class RefreshTable(LogicalPlan):
         self._table_name = table_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(refresh_table=proto.RefreshTable()))
+        plan = self._create_proto_relation()
         plan.catalog.refresh_table.table_name = self._table_name
         return plan
 
@@ -2009,7 +2007,7 @@ class RefreshByPath(LogicalPlan):
         self._path = path
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(refresh_by_path=proto.RefreshByPath()))
+        plan = self._create_proto_relation()
         plan.catalog.refresh_by_path.path = self._path
         return plan
 
@@ -2019,7 +2017,9 @@ class CurrentCatalog(LogicalPlan):
         super().__init__(None)
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        return proto.Relation(catalog=proto.Catalog(current_catalog=proto.CurrentCatalog()))
+        plan = self._create_proto_relation()
+        plan.catalog.current_catalog.SetInParent()
+        return plan
 
 
 class SetCurrentCatalog(LogicalPlan):
@@ -2028,7 +2028,7 @@ class SetCurrentCatalog(LogicalPlan):
         self._catalog_name = catalog_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(set_current_catalog=proto.SetCurrentCatalog()))
+        plan = self._create_proto_relation()
         plan.catalog.set_current_catalog.catalog_name = self._catalog_name
         return plan
 
@@ -2039,7 +2039,8 @@ class ListCatalogs(LogicalPlan):
         self._pattern = pattern
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(list_catalogs=proto.ListCatalogs()))
+        plan = self._create_proto_relation()
+        plan.catalog.list_catalogs.SetInParent()
         if self._pattern is not None:
             plan.catalog.list_catalogs.pattern = self._pattern
         return plan


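A note on the `SetInParent()` calls above: for catalog messages that take no
parameters here (e.g. `CurrentDatabase`, `ClearCache`), merely reading
`plan.catalog.current_database` would not mark the corresponding `oneof` case
as set in Python protobuf; `SetInParent()` explicitly materializes the empty
submessage so the right case is serialized. A minimal illustration, assuming
the `proto` module alias used in plan.py:

```python
import pyspark.sql.connect.proto as proto

plan = proto.Relation()

# Reading a submessage field alone does not mark it as present ...
_ = plan.catalog.current_database
assert not plan.catalog.HasField("current_database")

# ... whereas SetInParent() materializes the empty submessage, so the
# current_database case of the catalog oneof is serialized on the wire.
plan.catalog.current_database.SetInParent()
assert plan.catalog.HasField("current_database")
```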