This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f12d2441a1b9 [SPARK-45619][CONNECT][PYTHON] Apply the observed metrics to Observation object
f12d2441a1b9 is described below
commit f12d2441a1b94118f759f4ed95ca7fbd4b5f0f74
Author: Takuya UESHIN <[email protected]>
AuthorDate: Tue Oct 24 08:30:34 2023 +0900
[SPARK-45619][CONNECT][PYTHON] Apply the observed metrics to Observation object
### What changes were proposed in this pull request?
Apply the observed metrics to the `Observation` object.
### Why are the changes needed?
When using `Observation`, the observed metrics should be applied to the object so that `Observation.get` can return them.
For example, in vanilla PySpark:
```py
>>> df = spark.createDataFrame([["Alice", 2], ["Bob", 5]], ["name", "age"])
>>> observation = Observation("my metrics")
>>> observed_df = df.observe(observation, count(lit(1)).alias("count"), max(col("age")))
>>> observed_df.count()
2
>>>
>>> observation.get
{'count': 2, 'max(age)': 5}
```
whereas in Spark Connect it currently fails with a `PySparkNotImplementedError`:
```py
>>> observation.get
Traceback (most recent call last):
...
raise PySparkNotImplementedError(
pyspark.errors.exceptions.base.PySparkNotImplementedError: [NOT_IMPLEMENTED] Observation support for Spark Connect is not implemented.
```
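With this change, running the same snippet against Spark Connect populates the `Observation` once an action executes, matching the vanilla PySpark output above:
```py
>>> observed_df.count()
2
>>> observation.get
{'count': 2, 'max(age)': 5}
```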
### Does this PR introduce _any_ user-facing change?
Yes. With Spark Connect, the observed metrics are now applied to the `Observation` object, so `Observation.get` returns them just as in vanilla PySpark.
### How was this patch tested?
Added/modified the related tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43469 from ueshin/issues/SPARK-45619/observation.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../src/main/protobuf/spark/connect/base.proto | 1 +
.../execution/SparkConnectPlanExecution.scala | 7 +-
dev/sparktestsupport/modules.py | 1 +
python/pyspark/sql/connect/client/core.py | 48 ++++--
python/pyspark/sql/connect/dataframe.py | 25 ++--
python/pyspark/sql/connect/observation.py | 95 ++++++++++++
python/pyspark/sql/connect/plan.py | 43 +++++-
python/pyspark/sql/connect/proto/base_pb2.py | 166 ++++++++++-----------
python/pyspark/sql/connect/proto/base_pb2.pyi | 11 +-
python/pyspark/sql/connect/readwriter.py | 36 +++--
python/pyspark/sql/observation.py | 13 +-
.../sql/tests/connect/test_connect_basic.py | 6 +-
.../pyspark/sql/tests/connect/test_connect_plan.py | 12 +-
.../sql/tests/connect/test_parity_dataframe.py | 5 -
python/pyspark/sql/utils.py | 16 --
15 files changed, 333 insertions(+), 152 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index e2532cfc66d1..5b94c6d663cc 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -404,6 +404,7 @@ message ExecutePlanResponse {
message ObservedMetrics {
string name = 1;
repeated Expression.Literal values = 2;
+ repeated string keys = 3;
}
message ResultComplete {
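The new repeated `keys` field labels each entry in `values` by position, so the client can rebuild the metric mapping without falling back to positional names. A minimal, framework-free illustration (the numbers stand in for decoded `Expression.Literal` values from the commit-message example):
```py
# Illustration only: ObservedMetrics.keys[i] labels ObservedMetrics.values[i].
keys = ["count", "max(age)"]    # sent by the server from the observed row's schema
values = [2, 5]                 # decoded Expression.Literal entries
print(dict(zip(keys, values)))  # {'count': 2, 'max(age)': 5}
```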
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index a3ce813271a3..9dd54e5b2b5d 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -244,11 +244,14 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
dataframe: DataFrame): ExecutePlanResponse = {
val observedMetrics = dataframe.queryExecution.observedMetrics.map { case (name, row) =>
val cols = (0 until row.length).map(i => toLiteralProto(row(i)))
- ExecutePlanResponse.ObservedMetrics
+ val metrics = ExecutePlanResponse.ObservedMetrics
.newBuilder()
.setName(name)
.addAllValues(cols.asJava)
- .build()
+ if (row.schema != null) {
+ metrics.addAllKeys(row.schema.fieldNames.toList.asJava)
+ }
+ metrics.build()
}
// Prepare a response with the observed metrics.
ExecutePlanResponse
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 7251fd7c1588..913d9df6bcf2 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -839,6 +839,7 @@ pyspark_connect = Module(
"pyspark.sql.connect.readwriter",
"pyspark.sql.connect.dataframe",
"pyspark.sql.connect.functions",
+ "pyspark.sql.connect.observation",
"pyspark.sql.connect.avro.functions",
"pyspark.sql.connect.protobuf.functions",
"pyspark.sql.connect.streaming.readwriter",
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 958bc460c953..318f7d7ade4a 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -79,6 +79,7 @@ from pyspark.errors.exceptions.connect import (
SparkConnectGrpcException,
)
from pyspark.sql.connect.expressions import (
+ LiteralExpression,
PythonUDF,
CommonInlineUserDefinedFunction,
JavaUDF,
@@ -87,6 +88,7 @@ from pyspark.sql.connect.plan import (
CommonInlineUserDefinedTableFunction,
PythonUDTF,
)
+from pyspark.sql.connect.observation import Observation
from pyspark.sql.connect.utils import get_python_ver
from pyspark.sql.pandas.types import _create_converter_to_pandas, from_arrow_schema
from pyspark.sql.types import DataType, StructType, TimestampType, _has_type
@@ -429,9 +431,10 @@ class PlanMetrics:
class PlanObservedMetrics:
- def __init__(self, name: str, metrics: List[pb2.Expression.Literal]):
+ def __init__(self, name: str, metrics: List[pb2.Expression.Literal], keys: List[str]):
self._name = name
self._metrics = metrics
+ self._keys = keys if keys else [f"observed_metric_{i}" for i in range(len(self.metrics))]
def __repr__(self) -> str:
return f"Plan observed({self._name}={self._metrics})"
@@ -444,6 +447,10 @@ class PlanObservedMetrics:
def metrics(self) -> List[pb2.Expression.Literal]:
return self._metrics
+ @property
+ def keys(self) -> List[str]:
+ return self._keys
+
class AnalyzeResult:
def __init__(
@@ -798,33 +805,37 @@ class SparkConnectClient(object):
def _build_observed_metrics(
self, metrics: Sequence["pb2.ExecutePlanResponse.ObservedMetrics"]
) -> Iterator[PlanObservedMetrics]:
- return (PlanObservedMetrics(x.name, [v for v in x.values]) for x in metrics)
+ return (PlanObservedMetrics(x.name, [v for v in x.values], list(x.keys)) for x in metrics)
- def to_table_as_iterator(self, plan: pb2.Plan) -> Iterator[Union[StructType, "pa.Table"]]:
+ def to_table_as_iterator(
+ self, plan: pb2.Plan, observations: Dict[str, Observation]
+ ) -> Iterator[Union[StructType, "pa.Table"]]:
"""
Return given plan as a PyArrow Table iterator.
"""
logger.info(f"Executing plan {self._proto_to_string(plan)}")
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
- for response in self._execute_and_fetch_as_iterator(req):
+ for response in self._execute_and_fetch_as_iterator(req, observations):
if isinstance(response, StructType):
yield response
elif isinstance(response, pa.RecordBatch):
yield pa.Table.from_batches([response])
- def to_table(self, plan: pb2.Plan) -> Tuple["pa.Table", Optional[StructType]]:
+ def to_table(
+ self, plan: pb2.Plan, observations: Dict[str, Observation]
+ ) -> Tuple["pa.Table", Optional[StructType]]:
"""
Return given plan as a PyArrow Table.
"""
logger.info(f"Executing plan {self._proto_to_string(plan)}")
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
- table, schema, _, _, _ = self._execute_and_fetch(req)
+ table, schema, _, _, _ = self._execute_and_fetch(req, observations)
assert table is not None
return table, schema
- def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame":
+ def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) -> "pd.DataFrame":
"""
Return given plan as a pandas DataFrame.
"""
@@ -836,7 +847,7 @@ class SparkConnectClient(object):
)
self_destruct = cast(str, self_destruct_conf).lower() == "true"
table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(
- req, self_destruct=self_destruct
+ req, observations, self_destruct=self_destruct
)
assert table is not None
@@ -945,7 +956,7 @@ class SparkConnectClient(object):
return result
def execute_command(
- self, command: pb2.Command
+ self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None
) -> Tuple[Optional[pd.DataFrame], Dict[str, Any]]:
"""
Execute given command.
@@ -955,7 +966,7 @@ class SparkConnectClient(object):
if self._user_id:
req.user_context.user_id = self._user_id
req.plan.command.CopyFrom(command)
- data, _, _, _, properties = self._execute_and_fetch(req)
+ data, _, _, _, properties = self._execute_and_fetch(req, observations or {})
if data is not None:
return (data.to_pandas(), properties)
else:
@@ -1155,7 +1166,7 @@ class SparkConnectClient(object):
self._handle_error(error)
def _execute_and_fetch_as_iterator(
- self, req: pb2.ExecutePlanRequest
+ self, req: pb2.ExecutePlanRequest, observations: Dict[str, Observation]
) -> Iterator[
Union[
"pa.RecordBatch",
@@ -1191,7 +1202,13 @@ class SparkConnectClient(object):
yield from self._build_metrics(b.metrics)
if b.observed_metrics:
logger.debug("Received observed metric batch.")
- yield from self._build_observed_metrics(b.observed_metrics)
+ for observed_metrics in self._build_observed_metrics(b.observed_metrics):
+ if observed_metrics.name in observations:
+ observations[observed_metrics.name]._result = {
+ key: LiteralExpression._to_value(metric)
+ for key, metric in zip(observed_metrics.keys, observed_metrics.metrics)
+ }
+ yield observed_metrics
if b.HasField("schema"):
logger.debug("Received the schema.")
dt = types.proto_schema_to_pyspark_data_type(b.schema)
@@ -1262,7 +1279,10 @@ class SparkConnectClient(object):
self._handle_error(error)
def _execute_and_fetch(
- self, req: pb2.ExecutePlanRequest, self_destruct: bool = False
+ self,
+ req: pb2.ExecutePlanRequest,
+ observations: Dict[str, Observation],
+ self_destruct: bool = False,
) -> Tuple[
Optional["pa.Table"],
Optional[StructType],
@@ -1278,7 +1298,7 @@ class SparkConnectClient(object):
schema: Optional[StructType] = None
properties: Dict[str, Any] = {}
- for response in self._execute_and_fetch_as_iterator(req):
+ for response in self._execute_and_fetch_as_iterator(req, observations):
if isinstance(response, StructType):
schema = response
elif isinstance(response, pa.RecordBatch):
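In short, when a batch of observed metrics arrives, the client now looks up the `Observation` registered under the metric's name and stores the decoded key/value pairs on it before yielding the batch. A standalone restatement of that logic (the helper name `_apply_observed_metrics` is mine, not part of the commit; `LiteralExpression._to_value` is the decoder used in the hunk above):
```py
from pyspark.sql.connect.expressions import LiteralExpression

def _apply_observed_metrics(observed, observations):
    # `observed` is a PlanObservedMetrics with parallel .keys / .metrics;
    # `observations` maps metric name -> Observation, gathered from the plan.
    if observed.name in observations:
        observations[observed.name]._result = {
            key: LiteralExpression._to_value(metric)
            for key, metric in zip(observed.keys, observed.metrics)
        }
```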
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 4044fab3bb35..b322ded84a46 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -45,7 +45,6 @@ from collections.abc import Iterable
from pyspark import _NoValue
from pyspark._globals import _NoValueType
-from pyspark.sql.observation import Observation
from pyspark.sql.types import Row, StructType, _create_row
from pyspark.sql.dataframe import (
DataFrame as PySparkDataFrame,
@@ -92,6 +91,7 @@ if TYPE_CHECKING:
PandasMapIterFunction,
ArrowMapIterFunction,
)
+ from pyspark.sql.connect.observation import Observation
from pyspark.sql.connect.session import SparkSession
from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
@@ -1051,6 +1051,8 @@ class DataFrame:
observation: Union["Observation", str],
*exprs: Column,
) -> "DataFrame":
+ from pyspark.sql.connect.observation import Observation
+
if len(exprs) == 0:
raise PySparkValueError(
error_class="CANNOT_BE_EMPTY",
@@ -1063,10 +1065,7 @@ class DataFrame:
)
if isinstance(observation, Observation):
- return DataFrame.withPlan(
- plan.CollectMetrics(self._plan, str(observation._name), list(exprs)),
- self._session,
- )
+ return observation._on(self, *exprs)
elif isinstance(observation, str):
return DataFrame.withPlan(
plan.CollectMetrics(self._plan, observation, list(exprs)),
@@ -1722,13 +1721,13 @@ class DataFrame:
def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
query = self._plan.to_proto(self._session.client)
- table, schema = self._session.client.to_table(query)
+ table, schema = self._session.client.to_table(query, self._plan.observations)
assert table is not None
return (table, schema)
def toPandas(self) -> "pandas.DataFrame":
query = self._plan.to_proto(self._session.client)
- return self._session.client.to_pandas(query)
+ return self._session.client.to_pandas(query, self._plan.observations)
toPandas.__doc__ = PySparkDataFrame.toPandas.__doc__
@@ -1865,7 +1864,7 @@ class DataFrame:
command = plan.CreateView(
child=self._plan, name=name, is_global=False, replace=False
).command(session=self._session.client)
- self._session.client.execute_command(command)
+ self._session.client.execute_command(command, self._plan.observations)
createTempView.__doc__ = PySparkDataFrame.createTempView.__doc__
@@ -1873,7 +1872,7 @@ class DataFrame:
command = plan.CreateView(
child=self._plan, name=name, is_global=False, replace=True
).command(session=self._session.client)
- self._session.client.execute_command(command)
+ self._session.client.execute_command(command, self._plan.observations)
createOrReplaceTempView.__doc__ = PySparkDataFrame.createOrReplaceTempView.__doc__
@@ -1881,7 +1880,7 @@ class DataFrame:
command = plan.CreateView(
child=self._plan, name=name, is_global=True, replace=False
).command(session=self._session.client)
- self._session.client.execute_command(command)
+ self._session.client.execute_command(command, self._plan.observations)
createGlobalTempView.__doc__ = PySparkDataFrame.createGlobalTempView.__doc__
@@ -1889,7 +1888,7 @@ class DataFrame:
command = plan.CreateView(
child=self._plan, name=name, is_global=True, replace=True
).command(session=self._session.client)
- self._session.client.execute_command(command)
+ self._session.client.execute_command(command, self._plan.observations)
createOrReplaceGlobalTempView.__doc__ = PySparkDataFrame.createOrReplaceGlobalTempView.__doc__
@@ -1936,7 +1935,9 @@ class DataFrame:
query = self._plan.to_proto(self._session.client)
schema: Optional[StructType] = None
- for schema_or_table in self._session.client.to_table_as_iterator(query):
+ for schema_or_table in self._session.client.to_table_as_iterator(
+ query, self._plan.observations
+ ):
if isinstance(schema_or_table, StructType):
assert schema is None
schema = schema_or_table
diff --git a/python/pyspark/sql/connect/observation.py b/python/pyspark/sql/connect/observation.py
new file mode 100644
index 000000000000..ff1044357496
--- /dev/null
+++ b/python/pyspark/sql/connect/observation.py
@@ -0,0 +1,95 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Any, Dict, Optional
+import uuid
+
+from pyspark.errors import IllegalArgumentException
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.dataframe import DataFrame
+from pyspark.sql.observation import Observation as PySparkObservation
+import pyspark.sql.connect.plan as plan
+
+
+__all__ = ["Observation"]
+
+
+class Observation:
+ def __init__(self, name: Optional[str] = None) -> None:
+ if name is not None:
+ if not isinstance(name, str):
+ raise TypeError("name should be a string")
+ if name == "":
+ raise ValueError("name should not be empty")
+ self._name = name
+ self._result: Optional[Dict[str, Any]] = None
+
+ __init__.__doc__ = PySparkObservation.__init__.__doc__
+
+ def _on(self, df: DataFrame, *exprs: Column) -> DataFrame:
+ assert self._result is None, "an Observation can be used with a DataFrame only once"
+
+ if self._name is None:
+ self._name = str(uuid.uuid4())
+
+ if df.isStreaming:
+ raise IllegalArgumentException("Observation does not support streaming Datasets")
+
+ self._result = {}
+ return DataFrame.withPlan(plan.CollectMetrics(df._plan, self, list(exprs)), df._session)
+
+ _on.__doc__ = PySparkObservation._on.__doc__
+
+ @property
+ def get(self) -> Dict[str, Any]:
+ assert self._result is not None
+ return self._result
+
+ get.__doc__ = PySparkObservation.get.__doc__
+
+
+Observation.__doc__ = PySparkObservation.__doc__
+
+
+def _test() -> None:
+ import sys
+ import doctest
+ from pyspark.sql import SparkSession as PySparkSession
+ import pyspark.sql.connect.observation
+
+ globs = pyspark.sql.connect.observation.__dict__.copy()
+ globs["spark"] = (
+ PySparkSession.builder.appName("sql.connect.observation tests")
+ .remote("local[4]")
+ .getOrCreate()
+ )
+
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.connect.observation,
+ globs=globs,
+ optionflags=doctest.ELLIPSIS
+ | doctest.NORMALIZE_WHITESPACE
+ | doctest.IGNORE_EXCEPTION_DETAIL,
+ )
+
+ globs["spark"].stop()
+
+ if failure_count:
+ sys.exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
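For reference, a usage sketch of the new class against a Spark Connect session; it assumes a Connect server reachable through `remote("local[4]")`, the same shortcut `_test()` uses above, and mirrors the commit-message example:
```py
from pyspark.sql import SparkSession
from pyspark.sql.connect.observation import Observation
import pyspark.sql.connect.functions as CF

# Assumes a local Spark Connect endpoint, as in _test() above.
spark = SparkSession.builder.remote("local[4]").getOrCreate()

df = spark.createDataFrame([["Alice", 2], ["Bob", 5]], ["name", "age"])
observation = Observation("my metrics")
observed_df = df.observe(observation, CF.count(CF.lit(1)).alias("count"), CF.max(CF.col("age")))
observed_df.count()     # an action triggers execution and delivers the metrics
print(observation.get)  # {'count': 2, 'max(age)': 5}
```
Under a remote session, the plain `pyspark.sql.Observation` import resolves to this class as well, via the `__new__` hook added to `python/pyspark/sql/observation.py` further down.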
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 0121d4c3d572..d888422d29f7 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -52,6 +52,7 @@ if TYPE_CHECKING:
from pyspark.sql.connect._typing import ColumnOrName
from pyspark.sql.connect.client import SparkConnectClient
from pyspark.sql.connect.udf import UserDefinedFunction
+ from pyspark.sql.connect.observation import Observation
class LogicalPlan:
@@ -129,6 +130,13 @@ class LogicalPlan:
return plan
+ @property
+ def observations(self) -> Dict[str, "Observation"]:
+ if self._child is None:
+ return {}
+ else:
+ return self._child.observations
+
def _parameters_to_print(self, parameters: Mapping[str, Any]) -> Mapping[str, Any]:
"""
Extracts the parameters that are able to be printed. It looks up the signature
@@ -879,6 +887,10 @@ class Join(LogicalPlan):
plan.join.join_type = self.how
return plan
+ @property
+ def observations(self) -> Dict[str, "Observation"]:
+ return dict(**super().observations, **self.right.observations)
+
def print(self, indent: int = 0) -> str:
i = " " * indent
o = " " * (indent + LogicalPlan.INDENT)
@@ -966,6 +978,10 @@ class AsOfJoin(LogicalPlan):
return plan
+ @property
+ def observations(self) -> Dict[str, "Observation"]:
+ return dict(**super().observations, **self.right.observations)
+
def print(self, indent: int = 0) -> str:
assert self.left is not None
assert self.right is not None
@@ -1035,6 +1051,13 @@ class SetOperation(LogicalPlan):
plan.set_op.allow_missing_columns = self.allow_missing_columns
return plan
+ @property
+ def observations(self) -> Dict[str, "Observation"]:
+ return dict(
+ **super().observations,
+ **(self.other.observations if self.other is not None else {}),
+ )
+
def print(self, indent: int = 0) -> str:
assert self._child is not None
assert self.other is not None
@@ -1269,11 +1292,11 @@ class CollectMetrics(LogicalPlan):
def __init__(
self,
child: Optional["LogicalPlan"],
- name: str,
+ observation: Union[str, "Observation"],
exprs: List["ColumnOrName"],
) -> None:
super().__init__(child)
- self._name = name
+ self._observation = observation
self._exprs = exprs
def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") -> proto.Expression:
@@ -1286,10 +1309,24 @@ class CollectMetrics(LogicalPlan):
assert self._child is not None
plan = self._create_proto_relation()
plan.collect_metrics.input.CopyFrom(self._child.plan(session))
- plan.collect_metrics.name = self._name
+ plan.collect_metrics.name = (
+ self._observation
+ if isinstance(self._observation, str)
+ else str(self._observation._name)
+ )
plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for x in self._exprs])
return plan
+ @property
+ def observations(self) -> Dict[str, "Observation"]:
+ from pyspark.sql.connect.observation import Observation
+
+ if isinstance(self._observation, Observation):
+ observations = {str(self._observation._name): self._observation}
+ else:
+ observations = {}
+ return dict(**super().observations, **observations)
+
class NAFill(LogicalPlan):
def __init__(
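The new `observations` property lets the client collect every `Observation` registered anywhere in a plan tree: the base `LogicalPlan` defers to its child, the binary nodes (`Join`, `AsOfJoin`, `SetOperation`) also merge their right-hand side, and `CollectMetrics` contributes its own entry keyed by the observation name. A rough sketch continuing the snippet above (`spark`, `df`, `CF` as before):
```py
obs2 = Observation("join metrics")
other = spark.createDataFrame([["Alice", "NL"], ["Bob", "US"]], ["name", "country"])

observed = df.observe(obs2, CF.count(CF.lit(1)).alias("count"))
joined = observed.join(other, on="name")

# CollectMetrics registers obs2 under its name; Join merges the observations
# of both sides, so the client can still find it at the plan root.
assert "join metrics" in joined._plan.observations
```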
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index bc9272772a87..05040d813501 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
[...]
+
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
[...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -120,7 +120,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXECUTEPLANREQUEST_REQUESTOPTION._serialized_start = 4924
_EXECUTEPLANREQUEST_REQUESTOPTION._serialized_end = 5089
_EXECUTEPLANRESPONSE._serialized_start = 5125
- _EXECUTEPLANRESPONSE._serialized_end = 7127
+ _EXECUTEPLANRESPONSE._serialized_end = 7147
_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 6283
_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 6354
_EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 6356
@@ -134,85 +134,85 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 6906
_EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 6994
_EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 6996
- _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7092
- _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7094
- _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7110
- _KEYVALUE._serialized_start = 7129
- _KEYVALUE._serialized_end = 7194
- _CONFIGREQUEST._serialized_start = 7197
- _CONFIGREQUEST._serialized_end = 8225
- _CONFIGREQUEST_OPERATION._serialized_start = 7417
- _CONFIGREQUEST_OPERATION._serialized_end = 7915
- _CONFIGREQUEST_SET._serialized_start = 7917
- _CONFIGREQUEST_SET._serialized_end = 7969
- _CONFIGREQUEST_GET._serialized_start = 7971
- _CONFIGREQUEST_GET._serialized_end = 7996
- _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 7998
- _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 8061
- _CONFIGREQUEST_GETOPTION._serialized_start = 8063
- _CONFIGREQUEST_GETOPTION._serialized_end = 8094
- _CONFIGREQUEST_GETALL._serialized_start = 8096
- _CONFIGREQUEST_GETALL._serialized_end = 8144
- _CONFIGREQUEST_UNSET._serialized_start = 8146
- _CONFIGREQUEST_UNSET._serialized_end = 8173
- _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 8175
- _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 8209
- _CONFIGRESPONSE._serialized_start = 8227
- _CONFIGRESPONSE._serialized_end = 8349
- _ADDARTIFACTSREQUEST._serialized_start = 8352
- _ADDARTIFACTSREQUEST._serialized_end = 9223
- _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 8739
- _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 8792
- _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 8794
- _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 8905
- _ADDARTIFACTSREQUEST_BATCH._serialized_start = 8907
- _ADDARTIFACTSREQUEST_BATCH._serialized_end = 9000
- _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 9003
- _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 9196
- _ADDARTIFACTSRESPONSE._serialized_start = 9226
- _ADDARTIFACTSRESPONSE._serialized_end = 9414
- _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 9333
- _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 9414
- _ARTIFACTSTATUSESREQUEST._serialized_start = 9417
- _ARTIFACTSTATUSESREQUEST._serialized_end = 9612
- _ARTIFACTSTATUSESRESPONSE._serialized_start = 9615
- _ARTIFACTSTATUSESRESPONSE._serialized_end = 9883
- _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 9726
- _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 9766
- _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 9768
- _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 9883
- _INTERRUPTREQUEST._serialized_start = 9886
- _INTERRUPTREQUEST._serialized_end = 10358
- _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 10201
- _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 10329
- _INTERRUPTRESPONSE._serialized_start = 10360
- _INTERRUPTRESPONSE._serialized_end = 10451
- _REATTACHOPTIONS._serialized_start = 10453
- _REATTACHOPTIONS._serialized_end = 10506
- _REATTACHEXECUTEREQUEST._serialized_start = 10509
- _REATTACHEXECUTEREQUEST._serialized_end = 10784
- _RELEASEEXECUTEREQUEST._serialized_start = 10787
- _RELEASEEXECUTEREQUEST._serialized_end = 11241
- _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 11153
- _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 11165
- _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 11167
- _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11214
- _RELEASEEXECUTERESPONSE._serialized_start = 11243
- _RELEASEEXECUTERESPONSE._serialized_end = 11355
- _FETCHERRORDETAILSREQUEST._serialized_start = 11358
- _FETCHERRORDETAILSREQUEST._serialized_end = 11559
- _FETCHERRORDETAILSRESPONSE._serialized_start = 11562
- _FETCHERRORDETAILSRESPONSE._serialized_end = 12837
- _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11707
- _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11881
- _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 11884
- _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12056
- _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12059
- _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12468
- _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12370
- _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12438
- _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12471
- _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 12818
- _SPARKCONNECTSERVICE._serialized_start = 12840
- _SPARKCONNECTSERVICE._serialized_end = 13689
+ _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7112
+ _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7114
+ _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7130
+ _KEYVALUE._serialized_start = 7149
+ _KEYVALUE._serialized_end = 7214
+ _CONFIGREQUEST._serialized_start = 7217
+ _CONFIGREQUEST._serialized_end = 8245
+ _CONFIGREQUEST_OPERATION._serialized_start = 7437
+ _CONFIGREQUEST_OPERATION._serialized_end = 7935
+ _CONFIGREQUEST_SET._serialized_start = 7937
+ _CONFIGREQUEST_SET._serialized_end = 7989
+ _CONFIGREQUEST_GET._serialized_start = 7991
+ _CONFIGREQUEST_GET._serialized_end = 8016
+ _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 8018
+ _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 8081
+ _CONFIGREQUEST_GETOPTION._serialized_start = 8083
+ _CONFIGREQUEST_GETOPTION._serialized_end = 8114
+ _CONFIGREQUEST_GETALL._serialized_start = 8116
+ _CONFIGREQUEST_GETALL._serialized_end = 8164
+ _CONFIGREQUEST_UNSET._serialized_start = 8166
+ _CONFIGREQUEST_UNSET._serialized_end = 8193
+ _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 8195
+ _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 8229
+ _CONFIGRESPONSE._serialized_start = 8247
+ _CONFIGRESPONSE._serialized_end = 8369
+ _ADDARTIFACTSREQUEST._serialized_start = 8372
+ _ADDARTIFACTSREQUEST._serialized_end = 9243
+ _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 8759
+ _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 8812
+ _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 8814
+ _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 8925
+ _ADDARTIFACTSREQUEST_BATCH._serialized_start = 8927
+ _ADDARTIFACTSREQUEST_BATCH._serialized_end = 9020
+ _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 9023
+ _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 9216
+ _ADDARTIFACTSRESPONSE._serialized_start = 9246
+ _ADDARTIFACTSRESPONSE._serialized_end = 9434
+ _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 9353
+ _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 9434
+ _ARTIFACTSTATUSESREQUEST._serialized_start = 9437
+ _ARTIFACTSTATUSESREQUEST._serialized_end = 9632
+ _ARTIFACTSTATUSESRESPONSE._serialized_start = 9635
+ _ARTIFACTSTATUSESRESPONSE._serialized_end = 9903
+ _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 9746
+ _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 9786
+ _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 9788
+ _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 9903
+ _INTERRUPTREQUEST._serialized_start = 9906
+ _INTERRUPTREQUEST._serialized_end = 10378
+ _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 10221
+ _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 10349
+ _INTERRUPTRESPONSE._serialized_start = 10380
+ _INTERRUPTRESPONSE._serialized_end = 10471
+ _REATTACHOPTIONS._serialized_start = 10473
+ _REATTACHOPTIONS._serialized_end = 10526
+ _REATTACHEXECUTEREQUEST._serialized_start = 10529
+ _REATTACHEXECUTEREQUEST._serialized_end = 10804
+ _RELEASEEXECUTEREQUEST._serialized_start = 10807
+ _RELEASEEXECUTEREQUEST._serialized_end = 11261
+ _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 11173
+ _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 11185
+ _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 11187
+ _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11234
+ _RELEASEEXECUTERESPONSE._serialized_start = 11263
+ _RELEASEEXECUTERESPONSE._serialized_end = 11375
+ _FETCHERRORDETAILSREQUEST._serialized_start = 11378
+ _FETCHERRORDETAILSREQUEST._serialized_end = 11579
+ _FETCHERRORDETAILSRESPONSE._serialized_start = 11582
+ _FETCHERRORDETAILSRESPONSE._serialized_end = 12857
+ _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11727
+ _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11901
+ _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 11904
+ _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12076
+ _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12079
+ _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12488
+ _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12390
+ _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12458
+ _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12491
+ _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 12838
+ _SPARKCONNECTSERVICE._serialized_start = 12860
+ _SPARKCONNECTSERVICE._serialized_end = 13709
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 0ad295dbe080..5d2ebeb57399 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -1349,6 +1349,7 @@ class ExecutePlanResponse(google.protobuf.message.Message):
NAME_FIELD_NUMBER: builtins.int
VALUES_FIELD_NUMBER: builtins.int
+ KEYS_FIELD_NUMBER: builtins.int
name: builtins.str
@property
def values(
@@ -1356,6 +1357,10 @@ class ExecutePlanResponse(google.protobuf.message.Message):
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
pyspark.sql.connect.proto.expressions_pb2.Expression.Literal
]: ...
+ @property
+ def keys(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
def __init__(
self,
*,
@@ -1364,9 +1369,13 @@ class ExecutePlanResponse(google.protobuf.message.Message):
pyspark.sql.connect.proto.expressions_pb2.Expression.Literal
]
| None = ...,
+ keys: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def ClearField(
- self, field_name: typing_extensions.Literal["name", b"name", "values", b"values"]
+ self,
+ field_name: typing_extensions.Literal[
+ "keys", b"keys", "name", b"name", "values", b"values"
+ ],
) -> None: ...
class ResultComplete(google.protobuf.message.Message):
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index ff8ee829749e..fee98cc34964 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -648,7 +648,9 @@ class DataFrameWriter(OptionUtils):
if format is not None:
self.format(format)
self._write.path = path
- self._spark.client.execute_command(self._write.command(self._spark.client))
+ self._spark.client.execute_command(
+ self._write.command(self._spark.client), self._write.observations
+ )
save.__doc__ = PySparkDataFrameWriter.save.__doc__
@@ -657,7 +659,9 @@ class DataFrameWriter(OptionUtils):
self.mode("overwrite" if overwrite else "append")
self._write.table_name = tableName
self._write.table_save_method = "insert_into"
- self._spark.client.execute_command(self._write.command(self._spark.client))
+ self._spark.client.execute_command(
+ self._write.command(self._spark.client), self._write.observations
+ )
insertInto.__doc__ = PySparkDataFrameWriter.insertInto.__doc__
@@ -676,7 +680,9 @@ class DataFrameWriter(OptionUtils):
self.format(format)
self._write.table_name = name
self._write.table_save_method = "save_as_table"
- self._spark.client.execute_command(self._write.command(self._spark.client))
+ self._spark.client.execute_command(
+ self._write.command(self._spark.client), self._write.observations
+ )
saveAsTable.__doc__ = PySparkDataFrameWriter.saveAsTable.__doc__
@@ -876,38 +882,50 @@ class DataFrameWriterV2(OptionUtils):
def create(self) -> None:
self._write.mode = "create"
- self._spark.client.execute_command(self._write.command(self._spark.client))
+ self._spark.client.execute_command(
+ self._write.command(self._spark.client), self._write.observations
+ )
create.__doc__ = PySparkDataFrameWriterV2.create.__doc__
def replace(self) -> None:
self._write.mode = "replace"
- self._spark.client.execute_command(self._write.command(self._spark.client))
+ self._spark.client.execute_command(
+ self._write.command(self._spark.client), self._write.observations
+ )
replace.__doc__ = PySparkDataFrameWriterV2.replace.__doc__
def createOrReplace(self) -> None:
self._write.mode = "create_or_replace"
- self._spark.client.execute_command(self._write.command(self._spark.client))
+ self._spark.client.execute_command(
+ self._write.command(self._spark.client), self._write.observations
+ )
createOrReplace.__doc__ = PySparkDataFrameWriterV2.createOrReplace.__doc__
def append(self) -> None:
self._write.mode = "append"
- self._spark.client.execute_command(self._write.command(self._spark.client))
+ self._spark.client.execute_command(
+ self._write.command(self._spark.client), self._write.observations
+ )
append.__doc__ = PySparkDataFrameWriterV2.append.__doc__
def overwrite(self, condition: "ColumnOrName") -> None:
self._write.mode = "overwrite"
self._write.overwrite_condition = condition
- self._spark.client.execute_command(self._write.command(self._spark.client))
+ self._spark.client.execute_command(
+ self._write.command(self._spark.client), self._write.observations
+ )
overwrite.__doc__ = PySparkDataFrameWriterV2.overwrite.__doc__
def overwritePartitions(self) -> None:
self._write.mode = "overwrite_partitions"
- self._spark.client.execute_command(self._write.command(self._spark.client))
+ self._spark.client.execute_command(
+ self._write.command(self._spark.client), self._write.observations
+ )
overwritePartitions.__doc__ = PySparkDataFrameWriterV2.overwritePartitions.__doc__
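The writer paths now pass the write plan's observations along with the command, so an `Observation` attached to the DataFrame being written goes through the same machinery when the triggering action is a write rather than a collect. A hedged sketch continuing the earlier snippet (the table name is arbitrary; whether values arrive depends on the server reporting observed metrics for the write):
```py
obs3 = Observation("write metrics")
df.observe(obs3, CF.count(CF.lit(1)).alias("rows")).write.mode("overwrite").saveAsTable("people_copy")
# If observed metrics were reported for this write, they are applied to obs3:
print(obs3.get)
```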
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index 686b036bb9ec..19201cdf0f3c 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import os
from typing import Any, Dict, Optional
from py4j.java_gateway import JavaObject, JVMView
@@ -21,7 +22,7 @@ from py4j.java_gateway import JavaObject, JVMView
from pyspark.sql import column
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.utils import try_remote_observation
+from pyspark.sql.utils import is_remote
__all__ = ["Observation"]
@@ -67,6 +68,13 @@ class Observation:
{'count': 2, 'max(age)': 5}
"""
+ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
+ if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+ from pyspark.sql.connect.observation import Observation as ConnectObservation
+
+ return ConnectObservation(*args, **kwargs)
+ return super().__new__(cls)
+
def __init__(self, name: Optional[str] = None) -> None:
"""Constructs a named or unnamed Observation instance.
@@ -84,7 +92,6 @@ class Observation:
self._jvm: Optional[JVMView] = None
self._jo: Optional[JavaObject] = None
- @try_remote_observation
def _on(self, df: DataFrame, *exprs: Column) -> DataFrame:
"""Attaches this observation to the given :class:`DataFrame` to
observe aggregations.
@@ -111,9 +118,7 @@ class Observation:
)
return DataFrame(observed_df, df.sparkSession)
- # Note that decorated property only works with Python 3.9+ which Spark Connect requires.
@property
- @try_remote_observation
def get(self) -> Dict[str, Any]:
"""Get the observed metrics.
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 306b2d3b2be7..c96d08b5bbe2 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1789,14 +1789,16 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
.toPandas(),
)
+ from pyspark.sql.connect.observation import Observation as ConnectObservation
from pyspark.sql.observation import Observation
+ cobservation = ConnectObservation(observation_name)
observation = Observation(observation_name)
cdf = (
self.connect.read.table(self.tbl_name)
.filter("id > 3")
- .observe(observation, CF.min("id"), CF.max("id"), CF.sum("id"))
+ .observe(cobservation, CF.min("id"), CF.max("id"), CF.sum("id"))
.toPandas()
)
df = (
@@ -1808,6 +1810,8 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.assert_eq(cdf, df)
+ self.assert_eq(cobservation.get, observation.get)
+
observed_metrics = cdf.attrs["observed_metrics"]
self.assert_eq(len(observed_metrics), 1)
self.assert_eq(observed_metrics[0].name, observation_name)
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index bd1c6e037154..f3be1683fb2d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -330,10 +330,18 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
"id",
)
- from pyspark.sql.observation import Observation
+ from pyspark.sql.connect.observation import Observation
+
+ class MockDF(DataFrame):
+ def __init__(self, df: DataFrame):
+ super().__init__(df._plan, df._session)
+
+ @property
+ def isStreaming(self) -> bool:
+ return False
plan = (
- df.filter(df.col_name > 3)
+ MockDF(df.filter(df.col_name > 3))
.observe(Observation("my_metric"), min("id"), max("id"), sum("id"))
._plan.to_proto(self.connect)
)
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index cc9f71f8b46f..b7b4fdcd287b 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -26,11 +26,6 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
def test_help_command(self):
super().test_help_command()
- # TODO(SPARK-41527): Implement DataFrame.observe
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_observe(self):
- super().test_observe()
-
# TODO(SPARK-41625): Support Structured Streaming
@unittest.skip("Fails in Spark Connect, should enable.")
def test_observe_str(self):
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index f584440e1212..b366c6c8bd8d 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -263,22 +263,6 @@ def get_active_spark_context() -> SparkContext:
return sc
-def try_remote_observation(f: FuncT) -> FuncT:
- """Mark API supported from Spark Connect."""
-
- @functools.wraps(f)
- def wrapped(*args: Any, **kwargs: Any) -> Any:
- # TODO(SPARK-41527): Add the support of Observation.
- if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
- raise PySparkNotImplementedError(
- error_class="NOT_IMPLEMENTED",
- message_parameters={"feature": "Observation support for Spark Connect"},
- )
- return f(*args, **kwargs)
-
- return cast(FuncT, wrapped)
-
-
def try_remote_session_classmethod(f: FuncT) -> FuncT:
"""Mark API supported from Spark Connect."""
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]