jingz-db commented on code in PR #49560:
URL: https://github.com/apache/spark/pull/49560#discussion_r1953454474


##########
python/pyspark/sql/streaming/stateful_processor_util.py:
##########
@@ -16,13 +16,220 @@
 #
 
 from enum import Enum
+import itertools
+from typing import Any, Iterator, Optional, TYPE_CHECKING
+from pyspark.sql.streaming.stateful_processor_api_client import (
+    StatefulProcessorApiClient,
+    StatefulProcessorHandleState,
+)
+from pyspark.sql.streaming.stateful_processor import (
+    ExpiredTimerInfo,
+    StatefulProcessor,
+    StatefulProcessorHandle,
+    TimerValues,
+)
+
+if TYPE_CHECKING:
+    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
 
 # This file places the utilities for transformWithStateInPandas; we have a 
separate file to avoid
 # putting internal classes to the stateful_processor.py file which contains 
public APIs.
 
 
 class TransformWithStateInPandasFuncMode(Enum):
+    """
+    Internal mode for python worker UDF mode for transformWithStateInPandas; 
external mode are in
+    `StatefulProcessorHandleState` for public use purposes.
+    """
+
     PROCESS_DATA = 1
     PROCESS_TIMER = 2
     COMPLETE = 3
     PRE_INIT = 4
+
+
+class TransformWithStateInPandasUdfUtils:
+    """
+    Internal Utility class used for python worker UDF for 
transformWithStateInPandas. This class is
+    shared for both classic and spark connect mode.
+    """
+
+    def __init__(self, stateful_processor: StatefulProcessor, time_mode: str):
+        self._stateful_processor = stateful_processor
+        self._time_mode = time_mode
+
+    def transformWithStateUDF(
+        self,
+        stateful_processor_api_client: StatefulProcessorApiClient,
+        mode: TransformWithStateInPandasFuncMode,
+        key: Any,
+        input_rows: Iterator["PandasDataFrameLike"],
+    ) -> Iterator["PandasDataFrameLike"]:
+        if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
+            return self._handle_pre_init(stateful_processor_api_client)
+
+        handle = StatefulProcessorHandle(stateful_processor_api_client)
+
+        if stateful_processor_api_client.handle_state == 
StatefulProcessorHandleState.CREATED:
+            self._stateful_processor.init(handle)
+            
stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)
+
+        if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
+            stateful_processor_api_client.set_handle_state(
+                StatefulProcessorHandleState.DATA_PROCESSED
+            )
+            result = self._handle_expired_timers(stateful_processor_api_client)
+            return result
+        elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
+            stateful_processor_api_client.set_handle_state(
+                StatefulProcessorHandleState.TIMER_PROCESSED
+            )
+            stateful_processor_api_client.remove_implicit_key()
+            self._stateful_processor.close()
+            
stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.CLOSED)
+            return iter([])
+        else:
+            # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
+            result = self._handle_data_rows(stateful_processor_api_client, 
key, input_rows)
+            return result
+
+    def transformWithStateWithInitStateUDF(
+        self,
+        stateful_processor_api_client: StatefulProcessorApiClient,
+        mode: TransformWithStateInPandasFuncMode,
+        key: Any,
+        input_rows: Iterator["PandasDataFrameLike"],
+        initial_states: Optional[Iterator["PandasDataFrameLike"]] = None,
+    ) -> Iterator["PandasDataFrameLike"]:
+        """
+        UDF for TWS operator with non-empty initial states. Possible input 
combinations
+        of inputRows and initialStates iterator:
+        - Both `inputRows` and `initialStates` are non-empty. Both input rows 
and initial
+         states contains the grouping key and data.
+        - `InitialStates` is non-empty, while `inputRows` is empty. Only 
initial states
+         contains the grouping key and data, and it is first batch.
+        - `initialStates` is empty, while `inputRows` is non-empty. Only 
inputRows contains the
+         grouping key and data, and it is first batch.
+        - `initialStates` is None, while `inputRows` is not empty. This is not 
first batch.
+         `initialStates` is initialized to the positional value as None.
+        """
+        if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
+            return self._handle_pre_init(stateful_processor_api_client)
+
+        handle = StatefulProcessorHandle(stateful_processor_api_client)
+
+        if stateful_processor_api_client.handle_state == 
StatefulProcessorHandleState.CREATED:
+            self._stateful_processor.init(handle)
+            
stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)
+
+        if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
+            stateful_processor_api_client.set_handle_state(
+                StatefulProcessorHandleState.DATA_PROCESSED
+            )
+            result = self._handle_expired_timers(stateful_processor_api_client)
+            return result
+        elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
+            stateful_processor_api_client.remove_implicit_key()
+            self._stateful_processor.close()
+            
stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.CLOSED)
+            return iter([])
+        else:
+            # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
+            batch_timestamp, watermark_timestamp = 
stateful_processor_api_client.get_timestamps(
+                self._time_mode
+            )
+
+        # only process initial state if first batch and initial state is not 
None
+        if initial_states is not None:
+            for cur_initial_state in initial_states:
+                stateful_processor_api_client.set_implicit_key(key)
+                self._stateful_processor.handleInitialState(
+                    key, cur_initial_state, TimerValues(batch_timestamp, 
watermark_timestamp)
+                )
+
+        # if we don't have input rows for the given key but only have initial 
state
+        # for the grouping key, the inputRows iterator could be empty
+        input_rows_empty = False
+        try:
+            first = next(input_rows)
+        except StopIteration:
+            input_rows_empty = True
+        else:
+            input_rows = itertools.chain([first], input_rows)
+
+        if not input_rows_empty:
+            result = self._handle_data_rows(stateful_processor_api_client, 
key, input_rows)
+        else:
+            result = iter([])
+
+        return result
+
+    def _handle_pre_init(

Review Comment:
   To reviewers: `_handle_pre_init`, `_handle_data_rows`, and
`_handle_expired_timers` were moved from `group_ops.py` with no code changes
other than a minor renaming of internal parameters
(`statefulProcessorApiClient` -> `stateful_processor_api_client`). These
functions are private because only `transformWithStateWithInitStateUDF` and
`transformWithStateUDF` are used in `group_ops.py` and `group.py`; the rest
should not be exposed to external classes.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to