This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9d4abaf19fa4 [SPARK-48638][CONNECT] Add ExecutionInfo support for 
DataFrame
9d4abaf19fa4 is described below

commit 9d4abaf19fa4d508aca35ac00528b2ce9f4e8805
Author: Martin Grund <[email protected]>
AuthorDate: Wed Jun 26 08:35:32 2024 +0900

    [SPARK-48638][CONNECT] Add ExecutionInfo support for DataFrame
    
    ### What changes were proposed in this pull request?
    
    One shortcoming of Spark Connect is that query execution metrics are not directly accessible. In Spark Classic, the query execution is only reachable through the private `_jdf` variable, which is not available in Spark Connect.
    
    However, since the first release of Spark Connect, the response messages have already contained the metrics of the executed plan.
    
    This patch makes them directly accessible and provides a way to visualize them.
    
    ```python
    df = spark.range(100)
    df.collect()
    metrics = df.executionInfo.metrics
    metrics.toDot()
    ```
    
    The `toDot()` method returns a `graphviz.Digraph` instance that can either be displayed directly in a notebook or be manipulated further.
    
    <img width="909" alt="image" src="https://github.com/apache/spark/assets/3421/e1710350-54d2-4e6d-9b80-0aaf1b8583e3">
    
    <img width="871" alt="image" src="https://github.com/apache/spark/assets/3421/6a972119-76b6-4e36-bc81-8d01110fa31c">
    
    The purpose of the `executionInfo` property and the associated `ExecutionInfo` class is not to provide an equivalent of the `QueryExecution` class used internally by Spark (for example, access to the analyzed, optimized, and executed plans), but rather to provide a convenient way of accessing execution-related information.
    
    ### Why are the changes needed?
    To improve the user experience.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. It adds a new API, `DataFrame.executionInfo`, for accessing information (such as plan metrics and observed metrics) about the execution of a Spark SQL query.
    
    ### How was this patch tested?
    Added new unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46996 from grundprinzip/SPARK-48638.
    
    Lead-authored-by: Martin Grund <[email protected]>
    Co-authored-by: Martin Grund <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../spark/sql/connect/utils/MetricGenerator.scala  |   1 +
 dev/requirements.txt                               |   3 +
 dev/sparktestsupport/modules.py                    |   1 +
 python/docs/source/getting_started/install.rst     |   1 +
 .../source/reference/pyspark.sql/dataframe.rst     |   1 +
 python/pyspark/errors/error-conditions.json        |   5 +
 python/pyspark/sql/classic/dataframe.py            |   8 +
 python/pyspark/sql/connect/client/core.py          |  91 +++----
 python/pyspark/sql/connect/dataframe.py            |  44 +++-
 python/pyspark/sql/connect/readwriter.py           |  48 +++-
 python/pyspark/sql/connect/session.py              |   7 +-
 python/pyspark/sql/connect/streaming/query.py      |   4 +-
 python/pyspark/sql/connect/streaming/readwriter.py |   2 +-
 python/pyspark/sql/dataframe.py                    |  26 ++
 python/pyspark/sql/metrics.py                      | 287 +++++++++++++++++++++
 python/pyspark/sql/tests/connect/test_df_debug.py  |  86 ++++++
 python/pyspark/sql/tests/test_dataframe.py         |  11 +-
 python/pyspark/testing/connectutils.py             |   7 +
 18 files changed, 544 insertions(+), 89 deletions(-)

diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala
index e2e412831187..d76bec5454ab 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala
@@ -70,6 +70,7 @@ private[connect] object MetricGenerator extends 
AdaptiveSparkPlanHelper {
       .newBuilder()
       .setName(p.nodeName)
       .setPlanId(p.id)
+      .setParent(parentId)
       .putAllExecutionMetrics(mv.asJava)
       .build()
     Seq(mo) ++ transformChildren(p)
diff --git a/dev/requirements.txt b/dev/requirements.txt
index d6530d8ce282..88883a963950 100644
--- a/dev/requirements.txt
+++ b/dev/requirements.txt
@@ -60,6 +60,9 @@ mypy-protobuf==3.3.0
 googleapis-common-protos-stubs==2.2.0
 grpc-stubs==1.24.11
 
+# Debug for Spark and Spark Connect
+graphviz==0.20.3
+
 # TorchDistributor dependencies
 torch
 torchvision
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 8c17af559c25..66927066faa7 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1063,6 +1063,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_pandas_udf_window",
         "pyspark.sql.tests.connect.test_resources",
         "pyspark.sql.tests.connect.shell.test_progress",
+        "pyspark.sql.tests.connect.test_df_debug",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, 
pandas, and pyarrow and
diff --git a/python/docs/source/getting_started/install.rst 
b/python/docs/source/getting_started/install.rst
index 21926ae295bf..6cc68cd46b11 100644
--- a/python/docs/source/getting_started/install.rst
+++ b/python/docs/source/getting_started/install.rst
@@ -210,6 +210,7 @@ Package                    Supported version Note
 `grpcio`                   >=1.62.0          Required for Spark Connect
 `grpcio-status`            >=1.62.0          Required for Spark Connect
 `googleapis-common-protos` >=1.56.4          Required for Spark Connect
+`graphviz`                 >=0.20            Optional for Spark Connect
 ========================== ================= ==========================
 
 Spark SQL
diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst 
b/python/docs/source/reference/pyspark.sql/dataframe.rst
index ec39b645b140..d0196baa7a05 100644
--- a/python/docs/source/reference/pyspark.sql/dataframe.rst
+++ b/python/docs/source/reference/pyspark.sql/dataframe.rst
@@ -55,6 +55,7 @@ DataFrame
     DataFrame.dropna
     DataFrame.dtypes
     DataFrame.exceptAll
+    DataFrame.executionInfo
     DataFrame.explain
     DataFrame.fillna
     DataFrame.filter
diff --git a/python/pyspark/errors/error-conditions.json 
b/python/pyspark/errors/error-conditions.json
index 30db37387249..dd70e814b1ea 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -149,6 +149,11 @@
       "Cannot <condition1> without <condition2>."
     ]
   },
+  "CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF": {
+    "message": [
+      "Calling property or member '<member>' is not supported in PySpark 
Classic, please use Spark Connect instead."
+    ]
+  },
   "COLLATION_INVALID_PROVIDER" : {
     "message" : [
       "The value <provider> does not represent a correct collation provider. 
Supported providers are: [<supportedProviders>]."
diff --git a/python/pyspark/sql/classic/dataframe.py 
b/python/pyspark/sql/classic/dataframe.py
index a03467aff194..1bedd624603e 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -94,6 +94,7 @@ if TYPE_CHECKING:
     from pyspark.sql.session import SparkSession
     from pyspark.sql.group import GroupedData
     from pyspark.sql.observation import Observation
+    from pyspark.sql.metrics import ExecutionInfo
 
 
 class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
@@ -1835,6 +1836,13 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, 
PandasConversionMixin):
     def toPandas(self) -> "PandasDataFrameLike":
         return PandasConversionMixin.toPandas(self)
 
+    @property
+    def executionInfo(self) -> Optional["ExecutionInfo"]:
+        raise PySparkValueError(
+            error_class="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF",
+            message_parameters={"member": "queryExecution"},
+        )
+
 
 def _to_scala_map(sc: "SparkContext", jm: Dict) -> "JavaObject":
     """
diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index f3bbab69f271..e91324150cbd 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -61,6 +61,7 @@ from pyspark.accumulators import SpecialAccumulatorIds
 from pyspark.loose_version import LooseVersion
 from pyspark.version import __version__
 from pyspark.resource.information import ResourceInformation
+from pyspark.sql.metrics import MetricValue, PlanMetrics, ExecutionInfo, 
ObservedMetrics
 from pyspark.sql.connect.client.artifact import ArtifactManager
 from pyspark.sql.connect.client.logging import logger
 from pyspark.sql.connect.profiler import ConnectProfilerCollector
@@ -447,56 +448,7 @@ class DefaultChannelBuilder(ChannelBuilder):
             return self._secure_channel(self.endpoint, creds)
 
 
-class MetricValue:
-    def __init__(self, name: str, value: Union[int, float], type: str):
-        self._name = name
-        self._type = type
-        self._value = value
-
-    def __repr__(self) -> str:
-        return f"<{self._name}={self._value} ({self._type})>"
-
-    @property
-    def name(self) -> str:
-        return self._name
-
-    @property
-    def value(self) -> Union[int, float]:
-        return self._value
-
-    @property
-    def metric_type(self) -> str:
-        return self._type
-
-
-class PlanMetrics:
-    def __init__(self, name: str, id: int, parent: int, metrics: 
List[MetricValue]):
-        self._name = name
-        self._id = id
-        self._parent_id = parent
-        self._metrics = metrics
-
-    def __repr__(self) -> str:
-        return f"Plan({self._name})={self._metrics}"
-
-    @property
-    def name(self) -> str:
-        return self._name
-
-    @property
-    def plan_id(self) -> int:
-        return self._id
-
-    @property
-    def parent_plan_id(self) -> int:
-        return self._parent_id
-
-    @property
-    def metrics(self) -> List[MetricValue]:
-        return self._metrics
-
-
-class PlanObservedMetrics:
+class PlanObservedMetrics(ObservedMetrics):
     def __init__(self, name: str, metrics: List[pb2.Expression.Literal], keys: 
List[str]):
         self._name = name
         self._metrics = metrics
@@ -513,6 +465,13 @@ class PlanObservedMetrics:
     def metrics(self) -> List[pb2.Expression.Literal]:
         return self._metrics
 
+    @property
+    def pairs(self) -> dict[str, Any]:
+        result = {}
+        for x in range(len(self._metrics)):
+            result[self.keys[x]] = LiteralExpression._to_value(self.metrics[x])
+        return result
+
     @property
     def keys(self) -> List[str]:
         return self._keys
@@ -888,7 +847,7 @@ class SparkConnectClient(object):
         logger.info("Fetching the resources")
         cmd = pb2.Command()
         cmd.get_resources_command.SetInParent()
-        (_, properties) = self.execute_command(cmd)
+        (_, properties, _) = self.execute_command(cmd)
         resources = properties["get_resources_command_result"]
         return resources
 
@@ -915,18 +874,23 @@ class SparkConnectClient(object):
 
     def to_table(
         self, plan: pb2.Plan, observations: Dict[str, Observation]
-    ) -> Tuple["pa.Table", Optional[StructType]]:
+    ) -> Tuple["pa.Table", Optional[StructType], ExecutionInfo]:
         """
         Return given plan as a PyArrow Table.
         """
         logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
-        table, schema, _, _, _ = self._execute_and_fetch(req, observations)
+        table, schema, metrics, observed_metrics, _ = 
self._execute_and_fetch(req, observations)
+
+        # Create a query execution object.
+        ei = ExecutionInfo(metrics, observed_metrics)
         assert table is not None
-        return table, schema
+        return table, schema, ei
 
-    def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) 
-> "pd.DataFrame":
+    def to_pandas(
+        self, plan: pb2.Plan, observations: Dict[str, Observation]
+    ) -> Tuple["pd.DataFrame", "ExecutionInfo"]:
         """
         Return given plan as a pandas DataFrame.
         """
@@ -941,6 +905,7 @@ class SparkConnectClient(object):
             req, observations, self_destruct=self_destruct
         )
         assert table is not None
+        ei = ExecutionInfo(metrics, observed_metrics)
 
         schema = schema or from_arrow_schema(table.schema, 
prefer_timestamp_ntz=True)
         assert schema is not None and isinstance(schema, StructType)
@@ -1007,7 +972,7 @@ class SparkConnectClient(object):
             pdf.attrs["metrics"] = metrics
         if len(observed_metrics) > 0:
             pdf.attrs["observed_metrics"] = observed_metrics
-        return pdf
+        return pdf, ei
 
     def _proto_to_string(self, p: google.protobuf.message.Message) -> str:
         """
@@ -1051,7 +1016,7 @@ class SparkConnectClient(object):
 
     def execute_command(
         self, command: pb2.Command, observations: Optional[Dict[str, 
Observation]] = None
-    ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any]]:
+    ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any], ExecutionInfo]:
         """
         Execute given command.
         """
@@ -1060,11 +1025,15 @@ class SparkConnectClient(object):
         if self._user_id:
             req.user_context.user_id = self._user_id
         req.plan.command.CopyFrom(command)
-        data, _, _, _, properties = self._execute_and_fetch(req, observations 
or {})
+        data, _, metrics, observed_metrics, properties = 
self._execute_and_fetch(
+            req, observations or {}
+        )
+        # Create a query execution object.
+        ei = ExecutionInfo(metrics, observed_metrics)
         if data is not None:
-            return (data.to_pandas(), properties)
+            return (data.to_pandas(), properties, ei)
         else:
-            return (None, properties)
+            return (None, properties, ei)
 
     def execute_command_as_iterator(
         self, command: pb2.Command, observations: Optional[Dict[str, 
Observation]] = None
@@ -1849,6 +1818,6 @@ class SparkConnectClient(object):
         logger.info("Creating the ResourceProfile")
         cmd = pb2.Command()
         cmd.create_resource_profile_command.profile.CopyFrom(profile)
-        (_, properties) = self.execute_command(cmd)
+        (_, properties, _) = self.execute_command(cmd)
         profile_id = properties["create_resource_profile_command_result"]
         return profile_id
diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index 678e66ee2b7b..1aa8fc00cfcc 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -101,6 +101,7 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.observation import Observation
     from pyspark.sql.connect.session import SparkSession
     from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+    from pyspark.sql.metrics import ExecutionInfo
 
 
 class DataFrame(ParentDataFrame):
@@ -137,6 +138,7 @@ class DataFrame(ParentDataFrame):
         # by __repr__ and _repr_html_ while eager evaluation opens.
         self._support_repr_html = False
         self._cached_schema: Optional[StructType] = None
+        self._execution_info: Optional["ExecutionInfo"] = None
 
     def __reduce__(self) -> Tuple:
         """
@@ -206,7 +208,10 @@ class DataFrame(ParentDataFrame):
 
     @property
     def write(self) -> "DataFrameWriter":
-        return DataFrameWriter(self._plan, self._session)
+        def cb(qe: "ExecutionInfo") -> None:
+            self._execution_info = qe
+
+        return DataFrameWriter(self._plan, self._session, cb)
 
     @functools.cache
     def isEmpty(self) -> bool:
@@ -1839,7 +1844,9 @@ class DataFrame(ParentDataFrame):
 
     def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
         query = self._plan.to_proto(self._session.client)
-        table, schema = self._session.client.to_table(query, 
self._plan.observations)
+        table, schema, self._execution_info = self._session.client.to_table(
+            query, self._plan.observations
+        )
         assert table is not None
         return (table, schema)
 
@@ -1850,7 +1857,9 @@ class DataFrame(ParentDataFrame):
 
     def toPandas(self) -> "PandasDataFrameLike":
         query = self._plan.to_proto(self._session.client)
-        return self._session.client.to_pandas(query, self._plan.observations)
+        pdf, ei = self._session.client.to_pandas(query, 
self._plan.observations)
+        self._execution_info = ei
+        return pdf
 
     @property
     def schema(self) -> StructType:
@@ -1976,25 +1985,29 @@ class DataFrame(ParentDataFrame):
         command = plan.CreateView(
             child=self._plan, name=name, is_global=False, replace=False
         ).command(session=self._session.client)
-        self._session.client.execute_command(command, self._plan.observations)
+        _, _, ei = self._session.client.execute_command(command, 
self._plan.observations)
+        self._execution_info = ei
 
     def createOrReplaceTempView(self, name: str) -> None:
         command = plan.CreateView(
             child=self._plan, name=name, is_global=False, replace=True
         ).command(session=self._session.client)
-        self._session.client.execute_command(command, self._plan.observations)
+        _, _, ei = self._session.client.execute_command(command, 
self._plan.observations)
+        self._execution_info = ei
 
     def createGlobalTempView(self, name: str) -> None:
         command = plan.CreateView(
             child=self._plan, name=name, is_global=True, replace=False
         ).command(session=self._session.client)
-        self._session.client.execute_command(command, self._plan.observations)
+        _, _, ei = self._session.client.execute_command(command, 
self._plan.observations)
+        self._execution_info = ei
 
     def createOrReplaceGlobalTempView(self, name: str) -> None:
         command = plan.CreateView(
             child=self._plan, name=name, is_global=True, replace=True
         ).command(session=self._session.client)
-        self._session.client.execute_command(command, self._plan.observations)
+        _, _, ei = self._session.client.execute_command(command, 
self._plan.observations)
+        self._execution_info = ei
 
     def cache(self) -> ParentDataFrame:
         return self.persist()
@@ -2169,14 +2182,19 @@ class DataFrame(ParentDataFrame):
         )
 
     def writeTo(self, table: str) -> "DataFrameWriterV2":
-        return DataFrameWriterV2(self._plan, self._session, table)
+        def cb(ei: "ExecutionInfo") -> None:
+            self._execution_info = ei
+
+        return DataFrameWriterV2(self._plan, self._session, table, cb)
 
     def offset(self, n: int) -> ParentDataFrame:
         return DataFrame(plan.Offset(child=self._plan, offset=n), 
session=self._session)
 
     def checkpoint(self, eager: bool = True) -> "DataFrame":
         cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager)
-        _, properties = 
self._session.client.execute_command(cmd.command(self._session.client))
+        _, properties, self._execution_info = 
self._session.client.execute_command(
+            cmd.command(self._session.client)
+        )
         assert "checkpoint_command_result" in properties
         checkpointed = properties["checkpoint_command_result"]
         assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
@@ -2184,7 +2202,9 @@ class DataFrame(ParentDataFrame):
 
     def localCheckpoint(self, eager: bool = True) -> "DataFrame":
         cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager)
-        _, properties = 
self._session.client.execute_command(cmd.command(self._session.client))
+        _, properties, self._execution_info = 
self._session.client.execute_command(
+            cmd.command(self._session.client)
+        )
         assert "checkpoint_command_result" in properties
         checkpointed = properties["checkpoint_command_result"]
         assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
@@ -2205,6 +2225,10 @@ class DataFrame(ParentDataFrame):
                 message_parameters={"feature": "rdd"},
             )
 
+    @property
+    def executionInfo(self) -> Optional["ExecutionInfo"]:
+        return self._execution_info
+
 
 class DataFrameNaFunctions(ParentDataFrameNaFunctions):
     def __init__(self, df: ParentDataFrame):
diff --git a/python/pyspark/sql/connect/readwriter.py 
b/python/pyspark/sql/connect/readwriter.py
index bf7dc4d36905..de62cf65b01e 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -19,7 +19,7 @@ from pyspark.sql.connect.utils import check_dependencies
 check_dependencies(__name__)
 
 from typing import Dict
-from typing import Optional, Union, List, overload, Tuple, cast
+from typing import Optional, Union, List, overload, Tuple, cast, Callable
 from typing import TYPE_CHECKING
 
 from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, 
WriteOperation, WriteOperationV2
@@ -37,6 +37,7 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.dataframe import DataFrame
     from pyspark.sql.connect._typing import ColumnOrName, OptionalPrimitiveType
     from pyspark.sql.connect.session import SparkSession
+    from pyspark.sql.metrics import ExecutionInfo
 
 __all__ = ["DataFrameReader", "DataFrameWriter"]
 
@@ -486,11 +487,18 @@ DataFrameReader.__doc__ = PySparkDataFrameReader.__doc__
 
 
 class DataFrameWriter(OptionUtils):
-    def __init__(self, plan: "LogicalPlan", session: "SparkSession"):
+    def __init__(
+        self,
+        plan: "LogicalPlan",
+        session: "SparkSession",
+        callback: Optional[Callable[["ExecutionInfo"], None]] = None,
+    ):
         self._df: "LogicalPlan" = plan
         self._spark: "SparkSession" = session
         self._write: "WriteOperation" = WriteOperation(self._df)
 
+        self._callback = callback if callback is not None else lambda _: None
+
     def mode(self, saveMode: Optional[str]) -> "DataFrameWriter":
         # At the JVM side, the default value of mode is already set to "error".
         # So, if the given saveMode is None, we will not call JVM-side's mode 
method.
@@ -649,9 +657,10 @@ class DataFrameWriter(OptionUtils):
         if format is not None:
             self.format(format)
         self._write.path = path
-        self._spark.client.execute_command(
+        _, _, ei = self._spark.client.execute_command(
             self._write.command(self._spark.client), self._write.observations
         )
+        self._callback(ei)
 
     save.__doc__ = PySparkDataFrameWriter.save.__doc__
 
@@ -660,9 +669,10 @@ class DataFrameWriter(OptionUtils):
             self.mode("overwrite" if overwrite else "append")
         self._write.table_name = tableName
         self._write.table_save_method = "insert_into"
-        self._spark.client.execute_command(
+        _, _, ei = self._spark.client.execute_command(
             self._write.command(self._spark.client), self._write.observations
         )
+        self._callback(ei)
 
     insertInto.__doc__ = PySparkDataFrameWriter.insertInto.__doc__
 
@@ -681,9 +691,10 @@ class DataFrameWriter(OptionUtils):
             self.format(format)
         self._write.table_name = name
         self._write.table_save_method = "save_as_table"
-        self._spark.client.execute_command(
+        _, _, ei = self._spark.client.execute_command(
             self._write.command(self._spark.client), self._write.observations
         )
+        self._callback(ei)
 
     saveAsTable.__doc__ = PySparkDataFrameWriter.saveAsTable.__doc__
 
@@ -845,11 +856,18 @@ class DataFrameWriter(OptionUtils):
 
 
 class DataFrameWriterV2(OptionUtils):
-    def __init__(self, plan: "LogicalPlan", session: "SparkSession", table: 
str):
+    def __init__(
+        self,
+        plan: "LogicalPlan",
+        session: "SparkSession",
+        table: str,
+        callback: Optional[Callable[["ExecutionInfo"], None]] = None,
+    ):
         self._df: "LogicalPlan" = plan
         self._spark: "SparkSession" = session
         self._table_name: str = table
         self._write: "WriteOperationV2" = WriteOperationV2(self._df, 
self._table_name)
+        self._callback = callback if callback is not None else lambda _: None
 
     def using(self, provider: str) -> "DataFrameWriterV2":
         self._write.provider = provider
@@ -884,50 +902,56 @@ class DataFrameWriterV2(OptionUtils):
 
     def create(self) -> None:
         self._write.mode = "create"
-        self._spark.client.execute_command(
+        _, _, ei = self._spark.client.execute_command(
             self._write.command(self._spark.client), self._write.observations
         )
+        self._callback(ei)
 
     create.__doc__ = PySparkDataFrameWriterV2.create.__doc__
 
     def replace(self) -> None:
         self._write.mode = "replace"
-        self._spark.client.execute_command(
+        _, _, ei = self._spark.client.execute_command(
             self._write.command(self._spark.client), self._write.observations
         )
+        self._callback(ei)
 
     replace.__doc__ = PySparkDataFrameWriterV2.replace.__doc__
 
     def createOrReplace(self) -> None:
         self._write.mode = "create_or_replace"
-        self._spark.client.execute_command(
+        _, _, ei = self._spark.client.execute_command(
             self._write.command(self._spark.client), self._write.observations
         )
+        self._callback(ei)
 
     createOrReplace.__doc__ = PySparkDataFrameWriterV2.createOrReplace.__doc__
 
     def append(self) -> None:
         self._write.mode = "append"
-        self._spark.client.execute_command(
+        _, _, ei = self._spark.client.execute_command(
             self._write.command(self._spark.client), self._write.observations
         )
+        self._callback(ei)
 
     append.__doc__ = PySparkDataFrameWriterV2.append.__doc__
 
     def overwrite(self, condition: "ColumnOrName") -> None:
         self._write.mode = "overwrite"
         self._write.overwrite_condition = F._to_col(condition)
-        self._spark.client.execute_command(
+        _, _, ei = self._spark.client.execute_command(
             self._write.command(self._spark.client), self._write.observations
         )
+        self._callback(ei)
 
     overwrite.__doc__ = PySparkDataFrameWriterV2.overwrite.__doc__
 
     def overwritePartitions(self) -> None:
         self._write.mode = "overwrite_partitions"
-        self._spark.client.execute_command(
+        _, _, ei = self._spark.client.execute_command(
             self._write.command(self._spark.client), self._write.observations
         )
+        self._callback(ei)
 
     overwritePartitions.__doc__ = 
PySparkDataFrameWriterV2.overwritePartitions.__doc__
 
diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index f359ab829483..8e277b3fc63a 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -720,9 +720,12 @@ class SparkSession:
                 _views.append(SubqueryAlias(df._plan, name))
 
         cmd = SQL(sqlQuery, _args, _named_args, _views)
-        data, properties = 
self.client.execute_command(cmd.command(self._client))
+        data, properties, ei = 
self.client.execute_command(cmd.command(self._client))
         if "sql_command_result" in properties:
-            return DataFrame(CachedRelation(properties["sql_command_result"]), 
self)
+            df = DataFrame(CachedRelation(properties["sql_command_result"]), 
self)
+            # A command result contains the execution.
+            df._execution_info = ei
+            return df
         else:
             return DataFrame(cmd, self)
 
diff --git a/python/pyspark/sql/connect/streaming/query.py 
b/python/pyspark/sql/connect/streaming/query.py
index 98ecdc4966c7..13458d650fa9 100644
--- a/python/pyspark/sql/connect/streaming/query.py
+++ b/python/pyspark/sql/connect/streaming/query.py
@@ -181,7 +181,7 @@ class StreamingQuery:
         cmd.query_id.run_id = self._run_id
         exec_cmd = pb2.Command()
         exec_cmd.streaming_query_command.CopyFrom(cmd)
-        (_, properties) = self._session.client.execute_command(exec_cmd)
+        (_, properties, _) = self._session.client.execute_command(exec_cmd)
         return cast(pb2.StreamingQueryCommandResult, 
properties["streaming_query_command_result"])
 
 
@@ -260,7 +260,7 @@ class StreamingQueryManager:
     ) -> pb2.StreamingQueryManagerCommandResult:
         exec_cmd = pb2.Command()
         exec_cmd.streaming_query_manager_command.CopyFrom(cmd)
-        (_, properties) = self._session.client.execute_command(exec_cmd)
+        (_, properties, _) = self._session.client.execute_command(exec_cmd)
         return cast(
             pb2.StreamingQueryManagerCommandResult,
             properties["streaming_query_manager_command_result"],
diff --git a/python/pyspark/sql/connect/streaming/readwriter.py 
b/python/pyspark/sql/connect/streaming/readwriter.py
index b5bb7f2a0912..9b11bf328b85 100644
--- a/python/pyspark/sql/connect/streaming/readwriter.py
+++ b/python/pyspark/sql/connect/streaming/readwriter.py
@@ -601,7 +601,7 @@ class DataStreamWriter:
             self._write_proto.table_name = tableName
 
         cmd = self._write_stream.command(self._session.client)
-        (_, properties) = self._session.client.execute_command(cmd)
+        (_, properties, _) = self._session.client.execute_command(cmd)
 
         start_result = cast(
             pb2.WriteStreamOperationStartResult, 
properties["write_stream_operation_start_result"]
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 62c46cfec93c..625678588bf9 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -64,6 +64,7 @@ if TYPE_CHECKING:
         ArrowMapIterFunction,
         DataFrameLike as PandasDataFrameLike,
     )
+    from pyspark.sql.metrics import ExecutionInfo
 
 
 __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]
@@ -6281,6 +6282,31 @@ class DataFrame:
         """
         ...
 
+    @property
+    def executionInfo(self) -> Optional["ExecutionInfo"]:
+        """
+        Returns an ExecutionInfo object after the query was executed.
+
+        The executionInfo property allows introspecting information about the actual
+        query execution after it has completed successfully. Accessing this member
+        before the query has been executed returns None.
+
+        If the same DataFrame is executed multiple times, the execution info will be
+        overwritten by the latest operation.
+
+        .. versionadded:: 4.0.0
+
+        Returns
+        -------
+        An instance of ExecutionInfo, or None if the value has not been set yet.
+
+        Notes
+        -----
+        This API is dedicated to the Spark Connect client only. With a regular
+        (classic) Spark session, it raises an exception.
+        """
+        ...
+
 
 class DataFrameNaFunctions:
     """Functionality for working with missing data in :class:`DataFrame`.
diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py
new file mode 100644
index 000000000000..666458295201
--- /dev/null
+++ b/python/pyspark/sql/metrics.py
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import abc
+import dataclasses
+from typing import Optional, List, Tuple, Dict, Any, Union, TYPE_CHECKING, 
Sequence
+
+from pyspark.errors import PySparkValueError
+
+if TYPE_CHECKING:
+    from pyspark.testing.connectutils import have_graphviz
+
+    if have_graphviz:
+        import graphviz  # type: ignore
+
+
class ObservedMetrics(abc.ABC):
    """Abstract interface for a named set of observed metrics.

    Concrete implementations expose the observation's ``name``, its values as a
    ``pairs`` mapping and the list of metric ``keys``.
    """

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """The name of the observation."""
        pass

    @property
    @abc.abstractmethod
    def pairs(self) -> Dict[str, Any]:
        """The observed values keyed by metric name."""
        pass

    @property
    @abc.abstractmethod
    def keys(self) -> List[str]:
        """The names of the observed metrics."""
        pass
+
+
class MetricValue:
    """Python-side copy of a single plan metric value reported by the JVM.

    Only plain values are stored; the object keeps no reference to the original value.
    """

    def __init__(self, name: str, value: Union[int, float], type: str):
        self._name, self._type, self._value = name, type, value

    def __repr__(self) -> str:
        return "<{}={} ({})>".format(self._name, self._value, self._type)

    @property
    def name(self) -> str:
        """Name of the metric."""
        return self._name

    @property
    def value(self) -> Union[int, float]:
        """Numeric value of the metric."""
        return self._value

    @property
    def metric_type(self) -> str:
        """Type of the metric as reported alongside the value."""
        return self._type
+
+
class PlanMetrics:
    """A single plan node together with the metrics reported for it.

    Each node is identified by its plan ID and points to the plan ID of its parent.
    """

    def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]):
        self._name, self._id, self._parent_id, self._metrics = name, id, parent, metrics

    def __repr__(self) -> str:
        return "Plan({}: {}->{})={}".format(self._name, self._id, self._parent_id, self._metrics)

    @property
    def name(self) -> str:
        """Name of the plan node."""
        return self._name

    @property
    def plan_id(self) -> int:
        """ID of this plan node."""
        return self._id

    @property
    def parent_plan_id(self) -> int:
        """ID of the parent plan node."""
        return self._parent_id

    @property
    def metrics(self) -> List[MetricValue]:
        """Metric values attached to this plan node."""
        return self._metrics
+
+
class CollectedMetrics:
    """Plan metrics of one query execution, organized as a graph of plan nodes.

    The flat list of :class:`PlanMetrics` is turned into a graph of
    :class:`CollectedMetrics.Node` objects keyed by plan ID, which can be rendered as text
    (:meth:`toText`) or as a Graphviz graph (:meth:`toDot`).
    """

    # Metrics that are shown in the textual representation produced by ``text``.
    _TEXT_METRICS = frozenset(["numPartitions", "peakMemory", "numOutputRows", "spillSize"])

    @dataclasses.dataclass
    class Node:
        """A plan node: its ID, name, metrics and the IDs of its children."""

        id: int
        name: str = dataclasses.field(default="")
        metrics: List["MetricValue"] = dataclasses.field(default_factory=list)
        children: List[int] = dataclasses.field(default_factory=list)

    def __init__(self, metrics: List["PlanMetrics"]):
        # Sort by parent plan ID. The sort is stable, so siblings keep their input order.
        self._metrics = sorted(metrics, key=lambda x: x.parent_plan_id)

    def text(self, current: "Node", graph: Dict[int, "Node"], prefix: str = "") -> str:
        """
        Converts the current node and its children into a textual representation. This is used
        to provide a usable output for the command line or other text-based interfaces. However,
        it is recommended to use the Graphviz representation for a more visual representation.

        Every child is printed one indentation level (two spaces) deeper than its parent.

        Parameters
        ----------
        current: Node
            Current node in the graph.
        graph: dict
            A dictionary representing the full graph mapping from node ID (int) to the node
            itself. The node is an instance of :class:`CollectedMetrics.Node`.
        prefix: str
            String prefix used for generating the output buffer.

        Returns
        -------
        The full string representation of the current node as root.
        """
        metric_buffer = [
            f"{m.name}: {m.value} ({m.metric_type})"
            for m in current.metrics
            if m.name in self._TEXT_METRICS
        ]
        buffer = f"{prefix}+- {current.name}({','.join(metric_buffer)})\n"

        child_prefix = prefix + "  "
        for child in current.children:
            # Skip self-references (a node listed as its own child) to avoid endless recursion.
            if child == current.id:
                continue
            buffer += self.text(graph[child], graph, child_prefix)
        return buffer

    def extract_graph(self) -> Tuple[int, Dict[int, "CollectedMetrics.Node"]]:
        """
        Builds the graph of the query plan. The graph is represented as a dictionary where the
        key is the node ID and the value is the node itself. The root node is the node that is
        not a child of any other node.

        Returns
        -------
        The root node ID and the graph of all nodes.
        """
        all_nodes: Dict[int, CollectedMetrics.Node] = {}

        for m in self._metrics:
            node = all_nodes.get(m.plan_id)
            if node is None:
                all_nodes[m.plan_id] = CollectedMetrics.Node(m.plan_id, m.name, m.metrics)
            else:
                # Already created earlier (e.g. as the placeholder parent of another entry).
                node.name = m.name
                node.metrics = m.metrics

            # Create a placeholder for the parent if it has not been seen yet.
            if m.parent_plan_id not in all_nodes:
                all_nodes[m.parent_plan_id] = CollectedMetrics.Node(m.parent_plan_id)
            all_nodes[m.parent_plan_id].children.append(m.plan_id)

        # Root nodes never appear as the child of another node (self-references don't count).
        candidates = set(all_nodes)
        for node_id, node in all_nodes.items():
            candidates.difference_update(c for c in node.children if c != node_id)

        assert len(candidates) == 1, f"Expected 1 root node, found {len(candidates)}"
        return candidates.pop(), all_nodes

    def toText(self) -> str:
        """
        Converts the execution graph from a graph into a textual representation
        that can be read at the command line for example.

        Returns
        -------
        A string representation of the collected metrics.
        """
        root, graph = self.extract_graph()
        return self.text(graph[root], graph)

    @staticmethod
    def _dot_label(node: "CollectedMetrics.Node") -> str:
        """Builds the Graphviz HTML-like label (a table of all metrics) for one node."""
        import html

        rows = []
        for m in node.metrics:
            # Escape user-visible text so '<', '>' and '&' cannot break the HTML-like label.
            name = html.escape(str(m.name))
            value = html.escape(f"{m.value} ({m.metric_type})")
            rows.append(
                f'<TR><TD><FONT POINT-SIZE="8">{name}</FONT></TD>'
                f'<TD><FONT POINT-SIZE="8">{value}</FONT></TD></TR>'
            )

        return """<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">
        <TR>
            <TD COLSPAN="2" BGCOLOR="lightgrey">
                <FONT POINT-SIZE="10">{}</FONT>
            </TD>
        </TR>
        <TR><TD COLSPAN="2"><FONT POINT-SIZE="10">Metrics</FONT></TD></TR>
        {}
        </TABLE>>""".format(
            html.escape(str(node.name)), "\n".join(rows)
        )

    def toDot(self, filename: Optional[str] = None, out_format: str = "png") -> "graphviz.Digraph":
        """
        Converts the collected metrics into a dot representation. Since the graphviz Digraph
        implementation provides the ability to render the result graph directly in a
        notebook, we return the graph object directly.

        If the graphviz package is not available, a PACKAGE_NOT_INSTALLED error is raised.

        Parameters
        ----------
        filename : str, optional
            The filename to save the graph to given an output format. The path can be
            relative or absolute.

        out_format : str
            The output format of the graph. The default is 'png'.

        Returns
        -------
        An instance of the graphviz.Digraph object.
        """
        # Only the import itself signals a missing package; other ImportErrors must surface.
        try:
            import graphviz
        except ImportError as e:
            raise PySparkValueError(
                error_class="PACKAGE_NOT_INSTALLED",
                message_parameters={"package_name": "graphviz", "minimum_version": "0.20"},
            ) from e

        dot = graphviz.Digraph(
            comment="Query Plan",
            node_attr={
                "shape": "box",
                "fontsize": "10",
            },
        )

        _, graph = self.extract_graph()
        for node_id, node in graph.items():
            dot.node(str(node_id), self._dot_label(node))
            for child in node.children:
                # Do not draw a self-loop for a node that lists itself as a child.
                if child != node_id:
                    dot.edge(str(node_id), str(child))

        if filename:
            dot.render(filename, format=out_format, cleanup=True)
        return dot
+
+
class ExecutionInfo:
    """The execution info class allows users to inspect the query execution of this particular
    DataFrame. This value is only set in the DataFrame if it was executed."""

    def __init__(
        self,
        metrics: Optional[List["PlanMetrics"]],
        obs: Optional[Sequence["ObservedMetrics"]],
    ):
        """
        Parameters
        ----------
        metrics : list of :class:`PlanMetrics`, optional
            Plan metrics of the execution; ``None`` or an empty list means no metrics.
        obs : sequence of :class:`ObservedMetrics`, optional
            Observed metrics of the execution; ``None`` or empty means no observations.
        """
        # Missing or empty plan metrics are normalized to None so callers need a single check.
        self._metrics = CollectedMetrics(metrics) if metrics else None
        self._observations = obs if obs else []

    @property
    def metrics(self) -> Optional["CollectedMetrics"]:
        """The collected plan metrics, or None if no plan metrics were provided."""
        return self._metrics

    @property
    def flows(self) -> List[Tuple[str, Dict[str, Any]]]:
        """The observed metrics as a list of ``(name, pairs)`` tuples."""
        return [(f.name, f.pairs) for f in self._observations]
diff --git a/python/pyspark/sql/tests/connect/test_df_debug.py 
b/python/pyspark/sql/tests/connect/test_df_debug.py
new file mode 100644
index 000000000000..8a4ec68fda84
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_df_debug.py
@@ -0,0 +1,86 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.testing.connectutils import (
+    should_test_connect,
+    have_graphviz,
+    graphviz_requirement_message,
+)
+from pyspark.sql.tests.connect.test_connect_basic import 
SparkConnectSQLTestCase
+
+if should_test_connect:
+    from pyspark.sql.connect.dataframe import DataFrame
+
+
class SparkConnectDataFrameDebug(SparkConnectSQLTestCase):
    """Tests for the Spark Connect ``DataFrame.executionInfo`` property."""

    def test_df_debug_basics(self):
        df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
        df.collect()
        ei = df.executionInfo
        self.assertIsNotNone(ei, "The execution info must be set after the action is executed")

        root, graph = ei.metrics.extract_graph()
        self.assertIn(root, graph, "The root must be rooted in the graph")

    def test_df_query_execution_empty_before_execution(self):
        df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
        ei = df.executionInfo
        self.assertIsNone(ei, "The query execution must be None before the action is executed")

    def test_df_query_execution_with_writes(self):
        import tempfile

        df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
        # Write into a private temporary directory instead of a fixed, shared /tmp path.
        with tempfile.TemporaryDirectory(prefix="test_df_query_execution_with_writes") as d:
            df.write.save(d, format="json", mode="overwrite")
        ei = df.executionInfo
        self.assertIsNotNone(
            ei, "The query execution must not be None after the write action is executed"
        )

    def test_query_execution_text_format(self):
        df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
        df.collect()
        self.assertIn("HashAggregate", df.executionInfo.metrics.toText())

        # Different execution mode.
        df2: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
        df2.toPandas()
        self.assertIn("HashAggregate", df2.executionInfo.metrics.toText())

    @unittest.skipIf(not have_graphviz, graphviz_requirement_message)
    def test_df_query_execution_metrics_to_dot(self):
        df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
        df.collect()
        ei = df.executionInfo

        dot = ei.metrics.toDot()
        # Check for None before dereferencing the result.
        self.assertIsNotNone(dot, "The dot representation must not be None")
        source = dot.source
        self.assertGreater(len(source), 0, "The dot representation must not be empty")
        self.assertIn("digraph", source, "The dot representation must contain the digraph keyword")
        self.assertIn("Metrics", source, "The dot representation must contain the Metrics keyword")
+
+
if __name__ == "__main__":
    from pyspark.sql.tests.connect.test_df_debug import *  # noqa: F401

    # Prefer the XML reporter when it is installed; otherwise fall back to the default runner.
    testRunner = None
    try:
        import xmlrunner  # type: ignore
    except ImportError:
        pass
    else:
        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_dataframe.py 
b/python/pyspark/sql/tests/test_dataframe.py
index d7b31bbc5215..c7cf43a33454 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -37,6 +37,7 @@ from pyspark.errors import (
     AnalysisException,
     IllegalArgumentException,
     PySparkTypeError,
+    PySparkValueError,
 )
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
@@ -849,7 +850,15 @@ class DataFrameTestsMixin:
 
 
 class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
-    pass
+    def test_query_execution_unsupported_in_classic(self):
+        with self.assertRaises(PySparkValueError) as pe:
+            self.spark.range(1).executionInfo
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF",
+            message_parameters={"member": "queryExecution"},
+        )
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/testing/connectutils.py 
b/python/pyspark/testing/connectutils.py
index 191505741eb4..b3004693724b 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -45,6 +45,13 @@ except ImportError as e:
     googleapis_common_protos_requirement_message = str(e)
 have_googleapis_common_protos = googleapis_common_protos_requirement_message 
is None
 
# Probe for the optional graphviz package; tests that need it are skipped when it is missing.
try:
    import graphviz  # noqa: F401
except ImportError as e:
    graphviz_requirement_message = str(e)
else:
    graphviz_requirement_message = None
have_graphviz: bool = graphviz_requirement_message is None
+
 from pyspark import Row, SparkConf
 from pyspark.util import is_remote_only
 from pyspark.testing.utils import PySparkErrorTestUtils


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to