This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dbd765dbfac2 [SPARK-53544][PYTHON] Support complex types on observations
dbd765dbfac2 is described below

commit dbd765dbfac244124d12ed3cd60012564a14f4b9
Author: Takuya Ueshin <ues...@databricks.com>
AuthorDate: Mon Sep 15 07:58:53 2025 +0900

    [SPARK-53544][PYTHON] Support complex types on observations
    
    ### What changes were proposed in this pull request?
    
    Supports complex types (struct, array, and map) on observations.
    
    ### Why are the changes needed?
    
    Observations didn't support complex types such as structs, arrays, and maps.
    
    For example:
    
    ```py
    >>> observation = Observation("struct")
    >>> df = spark.range(10).observe(
    ...     observation,
    ...     F.struct(F.count(F.lit(1)).alias("rows"), F.max("id").alias("maxid")).alias("struct"),
    ... )
    ```
    
    - classic
    
    ```py
    >>> df.collect()
    [Row(id=0), Row(id=1), Row(id=2), Row(id=3), Row(id=4), Row(id=5), Row(id=6), Row(id=7), Row(id=8), Row(id=9)]
    >>> observation.get
    {'struct': JavaObject id=o61}
    ```
    
    - connect
    
    ```py
    >>> df.collect()
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.base.PySparkTypeError: [UNSUPPORTED_LITERAL] Unsupported Literal 'struct {
    ...
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, complex types are available on observations.
    
    ```py
    >>> df.collect()
    [Row(id=0), Row(id=1), Row(id=2), Row(id=3), Row(id=4), Row(id=5), Row(id=6), Row(id=7), Row(id=8), Row(id=9)]
    >>>
    >>> observation.get
    {'struct': Row(rows=10, maxid=9)}
    ```
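
    Array and map results work the same way. A minimal sketch based on the new tests (assuming an active `spark` session and `F` bound to `pyspark.sql.functions`, as in the example above):

    ```py
    >>> observation = Observation("map")
    >>> df = spark.range(10).observe(
    ...     observation,
    ...     F.create_map(F.lit("count"), F.count(F.lit(1))).alias("map"),
    ... )
    >>> df.collect()
    [Row(id=0), Row(id=1), Row(id=2), Row(id=3), Row(id=4), Row(id=5), Row(id=6), Row(id=7), Row(id=8), Row(id=9)]
    >>> observation.get
    {'map': {'count': 10}}
    ```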
    
    ### How was this patch tested?
    
    Added related tests for struct, array, and map observations in `python/pyspark/sql/tests/test_observation.py`.
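    They can be run with the PySpark test runner, e.g. `python/run-tests --testnames 'pyspark.sql.tests.test_observation'` (the path and flag assume a standard Spark development checkout).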
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #52321 from ueshin/issues/SPARK-53544/complex_observation.
    
    Authored-by: Takuya Ueshin <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/expressions.py          | 26 +++++++++
 python/pyspark/sql/observation.py                  | 10 +++-
 python/pyspark/sql/tests/test_observation.py       | 68 ++++++++++++++++------
 .../scala/org/apache/spark/sql/Observation.scala   | 12 +++-
 4 files changed, 94 insertions(+), 22 deletions(-)

diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 4ddf13757db4..624599aac9e3 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -39,6 +39,7 @@ import numpy as np
 
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.types import (
+    _create_row,
     _from_numpy_type,
     DateType,
     ArrayType,
@@ -58,6 +59,8 @@ from pyspark.sql.types import (
     TimestampType,
     TimestampNTZType,
     DayTimeIntervalType,
+    MapType,
+    StructType,
 )
 
 import pyspark.sql.connect.proto as proto
@@ -441,6 +444,29 @@ class LiteralExpression(Expression):
                 assert isinstance(dataType, ArrayType)
                 assert elementType == dataType.elementType
             return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]
+        elif literal.HasField("map"):
+            keyType = proto_schema_to_pyspark_data_type(literal.map.key_type)
+            valueType = proto_schema_to_pyspark_data_type(literal.map.value_type)
+            if dataType is not None:
+                assert isinstance(dataType, MapType)
+                assert keyType == dataType.keyType
+                assert valueType == dataType.valueType
+            return {
+                LiteralExpression._to_value(k, keyType): LiteralExpression._to_value(v, valueType)
+                for k, v in zip(literal.map.keys, literal.map.values)
+            }
+        elif literal.HasField("struct"):
+            struct_type = cast(
+                StructType, proto_schema_to_pyspark_data_type(literal.struct.struct_type)
+            )
+            if dataType is not None:
+                assert isinstance(dataType, StructType)
+                assert struct_type == dataType
+            values = [
+                LiteralExpression._to_value(v, f.dataType)
+                for v, f in zip(literal.struct.elements, struct_type.fields)
+            ]
+            return _create_row(struct_type.names, values)
 
         raise PySparkTypeError(
             errorClass="UNSUPPORTED_LITERAL",
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index 09ae7a339a4c..36accdb25b70 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -18,6 +18,8 @@ import os
 from typing import Any, Dict, Optional, TYPE_CHECKING
 
 from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkAssertionError
+from pyspark.serializers import CPickleSerializer
+from pyspark.sql import Row
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.utils import is_remote
@@ -144,9 +146,11 @@ class Observation:
         if self._jo is None:
             raise PySparkAssertionError(errorClass="NO_OBSERVE_BEFORE_GET", messageParameters={})
 
-        jmap = self._jo.getAsJava()
-        # return a pure Python dict, not jmap which is a py4j JavaMap
-        return {k: v for k, v in jmap.items()}
+        assert self._jvm is not None
+        utils = getattr(self._jvm, "org.apache.spark.sql.api.python.PythonSQLUtils")
+        jrow = self._jo.getRow()
+        row: Row = CPickleSerializer().loads(utils.toPyRow(jrow))
+        return row.asDict(recursive=False)
 
 
 def _test() -> None:
diff --git a/python/pyspark/sql/tests/test_observation.py b/python/pyspark/sql/tests/test_observation.py
index d780da300c35..5af709d191c1 100644
--- a/python/pyspark/sql/tests/test_observation.py
+++ b/python/pyspark/sql/tests/test_observation.py
@@ -18,21 +18,19 @@
 import time
 import unittest
 
-from pyspark.sql import Row
-from pyspark.sql.functions import col, lit, count, sum, mean
+from pyspark.sql import Row, Observation, functions as F
 from pyspark.errors import (
     PySparkAssertionError,
     PySparkTypeError,
     PySparkValueError,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual
 
 
 class DataFrameObservationTestsMixin:
     def test_observe(self):
         # SPARK-36263: tests the DataFrame.observe(Observation, *Column) method
-        from pyspark.sql import Observation
-
         df = self.spark.createDataFrame(
             [
                 (1, 1.0, "one"),
@@ -58,11 +56,11 @@ class DataFrameObservationTestsMixin:
             df.orderBy("id")
             .observe(
                 named_observation,
-                count(lit(1)).alias("cnt"),
-                sum(col("id")).alias("sum"),
-                mean(col("val")).alias("mean"),
+                F.count(F.lit(1)).alias("cnt"),
+                F.sum(F.col("id")).alias("sum"),
+                F.mean(F.col("val")).alias("mean"),
             )
-            .observe(unnamed_observation, count(lit(1)).alias("rows"))
+            .observe(unnamed_observation, F.count(F.lit(1)).alias("rows"))
         )
 
         # test that observe works transparently
@@ -81,7 +79,7 @@ class DataFrameObservationTestsMixin:
         self.assertEqual(unnamed_observation.get, dict(rows=3))
 
         with self.assertRaises(PySparkAssertionError) as pe:
-            df.observe(named_observation, count(lit(1)).alias("count"))
+            df.observe(named_observation, F.count(F.lit(1)).alias("count"))
 
         self.check_error(
             exception=pe.exception,
@@ -106,7 +104,7 @@ class DataFrameObservationTestsMixin:
         )
 
         # dataframe.observe requires non-None Columns
-        for args in [(None,), ("id",), (lit(1), None), (lit(1), "id")]:
+        for args in [(None,), ("id",), (F.lit(1), None), (F.lit(1), "id")]:
             with self.subTest(args=args):
                 with self.assertRaises(PySparkTypeError) as pe:
                     df.observe(Observation(), *args)
@@ -140,7 +138,9 @@ class DataFrameObservationTestsMixin:
         self.spark.streams.addListener(TestListener())
 
         df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-        df = df.observe("metric", count(lit(1)).alias("cnt"), sum(col("value")).alias("sum"))
+        df = df.observe(
+            "metric", F.count(F.lit(1)).alias("cnt"), F.sum(F.col("value")).alias("sum")
+        )
         q = df.writeStream.format("noop").queryName("test").start()
         self.assertTrue(q.isActive)
         time.sleep(10)
@@ -157,15 +157,13 @@ class DataFrameObservationTestsMixin:
 
     def test_observe_with_same_name_on_different_dataframe(self):
         # SPARK-45656: named observations with the same name on different datasets
-        from pyspark.sql import Observation
-
         observation1 = Observation("named")
         df1 = self.spark.range(50)
-        observed_df1 = df1.observe(observation1, count(lit(1)).alias("cnt"))
+        observed_df1 = df1.observe(observation1, F.count(F.lit(1)).alias("cnt"))
 
         observation2 = Observation("named")
         df2 = self.spark.range(100)
-        observed_df2 = df2.observe(observation2, count(lit(1)).alias("cnt"))
+        observed_df2 = df2.observe(observation2, F.count(F.lit(1)).alias("cnt"))
 
         observed_df1.collect()
         observed_df2.collect()
@@ -174,8 +172,6 @@ class DataFrameObservationTestsMixin:
         self.assertEqual(observation2.get, dict(cnt=100))
 
     def test_observe_on_commands(self):
-        from pyspark.sql import Observation
-
         df = self.spark.range(50)
 
         test_table = "test_table"
@@ -190,10 +186,46 @@ class DataFrameObservationTestsMixin:
             ]:
                 with self.subTest(command=command):
                     observation = Observation()
-                    observed_df = df.observe(observation, count(lit(1)).alias("cnt"))
+                    observed_df = df.observe(observation, F.count(F.lit(1)).alias("cnt"))
                     action(observed_df)
                     self.assertEqual(observation.get, dict(cnt=50))
 
+    def test_observe_with_struct_type(self):
+        observation = Observation("struct")
+
+        df = self.spark.range(10).observe(
+            observation,
+            F.struct(F.count(F.lit(1)).alias("rows"), F.max("id").alias("maxid")).alias("struct"),
+        )
+
+        assertDataFrameEqual(df, [Row(id=id) for id in range(10)])
+
+        self.assertEqual(observation.get, {"struct": Row(rows=10, maxid=9)})
+
+    def test_observe_with_array_type(self):
+        observation = Observation("array")
+
+        df = self.spark.range(10).observe(
+            observation,
+            F.array(F.count(F.lit(1))).alias("array"),
+        )
+
+        assertDataFrameEqual(df, [Row(id=id) for id in range(10)])
+
+        self.assertEqual(observation.get, {"array": [10]})
+
+    def test_observe_with_map_type(self):
+        observation = Observation("map")
+
+        df = self.spark.range(10).observe(
+            observation,
+            F.create_map(F.lit("count"), F.count(F.lit(1))).alias("map"),
+        )
+
+        assertDataFrameEqual(df, [Row(id=id) for id in range(10)])
+
+        self.assertEqual(observation.get, {"map": {"count": 10}})
+
 
 class DataFrameObservationTests(
     DataFrameObservationTestsMixin,
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
index 5e271983aa6b..59c27d1e5630 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -77,7 +77,7 @@ class Observation(val name: String) {
    */
   @throws[InterruptedException]
   def get: Map[String, Any] = {
-    val row = SparkThreadUtils.awaitResult(future, Duration.Inf)
+    val row = getRow
     row.getValuesMap(row.schema.map(_.name))
   }
 
@@ -134,6 +134,16 @@ class Observation(val name: String) {
   private[sql] def getRowOrEmpty: Option[Row] = {
     Try(SparkThreadUtils.awaitResult(future, 100.millis)).toOption
   }
+
+  /**
+   * Get the observed metrics as a Row.
+   *
+   * @return
+   *   the observed metrics as a `Row`.
+   */
+  private[sql] def getRow: Row = {
+    SparkThreadUtils.awaitResult(future, Duration.Inf)
+  }
 }
 
 /**

