This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f12d2441a1b9 [SPARK-45619][CONNECT][PYTHON] Apply the observed metrics to Observation object
f12d2441a1b9 is described below

commit f12d2441a1b94118f759f4ed95ca7fbd4b5f0f74
Author: Takuya UESHIN <[email protected]>
AuthorDate: Tue Oct 24 08:30:34 2023 +0900

    [SPARK-45619][CONNECT][PYTHON] Apply the observed metrics to Observation object
    
    ### What changes were proposed in this pull request?
    
    Apply the observed metrics to `Observation` object.
    
    ### Why are the changes needed?
    
    When using `Observation`, the observed metrics should be applied to the object.
    
    For example, in vanilla PySpark:
    
    ```py
    >>> df = spark.createDataFrame([["Alice", 2], ["Bob", 5]], ["name", "age"])
    >>> observation = Observation("my metrics")
    >>> observed_df = df.observe(observation, count(lit(1)).alias("count"), max(col("age")))
    >>> observed_df.count()
    2
    >>>
    >>> observation.get
    {'count': 2, 'max(age)': 5}
    ```
    
    whereas in Spark Connect, it currently fails with `PySparkNotImplementedError`:
    
    ```py
    >>> observation.get
    Traceback (most recent call last):
    ...
        raise PySparkNotImplementedError(
    pyspark.errors.exceptions.base.PySparkNotImplementedError: [NOT_IMPLEMENTED] Observation support for Spark Connect is not implemented.
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
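    Yes. `Observation` now works with Spark Connect: after an action runs on the observed DataFrame, `observation.get` returns the observed metrics instead of raising `PySparkNotImplementedError`.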
    
    ### How was this patch tested?
    
    Added/modified the related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43469 from ueshin/issues/SPARK-45619/observation.
    
    Authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../src/main/protobuf/spark/connect/base.proto     |   1 +
 .../execution/SparkConnectPlanExecution.scala      |   7 +-
 dev/sparktestsupport/modules.py                    |   1 +
 python/pyspark/sql/connect/client/core.py          |  48 ++++--
 python/pyspark/sql/connect/dataframe.py            |  25 ++--
 python/pyspark/sql/connect/observation.py          |  95 ++++++++++++
 python/pyspark/sql/connect/plan.py                 |  43 +++++-
 python/pyspark/sql/connect/proto/base_pb2.py       | 166 ++++++++++-----------
 python/pyspark/sql/connect/proto/base_pb2.pyi      |  11 +-
 python/pyspark/sql/connect/readwriter.py           |  36 +++--
 python/pyspark/sql/observation.py                  |  13 +-
 .../sql/tests/connect/test_connect_basic.py        |   6 +-
 .../pyspark/sql/tests/connect/test_connect_plan.py |  12 +-
 .../sql/tests/connect/test_parity_dataframe.py     |   5 -
 python/pyspark/sql/utils.py                        |  16 --
 15 files changed, 333 insertions(+), 152 deletions(-)

diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/base.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index e2532cfc66d1..5b94c6d663cc 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -404,6 +404,7 @@ message ExecutePlanResponse {
   message ObservedMetrics {
     string name = 1;
     repeated Expression.Literal values = 2;
+    repeated string keys = 3;
   }
 
   message ResultComplete {
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index a3ce813271a3..9dd54e5b2b5d 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -244,11 +244,14 @@ private[execution] class 
SparkConnectPlanExecution(executeHolder: ExecuteHolder)
       dataframe: DataFrame): ExecutePlanResponse = {
     val observedMetrics = dataframe.queryExecution.observedMetrics.map { case 
(name, row) =>
       val cols = (0 until row.length).map(i => toLiteralProto(row(i)))
-      ExecutePlanResponse.ObservedMetrics
+      val metrics = ExecutePlanResponse.ObservedMetrics
         .newBuilder()
         .setName(name)
         .addAllValues(cols.asJava)
-        .build()
+      if (row.schema != null) {
+        metrics.addAllKeys(row.schema.fieldNames.toList.asJava)
+      }
+      metrics.build()
     }
     // Prepare a response with the observed metrics.
     ExecutePlanResponse
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 7251fd7c1588..913d9df6bcf2 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -839,6 +839,7 @@ pyspark_connect = Module(
         "pyspark.sql.connect.readwriter",
         "pyspark.sql.connect.dataframe",
         "pyspark.sql.connect.functions",
+        "pyspark.sql.connect.observation",
         "pyspark.sql.connect.avro.functions",
         "pyspark.sql.connect.protobuf.functions",
         "pyspark.sql.connect.streaming.readwriter",
diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index 958bc460c953..318f7d7ade4a 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -79,6 +79,7 @@ from pyspark.errors.exceptions.connect import (
     SparkConnectGrpcException,
 )
 from pyspark.sql.connect.expressions import (
+    LiteralExpression,
     PythonUDF,
     CommonInlineUserDefinedFunction,
     JavaUDF,
@@ -87,6 +88,7 @@ from pyspark.sql.connect.plan import (
     CommonInlineUserDefinedTableFunction,
     PythonUDTF,
 )
+from pyspark.sql.connect.observation import Observation
 from pyspark.sql.connect.utils import get_python_ver
 from pyspark.sql.pandas.types import _create_converter_to_pandas, 
from_arrow_schema
 from pyspark.sql.types import DataType, StructType, TimestampType, _has_type
@@ -429,9 +431,10 @@ class PlanMetrics:
 
 
 class PlanObservedMetrics:
-    def __init__(self, name: str, metrics: List[pb2.Expression.Literal]):
+    def __init__(self, name: str, metrics: List[pb2.Expression.Literal], keys: 
List[str]):
         self._name = name
         self._metrics = metrics
+        self._keys = keys if keys else [f"observed_metric_{i}" for i in 
range(len(self.metrics))]
 
     def __repr__(self) -> str:
         return f"Plan observed({self._name}={self._metrics})"
@@ -444,6 +447,10 @@ class PlanObservedMetrics:
     def metrics(self) -> List[pb2.Expression.Literal]:
         return self._metrics
 
+    @property
+    def keys(self) -> List[str]:
+        return self._keys
+
 
 class AnalyzeResult:
     def __init__(
@@ -798,33 +805,37 @@ class SparkConnectClient(object):
     def _build_observed_metrics(
         self, metrics: Sequence["pb2.ExecutePlanResponse.ObservedMetrics"]
     ) -> Iterator[PlanObservedMetrics]:
-        return (PlanObservedMetrics(x.name, [v for v in x.values]) for x in 
metrics)
+        return (PlanObservedMetrics(x.name, [v for v in x.values], 
list(x.keys)) for x in metrics)
 
-    def to_table_as_iterator(self, plan: pb2.Plan) -> 
Iterator[Union[StructType, "pa.Table"]]:
+    def to_table_as_iterator(
+        self, plan: pb2.Plan, observations: Dict[str, Observation]
+    ) -> Iterator[Union[StructType, "pa.Table"]]:
         """
         Return given plan as a PyArrow Table iterator.
         """
         logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
-        for response in self._execute_and_fetch_as_iterator(req):
+        for response in self._execute_and_fetch_as_iterator(req, observations):
             if isinstance(response, StructType):
                 yield response
             elif isinstance(response, pa.RecordBatch):
                 yield pa.Table.from_batches([response])
 
-    def to_table(self, plan: pb2.Plan) -> Tuple["pa.Table", 
Optional[StructType]]:
+    def to_table(
+        self, plan: pb2.Plan, observations: Dict[str, Observation]
+    ) -> Tuple["pa.Table", Optional[StructType]]:
         """
         Return given plan as a PyArrow Table.
         """
         logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
-        table, schema, _, _, _ = self._execute_and_fetch(req)
+        table, schema, _, _, _ = self._execute_and_fetch(req, observations)
         assert table is not None
         return table, schema
 
-    def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame":
+    def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) 
-> "pd.DataFrame":
         """
         Return given plan as a pandas DataFrame.
         """
@@ -836,7 +847,7 @@ class SparkConnectClient(object):
         )
         self_destruct = cast(str, self_destruct_conf).lower() == "true"
         table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(
-            req, self_destruct=self_destruct
+            req, observations, self_destruct=self_destruct
         )
         assert table is not None
 
@@ -945,7 +956,7 @@ class SparkConnectClient(object):
         return result
 
     def execute_command(
-        self, command: pb2.Command
+        self, command: pb2.Command, observations: Optional[Dict[str, 
Observation]] = None
     ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any]]:
         """
         Execute given command.
@@ -955,7 +966,7 @@ class SparkConnectClient(object):
         if self._user_id:
             req.user_context.user_id = self._user_id
         req.plan.command.CopyFrom(command)
-        data, _, _, _, properties = self._execute_and_fetch(req)
+        data, _, _, _, properties = self._execute_and_fetch(req, observations 
or {})
         if data is not None:
             return (data.to_pandas(), properties)
         else:
@@ -1155,7 +1166,7 @@ class SparkConnectClient(object):
             self._handle_error(error)
 
     def _execute_and_fetch_as_iterator(
-        self, req: pb2.ExecutePlanRequest
+        self, req: pb2.ExecutePlanRequest, observations: Dict[str, Observation]
     ) -> Iterator[
         Union[
             "pa.RecordBatch",
@@ -1191,7 +1202,13 @@ class SparkConnectClient(object):
                 yield from self._build_metrics(b.metrics)
             if b.observed_metrics:
                 logger.debug("Received observed metric batch.")
-                yield from self._build_observed_metrics(b.observed_metrics)
+                for observed_metrics in 
self._build_observed_metrics(b.observed_metrics):
+                    if observed_metrics.name in observations:
+                        observations[observed_metrics.name]._result = {
+                            key: LiteralExpression._to_value(metric)
+                            for key, metric in zip(observed_metrics.keys, 
observed_metrics.metrics)
+                        }
+                    yield observed_metrics
             if b.HasField("schema"):
                 logger.debug("Received the schema.")
                 dt = types.proto_schema_to_pyspark_data_type(b.schema)
@@ -1262,7 +1279,10 @@ class SparkConnectClient(object):
             self._handle_error(error)
 
     def _execute_and_fetch(
-        self, req: pb2.ExecutePlanRequest, self_destruct: bool = False
+        self,
+        req: pb2.ExecutePlanRequest,
+        observations: Dict[str, Observation],
+        self_destruct: bool = False,
     ) -> Tuple[
         Optional["pa.Table"],
         Optional[StructType],
@@ -1278,7 +1298,7 @@ class SparkConnectClient(object):
         schema: Optional[StructType] = None
         properties: Dict[str, Any] = {}
 
-        for response in self._execute_and_fetch_as_iterator(req):
+        for response in self._execute_and_fetch_as_iterator(req, observations):
             if isinstance(response, StructType):
                 schema = response
             elif isinstance(response, pa.RecordBatch):
diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index 4044fab3bb35..b322ded84a46 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -45,7 +45,6 @@ from collections.abc import Iterable
 
 from pyspark import _NoValue
 from pyspark._globals import _NoValueType
-from pyspark.sql.observation import Observation
 from pyspark.sql.types import Row, StructType, _create_row
 from pyspark.sql.dataframe import (
     DataFrame as PySparkDataFrame,
@@ -92,6 +91,7 @@ if TYPE_CHECKING:
         PandasMapIterFunction,
         ArrowMapIterFunction,
     )
+    from pyspark.sql.connect.observation import Observation
     from pyspark.sql.connect.session import SparkSession
     from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
 
@@ -1051,6 +1051,8 @@ class DataFrame:
         observation: Union["Observation", str],
         *exprs: Column,
     ) -> "DataFrame":
+        from pyspark.sql.connect.observation import Observation
+
         if len(exprs) == 0:
             raise PySparkValueError(
                 error_class="CANNOT_BE_EMPTY",
@@ -1063,10 +1065,7 @@ class DataFrame:
             )
 
         if isinstance(observation, Observation):
-            return DataFrame.withPlan(
-                plan.CollectMetrics(self._plan, str(observation._name), 
list(exprs)),
-                self._session,
-            )
+            return observation._on(self, *exprs)
         elif isinstance(observation, str):
             return DataFrame.withPlan(
                 plan.CollectMetrics(self._plan, observation, list(exprs)),
@@ -1722,13 +1721,13 @@ class DataFrame:
 
     def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
         query = self._plan.to_proto(self._session.client)
-        table, schema = self._session.client.to_table(query)
+        table, schema = self._session.client.to_table(query, 
self._plan.observations)
         assert table is not None
         return (table, schema)
 
     def toPandas(self) -> "pandas.DataFrame":
         query = self._plan.to_proto(self._session.client)
-        return self._session.client.to_pandas(query)
+        return self._session.client.to_pandas(query, self._plan.observations)
 
     toPandas.__doc__ = PySparkDataFrame.toPandas.__doc__
 
@@ -1865,7 +1864,7 @@ class DataFrame:
         command = plan.CreateView(
             child=self._plan, name=name, is_global=False, replace=False
         ).command(session=self._session.client)
-        self._session.client.execute_command(command)
+        self._session.client.execute_command(command, self._plan.observations)
 
     createTempView.__doc__ = PySparkDataFrame.createTempView.__doc__
 
@@ -1873,7 +1872,7 @@ class DataFrame:
         command = plan.CreateView(
             child=self._plan, name=name, is_global=False, replace=True
         ).command(session=self._session.client)
-        self._session.client.execute_command(command)
+        self._session.client.execute_command(command, self._plan.observations)
 
     createOrReplaceTempView.__doc__ = 
PySparkDataFrame.createOrReplaceTempView.__doc__
 
@@ -1881,7 +1880,7 @@ class DataFrame:
         command = plan.CreateView(
             child=self._plan, name=name, is_global=True, replace=False
         ).command(session=self._session.client)
-        self._session.client.execute_command(command)
+        self._session.client.execute_command(command, self._plan.observations)
 
     createGlobalTempView.__doc__ = 
PySparkDataFrame.createGlobalTempView.__doc__
 
@@ -1889,7 +1888,7 @@ class DataFrame:
         command = plan.CreateView(
             child=self._plan, name=name, is_global=True, replace=True
         ).command(session=self._session.client)
-        self._session.client.execute_command(command)
+        self._session.client.execute_command(command, self._plan.observations)
 
     createOrReplaceGlobalTempView.__doc__ = 
PySparkDataFrame.createOrReplaceGlobalTempView.__doc__
 
@@ -1936,7 +1935,9 @@ class DataFrame:
         query = self._plan.to_proto(self._session.client)
 
         schema: Optional[StructType] = None
-        for schema_or_table in 
self._session.client.to_table_as_iterator(query):
+        for schema_or_table in self._session.client.to_table_as_iterator(
+            query, self._plan.observations
+        ):
             if isinstance(schema_or_table, StructType):
                 assert schema is None
                 schema = schema_or_table
diff --git a/python/pyspark/sql/connect/observation.py 
b/python/pyspark/sql/connect/observation.py
new file mode 100644
index 000000000000..ff1044357496
--- /dev/null
+++ b/python/pyspark/sql/connect/observation.py
@@ -0,0 +1,95 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Any, Dict, Optional
+import uuid
+
+from pyspark.errors import IllegalArgumentException
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.dataframe import DataFrame
+from pyspark.sql.observation import Observation as PySparkObservation
+import pyspark.sql.connect.plan as plan
+
+
+__all__ = ["Observation"]
+
+
+class Observation:
+    def __init__(self, name: Optional[str] = None) -> None:
+        if name is not None:
+            if not isinstance(name, str):
+                raise TypeError("name should be a string")
+            if name == "":
+                raise ValueError("name should not be empty")
+        self._name = name
+        self._result: Optional[Dict[str, Any]] = None
+
+    __init__.__doc__ = PySparkObservation.__init__.__doc__
+
+    def _on(self, df: DataFrame, *exprs: Column) -> DataFrame:
+        assert self._result is None, "an Observation can be used with a 
DataFrame only once"
+
+        if self._name is None:
+            self._name = str(uuid.uuid4())
+
+        if df.isStreaming:
+            raise IllegalArgumentException("Observation does not support 
streaming Datasets")
+
+        self._result = {}
+        return DataFrame.withPlan(plan.CollectMetrics(df._plan, self, 
list(exprs)), df._session)
+
+    _on.__doc__ = PySparkObservation._on.__doc__
+
+    @property
+    def get(self) -> Dict[str, Any]:
+        assert self._result is not None
+        return self._result
+
+    get.__doc__ = PySparkObservation.get.__doc__
+
+
+Observation.__doc__ = PySparkObservation.__doc__
+
+
+def _test() -> None:
+    import sys
+    import doctest
+    from pyspark.sql import SparkSession as PySparkSession
+    import pyspark.sql.connect.observation
+
+    globs = pyspark.sql.connect.observation.__dict__.copy()
+    globs["spark"] = (
+        PySparkSession.builder.appName("sql.connect.observation tests")
+        .remote("local[4]")
+        .getOrCreate()
+    )
+
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.connect.observation,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS
+        | doctest.NORMALIZE_WHITESPACE
+        | doctest.IGNORE_EXCEPTION_DETAIL,
+    )
+
+    globs["spark"].stop()
+
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/sql/connect/plan.py 
b/python/pyspark/sql/connect/plan.py
index 0121d4c3d572..d888422d29f7 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -52,6 +52,7 @@ if TYPE_CHECKING:
     from pyspark.sql.connect._typing import ColumnOrName
     from pyspark.sql.connect.client import SparkConnectClient
     from pyspark.sql.connect.udf import UserDefinedFunction
+    from pyspark.sql.connect.observation import Observation
 
 
 class LogicalPlan:
@@ -129,6 +130,13 @@ class LogicalPlan:
 
         return plan
 
+    @property
+    def observations(self) -> Dict[str, "Observation"]:
+        if self._child is None:
+            return {}
+        else:
+            return self._child.observations
+
     def _parameters_to_print(self, parameters: Mapping[str, Any]) -> 
Mapping[str, Any]:
         """
         Extracts the parameters that are able to be printed. It looks up the 
signature
@@ -879,6 +887,10 @@ class Join(LogicalPlan):
         plan.join.join_type = self.how
         return plan
 
+    @property
+    def observations(self) -> Dict[str, "Observation"]:
+        return dict(**super().observations, **self.right.observations)
+
     def print(self, indent: int = 0) -> str:
         i = " " * indent
         o = " " * (indent + LogicalPlan.INDENT)
@@ -966,6 +978,10 @@ class AsOfJoin(LogicalPlan):
 
         return plan
 
+    @property
+    def observations(self) -> Dict[str, "Observation"]:
+        return dict(**super().observations, **self.right.observations)
+
     def print(self, indent: int = 0) -> str:
         assert self.left is not None
         assert self.right is not None
@@ -1035,6 +1051,13 @@ class SetOperation(LogicalPlan):
         plan.set_op.allow_missing_columns = self.allow_missing_columns
         return plan
 
+    @property
+    def observations(self) -> Dict[str, "Observation"]:
+        return dict(
+            **super().observations,
+            **(self.other.observations if self.other is not None else {}),
+        )
+
     def print(self, indent: int = 0) -> str:
         assert self._child is not None
         assert self.other is not None
@@ -1269,11 +1292,11 @@ class CollectMetrics(LogicalPlan):
     def __init__(
         self,
         child: Optional["LogicalPlan"],
-        name: str,
+        observation: Union[str, "Observation"],
         exprs: List["ColumnOrName"],
     ) -> None:
         super().__init__(child)
-        self._name = name
+        self._observation = observation
         self._exprs = exprs
 
     def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") 
-> proto.Expression:
@@ -1286,10 +1309,24 @@ class CollectMetrics(LogicalPlan):
         assert self._child is not None
         plan = self._create_proto_relation()
         plan.collect_metrics.input.CopyFrom(self._child.plan(session))
-        plan.collect_metrics.name = self._name
+        plan.collect_metrics.name = (
+            self._observation
+            if isinstance(self._observation, str)
+            else str(self._observation._name)
+        )
         plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for 
x in self._exprs])
         return plan
 
+    @property
+    def observations(self) -> Dict[str, "Observation"]:
+        from pyspark.sql.connect.observation import Observation
+
+        if isinstance(self._observation, Observation):
+            observations = {str(self._observation._name): self._observation}
+        else:
+            observations = {}
+        return dict(**super().observations, **observations)
+
 
 class NAFill(LogicalPlan):
     def __init__(
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py 
b/python/pyspark/sql/connect/proto/base_pb2.py
index bc9272772a87..05040d813501 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as 
spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
 
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
 
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
 [...]
+    
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
 
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
 
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -120,7 +120,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_start = 4924
     _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_end = 5089
     _EXECUTEPLANRESPONSE._serialized_start = 5125
-    _EXECUTEPLANRESPONSE._serialized_end = 7127
+    _EXECUTEPLANRESPONSE._serialized_end = 7147
     _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 6283
     _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 6354
     _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 6356
@@ -134,85 +134,85 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 6906
     _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 6994
     _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 6996
-    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7092
-    _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7094
-    _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7110
-    _KEYVALUE._serialized_start = 7129
-    _KEYVALUE._serialized_end = 7194
-    _CONFIGREQUEST._serialized_start = 7197
-    _CONFIGREQUEST._serialized_end = 8225
-    _CONFIGREQUEST_OPERATION._serialized_start = 7417
-    _CONFIGREQUEST_OPERATION._serialized_end = 7915
-    _CONFIGREQUEST_SET._serialized_start = 7917
-    _CONFIGREQUEST_SET._serialized_end = 7969
-    _CONFIGREQUEST_GET._serialized_start = 7971
-    _CONFIGREQUEST_GET._serialized_end = 7996
-    _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 7998
-    _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 8061
-    _CONFIGREQUEST_GETOPTION._serialized_start = 8063
-    _CONFIGREQUEST_GETOPTION._serialized_end = 8094
-    _CONFIGREQUEST_GETALL._serialized_start = 8096
-    _CONFIGREQUEST_GETALL._serialized_end = 8144
-    _CONFIGREQUEST_UNSET._serialized_start = 8146
-    _CONFIGREQUEST_UNSET._serialized_end = 8173
-    _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 8175
-    _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 8209
-    _CONFIGRESPONSE._serialized_start = 8227
-    _CONFIGRESPONSE._serialized_end = 8349
-    _ADDARTIFACTSREQUEST._serialized_start = 8352
-    _ADDARTIFACTSREQUEST._serialized_end = 9223
-    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 8739
-    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 8792
-    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 8794
-    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 8905
-    _ADDARTIFACTSREQUEST_BATCH._serialized_start = 8907
-    _ADDARTIFACTSREQUEST_BATCH._serialized_end = 9000
-    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 9003
-    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 9196
-    _ADDARTIFACTSRESPONSE._serialized_start = 9226
-    _ADDARTIFACTSRESPONSE._serialized_end = 9414
-    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 9333
-    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 9414
-    _ARTIFACTSTATUSESREQUEST._serialized_start = 9417
-    _ARTIFACTSTATUSESREQUEST._serialized_end = 9612
-    _ARTIFACTSTATUSESRESPONSE._serialized_start = 9615
-    _ARTIFACTSTATUSESRESPONSE._serialized_end = 9883
-    _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 9726
-    _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 9766
-    _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 9768
-    _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 9883
-    _INTERRUPTREQUEST._serialized_start = 9886
-    _INTERRUPTREQUEST._serialized_end = 10358
-    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 10201
-    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 10329
-    _INTERRUPTRESPONSE._serialized_start = 10360
-    _INTERRUPTRESPONSE._serialized_end = 10451
-    _REATTACHOPTIONS._serialized_start = 10453
-    _REATTACHOPTIONS._serialized_end = 10506
-    _REATTACHEXECUTEREQUEST._serialized_start = 10509
-    _REATTACHEXECUTEREQUEST._serialized_end = 10784
-    _RELEASEEXECUTEREQUEST._serialized_start = 10787
-    _RELEASEEXECUTEREQUEST._serialized_end = 11241
-    _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 11153
-    _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 11165
-    _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 11167
-    _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11214
-    _RELEASEEXECUTERESPONSE._serialized_start = 11243
-    _RELEASEEXECUTERESPONSE._serialized_end = 11355
-    _FETCHERRORDETAILSREQUEST._serialized_start = 11358
-    _FETCHERRORDETAILSREQUEST._serialized_end = 11559
-    _FETCHERRORDETAILSRESPONSE._serialized_start = 11562
-    _FETCHERRORDETAILSRESPONSE._serialized_end = 12837
-    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11707
-    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11881
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 11884
-    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12056
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12059
-    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12468
-    
_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start
 = 12370
-    
_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end
 = 12438
-    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12471
-    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 12818
-    _SPARKCONNECTSERVICE._serialized_start = 12840
-    _SPARKCONNECTSERVICE._serialized_end = 13689
+    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7112
+    _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7114
+    _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7130
+    _KEYVALUE._serialized_start = 7149
+    _KEYVALUE._serialized_end = 7214
+    _CONFIGREQUEST._serialized_start = 7217
+    _CONFIGREQUEST._serialized_end = 8245
+    _CONFIGREQUEST_OPERATION._serialized_start = 7437
+    _CONFIGREQUEST_OPERATION._serialized_end = 7935
+    _CONFIGREQUEST_SET._serialized_start = 7937
+    _CONFIGREQUEST_SET._serialized_end = 7989
+    _CONFIGREQUEST_GET._serialized_start = 7991
+    _CONFIGREQUEST_GET._serialized_end = 8016
+    _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 8018
+    _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 8081
+    _CONFIGREQUEST_GETOPTION._serialized_start = 8083
+    _CONFIGREQUEST_GETOPTION._serialized_end = 8114
+    _CONFIGREQUEST_GETALL._serialized_start = 8116
+    _CONFIGREQUEST_GETALL._serialized_end = 8164
+    _CONFIGREQUEST_UNSET._serialized_start = 8166
+    _CONFIGREQUEST_UNSET._serialized_end = 8193
+    _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 8195
+    _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 8229
+    _CONFIGRESPONSE._serialized_start = 8247
+    _CONFIGRESPONSE._serialized_end = 8369
+    _ADDARTIFACTSREQUEST._serialized_start = 8372
+    _ADDARTIFACTSREQUEST._serialized_end = 9243
+    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 8759
+    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 8812
+    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 8814
+    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 8925
+    _ADDARTIFACTSREQUEST_BATCH._serialized_start = 8927
+    _ADDARTIFACTSREQUEST_BATCH._serialized_end = 9020
+    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 9023
+    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 9216
+    _ADDARTIFACTSRESPONSE._serialized_start = 9246
+    _ADDARTIFACTSRESPONSE._serialized_end = 9434
+    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 9353
+    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 9434
+    _ARTIFACTSTATUSESREQUEST._serialized_start = 9437
+    _ARTIFACTSTATUSESREQUEST._serialized_end = 9632
+    _ARTIFACTSTATUSESRESPONSE._serialized_start = 9635
+    _ARTIFACTSTATUSESRESPONSE._serialized_end = 9903
+    _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 9746
+    _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 9786
+    _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 9788
+    _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 9903
+    _INTERRUPTREQUEST._serialized_start = 9906
+    _INTERRUPTREQUEST._serialized_end = 10378
+    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 10221
+    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 10349
+    _INTERRUPTRESPONSE._serialized_start = 10380
+    _INTERRUPTRESPONSE._serialized_end = 10471
+    _REATTACHOPTIONS._serialized_start = 10473
+    _REATTACHOPTIONS._serialized_end = 10526
+    _REATTACHEXECUTEREQUEST._serialized_start = 10529
+    _REATTACHEXECUTEREQUEST._serialized_end = 10804
+    _RELEASEEXECUTEREQUEST._serialized_start = 10807
+    _RELEASEEXECUTEREQUEST._serialized_end = 11261
+    _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 11173
+    _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 11185
+    _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 11187
+    _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11234
+    _RELEASEEXECUTERESPONSE._serialized_start = 11263
+    _RELEASEEXECUTERESPONSE._serialized_end = 11375
+    _FETCHERRORDETAILSREQUEST._serialized_start = 11378
+    _FETCHERRORDETAILSREQUEST._serialized_end = 11579
+    _FETCHERRORDETAILSRESPONSE._serialized_start = 11582
+    _FETCHERRORDETAILSRESPONSE._serialized_end = 12857
+    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11727
+    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11901
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 11904
+    _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12076
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12079
+    _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12488
+    
_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start
 = 12390
+    
_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end
 = 12458
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12491
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 12838
+    _SPARKCONNECTSERVICE._serialized_start = 12860
+    _SPARKCONNECTSERVICE._serialized_end = 13709
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi 
b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 0ad295dbe080..5d2ebeb57399 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -1349,6 +1349,7 @@ class 
ExecutePlanResponse(google.protobuf.message.Message):
 
         NAME_FIELD_NUMBER: builtins.int
         VALUES_FIELD_NUMBER: builtins.int
+        KEYS_FIELD_NUMBER: builtins.int
         name: builtins.str
         @property
         def values(
@@ -1356,6 +1357,10 @@ class 
ExecutePlanResponse(google.protobuf.message.Message):
         ) -> 
google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
             pyspark.sql.connect.proto.expressions_pb2.Expression.Literal
         ]: ...
+        @property
+        def keys(
+            self,
+        ) -> 
google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: 
...
         def __init__(
             self,
             *,
@@ -1364,9 +1369,13 @@ class 
ExecutePlanResponse(google.protobuf.message.Message):
                 pyspark.sql.connect.proto.expressions_pb2.Expression.Literal
             ]
             | None = ...,
+            keys: collections.abc.Iterable[builtins.str] | None = ...,
         ) -> None: ...
         def ClearField(
-            self, field_name: typing_extensions.Literal["name", b"name", 
"values", b"values"]
+            self,
+            field_name: typing_extensions.Literal[
+                "keys", b"keys", "name", b"name", "values", b"values"
+            ],
         ) -> None: ...
 
     class ResultComplete(google.protobuf.message.Message):
diff --git a/python/pyspark/sql/connect/readwriter.py 
b/python/pyspark/sql/connect/readwriter.py
index ff8ee829749e..fee98cc34964 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -648,7 +648,9 @@ class DataFrameWriter(OptionUtils):
         if format is not None:
             self.format(format)
         self._write.path = path
-        
self._spark.client.execute_command(self._write.command(self._spark.client))
+        self._spark.client.execute_command(
+            self._write.command(self._spark.client), self._write.observations
+        )
 
     save.__doc__ = PySparkDataFrameWriter.save.__doc__
 
@@ -657,7 +659,9 @@ class DataFrameWriter(OptionUtils):
             self.mode("overwrite" if overwrite else "append")
         self._write.table_name = tableName
         self._write.table_save_method = "insert_into"
-        
self._spark.client.execute_command(self._write.command(self._spark.client))
+        self._spark.client.execute_command(
+            self._write.command(self._spark.client), self._write.observations
+        )
 
     insertInto.__doc__ = PySparkDataFrameWriter.insertInto.__doc__
 
@@ -676,7 +680,9 @@ class DataFrameWriter(OptionUtils):
             self.format(format)
         self._write.table_name = name
         self._write.table_save_method = "save_as_table"
-        
self._spark.client.execute_command(self._write.command(self._spark.client))
+        self._spark.client.execute_command(
+            self._write.command(self._spark.client), self._write.observations
+        )
 
     saveAsTable.__doc__ = PySparkDataFrameWriter.saveAsTable.__doc__
 
@@ -876,38 +882,50 @@ class DataFrameWriterV2(OptionUtils):
 
     def create(self) -> None:
         self._write.mode = "create"
-        
self._spark.client.execute_command(self._write.command(self._spark.client))
+        self._spark.client.execute_command(
+            self._write.command(self._spark.client), self._write.observations
+        )
 
     create.__doc__ = PySparkDataFrameWriterV2.create.__doc__
 
     def replace(self) -> None:
         self._write.mode = "replace"
-        
self._spark.client.execute_command(self._write.command(self._spark.client))
+        self._spark.client.execute_command(
+            self._write.command(self._spark.client), self._write.observations
+        )
 
     replace.__doc__ = PySparkDataFrameWriterV2.replace.__doc__
 
     def createOrReplace(self) -> None:
         self._write.mode = "create_or_replace"
-        
self._spark.client.execute_command(self._write.command(self._spark.client))
+        self._spark.client.execute_command(
+            self._write.command(self._spark.client), self._write.observations
+        )
 
     createOrReplace.__doc__ = PySparkDataFrameWriterV2.createOrReplace.__doc__
 
     def append(self) -> None:
         self._write.mode = "append"
-        
self._spark.client.execute_command(self._write.command(self._spark.client))
+        self._spark.client.execute_command(
+            self._write.command(self._spark.client), self._write.observations
+        )
 
     append.__doc__ = PySparkDataFrameWriterV2.append.__doc__
 
     def overwrite(self, condition: "ColumnOrName") -> None:
         self._write.mode = "overwrite"
         self._write.overwrite_condition = condition
-        
self._spark.client.execute_command(self._write.command(self._spark.client))
+        self._spark.client.execute_command(
+            self._write.command(self._spark.client), self._write.observations
+        )
 
     overwrite.__doc__ = PySparkDataFrameWriterV2.overwrite.__doc__
 
     def overwritePartitions(self) -> None:
         self._write.mode = "overwrite_partitions"
-        
self._spark.client.execute_command(self._write.command(self._spark.client))
+        self._spark.client.execute_command(
+            self._write.command(self._spark.client), self._write.observations
+        )
 
     overwritePartitions.__doc__ = 
PySparkDataFrameWriterV2.overwritePartitions.__doc__
 
diff --git a/python/pyspark/sql/observation.py 
b/python/pyspark/sql/observation.py
index 686b036bb9ec..19201cdf0f3c 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 from typing import Any, Dict, Optional
 
 from py4j.java_gateway import JavaObject, JVMView
@@ -21,7 +22,7 @@ from py4j.java_gateway import JavaObject, JVMView
 from pyspark.sql import column
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.utils import try_remote_observation
+from pyspark.sql.utils import is_remote
 
 __all__ = ["Observation"]
 
@@ -67,6 +68,13 @@ class Observation:
     {'count': 2, 'max(age)': 5}
     """
 
+    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            from pyspark.sql.connect.observation import Observation as 
ConnectObservation
+
+            return ConnectObservation(*args, **kwargs)
+        return super().__new__(cls)
+
     def __init__(self, name: Optional[str] = None) -> None:
         """Constructs a named or unnamed Observation instance.
 
@@ -84,7 +92,6 @@ class Observation:
         self._jvm: Optional[JVMView] = None
         self._jo: Optional[JavaObject] = None
 
-    @try_remote_observation
     def _on(self, df: DataFrame, *exprs: Column) -> DataFrame:
         """Attaches this observation to the given :class:`DataFrame` to 
observe aggregations.
 
@@ -111,9 +118,7 @@ class Observation:
         )
         return DataFrame(observed_df, df.sparkSession)
 
-    # Note that decorated property only works with Python 3.9+ which Spark 
Connect requires.
     @property
-    @try_remote_observation
     def get(self) -> Dict[str, Any]:
         """Get the observed metrics.
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 306b2d3b2be7..c96d08b5bbe2 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1789,14 +1789,16 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             .toPandas(),
         )
 
+        from pyspark.sql.connect.observation import Observation as 
ConnectObservation
         from pyspark.sql.observation import Observation
 
+        cobservation = ConnectObservation(observation_name)
         observation = Observation(observation_name)
 
         cdf = (
             self.connect.read.table(self.tbl_name)
             .filter("id > 3")
-            .observe(observation, CF.min("id"), CF.max("id"), CF.sum("id"))
+            .observe(cobservation, CF.min("id"), CF.max("id"), CF.sum("id"))
             .toPandas()
         )
         df = (
@@ -1808,6 +1810,8 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
 
         self.assert_eq(cdf, df)
 
+        self.assert_eq(cobservation.get, observation.get)
+
         observed_metrics = cdf.attrs["observed_metrics"]
         self.assert_eq(len(observed_metrics), 1)
         self.assert_eq(observed_metrics[0].name, observation_name)
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py 
b/python/pyspark/sql/tests/connect/test_connect_plan.py
index bd1c6e037154..f3be1683fb2d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -330,10 +330,18 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
             "id",
         )
 
-        from pyspark.sql.observation import Observation
+        from pyspark.sql.connect.observation import Observation
+
+        class MockDF(DataFrame):
+            def __init__(self, df: DataFrame):
+                super().__init__(df._plan, df._session)
+
+            @property
+            def isStreaming(self) -> bool:
+                return False
 
         plan = (
-            df.filter(df.col_name > 3)
+            MockDF(df.filter(df.col_name > 3))
             .observe(Observation("my_metric"), min("id"), max("id"), sum("id"))
             ._plan.to_proto(self.connect)
         )
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py 
b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index cc9f71f8b46f..b7b4fdcd287b 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -26,11 +26,6 @@ class DataFrameParityTests(DataFrameTestsMixin, 
ReusedConnectTestCase):
     def test_help_command(self):
         super().test_help_command()
 
-    # TODO(SPARK-41527): Implement DataFrame.observe
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_observe(self):
-        super().test_observe()
-
     # TODO(SPARK-41625): Support Structured Streaming
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_observe_str(self):
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index f584440e1212..b366c6c8bd8d 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -263,22 +263,6 @@ def get_active_spark_context() -> SparkContext:
     return sc
 
 
-def try_remote_observation(f: FuncT) -> FuncT:
-    """Mark API supported from Spark Connect."""
-
-    @functools.wraps(f)
-    def wrapped(*args: Any, **kwargs: Any) -> Any:
-        # TODO(SPARK-41527): Add the support of Observation.
-        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
-            raise PySparkNotImplementedError(
-                error_class="NOT_IMPLEMENTED",
-                message_parameters={"feature": "Observation support for Spark 
Connect"},
-            )
-        return f(*args, **kwargs)
-
-    return cast(FuncT, wrapped)
-
-
 def try_remote_session_classmethod(f: FuncT) -> FuncT:
     """Mark API supported from Spark Connect."""
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to