bogao007 commented on code in PR #47133:
URL: https://github.com/apache/spark/pull/47133#discussion_r1663048758
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1116,3 +1122,86 @@ def init_stream_yield_batches(batches):
batches_to_write = init_stream_yield_batches(serialize_batches())
return ArrowStreamSerializer.dump_stream(self, batches_to_write,
stream)
+
+
+# Reads Arrow record batches, converts each batch's columns to pandas, and
+# yields (key, group) pairs for consecutive batches sharing a grouping key;
+# dump_stream flattens the UDF output before handing it to the parent.
+class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
+
+ def __init__(
+ self,
+ timezone,
+ safecheck,
+ assign_cols_by_name,
+ arrow_max_records_per_batch):
+ # Only timezone/safecheck/assign_cols_by_name are forwarded to the
+ # parent; arrow_max_records_per_batch is stored on this instance below.
+ super(
+ TransformWithStateInPandasSerializer,
+ self
+ ).__init__(timezone, safecheck, assign_cols_by_name)
+
+ # NOTE(review): the commented-out state-server socket setup below is
+ # dead code (it references TransformWithStateInPandasStateSerializer);
+ # consider deleting it rather than keeping it in the file.
+ # self.state_server_port = state_server_port
+
+ # # open client connection to state server socket
+ # self._client_socket = socket.socket()
+ # self._client_socket.connect(("localhost", state_server_port))
+ # sockfile = self._client_socket.makefile("rwb",
int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
+ # self.state_serializer =
TransformWithStateInPandasStateSerializer(sockfile)
+ self.arrow_max_records_per_batch = arrow_max_records_per_batch
+ # Column indices of the grouping key. Must be assigned a sequence before
+ # load_stream's generator runs: iterating the default None raises TypeError.
+ self.key_offsets = None
+
+ # Nothing special here, we need to create the handle and read
+ # data in groups.
+ # Yields (key, group) pairs; group is an iterator over the
+ # (batch_key, data_pandas) tuples of consecutive batches with equal key.
+ def load_stream(self, stream):
+ import pyarrow as pa
+ # NOTE(review): `tee` is imported but not used in this method.
+ from itertools import tee
+
+ # Converts each Arrow batch to a list of pandas objects (one per column)
+ # and tags it with a key tuple built from the key columns.
+ def generate_data_batches(batches):
+ for batch in batches:
+ data_pandas = [self.arrow_to_pandas(c) for c in
pa.Table.from_batches([batch]).itercolumns()]
+ key_series = [data_pandas[o] for o in self.key_offsets]
+ # Key is taken from row 0 of each key column; assumes every row in
+ # a batch shares the same grouping key — TODO confirm upstream.
+ batch_key = tuple(s[0] for s in key_series)
+ yield (batch_key, data_pandas)
+
+ # NOTE(review): debug print; presumably should be removed or replaced
+ # with logging before merge.
+ print("Generating data batches...")
+ # super(ArrowStreamPandasSerializer, self) skips
+ # ArrowStreamPandasSerializer.load_stream in the MRO, presumably to get
+ # raw Arrow record batches instead of pandas conversions — verify.
+ _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
+ data_batches = generate_data_batches(_batches)
+
+ # NOTE(review): debug print; see note above.
+ print("Returning data batches...")
+ # NOTE(review): `groupby` is not imported in this hunk; presumably
+ # itertools.groupby at module level — verify. It only merges *adjacent*
+ # items with equal keys, so batches must arrive clustered by key.
+ for k, g in groupby(data_batches, key=lambda x: x[0]):
+ yield (k, g)
+
+
+ # Flattens `iterator`: each element x is an iterable of (y, t) pairs, and
+ # every b in y is emitted as (b, t). The whole result is materialized as a
+ # list before being passed to the parent dump_stream.
+ def dump_stream(self, iterator, stream):
+ result = [(b, t) for x in iterator for y, t in x for b in y]
+ super().dump_stream(result, stream)
+
+# Holds a single current grouping key: setKey overwrites it and getKey
+# returns it (None until setKey is first called).
+# NOTE(review): `Any` is not imported in this hunk; presumably from typing
+# at module level — verify.
+class ImplicitGroupingKeyTracker:
+ def __init__(self) -> None:
+ self._key = None
+
+ def setKey(self, key: Any) -> None:
+ self._key = key
+
+ def getKey(self) -> Any:
+ return self._key
+
+
+class TransformWithStateInPandasStateSerializer:
Review Comment:
I will remove this for now; it can be added when we implement ListState.
By the way, I'm going to use the socket file to pass the result of `ValueState.get()`
rather than an Arrow record batch. Let me know if you have any concerns, thanks!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]