This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9d4abaf19fa4 [SPARK-48638][CONNECT] Add ExecutionInfo support for DataFrame
9d4abaf19fa4 is described below
commit 9d4abaf19fa4d508aca35ac00528b2ce9f4e8805
Author: Martin Grund <[email protected]>
AuthorDate: Wed Jun 26 08:35:32 2024 +0900
[SPARK-48638][CONNECT] Add ExecutionInfo support for DataFrame
### What changes were proposed in this pull request?
One of the shortcomings of Spark Connect is that query execution metrics are not directly accessible. In Spark Classic, the query execution is only reachable through the private `_jdf` variable, which is not available in Spark Connect.
However, since the first release of Spark Connect, the response messages have already carried the metrics of the executed plan.
This patch makes those metrics directly accessible and provides a way to visualize them.
```python
df = spark.range(100)
df.collect()
metrics = df.executionInfo.metrics
metrics.toDot()
```
The `toDot()` method returns a `graphviz.Digraph` instance that can either be displayed directly in a notebook or manipulated further.
<img width="909" alt="image" src="https://github.com/apache/spark/assets/3421/e1710350-54d2-4e6d-9b80-0aaf1b8583e3">
<img width="871" alt="image" src="https://github.com/apache/spark/assets/3421/6a972119-76b6-4e36-bc81-8d01110fa31c">
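The collected metrics can also be inspected without a notebook; a minimal sketch based on the `toText()` and `toDot()` helpers added in this patch (the output path below is only illustrative):
```python
df = spark.range(100).repartition(10).groupBy("id").count()
df.collect()

metrics = df.executionInfo.metrics
# Plain-text tree of the executed plan and its metrics.
print(metrics.toText())
# Render the plan graph to a file; requires the optional graphviz package.
metrics.toDot(filename="/tmp/query_plan", out_format="png")
```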
The purpose of the `executionInfo` property and the associated `ExecutionInfo` class is not to provide an equivalent of the `QueryExecution` class used internally by Spark (with, for example, access to the analyzed, optimized, and executed plans), but rather to provide a convenient way of accessing execution-related information.
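Observed metrics registered via `DataFrame.observe` are exposed as well, through the `flows` property of `ExecutionInfo`; a rough sketch (the observation name and aggregate are made up for illustration):
```python
from pyspark.sql import functions as F

df = spark.range(100).observe("my_observation", F.count(F.lit(1)).alias("rows"))
df.collect()

# Each entry is a (name, {metric: value}) tuple for one observation.
for name, pairs in df.executionInfo.flows:
    print(name, pairs)
```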
### Why are the changes needed?
Improves the user experience: query execution metrics become directly accessible from the DataFrame API in Spark Connect.
### Does this PR introduce _any_ user-facing change?
Yes. It adds a new API, `DataFrame.executionInfo`, for accessing the execution information of a Spark SQL execution.
### How was this patch tested?
Added new unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #46996 from grundprinzip/SPARK-48638.
Lead-authored-by: Martin Grund <[email protected]>
Co-authored-by: Martin Grund <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../spark/sql/connect/utils/MetricGenerator.scala | 1 +
dev/requirements.txt | 3 +
dev/sparktestsupport/modules.py | 1 +
python/docs/source/getting_started/install.rst | 1 +
.../source/reference/pyspark.sql/dataframe.rst | 1 +
python/pyspark/errors/error-conditions.json | 5 +
python/pyspark/sql/classic/dataframe.py | 8 +
python/pyspark/sql/connect/client/core.py | 91 +++----
python/pyspark/sql/connect/dataframe.py | 44 +++-
python/pyspark/sql/connect/readwriter.py | 48 +++-
python/pyspark/sql/connect/session.py | 7 +-
python/pyspark/sql/connect/streaming/query.py | 4 +-
python/pyspark/sql/connect/streaming/readwriter.py | 2 +-
python/pyspark/sql/dataframe.py | 26 ++
python/pyspark/sql/metrics.py | 287 +++++++++++++++++++++
python/pyspark/sql/tests/connect/test_df_debug.py | 86 ++++++
python/pyspark/sql/tests/test_dataframe.py | 11 +-
python/pyspark/testing/connectutils.py | 7 +
18 files changed, 544 insertions(+), 89 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala
index e2e412831187..d76bec5454ab 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala
@@ -70,6 +70,7 @@ private[connect] object MetricGenerator extends AdaptiveSparkPlanHelper {
.newBuilder()
.setName(p.nodeName)
.setPlanId(p.id)
+ .setParent(parentId)
.putAllExecutionMetrics(mv.asJava)
.build()
Seq(mo) ++ transformChildren(p)
diff --git a/dev/requirements.txt b/dev/requirements.txt
index d6530d8ce282..88883a963950 100644
--- a/dev/requirements.txt
+++ b/dev/requirements.txt
@@ -60,6 +60,9 @@ mypy-protobuf==3.3.0
googleapis-common-protos-stubs==2.2.0
grpc-stubs==1.24.11
+# Debug for Spark and Spark Connect
+graphviz==0.20.3
+
# TorchDistributor dependencies
torch
torchvision
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 8c17af559c25..66927066faa7 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1063,6 +1063,7 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.test_parity_pandas_udf_window",
"pyspark.sql.tests.connect.test_resources",
"pyspark.sql.tests.connect.shell.test_progress",
+ "pyspark.sql.tests.connect.test_df_debug",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst
index 21926ae295bf..6cc68cd46b11 100644
--- a/python/docs/source/getting_started/install.rst
+++ b/python/docs/source/getting_started/install.rst
@@ -210,6 +210,7 @@ Package Supported version Note
`grpcio` >=1.62.0 Required for Spark Connect
`grpcio-status` >=1.62.0 Required for Spark Connect
`googleapis-common-protos` >=1.56.4 Required for Spark Connect
+`graphviz` >=0.20 Optional for Spark Connect
========================== ================= ==========================
Spark SQL
diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst b/python/docs/source/reference/pyspark.sql/dataframe.rst
index ec39b645b140..d0196baa7a05 100644
--- a/python/docs/source/reference/pyspark.sql/dataframe.rst
+++ b/python/docs/source/reference/pyspark.sql/dataframe.rst
@@ -55,6 +55,7 @@ DataFrame
DataFrame.dropna
DataFrame.dtypes
DataFrame.exceptAll
+ DataFrame.executionInfo
DataFrame.explain
DataFrame.fillna
DataFrame.filter
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index 30db37387249..dd70e814b1ea 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -149,6 +149,11 @@
"Cannot <condition1> without <condition2>."
]
},
+ "CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF": {
+ "message": [
+ "Calling property or member '<member>' is not supported in PySpark Classic, please use Spark Connect instead."
+ ]
+ },
"COLLATION_INVALID_PROVIDER" : {
"message" : [
"The value <provider> does not represent a correct collation provider. Supported providers are: [<supportedProviders>]."
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index a03467aff194..1bedd624603e 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -94,6 +94,7 @@ if TYPE_CHECKING:
from pyspark.sql.session import SparkSession
from pyspark.sql.group import GroupedData
from pyspark.sql.observation import Observation
+ from pyspark.sql.metrics import ExecutionInfo
class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
@@ -1835,6 +1836,13 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
def toPandas(self) -> "PandasDataFrameLike":
return PandasConversionMixin.toPandas(self)
+ @property
+ def executionInfo(self) -> Optional["ExecutionInfo"]:
+ raise PySparkValueError(
+ error_class="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF",
+ message_parameters={"member": "queryExecution"},
+ )
+
def _to_scala_map(sc: "SparkContext", jm: Dict) -> "JavaObject":
"""
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index f3bbab69f271..e91324150cbd 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -61,6 +61,7 @@ from pyspark.accumulators import SpecialAccumulatorIds
from pyspark.loose_version import LooseVersion
from pyspark.version import __version__
from pyspark.resource.information import ResourceInformation
+from pyspark.sql.metrics import MetricValue, PlanMetrics, ExecutionInfo, ObservedMetrics
from pyspark.sql.connect.client.artifact import ArtifactManager
from pyspark.sql.connect.client.logging import logger
from pyspark.sql.connect.profiler import ConnectProfilerCollector
@@ -447,56 +448,7 @@ class DefaultChannelBuilder(ChannelBuilder):
return self._secure_channel(self.endpoint, creds)
-class MetricValue:
- def __init__(self, name: str, value: Union[int, float], type: str):
- self._name = name
- self._type = type
- self._value = value
-
- def __repr__(self) -> str:
- return f"<{self._name}={self._value} ({self._type})>"
-
- @property
- def name(self) -> str:
- return self._name
-
- @property
- def value(self) -> Union[int, float]:
- return self._value
-
- @property
- def metric_type(self) -> str:
- return self._type
-
-
-class PlanMetrics:
- def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]):
- self._name = name
- self._id = id
- self._parent_id = parent
- self._metrics = metrics
-
- def __repr__(self) -> str:
- return f"Plan({self._name})={self._metrics}"
-
- @property
- def name(self) -> str:
- return self._name
-
- @property
- def plan_id(self) -> int:
- return self._id
-
- @property
- def parent_plan_id(self) -> int:
- return self._parent_id
-
- @property
- def metrics(self) -> List[MetricValue]:
- return self._metrics
-
-
-class PlanObservedMetrics:
+class PlanObservedMetrics(ObservedMetrics):
def __init__(self, name: str, metrics: List[pb2.Expression.Literal], keys: List[str]):
self._name = name
self._metrics = metrics
@@ -513,6 +465,13 @@ class PlanObservedMetrics:
def metrics(self) -> List[pb2.Expression.Literal]:
return self._metrics
+ @property
+ def pairs(self) -> dict[str, Any]:
+ result = {}
+ for x in range(len(self._metrics)):
+ result[self.keys[x]] = LiteralExpression._to_value(self.metrics[x])
+ return result
+
@property
def keys(self) -> List[str]:
return self._keys
@@ -888,7 +847,7 @@ class SparkConnectClient(object):
logger.info("Fetching the resources")
cmd = pb2.Command()
cmd.get_resources_command.SetInParent()
- (_, properties) = self.execute_command(cmd)
+ (_, properties, _) = self.execute_command(cmd)
resources = properties["get_resources_command_result"]
return resources
@@ -915,18 +874,23 @@ class SparkConnectClient(object):
def to_table(
self, plan: pb2.Plan, observations: Dict[str, Observation]
- ) -> Tuple["pa.Table", Optional[StructType]]:
+ ) -> Tuple["pa.Table", Optional[StructType], ExecutionInfo]:
"""
Return given plan as a PyArrow Table.
"""
logger.info(f"Executing plan {self._proto_to_string(plan)}")
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
- table, schema, _, _, _ = self._execute_and_fetch(req, observations)
+ table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req, observations)
+
+ # Create a query execution object.
+ ei = ExecutionInfo(metrics, observed_metrics)
assert table is not None
- return table, schema
+ return table, schema, ei
- def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) -> "pd.DataFrame":
+ def to_pandas(
+ self, plan: pb2.Plan, observations: Dict[str, Observation]
+ ) -> Tuple["pd.DataFrame", "ExecutionInfo"]:
"""
Return given plan as a pandas DataFrame.
"""
@@ -941,6 +905,7 @@ class SparkConnectClient(object):
req, observations, self_destruct=self_destruct
)
assert table is not None
+ ei = ExecutionInfo(metrics, observed_metrics)
schema = schema or from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
assert schema is not None and isinstance(schema, StructType)
@@ -1007,7 +972,7 @@ class SparkConnectClient(object):
pdf.attrs["metrics"] = metrics
if len(observed_metrics) > 0:
pdf.attrs["observed_metrics"] = observed_metrics
- return pdf
+ return pdf, ei
def _proto_to_string(self, p: google.protobuf.message.Message) -> str:
"""
@@ -1051,7 +1016,7 @@ class SparkConnectClient(object):
def execute_command(
self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None
- ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any]]:
+ ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any], ExecutionInfo]:
"""
Execute given command.
"""
@@ -1060,11 +1025,15 @@ class SparkConnectClient(object):
if self._user_id:
req.user_context.user_id = self._user_id
req.plan.command.CopyFrom(command)
- data, _, _, _, properties = self._execute_and_fetch(req, observations or {})
+ data, _, metrics, observed_metrics, properties = self._execute_and_fetch(
+ req, observations or {}
+ )
+ # Create a query execution object.
+ ei = ExecutionInfo(metrics, observed_metrics)
if data is not None:
- return (data.to_pandas(), properties)
+ return (data.to_pandas(), properties, ei)
else:
- return (None, properties)
+ return (None, properties, ei)
def execute_command_as_iterator(
self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None
@@ -1849,6 +1818,6 @@ class SparkConnectClient(object):
logger.info("Creating the ResourceProfile")
cmd = pb2.Command()
cmd.create_resource_profile_command.profile.CopyFrom(profile)
- (_, properties) = self.execute_command(cmd)
+ (_, properties, _) = self.execute_command(cmd)
profile_id = properties["create_resource_profile_command_result"]
return profile_id
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 678e66ee2b7b..1aa8fc00cfcc 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -101,6 +101,7 @@ if TYPE_CHECKING:
from pyspark.sql.connect.observation import Observation
from pyspark.sql.connect.session import SparkSession
from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+ from pyspark.sql.metrics import ExecutionInfo
class DataFrame(ParentDataFrame):
@@ -137,6 +138,7 @@ class DataFrame(ParentDataFrame):
# by __repr__ and _repr_html_ while eager evaluation opens.
self._support_repr_html = False
self._cached_schema: Optional[StructType] = None
+ self._execution_info: Optional["ExecutionInfo"] = None
def __reduce__(self) -> Tuple:
"""
@@ -206,7 +208,10 @@ class DataFrame(ParentDataFrame):
@property
def write(self) -> "DataFrameWriter":
- return DataFrameWriter(self._plan, self._session)
+ def cb(qe: "ExecutionInfo") -> None:
+ self._execution_info = qe
+
+ return DataFrameWriter(self._plan, self._session, cb)
@functools.cache
def isEmpty(self) -> bool:
@@ -1839,7 +1844,9 @@ class DataFrame(ParentDataFrame):
def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
query = self._plan.to_proto(self._session.client)
- table, schema = self._session.client.to_table(query, self._plan.observations)
+ table, schema, self._execution_info = self._session.client.to_table(
+ query, self._plan.observations
+ )
assert table is not None
return (table, schema)
@@ -1850,7 +1857,9 @@ class DataFrame(ParentDataFrame):
def toPandas(self) -> "PandasDataFrameLike":
query = self._plan.to_proto(self._session.client)
- return self._session.client.to_pandas(query, self._plan.observations)
+ pdf, ei = self._session.client.to_pandas(query, self._plan.observations)
+ self._execution_info = ei
+ return pdf
@property
def schema(self) -> StructType:
@@ -1976,25 +1985,29 @@ class DataFrame(ParentDataFrame):
command = plan.CreateView(
child=self._plan, name=name, is_global=False, replace=False
).command(session=self._session.client)
- self._session.client.execute_command(command, self._plan.observations)
+ _, _, ei = self._session.client.execute_command(command, self._plan.observations)
+ self._execution_info = ei
def createOrReplaceTempView(self, name: str) -> None:
command = plan.CreateView(
child=self._plan, name=name, is_global=False, replace=True
).command(session=self._session.client)
- self._session.client.execute_command(command, self._plan.observations)
+ _, _, ei = self._session.client.execute_command(command, self._plan.observations)
+ self._execution_info = ei
def createGlobalTempView(self, name: str) -> None:
command = plan.CreateView(
child=self._plan, name=name, is_global=True, replace=False
).command(session=self._session.client)
- self._session.client.execute_command(command, self._plan.observations)
+ _, _, ei = self._session.client.execute_command(command, self._plan.observations)
+ self._execution_info = ei
def createOrReplaceGlobalTempView(self, name: str) -> None:
command = plan.CreateView(
child=self._plan, name=name, is_global=True, replace=True
).command(session=self._session.client)
- self._session.client.execute_command(command, self._plan.observations)
+ _, _, ei = self._session.client.execute_command(command, self._plan.observations)
+ self._execution_info = ei
def cache(self) -> ParentDataFrame:
return self.persist()
@@ -2169,14 +2182,19 @@ class DataFrame(ParentDataFrame):
)
def writeTo(self, table: str) -> "DataFrameWriterV2":
- return DataFrameWriterV2(self._plan, self._session, table)
+ def cb(ei: "ExecutionInfo") -> None:
+ self._execution_info = ei
+
+ return DataFrameWriterV2(self._plan, self._session, table, cb)
def offset(self, n: int) -> ParentDataFrame:
return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session)
def checkpoint(self, eager: bool = True) -> "DataFrame":
cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager)
- _, properties = self._session.client.execute_command(cmd.command(self._session.client))
+ _, properties, self._execution_info = self._session.client.execute_command(
+ cmd.command(self._session.client)
+ )
assert "checkpoint_command_result" in properties
checkpointed = properties["checkpoint_command_result"]
assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
@@ -2184,7 +2202,9 @@ class DataFrame(ParentDataFrame):
def localCheckpoint(self, eager: bool = True) -> "DataFrame":
cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager)
- _, properties = self._session.client.execute_command(cmd.command(self._session.client))
+ _, properties, self._execution_info = self._session.client.execute_command(
+ cmd.command(self._session.client)
+ )
assert "checkpoint_command_result" in properties
checkpointed = properties["checkpoint_command_result"]
assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
@@ -2205,6 +2225,10 @@ class DataFrame(ParentDataFrame):
message_parameters={"feature": "rdd"},
)
+ @property
+ def executionInfo(self) -> Optional["ExecutionInfo"]:
+ return self._execution_info
+
class DataFrameNaFunctions(ParentDataFrameNaFunctions):
def __init__(self, df: ParentDataFrame):
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index bf7dc4d36905..de62cf65b01e 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -19,7 +19,7 @@ from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__)
from typing import Dict
-from typing import Optional, Union, List, overload, Tuple, cast
+from typing import Optional, Union, List, overload, Tuple, cast, Callable
from typing import TYPE_CHECKING
from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, WriteOperation, WriteOperationV2
@@ -37,6 +37,7 @@ if TYPE_CHECKING:
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect._typing import ColumnOrName, OptionalPrimitiveType
from pyspark.sql.connect.session import SparkSession
+ from pyspark.sql.metrics import ExecutionInfo
__all__ = ["DataFrameReader", "DataFrameWriter"]
@@ -486,11 +487,18 @@ DataFrameReader.__doc__ = PySparkDataFrameReader.__doc__
class DataFrameWriter(OptionUtils):
- def __init__(self, plan: "LogicalPlan", session: "SparkSession"):
+ def __init__(
+ self,
+ plan: "LogicalPlan",
+ session: "SparkSession",
+ callback: Optional[Callable[["ExecutionInfo"], None]] = None,
+ ):
self._df: "LogicalPlan" = plan
self._spark: "SparkSession" = session
self._write: "WriteOperation" = WriteOperation(self._df)
+ self._callback = callback if callback is not None else lambda _: None
+
def mode(self, saveMode: Optional[str]) -> "DataFrameWriter":
# At the JVM side, the default value of mode is already set to "error".
# So, if the given saveMode is None, we will not call JVM-side's mode method.
@@ -649,9 +657,10 @@ class DataFrameWriter(OptionUtils):
if format is not None:
self.format(format)
self._write.path = path
- self._spark.client.execute_command(
+ _, _, ei = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
+ self._callback(ei)
save.__doc__ = PySparkDataFrameWriter.save.__doc__
@@ -660,9 +669,10 @@ class DataFrameWriter(OptionUtils):
self.mode("overwrite" if overwrite else "append")
self._write.table_name = tableName
self._write.table_save_method = "insert_into"
- self._spark.client.execute_command(
+ _, _, ei = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
+ self._callback(ei)
insertInto.__doc__ = PySparkDataFrameWriter.insertInto.__doc__
@@ -681,9 +691,10 @@ class DataFrameWriter(OptionUtils):
self.format(format)
self._write.table_name = name
self._write.table_save_method = "save_as_table"
- self._spark.client.execute_command(
+ _, _, ei = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
+ self._callback(ei)
saveAsTable.__doc__ = PySparkDataFrameWriter.saveAsTable.__doc__
@@ -845,11 +856,18 @@ class DataFrameWriter(OptionUtils):
class DataFrameWriterV2(OptionUtils):
- def __init__(self, plan: "LogicalPlan", session: "SparkSession", table: str):
+ def __init__(
+ self,
+ plan: "LogicalPlan",
+ session: "SparkSession",
+ table: str,
+ callback: Optional[Callable[["ExecutionInfo"], None]] = None,
+ ):
self._df: "LogicalPlan" = plan
self._spark: "SparkSession" = session
self._table_name: str = table
self._write: "WriteOperationV2" = WriteOperationV2(self._df, self._table_name)
+ self._callback = callback if callback is not None else lambda _: None
def using(self, provider: str) -> "DataFrameWriterV2":
self._write.provider = provider
@@ -884,50 +902,56 @@ class DataFrameWriterV2(OptionUtils):
def create(self) -> None:
self._write.mode = "create"
- self._spark.client.execute_command(
+ _, _, ei = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
+ self._callback(ei)
create.__doc__ = PySparkDataFrameWriterV2.create.__doc__
def replace(self) -> None:
self._write.mode = "replace"
- self._spark.client.execute_command(
+ _, _, ei = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
+ self._callback(ei)
replace.__doc__ = PySparkDataFrameWriterV2.replace.__doc__
def createOrReplace(self) -> None:
self._write.mode = "create_or_replace"
- self._spark.client.execute_command(
+ _, _, ei = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
+ self._callback(ei)
createOrReplace.__doc__ = PySparkDataFrameWriterV2.createOrReplace.__doc__
def append(self) -> None:
self._write.mode = "append"
- self._spark.client.execute_command(
+ _, _, ei = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
+ self._callback(ei)
append.__doc__ = PySparkDataFrameWriterV2.append.__doc__
def overwrite(self, condition: "ColumnOrName") -> None:
self._write.mode = "overwrite"
self._write.overwrite_condition = F._to_col(condition)
- self._spark.client.execute_command(
+ _, _, ei = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
+ self._callback(ei)
overwrite.__doc__ = PySparkDataFrameWriterV2.overwrite.__doc__
def overwritePartitions(self) -> None:
self._write.mode = "overwrite_partitions"
- self._spark.client.execute_command(
+ _, _, ei = self._spark.client.execute_command(
self._write.command(self._spark.client), self._write.observations
)
+ self._callback(ei)
overwritePartitions.__doc__ = PySparkDataFrameWriterV2.overwritePartitions.__doc__
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index f359ab829483..8e277b3fc63a 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -720,9 +720,12 @@ class SparkSession:
_views.append(SubqueryAlias(df._plan, name))
cmd = SQL(sqlQuery, _args, _named_args, _views)
- data, properties = self.client.execute_command(cmd.command(self._client))
+ data, properties, ei = self.client.execute_command(cmd.command(self._client))
if "sql_command_result" in properties:
- return DataFrame(CachedRelation(properties["sql_command_result"]), self)
+ df = DataFrame(CachedRelation(properties["sql_command_result"]), self)
+ # A command result contains the execution.
+ df._execution_info = ei
+ return df
else:
return DataFrame(cmd, self)
diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py
index 98ecdc4966c7..13458d650fa9 100644
--- a/python/pyspark/sql/connect/streaming/query.py
+++ b/python/pyspark/sql/connect/streaming/query.py
@@ -181,7 +181,7 @@ class StreamingQuery:
cmd.query_id.run_id = self._run_id
exec_cmd = pb2.Command()
exec_cmd.streaming_query_command.CopyFrom(cmd)
- (_, properties) = self._session.client.execute_command(exec_cmd)
+ (_, properties, _) = self._session.client.execute_command(exec_cmd)
return cast(pb2.StreamingQueryCommandResult, properties["streaming_query_command_result"])
@@ -260,7 +260,7 @@ class StreamingQueryManager:
) -> pb2.StreamingQueryManagerCommandResult:
exec_cmd = pb2.Command()
exec_cmd.streaming_query_manager_command.CopyFrom(cmd)
- (_, properties) = self._session.client.execute_command(exec_cmd)
+ (_, properties, _) = self._session.client.execute_command(exec_cmd)
return cast(
pb2.StreamingQueryManagerCommandResult,
properties["streaming_query_manager_command_result"],
diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py
index b5bb7f2a0912..9b11bf328b85 100644
--- a/python/pyspark/sql/connect/streaming/readwriter.py
+++ b/python/pyspark/sql/connect/streaming/readwriter.py
@@ -601,7 +601,7 @@ class DataStreamWriter:
self._write_proto.table_name = tableName
cmd = self._write_stream.command(self._session.client)
- (_, properties) = self._session.client.execute_command(cmd)
+ (_, properties, _) = self._session.client.execute_command(cmd)
start_result = cast(
pb2.WriteStreamOperationStartResult, properties["write_stream_operation_start_result"]
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 62c46cfec93c..625678588bf9 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -64,6 +64,7 @@ if TYPE_CHECKING:
ArrowMapIterFunction,
DataFrameLike as PandasDataFrameLike,
)
+ from pyspark.sql.metrics import ExecutionInfo
__all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]
@@ -6281,6 +6282,31 @@ class DataFrame:
"""
...
+ @property
+ def executionInfo(self) -> Optional["ExecutionInfo"]:
+ """
+ Returns a QueryExecution object after the query was executed.
+
+ The queryExecution method allows to introspect information about the actual
+ query execution after the successful execution. Accessing this member before
+ the query execution will return None.
+
+ If the same DataFrame is executed multiple times, the execution info will be
+ overwritten by the latest operation.
+
+ .. versionadded:: 4.0.0
+
+ Returns
+ -------
+ An instance of QueryExecution or None when the value is not set yet.
+
+ Notes
+ -----
+ This is an API dedicated to Spark Connect client only. With regular Spark Session, it throws
+ an exception.
+ """
+ ...
+
class DataFrameNaFunctions:
"""Functionality for working with missing data in :class:`DataFrame`.
diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py
new file mode 100644
index 000000000000..666458295201
--- /dev/null
+++ b/python/pyspark/sql/metrics.py
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import abc
+import dataclasses
+from typing import Optional, List, Tuple, Dict, Any, Union, TYPE_CHECKING, Sequence
+
+from pyspark.errors import PySparkValueError
+
+if TYPE_CHECKING:
+ from pyspark.testing.connectutils import have_graphviz
+
+ if have_graphviz:
+ import graphviz # type: ignore
+
+
+class ObservedMetrics(abc.ABC):
+ @property
+ @abc.abstractmethod
+ def name(self) -> str:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def pairs(self) -> Dict[str, Any]:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def keys(self) -> List[str]:
+ ...
+
+
+class MetricValue:
+ """The metric values is the Python representation of a plan metric value from the JVM.
+ However, it does not have any reference to the original value."""
+
+ def __init__(self, name: str, value: Union[int, float], type: str):
+ self._name = name
+ self._type = type
+ self._value = value
+
+ def __repr__(self) -> str:
+ return f"<{self._name}={self._value} ({self._type})>"
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def value(self) -> Union[int, float]:
+ return self._value
+
+ @property
+ def metric_type(self) -> str:
+ return self._type
+
+
+class PlanMetrics:
+ """Represents a particular plan node and the associated metrics of this node."""
+
+ def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]):
+ self._name = name
+ self._id = id
+ self._parent_id = parent
+ self._metrics = metrics
+
+ def __repr__(self) -> str:
+ return f"Plan({self._name}: {self._id}->{self._parent_id})={self._metrics}"
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def plan_id(self) -> int:
+ return self._id
+
+ @property
+ def parent_plan_id(self) -> int:
+ return self._parent_id
+
+ @property
+ def metrics(self) -> List[MetricValue]:
+ return self._metrics
+
+
+class CollectedMetrics:
+ @dataclasses.dataclass
+ class Node:
+ id: int
+ name: str = dataclasses.field(default="")
+ metrics: List[MetricValue] = dataclasses.field(default_factory=list)
+ children: List[int] = dataclasses.field(default_factory=list)
+
+ def text(self, current: "Node", graph: Dict[int, "Node"], prefix: str = "") -> str:
+ """
+ Converts the current node and its children into a textual representation. This is used
+ to provide a usable output for the command line or other text-based interfaces. However,
+ it is recommended to use the Graphviz representation for a more visual representation.
+
+ Parameters
+ ----------
+ current: Node
+ Current node in the graph.
+ graph: dict
A dictionary representing the full graph mapping from node ID (int) to the node itself.
+ The node is an instance of :class:`CollectedMetrics:Node`.
+ prefix: str
+ String prefix used for generating the output buffer.
+
+ Returns
+ -------
+ The full string representation of the current node as root.
+ """
+ base_metrics = set(["numPartitions", "peakMemory", "numOutputRows", "spillSize"])
+
+ # Format the metrics of this node:
+ metric_buffer = []
+ for m in current.metrics:
+ if m.name in base_metrics:
+ metric_buffer.append(f"{m.name}: {m.value} ({m.metric_type})")
+
+ buffer = f"{prefix}+- {current.name}({','.join(metric_buffer)})\n"
+ for i, child in enumerate(current.children):
+ c = graph[child]
+ new_prefix = prefix + " " if i == len(c.children) - 1 else prefix
+ if current.id != c.id:
+ buffer += self.text(c, graph, new_prefix)
+ return buffer
+
+ def __init__(self, metrics: List[PlanMetrics]):
+ # Sort the input list
+ self._metrics = sorted(metrics, key=lambda x: x._parent_id, reverse=False)
+
+ def extract_graph(self) -> Tuple[int, Dict[int, "CollectedMetrics.Node"]]:
+ """
+ Builds the graph of the query plan. The graph is represented as a dictionary where the key
+ is the node ID and the value is the node itself. The root node is the node that has no
+ parent.
+
+ Returns
+ -------
+ The root node ID and the graph of all nodes.
+ """
+ all_nodes: Dict[int, CollectedMetrics.Node] = {}
+
+ for m in self._metrics:
+ # Add yourself to the list if you have to.
+ if m.plan_id not in all_nodes:
+ all_nodes[m.plan_id] = CollectedMetrics.Node(m.plan_id, m.name, m.metrics)
+ else:
+ all_nodes[m.plan_id].name = m.name
+ all_nodes[m.plan_id].metrics = m.metrics
+
+ # Now check for the parent of this node if it's in
+ if m.parent_plan_id not in all_nodes:
+ all_nodes[m.parent_plan_id] = CollectedMetrics.Node(m.parent_plan_id)
+
+ all_nodes[m.parent_plan_id].children.append(m.plan_id)
+
+ # Next step is to find all the root nodes. Root nodes are never used in children.
+ # So we start with all node ids as candidates.
+ candidates = set(all_nodes.keys())
+ for k, v in all_nodes.items():
+ for c in v.children:
+ if c in candidates and c != k:
+ candidates.remove(c)
+
+ assert len(candidates) == 1, f"Expected 1 root node, found {len(candidates)}"
+ return candidates.pop(), all_nodes
+
+ def toText(self) -> str:
+ """
+ Converts the execution graph from a graph into a textual representation
+ that can be read at the command line for example.
+
+ Returns
+ -------
+ A string representation of the collected metrics.
+ """
+ root, graph = self.extract_graph()
+ return self.text(graph[root], graph)
+
+ def toDot(self, filename: Optional[str] = None, out_format: str = "png") -> "graphviz.Digraph":
+ """
+ Converts the collected metrics into a dot representation. Since the graphviz Digraph
+ implementation provides the ability to render the result graph directly in a
+ notebook, we return the graph object directly.
+
+ If the graphviz package is not available, a PACKAGE_NOT_INSTALLED error is raised.
+
+ Parameters
+ ----------
+ filename : str, optional
+ The filename to save the graph to given an output format. The path can be
+ relative or absolute.
+
+ out_format : str
+ The output format of the graph. The default is 'png'.
+
+ Returns
+ -------
+ An instance of the graphviz.Digraph object.
+ """
+ try:
+ import graphviz
+
+ dot = graphviz.Digraph(
+ comment="Query Plan",
+ node_attr={
+ "shape": "box",
+ "font-size": "10pt",
+ },
+ )
+
+ root, graph = self.extract_graph()
+ for k, v in graph.items():
+ # Build table rows for the metrics
+ rows = "\n".join(
+ [
+ (
+ f'<TR><TD><FONT POINT-SIZE="8">{x.name}</FONT></TD><TD>'
+ f'<FONT POINT-SIZE="8">{x.value} ({x.metric_type})</FONT></TD></TR>'
+ )
+ for x in v.metrics
+ ]
+ )
+
+ dot.node(
+ str(k),
+ """<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">
+ <TR>
+ <TD COLSPAN="2" BGCOLOR="lightgrey">
+ <FONT POINT-SIZE=\"10\">{}</FONT>
+ </TD>
+ </TR>
+ <TR><TD COLSPAN="2"><FONT POINT-SIZE=\"10\">Metrics</FONT></TD></TR>
+ {}
+ </TABLE>>""".format(
+ v.name, rows
+ ),
+ )
+ for c in v.children:
+ dot.edge(str(k), str(c))
+
+ if filename:
+ dot.render(filename, format=out_format, cleanup=True)
+ return dot
+
+ except ImportError:
+ raise PySparkValueError(
+ error_class="PACKAGE_NOT_INSTALLED",
+ message_parameters={"package_name": "graphviz", "minimum_version": "0.20"},
+ )
+
+
+class ExecutionInfo:
+ """The query execution class allows users to inspect the query execution of this particular
+ data frame. This value is only set in the data frame if it was executed."""
+
+ def __init__(
+ self, metrics: Optional[list[PlanMetrics]], obs: Optional[Sequence[ObservedMetrics]]
+ ):
+ self._metrics = CollectedMetrics(metrics) if metrics else None
+ self._observations = obs if obs else []
+
+ @property
+ def metrics(self) -> Optional[CollectedMetrics]:
+ return self._metrics
+
+ @property
+ def flows(self) -> List[Tuple[str, Dict[str, Any]]]:
+ return [(f.name, f.pairs) for f in self._observations]
diff --git a/python/pyspark/sql/tests/connect/test_df_debug.py b/python/pyspark/sql/tests/connect/test_df_debug.py
new file mode 100644
index 000000000000..8a4ec68fda84
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_df_debug.py
@@ -0,0 +1,86 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.testing.connectutils import (
+ should_test_connect,
+ have_graphviz,
+ graphviz_requirement_message,
+)
+from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
+
+if should_test_connect:
+ from pyspark.sql.connect.dataframe import DataFrame
+
+
+class SparkConnectDataFrameDebug(SparkConnectSQLTestCase):
+ def test_df_debug_basics(self):
+ df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+ x = df.collect() # noqa: F841
+ ei = df.executionInfo
+
+ root, graph = ei.metrics.extract_graph()
+ self.assertIn(root, graph, "The root must be rooted in the graph")
+
+ def test_df_quey_execution_empty_before_execution(self):
+ df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+ ei = df.executionInfo
+ self.assertIsNone(ei, "The query execution must be None before the action is executed")
+
+ def test_df_query_execution_with_writes(self):
+ df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+ df.write.save("/tmp/test_df_query_execution_with_writes", format="json", mode="overwrite")
+ ei = df.executionInfo
+ self.assertIsNotNone(
+ ei, "The query execution must be None after the write action is executed"
+ )
+
+ def test_query_execution_text_format(self):
+ df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+ df.collect()
+ self.assertIn("HashAggregate", df.executionInfo.metrics.toText())
+
+ # Different execution mode.
+ df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+ df.toPandas()
+ self.assertIn("HashAggregate", df.executionInfo.metrics.toText())
+
+ @unittest.skipIf(not have_graphviz, graphviz_requirement_message)
+ def test_df_query_execution_metrics_to_dot(self):
+ df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count()
+ x = df.collect() # noqa: F841
+ ei = df.executionInfo
+
+ dot = ei.metrics.toDot()
+ source = dot.source
+ self.assertIsNotNone(dot, "The dot representation must not be None")
+ self.assertGreater(len(source), 0, "The dot representation must not be empty")
+ self.assertIn("digraph", source, "The dot representation must contain the digraph keyword")
+ self.assertIn("Metrics", source, "The dot representation must contain the Metrics keyword")
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.connect.test_df_debug import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index d7b31bbc5215..c7cf43a33454 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -37,6 +37,7 @@ from pyspark.errors import (
AnalysisException,
IllegalArgumentException,
PySparkTypeError,
+ PySparkValueError,
)
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
@@ -849,7 +850,15 @@ class DataFrameTestsMixin:
class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
- pass
+ def test_query_execution_unsupported_in_classic(self):
+ with self.assertRaises(PySparkValueError) as pe:
+ self.spark.range(1).executionInfo
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF",
+ message_parameters={"member": "queryExecution"},
+ )
if __name__ == "__main__":
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 191505741eb4..b3004693724b 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -45,6 +45,13 @@ except ImportError as e:
googleapis_common_protos_requirement_message = str(e)
have_googleapis_common_protos = googleapis_common_protos_requirement_message is None
+graphviz_requirement_message = None
+try:
+ import graphviz
+except ImportError as e:
+ graphviz_requirement_message = str(e)
+have_graphviz: bool = graphviz_requirement_message is None
+
from pyspark import Row, SparkConf
from pyspark.util import is_remote_only
from pyspark.testing.utils import PySparkErrorTestUtils
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]