This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new dbd765dbfac2 [SPARK-53544][PYTHON] Support complex types on observations
dbd765dbfac2 is described below

commit dbd765dbfac244124d12ed3cd60012564a14f4b9
Author: Takuya Ueshin <ues...@databricks.com>
AuthorDate: Mon Sep 15 07:58:53 2025 +0900

[SPARK-53544][PYTHON] Support complex types on observations

### What changes were proposed in this pull request?

Supports complex types on observations.

### Why are the changes needed?

Observations did not support complex types. For example:

```py
>>> observation = Observation("struct")
>>> df = spark.range(10).observe(
...     observation,
...     F.struct(F.count(F.lit(1)).alias("rows"), F.max("id").alias("maxid")).alias("struct"),
... )
```

- classic

```py
>>> df.collect()
[Row(id=0), Row(id=1), Row(id=2), Row(id=3), Row(id=4), Row(id=5), Row(id=6), Row(id=7), Row(id=8), Row(id=9)]
>>> observation.get
{'struct': JavaObject id=o61}
```

- connect

```py
>>> df.collect()
Traceback (most recent call last):
...
pyspark.errors.exceptions.base.PySparkTypeError: [UNSUPPORTED_LITERAL] Unsupported Literal 'struct {
...
```

### Does this PR introduce _any_ user-facing change?

Yes, complex types are now available on observations.

```py
>>> df.collect()
[Row(id=0), Row(id=1), Row(id=2), Row(id=3), Row(id=4), Row(id=5), Row(id=6), Row(id=7), Row(id=8), Row(id=9)]
>>>
>>> observation.get
{'struct': Row(rows=10, maxid=9)}
```

### How was this patch tested?

Added the related tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52321 from ueshin/issues/SPARK-53544/complex_observation.

Authored-by: Takuya Ueshin <ues...@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/expressions.py        | 26 +++++++++
 python/pyspark/sql/observation.py                | 10 +++-
 python/pyspark/sql/tests/test_observation.py     | 68 ++++++++++++++++------
 .../scala/org/apache/spark/sql/Observation.scala | 12 +++-
 4 files changed, 94 insertions(+), 22 deletions(-)

diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 4ddf13757db4..624599aac9e3 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -39,6 +39,7 @@ import numpy as np
 
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.types import (
+    _create_row,
     _from_numpy_type,
     DateType,
     ArrayType,
@@ -58,6 +59,8 @@ from pyspark.sql.types import (
     TimestampType,
     TimestampNTZType,
     DayTimeIntervalType,
+    MapType,
+    StructType,
 )
 
 import pyspark.sql.connect.proto as proto
@@ -441,6 +444,29 @@ class LiteralExpression(Expression):
                 assert isinstance(dataType, ArrayType)
                 assert elementType == dataType.elementType
             return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]
+        elif literal.HasField("map"):
+            keyType = proto_schema_to_pyspark_data_type(literal.map.key_type)
+            valueType = proto_schema_to_pyspark_data_type(literal.map.value_type)
+            if dataType is not None:
+                assert isinstance(dataType, MapType)
+                assert keyType == dataType.keyType
+                assert valueType == dataType.valueType
+            return {
+                LiteralExpression._to_value(k, keyType): LiteralExpression._to_value(v, valueType)
+                for k, v in zip(literal.map.keys, literal.map.values)
+            }
+        elif literal.HasField("struct"):
+            struct_type = cast(
+                StructType, proto_schema_to_pyspark_data_type(literal.struct.struct_type)
+            )
+            if dataType is not None:
+                assert isinstance(dataType, StructType)
+                assert struct_type == dataType
+            values = [
+                LiteralExpression._to_value(v, f.dataType)
+                for v, f in zip(literal.struct.elements, struct_type.fields)
+            ]
+            return _create_row(struct_type.names, values)
 
         raise PySparkTypeError(
             errorClass="UNSUPPORTED_LITERAL",
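
A note on the struct branch above: the conversion ends by calling `_create_row`, a small
helper in `pyspark.sql.types` that builds a `Row` with explicit field names. A minimal,
illustrative sketch of that final step (standalone code, not part of the patch; the field
names are borrowed from the example in the description):

```py
# Illustrative sketch only: _create_row pairs field names with already-converted
# values, which is how a struct literal becomes a Row on the Python side.
from pyspark.sql.types import _create_row

row = _create_row(["rows", "maxid"], [10, 9])
print(row)        # Row(rows=10, maxid=9)
print(row.maxid)  # 9
```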
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index 09ae7a339a4c..36accdb25b70 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -18,6 +18,8 @@ import os
 from typing import Any, Dict, Optional, TYPE_CHECKING
 
 from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkAssertionError
+from pyspark.serializers import CPickleSerializer
+from pyspark.sql import Row
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.utils import is_remote
@@ -144,9 +146,11 @@ class Observation:
         if self._jo is None:
             raise PySparkAssertionError(errorClass="NO_OBSERVE_BEFORE_GET", messageParameters={})
 
-        jmap = self._jo.getAsJava()
-        # return a pure Python dict, not jmap which is a py4j JavaMap
-        return {k: v for k, v in jmap.items()}
+        assert self._jvm is not None
+        utils = getattr(self._jvm, "org.apache.spark.sql.api.python.PythonSQLUtils")
+        jrow = self._jo.getRow()
+        row: Row = CPickleSerializer().loads(utils.toPyRow(jrow))
+        return row.asDict(recursive=False)
 
 
 def _test() -> None:
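
On the `Observation.get` change above: the JVM side serializes the full metrics `Row`
(via `PythonSQLUtils.toPyRow`), Python unpickles it with `CPickleSerializer`, and
`asDict(recursive=False)` turns it into a dict while keeping nested values as `Row`
objects rather than flattening them. An illustrative, standalone sketch of that last
step (not part of the patch):

```py
# Illustrative sketch: asDict(recursive=False) keeps a nested Row as a Row,
# while recursive=True would flatten it into plain dicts.
from pyspark.sql import Row

metrics = Row(struct=Row(rows=10, maxid=9))
print(metrics.asDict(recursive=False))  # {'struct': Row(rows=10, maxid=9)}
print(metrics.asDict(recursive=True))   # {'struct': {'rows': 10, 'maxid': 9}}
```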
diff --git a/python/pyspark/sql/tests/test_observation.py b/python/pyspark/sql/tests/test_observation.py
index d780da300c35..5af709d191c1 100644
--- a/python/pyspark/sql/tests/test_observation.py
+++ b/python/pyspark/sql/tests/test_observation.py
@@ -18,21 +18,19 @@
 import time
 import unittest
 
-from pyspark.sql import Row
-from pyspark.sql.functions import col, lit, count, sum, mean
+from pyspark.sql import Row, Observation, functions as F
 from pyspark.errors import (
     PySparkAssertionError,
     PySparkTypeError,
     PySparkValueError,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual
 
 
 class DataFrameObservationTestsMixin:
     def test_observe(self):
         # SPARK-36263: tests the DataFrame.observe(Observation, *Column) method
-        from pyspark.sql import Observation
-
         df = self.spark.createDataFrame(
             [
                 (1, 1.0, "one"),
@@ -58,11 +56,11 @@ class DataFrameObservationTestsMixin:
             df.orderBy("id")
             .observe(
                 named_observation,
-                count(lit(1)).alias("cnt"),
-                sum(col("id")).alias("sum"),
-                mean(col("val")).alias("mean"),
+                F.count(F.lit(1)).alias("cnt"),
+                F.sum(F.col("id")).alias("sum"),
+                F.mean(F.col("val")).alias("mean"),
             )
-            .observe(unnamed_observation, count(lit(1)).alias("rows"))
+            .observe(unnamed_observation, F.count(F.lit(1)).alias("rows"))
         )
 
         # test that observe works transparently
@@ -81,7 +79,7 @@ class DataFrameObservationTestsMixin:
         self.assertEqual(unnamed_observation.get, dict(rows=3))
 
         with self.assertRaises(PySparkAssertionError) as pe:
-            df.observe(named_observation, count(lit(1)).alias("count"))
+            df.observe(named_observation, F.count(F.lit(1)).alias("count"))
 
         self.check_error(
             exception=pe.exception,
@@ -106,7 +104,7 @@ class DataFrameObservationTestsMixin:
         )
 
         # dataframe.observe requires non-None Columns
-        for args in [(None,), ("id",), (lit(1), None), (lit(1), "id")]:
+        for args in [(None,), ("id",), (F.lit(1), None), (F.lit(1), "id")]:
             with self.subTest(args=args):
                 with self.assertRaises(PySparkTypeError) as pe:
                     df.observe(Observation(), *args)
@@ -140,7 +138,9 @@ class DataFrameObservationTestsMixin:
         self.spark.streams.addListener(TestListener())
 
         df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-        df = df.observe("metric", count(lit(1)).alias("cnt"), sum(col("value")).alias("sum"))
+        df = df.observe(
+            "metric", F.count(F.lit(1)).alias("cnt"), F.sum(F.col("value")).alias("sum")
+        )
         q = df.writeStream.format("noop").queryName("test").start()
         self.assertTrue(q.isActive)
         time.sleep(10)
@@ -157,15 +157,13 @@ class DataFrameObservationTestsMixin:
 
     def test_observe_with_same_name_on_different_dataframe(self):
         # SPARK-45656: named observations with the same name on different datasets
-        from pyspark.sql import Observation
-
         observation1 = Observation("named")
         df1 = self.spark.range(50)
-        observed_df1 = df1.observe(observation1, count(lit(1)).alias("cnt"))
+        observed_df1 = df1.observe(observation1, F.count(F.lit(1)).alias("cnt"))
 
         observation2 = Observation("named")
         df2 = self.spark.range(100)
-        observed_df2 = df2.observe(observation2, count(lit(1)).alias("cnt"))
+        observed_df2 = df2.observe(observation2, F.count(F.lit(1)).alias("cnt"))
 
         observed_df1.collect()
         observed_df2.collect()
@@ -174,8 +172,6 @@ class DataFrameObservationTestsMixin:
         self.assertEqual(observation2.get, dict(cnt=100))
 
     def test_observe_on_commands(self):
-        from pyspark.sql import Observation
-
         df = self.spark.range(50)
 
         test_table = "test_table"
@@ -190,10 +186,46 @@ class DataFrameObservationTestsMixin:
             ]:
                 with self.subTest(command=command):
                     observation = Observation()
-                    observed_df = df.observe(observation, count(lit(1)).alias("cnt"))
+                    observed_df = df.observe(observation, F.count(F.lit(1)).alias("cnt"))
                     action(observed_df)
                     self.assertEqual(observation.get, dict(cnt=50))
 
+    def test_observe_with_struct_type(self):
+        observation = Observation("struct")
+
+        df = self.spark.range(10).observe(
+            observation,
+            F.struct(F.count(F.lit(1)).alias("rows"), F.max("id").alias("maxid")).alias("struct"),
+        )
+
+        assertDataFrameEqual(df, [Row(id=id) for id in range(10)])
+
+        self.assertEqual(observation.get, {"struct": Row(rows=10, maxid=9)})
+
+    def test_observe_with_array_type(self):
+        observation = Observation("array")
+
+        df = self.spark.range(10).observe(
+            observation,
+            F.array(F.count(F.lit(1))).alias("array"),
+        )
+
+        assertDataFrameEqual(df, [Row(id=id) for id in range(10)])
+
+        self.assertEqual(observation.get, {"array": [10]})
+
+    def test_observe_with_map_type(self):
+        observation = Observation("map")
+
+        df = self.spark.range(10).observe(
+            observation,
+            F.create_map(F.lit("count"), F.count(F.lit(1))).alias("map"),
+        )
+
+        assertDataFrameEqual(df, [Row(id=id) for id in range(10)])
+
+        self.assertEqual(observation.get, {"map": {"count": 10}})
+
 
 class DataFrameObservationTests(
     DataFrameObservationTestsMixin,
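
The new tests rely on `assertDataFrameEqual` from `pyspark.testing.utils`, which compares
a DataFrame against either another DataFrame or a list of expected Rows. A minimal usage
sketch (standalone code, not part of the patch):

```py
# Minimal sketch: assertDataFrameEqual checks DataFrame contents against expected Rows.
from pyspark.sql import Row, SparkSession
from pyspark.testing.utils import assertDataFrameEqual

spark = SparkSession.builder.getOrCreate()
df = spark.range(3)
assertDataFrameEqual(df, [Row(id=0), Row(id=1), Row(id=2)])
```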
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
index 5e271983aa6b..59c27d1e5630 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -77,7 +77,7 @@ class Observation(val name: String) {
    */
   @throws[InterruptedException]
   def get: Map[String, Any] = {
-    val row = SparkThreadUtils.awaitResult(future, Duration.Inf)
+    val row = getRow
     row.getValuesMap(row.schema.map(_.name))
   }
 
@@ -134,6 +134,16 @@ class Observation(val name: String) {
   private[sql] def getRowOrEmpty: Option[Row] = {
     Try(SparkThreadUtils.awaitResult(future, 100.millis)).toOption
   }
+
+  /**
+   * Get the observed metrics as a Row.
+   *
+   * @return
+   *   the observed metrics as a `Row`.
+   */
+  private[sql] def getRow: Row = {
+    SparkThreadUtils.awaitResult(future, Duration.Inf)
+  }
 }
 
 /**

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org