This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f6ae843e53c [SPARK-55754][PYTHON][TEST] Add ASV microbenchmarks for scalar Arrow UDF eval types
4f6ae843e53c is described below

commit 4f6ae843e53c4eee18f33edfe2d59ed72d9238f2
Author: Yicong-Huang <[email protected]>
AuthorDate: Mon Mar 9 11:31:34 2026 +0800

    [SPARK-55754][PYTHON][TEST] Add ASV microbenchmarks for scalar Arrow UDF eval types
    
    ### What changes were proposed in this pull request?
    
    Add ASV microbenchmarks for `SQL_SCALAR_ARROW_UDF` and `SQL_SCALAR_ARROW_ITER_UDF`.
    
    **Benchmark classes** (4 total, split by eval type and metric):
    - `ScalarArrowUDFTimeBench` / `ScalarArrowUDFPeakmemBench`
    - `ScalarArrowIterUDFTimeBench` / `ScalarArrowIterUDFPeakmemBench`
    
    **Two ASV parameter dimensions** (`scenario` x `udf`):
    - 9 data scenarios: 4 batch-size/col-count combos + 5 pure-type scenarios (`pure_ints`, `pure_floats`, `pure_strings`, `pure_ts`, `mixed_types`)
    - 3 type-agnostic UDFs: `identity_udf`, `sort_udf`, `nullcheck_udf` (all use `arg_offsets=[0]`, work on any Arrow type)
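
    The three UDFs can be sketched on plain Python lists as stand-ins for `pa.Array` (the real benchmark versions use `pyarrow.compute`: `pc.take`/`pc.sort_indices` and `pc.is_valid`); names below mirror the ASV parameter values:

    ```python
    # Illustrative stand-ins only: each takes a single column (arg_offsets=[0])
    # and is independent of the column's value type.
    def identity_udf(col):
        return col

    def sort_udf(col):
        # Nulls sort last, matching pyarrow's default null placement.
        return sorted(col, key=lambda x: (x is None, x))

    def nullcheck_udf(col):
        return [x is not None for x in col]
    ```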
    
    **Design decisions:**
    - `time_*` materializes the full input in `BytesIO` during setup to eliminate disk I/O noise
    - `peakmem_*` stores only params in setup; it streams to a temp file and replays from disk during the benchmark to avoid inflating peak memory
    - UDF return type is dynamically resolved from the scenario's `col0_type` when the UDF declares `ret_type=None`
    - Mixin pattern (`_NonGroupedBenchMixin`) + base classes (`_TimeBenchBase`, `_PeakmemBenchBase`) for easy extension to future eval types (e.g., UDTF)
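
    The time/peakmem split can be sketched with stdlib pieces only; `run_worker` below is a hypothetical stand-in for `pyspark.worker.main`, and the class and attribute names are illustrative, not the benchmark's actual API:

    ```python
    import io
    import os
    import tempfile

    def run_worker(infile, outfile):
        # Hypothetical stand-in for pyspark.worker.main: drain input, echo it.
        outfile.write(infile.read())

    class TimeBenchSketch:
        def setup(self):
            # Materialize the entire input once, in memory, so the timed
            # region sees no disk I/O.
            self._input = b"payload" * 1024

        def time_worker(self):
            run_worker(io.BytesIO(self._input), io.BytesIO())

    class PeakmemBenchSketch:
        def setup(self):
            # Store only parameters; building the input here would count
            # toward the measured peak memory.
            self._payload = b"payload"

        def peakmem_worker(self):
            # Stream to a temp file, then replay from disk, so the full
            # input never lives in memory alongside the worker's allocations.
            fd, path = tempfile.mkstemp(prefix="bench-replay-")
            try:
                with os.fdopen(fd, "w+b") as f:
                    for _ in range(1024):
                        f.write(self._payload)
                    f.flush()
                    f.seek(0)
                    run_worker(f, io.BytesIO())
            finally:
                os.remove(path)
    ```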
    
    ### Why are the changes needed?
    
    Part of SPARK-55724.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. Benchmark files only.
    
    ### How was this patch tested?
    
    `COLUMNS=120 asv run --python=same --bench "ScalarArrow" --attribute "repeat=(3,5,5.0)"`:
    
    **ScalarArrowUDFTimeBench** (`SQL_SCALAR_ARROW_UDF`):
    ```
    =================== ============== ============ ===============
    --                                      udf
    ------------------- -------------------------------------------
          scenario       identity_udf    sort_udf    nullcheck_udf
    =================== ============== ============ ===============
      sm_batch_few_col    70.3+-0.4ms    184+-0.2ms      53.4+-0.5ms
     sm_batch_many_col    24.1+-0.2ms    42.0+-0.3ms     20.1+-0.2ms
      lg_batch_few_col     461+-1ms       819+-3ms        265+-1ms
     lg_batch_many_col    320+-0.4ms      359+-3ms        327+-1ms
         pure_ints         112+-1ms      175+-0.7ms      87.0+-0.7ms
        pure_floats       103+-0.8ms      564+-2ms       86.0+-0.8ms
        pure_strings      116+-0.8ms     515+-0.5ms      93.4+-0.5ms
          pure_ts         103+-0.7ms      212+-4ms       86.1+-0.5ms
        mixed_types        103+-2ms      164+-0.9ms      80.4+-0.6ms
    =================== ============== ============ ===============
    ```
    
    **ScalarArrowUDFPeakmemBench** (`SQL_SCALAR_ARROW_UDF`):
    ```
    =================== ============== ========== ===============
    --                                     udf
    ------------------- -----------------------------------------
          scenario       identity_udf   sort_udf   nullcheck_udf
    =================== ============== ========== ===============
      sm_batch_few_col       121M         123M          113M
     sm_batch_many_col       116M         117M          114M
      lg_batch_few_col       256M         258M          120M
     lg_batch_many_col       136M         138M          123M
         pure_ints           136M         137M          114M
        pure_floats          154M         155M          115M
        pure_strings         157M         159M          115M
          pure_ts            154M         155M          115M
        mixed_types          136M         137M          115M
    =================== ============== ========== ===============
    ```
    
    **ScalarArrowIterUDFTimeBench** (`SQL_SCALAR_ARROW_ITER_UDF`):
    ```
    =================== ============== ============= ===============
    --                                      udf
    ------------------- --------------------------------------------
          scenario       identity_udf     sort_udf    nullcheck_udf
    =================== ============== ============= ===============
      sm_batch_few_col    66.0+-0.5ms      179+-1ms       48.9+-0.7ms
     sm_batch_many_col    22.8+-0.1ms    41.0+-0.05ms     19.3+-0.1ms
      lg_batch_few_col     447+-2ms        803+-3ms        249+-1ms
     lg_batch_many_col    320+-0.5ms       357+-3ms        325+-1ms
         pure_ints        105+-0.5ms       169+-1ms       83.4+-0.6ms
        pure_floats        101+-1ms        559+-2ms        83.1+-2ms
        pure_strings       116+-3ms        513+-3ms        90.5+-1ms
          pure_ts          103+-2ms       208+-0.6ms       83.5+-2ms
        mixed_types       100+-0.5ms       163+-1ms       78.0+-0.2ms
    =================== ============== ============= ===============
    ```
    
    **ScalarArrowIterUDFPeakmemBench** (`SQL_SCALAR_ARROW_ITER_UDF`):
    ```
    =================== ============== ========== ===============
    --                                     udf
    ------------------- -----------------------------------------
          scenario       identity_udf   sort_udf   nullcheck_udf
    =================== ============== ========== ===============
      sm_batch_few_col       122M         123M          113M
     sm_batch_many_col       116M         117M          114M
      lg_batch_few_col       256M         258M          120M
     lg_batch_many_col       136M         138M          123M
         pure_ints           136M         137M          115M
        pure_floats          154M         155M          114M
        pure_strings         157M         159M          115M
          pure_ts            154M         155M          114M
        mixed_types          136M         137M          114M
    =================== ============== ========== ===============
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #54555 from Yicong-Huang/SPARK-55754/benchmark/scalar-arrow-udf.
    
    Lead-authored-by: Yicong-Huang <[email protected]>
    Co-authored-by: Yicong Huang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/benchmarks/bench_eval_type.py | 447 ++++++++++++++++++++---------------
 1 file changed, 257 insertions(+), 190 deletions(-)

diff --git a/python/benchmarks/bench_eval_type.py b/python/benchmarks/bench_eval_type.py
index eaf5313a4cc8..53a4f208e99d 100644
--- a/python/benchmarks/bench_eval_type.py
+++ b/python/benchmarks/bench_eval_type.py
@@ -24,10 +24,12 @@ by constructing the complete binary protocol that ``worker.py``'s
 """
 
 import io
+import os
 import json
 import struct
 import sys
-from typing import Any, Callable, Optional
+import tempfile
+from typing import Any, Callable, Iterator
 
 import numpy as np
 import pyarrow as pa
@@ -35,11 +37,14 @@ import pyarrow as pa
 from pyspark.cloudpickle import dumps as cloudpickle_dumps
 from pyspark.serializers import write_int, write_long
 from pyspark.sql.types import (
+    BinaryType,
+    BooleanType,
     DoubleType,
     IntegerType,
     StringType,
     StructField,
     StructType,
+    TimestampNTZType,
 )
 from pyspark.util import PythonEvalType
 from pyspark.worker import main as worker_main
@@ -110,239 +115,301 @@ def _build_udf_payload(
     write_long(0, buf)  # result_id
 
 
-def _build_grouped_arrow_data(
-    arrow_batch: pa.RecordBatch,
-    num_groups: int,
-    buf: io.BytesIO,
-    max_records_per_batch: Optional[int] = None,
-) -> None:
-    """Write grouped-map Arrow data: ``(write_int(1) + IPC) * N + 
write_int(0)``.
+def _write_arrow_ipc_batches(batch_iter: Iterator[pa.RecordBatch], buf: 
io.BufferedIOBase) -> None:
+    """Write a plain Arrow IPC stream from an iterator of Arrow batches."""
+    first_batch = next(batch_iter)
+    writer = pa.RecordBatchStreamWriter(buf, first_batch.schema)
+    writer.write_batch(first_batch)
+    for batch in batch_iter:
+        writer.write_batch(batch)
+    writer.close()
 
-    When *max_records_per_batch* is set and the batch exceeds that limit, each
-    group is split into multiple smaller Arrow batches inside the same IPC stream,
-    mirroring what the JVM does under ``spark.sql.execution.arrow.maxRecordsPerBatch``.
-    """
-    for _ in range(num_groups):
-        write_int(1, buf)
-        writer = pa.RecordBatchStreamWriter(buf, arrow_batch.schema)
-        if max_records_per_batch and arrow_batch.num_rows > max_records_per_batch:
-            for offset in range(0, arrow_batch.num_rows, max_records_per_batch):
-                writer.write_batch(arrow_batch.slice(offset, max_records_per_batch))
-        else:
-            writer.write_batch(arrow_batch)
-        writer.close()
-    write_int(0, buf)
-
-
-def _build_worker_input(
+
+def _write_worker_input(
     eval_type: int,
-    udf_func: Callable[..., Any],
-    return_type: StructType,
-    arg_offsets: list[int],
-    arrow_batch: pa.RecordBatch,
-    num_groups: int,
-    max_records_per_batch: Optional[int] = None,
-) -> bytes:
-    """Assemble the full binary stream consumed by ``worker_main(infile, 
outfile)``.
-
-    Parameters
-    ----------
-    max_records_per_batch : int, optional
-        When set, each group's Arrow data is split into sub-batches of this
-        size, mirroring the JVM-side behaviour of
-        ``spark.sql.execution.arrow.maxRecordsPerBatch``.
-    """
-    buf = io.BytesIO()
+    write_command: Callable[[io.BufferedIOBase], None],
+    write_data: Callable[[io.BufferedIOBase], None],
+    buf: io.BufferedIOBase,
+) -> None:
+    """Write the full worker binary stream: preamble + command + data + end.
 
+    This is the general skeleton shared by all eval types.  Callers provide
+    *write_command* (e.g. ``_build_udf_payload``) and *write_data*
+    (e.g. ``_write_arrow_ipc_batches``) to plug in protocol specifics.
+    """
     _build_preamble(buf)
     write_int(eval_type, buf)
-
     write_int(0, buf)  # RunnerConf  (0 key-value pairs)
     write_int(0, buf)  # EvalConf    (0 key-value pairs)
-
-    _build_udf_payload(udf_func, return_type, arg_offsets, buf)
-    _build_grouped_arrow_data(arrow_batch, num_groups, buf, max_records_per_batch)
+    write_command(buf)
+    write_data(buf)
     write_int(-4, buf)  # SpecialLengths.END_OF_STREAM
 
-    return buf.getvalue()
+
+def _run_worker_from_replayed_file(write_input: Callable[[io.BufferedIOBase], None]) -> None:
+    """Write input to a temp file, then replay it through ``worker_main``."""
+    fd, path = tempfile.mkstemp(prefix="spark-bench-replay-", suffix=".bin")
+    try:
+        with os.fdopen(fd, "w+b") as infile:
+            write_input(infile)
+            infile.flush()
+            infile.seek(0)
+            worker_main(infile, io.BytesIO())
+    finally:
+        try:
+            os.remove(path)
+        except FileNotFoundError:
+            pass
+
+
+def _make_typed_batch(rows: int, n_cols: int) -> tuple[pa.RecordBatch, IntegerType]:
+    """Columns cycling through int64, string, binary, boolean — reflects realistic serde costs."""
+    type_cycle = [
+        (lambda r: pa.array(np.random.randint(0, 1000, r, dtype=np.int64)), IntegerType()),
+        (lambda r: pa.array([f"s{j}" for j in range(r)]), StringType()),
+        (lambda r: pa.array([f"b{j}".encode() for j in range(r)]), BinaryType()),
+        (lambda r: pa.array(np.random.choice([True, False], r)), BooleanType()),
+    ]
+    arrays = [type_cycle[i % len(type_cycle)][0](rows) for i in range(n_cols)]
+    fields = [StructField(f"col_{i}", type_cycle[i % len(type_cycle)][1]) for i in range(n_cols)]
+    return (
+        pa.RecordBatch.from_arrays(arrays, names=[f.name for f in fields]),
+        IntegerType(),
+    )
 
 
 # ---------------------------------------------------------------------------
-# Data helpers
+# General benchmark base classes
 # ---------------------------------------------------------------------------
 
 
-def _build_grouped_arg_offsets(n_cols: int, n_keys: int = 0) -> list[int]:
-    """``[len, num_keys, key_col_0, …, val_col_0, …]``"""
-    keys = list(range(n_keys))
-    vals = list(range(n_keys, n_cols))
-    offsets = [n_keys] + keys + vals
-    return [len(offsets)] + offsets
+class _TimeBenchBase:
+    """Base for ``time_*`` benchmarks (any eval type).
 
+    Setup materializes full input bytes in memory so that disk I/O does not
+    affect latency measurements.
 
-def _make_grouped_batch(rows_per_group: int, n_cols: int) -> tuple[pa.RecordBatch, StructType]:
-    """``group_key (int64)`` + ``(n_cols - 1)`` float32 value columns."""
-    arrays = [pa.array(np.zeros(rows_per_group, dtype=np.int64))] + [
-        pa.array(np.random.rand(rows_per_group).astype(np.float32)) for _ in range(n_cols - 1)
-    ]
-    fields = [StructField("group_key", IntegerType())] + [
-        StructField(f"some_field_{i}", DoubleType()) for i in range(n_cols - 1)
-    ]
-    return (
-        pa.RecordBatch.from_arrays(arrays, names=[f.name for f in fields]),
-        StructType(fields),
-    )
+    Subclasses provide ``params``, ``param_names``, and ``_write_scenario``.
+    """
 
+    def setup(self, *args):
+        buf = io.BytesIO()
+        self._write_scenario(*args, buf)
+        self._input = buf.getvalue()
 
-def _make_mixed_batch(rows_per_group: int) -> tuple[pa.RecordBatch, StructType]:
-    """``id``, ``str_col``, ``float_col``, ``double_col``, ``long_col``."""
-    arrays = [
-        pa.array(np.zeros(rows_per_group, dtype=np.int64)),
-        pa.array([f"s{j}" for j in range(rows_per_group)]),
-        pa.array(np.random.rand(rows_per_group).astype(np.float32)),
-        pa.array(np.random.rand(rows_per_group)),
-        pa.array(np.zeros(rows_per_group, dtype=np.int64)),
-    ]
-    fields = [
-        StructField("id", IntegerType()),
-        StructField("str_col", StringType()),
-        StructField("float_col", DoubleType()),
-        StructField("double_col", DoubleType()),
-        StructField("long_col", IntegerType()),
-    ]
-    return (
-        pa.RecordBatch.from_arrays(arrays, names=[f.name for f in fields]),
-        StructType(fields),
-    )
+    def time_worker(self, *args):
+        worker_main(io.BytesIO(self._input), io.BytesIO())
+
+
+class _PeakmemBenchBase:
+    """Base for ``peakmem_*`` benchmarks (any eval type).
+
+    Benchmark streams input to a temp file and replays from disk so that
+    setup memory does not inflate peak-memory measurements.
+
+    Subclasses provide ``params``, ``param_names``, and ``_write_scenario``.
+    """
+
+    def setup(self, *args):
+        self._args = args
+
+    def peakmem_worker(self, *args):
+        _run_worker_from_replayed_file(lambda buf: self._write_scenario(*self._args, buf))
 
 
 # ---------------------------------------------------------------------------
-# SQL_GROUPED_MAP_PANDAS_UDF
+# Non-grouped Arrow UDF benchmarks
 # ---------------------------------------------------------------------------
 
+# Data-shape scenarios shared by all non-grouped eval types.
+# Each entry maps to a ``(batch, num_batches, col0_type)`` tuple; the UDF is
+# selected independently via the ``udf`` ASV parameter.
+# ``col0_type`` is the Spark type of column 0, used as the default UDF return
+# type when the UDF declaration leaves it as ``None``.
+
+
+def _make_pure_batch(rows, n_cols, make_array, spark_type):
+    """Create a batch where all columns share the same Arrow type."""
+    arrays = [make_array(rows) for _ in range(n_cols)]
+    fields = [StructField(f"col_{i}", spark_type) for i in range(n_cols)]
+    return pa.RecordBatch.from_arrays(arrays, names=[f.name for f in fields])
 
-class GroupedMapPandasUDFBench:
-    """Full worker round-trip for ``SQL_GROUPED_MAP_PANDAS_UDF``.
 
-    Large groups (100k rows) are split into Arrow sub-batches of at most
-    ``_MAX_RECORDS_PER_BATCH`` rows, mirroring the JVM-side splitting
-    behaviour (``spark.sql.execution.arrow.maxRecordsPerBatch`` default 10 000).
-    Small groups (1k rows) are unaffected.
+def _build_non_grouped_scenarios():
+    """Build data-shape scenarios for non-grouped Arrow eval types.
+
+    Returns a dict mapping scenario name to ``(batch, num_batches, col0_type)``.
     """
+    scenarios = {}
+
+    # Varying batch size and column count (mixed types cycling int/str/bin/bool)
+    for name, (rows, n_cols, num_batches) in {
+        "sm_batch_few_col": (1_000, 5, 1_500),
+        "sm_batch_many_col": (1_000, 50, 200),
+        "lg_batch_few_col": (10_000, 5, 3_500),
+        "lg_batch_many_col": (10_000, 50, 400),
+    }.items():
+        batch, col0_type = _make_typed_batch(rows, n_cols)
+        scenarios[name] = (batch, num_batches, col0_type)
+
+    # Pure-type scenarios (5000 rows, 10 cols, 1000 batches)
+    _PURE_ROWS, _PURE_COLS, _PURE_BATCHES = 5_000, 10, 1_000
+
+    scenarios["pure_ints"] = (
+        _make_pure_batch(
+            _PURE_ROWS,
+            _PURE_COLS,
+            lambda r: pa.array(np.random.randint(0, 1000, r, dtype=np.int64)),
+            IntegerType(),
+        ),
+        _PURE_BATCHES,
+        IntegerType(),
+    )
+    scenarios["pure_floats"] = (
+        _make_pure_batch(
+            _PURE_ROWS,
+            _PURE_COLS,
+            lambda r: pa.array(np.random.rand(r)),
+            DoubleType(),
+        ),
+        _PURE_BATCHES,
+        DoubleType(),
+    )
+    scenarios["pure_strings"] = (
+        _make_pure_batch(
+            _PURE_ROWS,
+            _PURE_COLS,
+            lambda r: pa.array([f"s{j}" for j in range(r)]),
+            StringType(),
+        ),
+        _PURE_BATCHES,
+        StringType(),
+    )
+    scenarios["pure_ts"] = (
+        _make_pure_batch(
+            _PURE_ROWS,
+            _PURE_COLS,
+            lambda r: pa.array(
+                np.arange(0, r, dtype="datetime64[us]"), type=pa.timestamp("us", tz=None)
+            ),
+            TimestampNTZType(),
+        ),
+        _PURE_BATCHES,
+        TimestampNTZType(),
+    )
+    scenarios["mixed_types"] = (
+        _make_typed_batch(_PURE_ROWS, _PURE_COLS)[0],
+        _PURE_BATCHES,
+        IntegerType(),
+    )
 
-    _MAX_RECORDS_PER_BATCH = 10_000  # matches spark.sql.execution.arrow.maxRecordsPerBatch default
-
-    def setup(self):
-        eval_type = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
-
-        # ---- varying group size (float data, identity UDF) ----
-        for name, (rows_per_group, n_cols, num_groups) in {
-            "small_few": (1_000, 5, 1_500),
-            "small_many": (1_000, 50, 200),
-            "large_few": (100_000, 5, 350),
-            "large_many": (100_000, 50, 40),
-        }.items():
-            batch, schema = _make_grouped_batch(rows_per_group, n_cols)
-            setattr(
-                self,
-                f"_{name}_input",
-                _build_worker_input(
-                    eval_type,
-                    lambda df: df,
-                    schema,
-                    _build_grouped_arg_offsets(n_cols),
-                    batch,
-                    num_groups=num_groups,
-                    max_records_per_batch=self._MAX_RECORDS_PER_BATCH,
-                ),
-            )
-
-        # ---- mixed types, 1-arg UDF ----
-        mixed_batch, mixed_schema = _make_mixed_batch(3)
-        n_mixed = len(mixed_schema.fields)
-
-        def double_col_add_mean(pdf):
-            return pdf.assign(double_col=pdf["double_col"] + pdf["double_col"].mean())
-
-        self._mixed_input = _build_worker_input(
-            eval_type,
-            double_col_add_mean,
-            mixed_schema,
-            _build_grouped_arg_offsets(n_mixed),
-            mixed_batch,
-            num_groups=1_300,
-        )
+    return scenarios
+
+
+_NON_GROUPED_SCENARIOS = _build_non_grouped_scenarios()
 
-        # ---- mixed types, 2-arg UDF with key ----
-        two_arg_schema = StructType(
-            [StructField("group_key", IntegerType()), StructField("double_col_mean", DoubleType())]
-        )
 
-        def double_col_key_add_mean(key, pdf):
-            import pandas as pd
+class _NonGroupedBenchMixin:
+    """Provides ``_write_scenario`` for non-grouped Arrow eval types.
 
-            return pd.DataFrame(
-                [{"group_key": key[0], "double_col_mean": pdf["double_col"].mean()}]
-            )
+    Subclasses set ``_eval_type``, ``_scenarios``, and ``_udfs``.
+    UDF entries with ``ret_type=None`` inherit ``col0_type`` from the scenario.
+    """
 
-        self._two_args_input = _build_worker_input(
-            eval_type,
-            double_col_key_add_mean,
-            two_arg_schema,
-            _build_grouped_arg_offsets(n_mixed, n_keys=1),
-            mixed_batch,
-            num_groups=1_600,
+    def _write_scenario(self, scenario, udf_name, buf):
+        batch, num_batches, col0_type = self._scenarios[scenario]
+        udf_func, ret_type, arg_offsets = self._udfs[udf_name]
+        if ret_type is None:
+            ret_type = col0_type
+        _write_worker_input(
+            self._eval_type,
+            lambda b: _build_udf_payload(udf_func, ret_type, arg_offsets, b),
+            lambda b: _write_arrow_ipc_batches((batch for _ in range(num_batches)), b),
+            buf,
         )
 
-    # -- benchmarks ---------------------------------------------------------
 
-    def _run(self, input_bytes):
-        worker_main(io.BytesIO(input_bytes), io.BytesIO())
+# -- SQL_SCALAR_ARROW_UDF ---------------------------------------------------
+# UDF receives individual ``pa.Array`` columns, returns a ``pa.Array``.
+# All UDFs operate on arg_offsets=[0] so they work with any column type.
+
+
+def _sort_arrow(c):
+    import pyarrow.compute as pc
+
+    return pc.take(c, pc.sort_indices(c))
+
+
+def _nullcheck_arrow(c):
+    import pyarrow.compute as pc
+
+    return pc.is_valid(c)
+
+
+# ret_type=None means "use col0_type from the scenario"
+_SCALAR_ARROW_UDFS = {
+    "identity_udf": (lambda c: c, None, [0]),
+    "sort_udf": (_sort_arrow, None, [0]),
+    "nullcheck_udf": (_nullcheck_arrow, BooleanType(), [0]),
+}
+
+
+class ScalarArrowUDFTimeBench(_NonGroupedBenchMixin, _TimeBenchBase):
+    _eval_type = PythonEvalType.SQL_SCALAR_ARROW_UDF
+    _scenarios = _NON_GROUPED_SCENARIOS
+    _udfs = _SCALAR_ARROW_UDFS
+    params = [list(_NON_GROUPED_SCENARIOS), list(_SCALAR_ARROW_UDFS)]
+    param_names = ["scenario", "udf"]
+
+
+class ScalarArrowUDFPeakmemBench(_NonGroupedBenchMixin, _PeakmemBenchBase):
+    _eval_type = PythonEvalType.SQL_SCALAR_ARROW_UDF
+    _scenarios = _NON_GROUPED_SCENARIOS
+    _udfs = _SCALAR_ARROW_UDFS
+    params = [list(_NON_GROUPED_SCENARIOS), list(_SCALAR_ARROW_UDFS)]
+    param_names = ["scenario", "udf"]
+
+
+# -- SQL_SCALAR_ARROW_ITER_UDF ----------------------------------------------
+# UDF receives ``Iterator[pa.Array]``, returns ``Iterator[pa.Array]``.
+
+
+def _identity_iter(it):
+    return (c for c in it)
 
-    def time_small_groups_few_cols(self):
-        """1k rows/group, 5 cols, 1500 groups."""
-        self._run(self._small_few_input)
 
-    def peakmem_small_groups_few_cols(self):
-        """1k rows/group, 5 cols, 1500 groups."""
-        self._run(self._small_few_input)
+def _sort_iter(it):
+    import pyarrow.compute as pc
 
-    def time_small_groups_many_cols(self):
-        """1k rows/group, 50 cols, 200 groups."""
-        self._run(self._small_many_input)
+    for c in it:
+        yield pc.take(c, pc.sort_indices(c))
 
-    def peakmem_small_groups_many_cols(self):
-        """1k rows/group, 50 cols, 200 groups."""
-        self._run(self._small_many_input)
 
-    def time_large_groups_few_cols(self):
-        """100k rows/group, 5 cols, 350 groups, split at 10k rows/batch."""
-        self._run(self._large_few_input)
+def _nullcheck_iter(it):
+    import pyarrow.compute as pc
 
-    def peakmem_large_groups_few_cols(self):
-        """100k rows/group, 5 cols, 350 groups, split at 10k rows/batch."""
-        self._run(self._large_few_input)
+    for c in it:
+        yield pc.is_valid(c)
 
-    def time_large_groups_many_cols(self):
-        """100k rows/group, 50 cols, 40 groups, split at 10k rows/batch."""
-        self._run(self._large_many_input)
 
-    def peakmem_large_groups_many_cols(self):
-        """100k rows/group, 50 cols, 40 groups, split at 10k rows/batch."""
-        self._run(self._large_many_input)
+_SCALAR_ARROW_ITER_UDFS = {
+    "identity_udf": (_identity_iter, None, [0]),
+    "sort_udf": (_sort_iter, None, [0]),
+    "nullcheck_udf": (_nullcheck_iter, BooleanType(), [0]),
+}
 
-    def time_mixed_types(self):
-        """Mixed column types, 1-arg UDF, 3 rows/group, 1300 groups."""
-        self._run(self._mixed_input)
 
-    def peakmem_mixed_types(self):
-        """Mixed column types, 1-arg UDF, 3 rows/group, 1300 groups."""
-        self._run(self._mixed_input)
+class ScalarArrowIterUDFTimeBench(_NonGroupedBenchMixin, _TimeBenchBase):
+    _eval_type = PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF
+    _scenarios = _NON_GROUPED_SCENARIOS
+    _udfs = _SCALAR_ARROW_ITER_UDFS
+    params = [list(_NON_GROUPED_SCENARIOS), list(_SCALAR_ARROW_ITER_UDFS)]
+    param_names = ["scenario", "udf"]
 
-    def time_mixed_types_two_args(self):
-        """Mixed column types, 2-arg UDF with key, 3 rows/group, 1600 groups."""
-        self._run(self._two_args_input)
 
-    def peakmem_mixed_types_two_args(self):
-        """Mixed column types, 2-arg UDF with key, 3 rows/group, 1600 groups."""
-        self._run(self._two_args_input)
+class ScalarArrowIterUDFPeakmemBench(_NonGroupedBenchMixin, _PeakmemBenchBase):
+    _eval_type = PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF
+    _scenarios = _NON_GROUPED_SCENARIOS
+    _udfs = _SCALAR_ARROW_ITER_UDFS
+    params = [list(_NON_GROUPED_SCENARIOS), list(_SCALAR_ARROW_ITER_UDFS)]
+    param_names = ["scenario", "udf"]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
