funrollloops commented on code in PR #56192:
URL: https://github.com/apache/spark/pull/56192#discussion_r3320875056
##########
python/benchmarks/bench_eval_type.py:
##########
@@ -1929,3 +1935,173 @@ class WindowAggPandasUDFTimeBench(_WindowAggPandasBenchMixin, _TimeBenchBase):
class WindowAggPandasUDFPeakmemBench(_WindowAggPandasBenchMixin, _PeakmemBenchBase):
pass
+
+
+# -- SQL_TRANSFORM_WITH_STATE_PANDAS_UDF ---------------------------------------
+# Stateful streaming with Pandas. UDF signature is
+# ``(api_client, mode, key, pdfs)`` and returns ``Iterator[pandas.DataFrame]``.
+# The input wire stream is a single plain Arrow stream pre-sorted by the
+# grouping key column at offset 0; ``TransformWithStateInPandasSerializer``
+# chunks rows into one ``(mode, key, pdfs)`` tuple per group, then emits a
+# phantom ``PROCESS_TIMER`` and ``COMPLETE`` call with an empty pdf iterator.
+# ``StatefulProcessorApiClient.__init__`` opens a real TCP socket to the JVM
+# state server; the stub listener below satisfies that connect. The benchmark
+# UDFs never invoke any state API method, so no protocol exchange is needed.
+
+
+class _StubStateServer:
+ """Stub TCP listener so ``StatefulProcessorApiClient`` init succeeds.
+
+ One instance per benchmark process; the port is reused across all scenarios
+ and ASV iterations. The accept loop stashes connections to keep them open
+ until the worker process tears them down (the worker never closes its end
+ explicitly, but Python GCs the socket on ``main`` return).
+ """
+
+ _instance: "_StubStateServer | None" = None
+
+ @classmethod
+ def get_port(cls) -> int:
+ if cls._instance is None:
+ cls._instance = cls()
+ return cls._instance.port
+
+ def __init__(self) -> None:
+ self._sock = socket.socket()
+ self._sock.bind(("127.0.0.1", 0))
+ self._sock.listen(128)
+ self.port = self._sock.getsockname()[1]
+ self._connections: list[socket.socket] = []
+ self._thread = threading.Thread(target=self._accept_loop, daemon=True)
+ self._thread.start()
+
+ def _accept_loop(self) -> None:
+ while True:
+ try:
+ conn, _ = self._sock.accept()
+ except OSError:
+ break
+ self._connections.append(conn)
+
+
+class _TransformWithStatePandasBenchMixin:
+ """Provides ``_write_scenario`` for SQL_TRANSFORM_WITH_STATE_PANDAS_UDF.
+
+ Each scenario emits one plain Arrow stream pre-sorted by the leading int
+ key column. UDFs receive an iterator of value-only Pandas DataFrames per
+ group plus phantom ``PROCESS_TIMER``/``COMPLETE`` calls (empty iterator).
+ """
+
+ # Each scenario: (num_groups, rows_per_group, num_value_cols).
+ # Row counts are scaled so identity_udf (full pdf passthrough -> ~equal
+ # input and output volume) stays under ASV's 60s per-sample timeout.
+ _scenario_configs = {
+ "few_groups_sm": (50, 5_000, 5),
+ "few_groups_lg": (50, 50_000, 5),
+ "many_groups_sm": (2_000, 500, 5),
+ "many_groups_lg": (500, 2_000, 5),
+ "wide_cols": (200, 5_000, 20),
+ }
+
+ @staticmethod
+ def _build_scenario(name):
+ """Build a single TWS Pandas scenario.
+
+ Returns ``(batches, schema)`` where ``batches`` is a plain list of Arrow
+ RecordBatches with rows pre-sorted by the leading int32 key column.
+ """
+ np.random.seed(42)
+ num_groups, rows_per_group, num_value_cols = (
+ _TransformWithStatePandasBenchMixin._scenario_configs[name]
+ )
+ total_rows = num_groups * rows_per_group
+ key_array = pa.array(
+ np.repeat(np.arange(num_groups, dtype=np.int32), rows_per_group),
+ type=pa.int32(),
+ )
+ value_pool = MockDataFactory.NUMERIC_TYPES
Review Comment:
Are you sure you don't want to add non-numeric types? Maybe some cases with
nested types (arrays, structs, etc.)?
##########
python/benchmarks/bench_eval_type.py:
##########
@@ -1929,3 +1935,173 @@ class WindowAggPandasUDFTimeBench(_WindowAggPandasBenchMixin, _TimeBenchBase):
class WindowAggPandasUDFPeakmemBench(_WindowAggPandasBenchMixin, _PeakmemBenchBase):
pass
+
+
+# -- SQL_TRANSFORM_WITH_STATE_PANDAS_UDF ---------------------------------------
+# Stateful streaming with Pandas. UDF signature is
+# ``(api_client, mode, key, pdfs)`` and returns ``Iterator[pandas.DataFrame]``.
+# The input wire stream is a single plain Arrow stream pre-sorted by the
+# grouping key column at offset 0; ``TransformWithStateInPandasSerializer``
+# chunks rows into one ``(mode, key, pdfs)`` tuple per group, then emits a
+# phantom ``PROCESS_TIMER`` and ``COMPLETE`` call with an empty pdf iterator.
+# ``StatefulProcessorApiClient.__init__`` opens a real TCP socket to the JVM
+# state server; the stub listener below satisfies that connect. The benchmark
+# UDFs never invoke any state API method, so no protocol exchange is needed.
+
+
+class _StubStateServer:
+ """Stub TCP listener so ``StatefulProcessorApiClient`` init succeeds.
+
+ One instance per benchmark process; the port is reused across all scenarios
+ and ASV iterations. The accept loop stashes connections to keep them open
+ until the worker process tears them down (the worker never closes its end
+ explicitly, but Python GCs the socket on ``main`` return).
+ """
+
+ _instance: "_StubStateServer | None" = None
+
+ @classmethod
+ def get_port(cls) -> int:
+ if cls._instance is None:
+ cls._instance = cls()
+ return cls._instance.port
+
+ def __init__(self) -> None:
+ self._sock = socket.socket()
+ self._sock.bind(("127.0.0.1", 0))
+ self._sock.listen(128)
+ self.port = self._sock.getsockname()[1]
+ self._connections: list[socket.socket] = []
+ self._thread = threading.Thread(target=self._accept_loop, daemon=True)
+ self._thread.start()
+
+ def _accept_loop(self) -> None:
+ while True:
+ try:
+ conn, _ = self._sock.accept()
+ except OSError:
+ break
+ self._connections.append(conn)
+
+
+class _TransformWithStatePandasBenchMixin:
+ """Provides ``_write_scenario`` for SQL_TRANSFORM_WITH_STATE_PANDAS_UDF.
+
+ Each scenario emits one plain Arrow stream pre-sorted by the leading int
+ key column. UDFs receive an iterator of value-only Pandas DataFrames per
+ group plus phantom ``PROCESS_TIMER``/``COMPLETE`` calls (empty iterator).
+ """
+
+ # Each scenario: (num_groups, rows_per_group, num_value_cols).
+ # Row counts are scaled so identity_udf (full pdf passthrough -> ~equal
+ # input and output volume) stays under ASV's 60s per-sample timeout.
+ _scenario_configs = {
+ "few_groups_sm": (50, 5_000, 5),
+ "few_groups_lg": (50, 50_000, 5),
+ "many_groups_sm": (2_000, 500, 5),
+ "many_groups_lg": (500, 2_000, 5),
+ "wide_cols": (200, 5_000, 20),
+ }
+
+ @staticmethod
+ def _build_scenario(name):
+ """Build a single TWS Pandas scenario.
+
+ Returns ``(batches, schema)`` where ``batches`` is a plain list of Arrow
+ RecordBatches with rows pre-sorted by the leading int32 key column.
+ """
+ np.random.seed(42)
+ num_groups, rows_per_group, num_value_cols = (
+ _TransformWithStatePandasBenchMixin._scenario_configs[name]
+ )
+ total_rows = num_groups * rows_per_group
+ key_array = pa.array(
+ np.repeat(np.arange(num_groups, dtype=np.int32), rows_per_group),
+ type=pa.int32(),
+ )
+ value_pool = MockDataFactory.NUMERIC_TYPES
+ value_arrays = [
+ value_pool[i % len(value_pool)][0](total_rows) for i in range(num_value_cols)
+ ]
+ names = ["col_0"] + [f"col_{i + 1}" for i in range(num_value_cols)]
+ full_batch = pa.RecordBatch.from_arrays([key_array] + value_arrays, names=names)
+ batch_size = MockDataFactory.MAX_RECORDS_PER_BATCH
+ batches = [
+ full_batch.slice(offset, min(batch_size, total_rows - offset))
+ for offset in range(0, total_rows, batch_size)
+ ]
+ schema = StructType(
+ [StructField("col_0", IntegerType())]
+ + [
+ StructField(f"col_{i + 1}", value_pool[i % len(value_pool)][1])
+ for i in range(num_value_cols)
+ ]
+ )
+ return batches, schema
+
+ def _tws_pandas_identity(api_client, mode, key, pdfs):
+ from pyspark.sql.streaming.stateful_processor_util import (
+ TransformWithStateInPandasFuncMode,
+ )
+
+ if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
+ yield from pdfs
+
+ def _tws_pandas_sort(api_client, mode, key, pdfs):
+ from pyspark.sql.streaming.stateful_processor_util import (
+ TransformWithStateInPandasFuncMode,
+ )
+
+ if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
+ for pdf in pdfs:
+ yield pdf.sort_values(pdf.columns[0])
+
+ def _tws_pandas_count(api_client, mode, key, pdfs):
+ import pandas as pd
+ from pyspark.sql.streaming.stateful_processor_util import (
+ TransformWithStateInPandasFuncMode,
+ )
+
+ if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
+ total = sum(len(pdf) for pdf in pdfs)
+ yield pd.DataFrame({"col_1": [total]})
+
+ # ret_type=None means "use all value columns of the input schema".
+ _udfs = {
+ "identity_udf": (_tws_pandas_identity, None),
+ "sort_udf": (_tws_pandas_sort, None),
+ "count_udf": (_tws_pandas_count, StructType([StructField("col_1", IntegerType())])),
+ }
+ params = [list(_scenario_configs), list(_udfs)]
+ param_names = ["scenario", "udf"]
+
+ _NUM_KEY_COLS = 1
+
+ def _write_scenario(self, scenario, udf_name, buf):
+ batches, schema = self._build_scenario(scenario)
+ udf_func, ret_type = self._udfs[udf_name]
+ if ret_type is None:
+ ret_type = StructType(schema.fields[self._NUM_KEY_COLS :])
Review Comment:
nit: we typically see the keys included in the output schema for transform
with state.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]