This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 609552e19cf [SPARK-45279][PYTHON][CONNECT] Attach plan_id for all logical plans
609552e19cf is described below

commit 609552e19cfe75109b1b4641baadd79360e75443
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Sep 25 08:17:08 2023 +0800

    [SPARK-45279][PYTHON][CONNECT] Attach plan_id for all logical plans

    ### What changes were proposed in this pull request?
    Attach plan_id for all logical plans, except `CachedRelation`.

    ### Why are the changes needed?
    1. All logical plans should contain their plan id in the protos.
    2. Catalog plans also contain the plan id in the Scala client, e.g.
       https://github.com/apache/spark/blob/05f5dccbd34218c7d399228529853bdb1595f3a2/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala#L63-L67
       where the `newDataset` method will set the plan id.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    CI

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #43055 from zhengruifeng/connect_plan_id.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/plan.py | 79 +++++++++++++++++++-------------------
 1 file changed, 40 insertions(+), 39 deletions(-)

diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 219545cf646..6758b3673f3 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1190,9 +1190,7 @@ class CollectMetrics(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
-        plan.common.plan_id = self._child._plan_id
+        plan = self._create_proto_relation()
         plan.collect_metrics.input.CopyFrom(self._child.plan(session))
         plan.collect_metrics.name = self._name
         plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for x in self._exprs])
@@ -1689,7 +1687,9 @@ class CurrentDatabase(LogicalPlan):
         super().__init__(None)
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        return proto.Relation(catalog=proto.Catalog(current_database=proto.CurrentDatabase()))
+        plan = self._create_proto_relation()
+        plan.catalog.current_database.SetInParent()
+        return plan
 
 
 class SetCurrentDatabase(LogicalPlan):
@@ -1698,7 +1698,7 @@ class SetCurrentDatabase(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.catalog.set_current_database.db_name = self._db_name
         return plan
 
@@ -1709,7 +1709,8 @@ class ListDatabases(LogicalPlan):
         self._pattern = pattern
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(list_databases=proto.ListDatabases()))
+        plan = self._create_proto_relation()
+        plan.catalog.list_databases.SetInParent()
         if self._pattern is not None:
             plan.catalog.list_databases.pattern = self._pattern
         return plan
@@ -1722,7 +1723,8 @@ class ListTables(LogicalPlan):
         self._pattern = pattern
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(list_tables=proto.ListTables()))
+        plan = self._create_proto_relation()
+        plan.catalog.list_tables.SetInParent()
         if self._db_name is not None:
             plan.catalog.list_tables.db_name = self._db_name
         if self._pattern is not None:
@@ -1737,7 +1739,8 @@ class ListFunctions(LogicalPlan):
         self._pattern = pattern
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(list_functions=proto.ListFunctions()))
+        plan = self._create_proto_relation()
+        plan.catalog.list_functions.SetInParent()
         if self._db_name is not None:
             plan.catalog.list_functions.db_name = self._db_name
         if self._pattern is not None:
@@ -1752,7 +1755,7 @@ class ListColumns(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(list_columns=proto.ListColumns()))
+        plan = self._create_proto_relation()
         plan.catalog.list_columns.table_name = self._table_name
         if self._db_name is not None:
             plan.catalog.list_columns.db_name = self._db_name
@@ -1765,7 +1768,7 @@ class GetDatabase(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(get_database=proto.GetDatabase()))
+        plan = self._create_proto_relation()
         plan.catalog.get_database.db_name = self._db_name
         return plan
 
@@ -1777,7 +1780,7 @@ class GetTable(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(get_table=proto.GetTable()))
+        plan = self._create_proto_relation()
         plan.catalog.get_table.table_name = self._table_name
         if self._db_name is not None:
             plan.catalog.get_table.db_name = self._db_name
@@ -1791,7 +1794,7 @@ class GetFunction(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(get_function=proto.GetFunction()))
+        plan = self._create_proto_relation()
         plan.catalog.get_function.function_name = self._function_name
         if self._db_name is not None:
             plan.catalog.get_function.db_name = self._db_name
@@ -1804,7 +1807,7 @@ class DatabaseExists(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(database_exists=proto.DatabaseExists()))
+        plan = self._create_proto_relation()
         plan.catalog.database_exists.db_name = self._db_name
         return plan
 
@@ -1816,7 +1819,7 @@ class TableExists(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(table_exists=proto.TableExists()))
+        plan = self._create_proto_relation()
         plan.catalog.table_exists.table_name = self._table_name
         if self._db_name is not None:
             plan.catalog.table_exists.db_name = self._db_name
@@ -1830,7 +1833,7 @@ class FunctionExists(LogicalPlan):
         self._db_name = db_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(function_exists=proto.FunctionExists()))
+        plan = self._create_proto_relation()
         plan.catalog.function_exists.function_name = self._function_name
         if self._db_name is not None:
             plan.catalog.function_exists.db_name = self._db_name
@@ -1854,9 +1857,7 @@ class CreateExternalTable(LogicalPlan):
         self._options = options
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(
-            catalog=proto.Catalog(create_external_table=proto.CreateExternalTable())
-        )
+        plan = self._create_proto_relation()
         plan.catalog.create_external_table.table_name = self._table_name
         if self._path is not None:
             plan.catalog.create_external_table.path = self._path
@@ -1892,7 +1893,7 @@ class CreateTable(LogicalPlan):
         self._options = options
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(create_table=proto.CreateTable()))
+        plan = self._create_proto_relation()
         plan.catalog.create_table.table_name = self._table_name
         if self._path is not None:
             plan.catalog.create_table.path = self._path
@@ -1915,7 +1916,7 @@ class DropTempView(LogicalPlan):
         self._view_name = view_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(drop_temp_view=proto.DropTempView()))
+        plan = self._create_proto_relation()
         plan.catalog.drop_temp_view.view_name = self._view_name
         return plan
 
@@ -1926,9 +1927,7 @@ class DropGlobalTempView(LogicalPlan):
         self._view_name = view_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(
-            catalog=proto.Catalog(drop_global_temp_view=proto.DropGlobalTempView())
-        )
+        plan = self._create_proto_relation()
         plan.catalog.drop_global_temp_view.view_name = self._view_name
         return plan
 
@@ -1939,11 +1938,8 @@ class RecoverPartitions(LogicalPlan):
         self._table_name = table_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(
-            catalog=proto.Catalog(
-                recover_partitions=proto.RecoverPartitions(table_name=self._table_name)
-            )
-        )
+        plan = self._create_proto_relation()
+        plan.catalog.recover_partitions.table_name = self._table_name
         return plan
 
 
@@ -1953,9 +1949,8 @@ class IsCached(LogicalPlan):
         self._table_name = table_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(
-            catalog=proto.Catalog(is_cached=proto.IsCached(table_name=self._table_name))
-        )
+        plan = self._create_proto_relation()
+        plan.catalog.is_cached.table_name = self._table_name
         return plan
 
 
@@ -1966,10 +1961,11 @@ class CacheTable(LogicalPlan):
         self._storage_level = storage_level
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        plan = self._create_proto_relation()
         _cache_table = proto.CacheTable(table_name=self._table_name)
         if self._storage_level:
             _cache_table.storage_level.CopyFrom(storage_level_to_proto(self._storage_level))
-        plan = proto.Relation(catalog=proto.Catalog(cache_table=_cache_table))
+        plan.catalog.cache_table.CopyFrom(_cache_table)
         return plan
 
@@ -1979,7 +1975,7 @@ class UncacheTable(LogicalPlan):
         self._table_name = table_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(uncache_table=proto.UncacheTable()))
+        plan = self._create_proto_relation()
         plan.catalog.uncache_table.table_name = self._table_name
         return plan
 
@@ -1989,7 +1985,9 @@ class ClearCache(LogicalPlan):
         super().__init__(None)
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        return proto.Relation(catalog=proto.Catalog(clear_cache=proto.ClearCache()))
+        plan = self._create_proto_relation()
+        plan.catalog.clear_cache.SetInParent()
+        return plan
 
 
 class RefreshTable(LogicalPlan):
@@ -1998,7 +1996,7 @@ class RefreshTable(LogicalPlan):
         self._table_name = table_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(refresh_table=proto.RefreshTable()))
+        plan = self._create_proto_relation()
         plan.catalog.refresh_table.table_name = self._table_name
         return plan
 
@@ -2009,7 +2007,7 @@ class RefreshByPath(LogicalPlan):
         self._path = path
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(refresh_by_path=proto.RefreshByPath()))
+        plan = self._create_proto_relation()
         plan.catalog.refresh_by_path.path = self._path
         return plan
 
@@ -2019,7 +2017,9 @@ class CurrentCatalog(LogicalPlan):
         super().__init__(None)
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        return proto.Relation(catalog=proto.Catalog(current_catalog=proto.CurrentCatalog()))
+        plan = self._create_proto_relation()
+        plan.catalog.current_catalog.SetInParent()
+        return plan
 
 
 class SetCurrentCatalog(LogicalPlan):
@@ -2028,7 +2028,7 @@ class SetCurrentCatalog(LogicalPlan):
         self._catalog_name = catalog_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(set_current_catalog=proto.SetCurrentCatalog()))
+        plan = self._create_proto_relation()
         plan.catalog.set_current_catalog.catalog_name = self._catalog_name
         return plan
 
@@ -2039,7 +2039,8 @@ class ListCatalogs(LogicalPlan):
         self._pattern = pattern
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation(catalog=proto.Catalog(list_catalogs=proto.ListCatalogs()))
+        plan = self._create_proto_relation()
+        plan.catalog.list_catalogs.SetInParent()
         if self._pattern is not None:
             plan.catalog.list_catalogs.pattern = self._pattern
         return plan
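A note on the pattern used throughout the diff: every hunk swaps a hand-built `proto.Relation(catalog=proto.Catalog(...))` for the shared `LogicalPlan._create_proto_relation()` helper, whose definition lives elsewhere in `plan.py` and is not part of this diff. Judging from the lines removed in `CollectMetrics` (`plan.common.plan_id = ...`), the helper presumably returns a fresh `Relation` with the plan's own id stamped into `relation.common.plan_id`, and the `SetInParent()` calls are ordinary protobuf usage for marking an otherwise field-less catalog message (e.g. `CurrentDatabase`, `ClearCache`) as present. Below is a minimal sketch under those assumptions; the `PlanSketch` class and its id counter are illustrative, not code from the file.

```python
import pyspark.sql.connect.proto as proto


class PlanSketch:
    """Illustrative stand-in for LogicalPlan; not the real class."""

    _next_plan_id = 0  # hypothetical counter; the real id bookkeeping lives in plan.py

    def __init__(self) -> None:
        PlanSketch._next_plan_id += 1
        self._plan_id = PlanSketch._next_plan_id

    def _create_proto_relation(self) -> proto.Relation:
        # Presumed behavior of the helper the diff switches to: a fresh Relation
        # that already carries this plan's id in relation.common.plan_id.
        plan = proto.Relation()
        plan.common.plan_id = self._plan_id
        return plan


rel = PlanSketch()._create_proto_relation()
# For parameterless catalog messages, SetInParent() is what marks the submessage
# as present (and selects the catalog oneof) even though no field gets assigned.
rel.catalog.current_database.SetInParent()
assert rel.common.plan_id > 0 and rel.catalog.HasField("current_database")
```

Centralizing the `Relation` construction in one helper is what enforces the "attach plan_id for all logical plans" invariant: catalog plans can no longer forget to set `common.plan_id`.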