WweiL commented on code in PR #41540:
URL: https://github.com/apache/spark/pull/41540#discussion_r1227207933
##########
python/pyspark/sql/streaming/listener.py:
##########
@@ -316,38 +385,107 @@ def errorClassOnException(self) -> Optional[str]:
class StreamingQueryProgress:
"""
.. versionadded:: 3.4.0
+ .. versionchanged:: 3.5.0
+ Add fromJson constructor to support Spark Connect.
Notes
-----
This API is evolving.
"""
- def __init__(self, jprogress: JavaObject) -> None:
+ def __init__(
+ self,
+ json: str,
+ prettyJson: str,
+ id: uuid.UUID,
+ runId: uuid.UUID,
+ name: Optional[str],
+ timestamp: str,
+ batchId: int,
+ batchDuration: int,
+ durationMs: Dict[str, int],
+ eventTime: Dict[str, str],
+ stateOperators: List["StateOperatorProgress"],
+ sources: List["SourceProgress"],
+ sink: "SinkProgress",
+ numInputRows: Optional[str],
+ inputRowsPerSecond: float,
+ processedRowsPerSecond: float,
+ observedMetrics: Dict[str, Row],
+ ):
+ self._json: str = json
+ self._prettyJson: str = prettyJson
+ self._id: uuid.UUID = id
+ self._runId: uuid.UUID = runId
+ self._name: Optional[str] = name
+ self._timestamp: str = timestamp
+ self._batchId: int = batchId
+ self._batchDuration: int = batchDuration
+ self._durationMs: Dict[str, int] = durationMs
+ self._eventTime: Dict[str, str] = eventTime
+ self._stateOperators: List[StateOperatorProgress] = stateOperators
+ self._sources: List[SourceProgress] = sources
+ self._sink: SinkProgress = sink
+ self._numInputRows: Optional[str] = numInputRows
+ self._inputRowsPerSecond: float = inputRowsPerSecond
+ self._processedRowsPerSecond: float = processedRowsPerSecond
+ self._observedMetrics: Dict[str, Row] = observedMetrics
+
+ @classmethod
+ def fromJObject(cls, jprogress: JavaObject) -> "StreamingQueryProgress":
from pyspark import SparkContext
- self._jprogress: JavaObject = jprogress
- self._id: uuid.UUID = uuid.UUID(jprogress.id().toString())
- self._runId: uuid.UUID = uuid.UUID(jprogress.runId().toString())
- self._name: Optional[str] = jprogress.name()
- self._timestamp: str = jprogress.timestamp()
- self._batchId: int = jprogress.batchId()
- self._inputRowsPerSecond: float = jprogress.inputRowsPerSecond()
- self._processedRowsPerSecond: float =
jprogress.processedRowsPerSecond()
- self._batchDuration: int = jprogress.batchDuration()
- self._durationMs: Dict[str, int] = dict(jprogress.durationMs())
- self._eventTime: Dict[str, str] = dict(jprogress.eventTime())
- self._stateOperators: List[StateOperatorProgress] = [
- StateOperatorProgress(js) for js in jprogress.stateOperators()
- ]
- self._sources: List[SourceProgress] = [SourceProgress(js) for js in
jprogress.sources()]
- self._sink: SinkProgress = SinkProgress(jprogress.sink())
-
- self._observedMetrics: Dict[str, Row] = {
- k: cloudpickle.loads(
- SparkContext._jvm.PythonSQLUtils.toPyRow(jr) # type:
ignore[union-attr]
- )
- for k, jr in dict(jprogress.observedMetrics()).items()
- }
+ return cls(
+ json=jprogress.json(),
+ prettyJson=jprogress.prettyJson(),
+ id=uuid.UUID(jprogress.id().toString()),
+ runId=uuid.UUID(jprogress.runId().toString()),
+ name=jprogress.name(),
+ timestamp=jprogress.timestamp(),
+ batchId=jprogress.batchId(),
+ batchDuration=jprogress.batchDuration(),
+ durationMs=dict(jprogress.durationMs()),
+ eventTime=dict(jprogress.eventTime()),
+ stateOperators=[
+ StateOperatorProgress.fromJObject(js) for js in
jprogress.stateOperators()
+ ],
+ sources=[SourceProgress.fromJObject(js) for js in
jprogress.sources()],
+ sink=SinkProgress.fromJObject(jprogress.sink()),
+ numInputRows=jprogress.numInputRows(),
+ inputRowsPerSecond=jprogress.inputRowsPerSecond(),
+ processedRowsPerSecond=jprogress.processedRowsPerSecond(),
+ observedMetrics={
+ k: cloudpickle.loads(
+ SparkContext._jvm.PythonSQLUtils.toPyRow(jr) # type:
ignore[union-attr]
+ )
+ for k, jr in dict(jprogress.observedMetrics()).items()
+ },
+ )
+
+ @classmethod
+ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
+ return cls(
+ json=json.dumps(j),
+ prettyJson=json.dumps(j, indent=4),
+ id=uuid.UUID(j["id"]),
+ runId=uuid.UUID(j["runId"]),
+ name=j["name"],
+ timestamp=j["timestamp"],
+ batchId=j["batchId"],
+ batchDuration=j["batchDuration"],
+ durationMs=dict(j["durationMs"]),
+ eventTime=dict(j["eventTime"]),
+ stateOperators=[StateOperatorProgress.fromJson(s) for s in
j["stateOperators"]],
+ sources=[SourceProgress.fromJson(s) for s in j["sources"]],
+ sink=SinkProgress.fromJson(j["sink"]),
+ numInputRows=j["numInputRows"],
+ inputRowsPerSecond=j["inputRowsPerSecond"],
+ processedRowsPerSecond=j["processedRowsPerSecond"],
+ observedMetrics={
+ k: Row(*row_dict.keys())(*row_dict.values()) # Assume no
nested rows
Review Comment:
Checking the original PR (https://github.com/apache/spark/pull/26127), the
intended use case of the `observe` method is to construct this Row by
aggregating over some fields. I don't think we need to handle nested rows
here, but I'm open to discussion.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]