This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 01a4732991 [python] Add parallel split reading to to_pandas / to_arrow 
(#7870)
01a4732991 is described below

commit 01a4732991f0bd838975f6985c5ae30f14135f44
Author: chaoyang <[email protected]>
AuthorDate: Wed May 20 19:20:34 2026 +0800

    [python] Add parallel split reading to to_pandas / to_arrow (#7870)
    
    Today `TableRead.to_pandas` / `to_arrow` iterate splits serially in
    `_arrow_batch_generator`, so wall time scales linearly with the number
    of splits even though PyArrow's parquet/orc readers release the GIL
    during decode. Unlike Java, where Flink/Spark fan splits out across
    TaskManagers/Executors, PyPaimon has no external framework above the
    SDK; split-level parallelism therefore has to live inside the SDK.
    
    This PR adds an opt-in `parallelism` parameter to `to_pandas` /
    `to_arrow`, backed by a new `read.parallelism` table option. Default
    behavior is unchanged.
---
 .../pypaimon/common/options/core_options.py        |  15 +
 paimon-python/pypaimon/read/table_read.py          | 215 ++++++++++-
 .../pypaimon/tests/reader_parallel_test.py         | 405 +++++++++++++++++++++
 3 files changed, 631 insertions(+), 4 deletions(-)

diff --git a/paimon-python/pypaimon/common/options/core_options.py b/paimon-python/pypaimon/common/options/core_options.py
index 7d9a227e4a..2d140b9539 100644
--- a/paimon-python/pypaimon/common/options/core_options.py
+++ b/paimon-python/pypaimon/common/options/core_options.py
@@ -449,6 +449,18 @@ class CoreOptions:
         .with_description("Read batch size for any file format if it supports.")
     )
 
+    READ_PARALLELISM: ConfigOption[int] = (
+        ConfigOptions.key("read.parallelism")
+        .int_type()
+        .default_value(1)
+        .with_description(
+            "Parallelism for reading splits within a single TableRead call. "
+            "The value 1 (default) keeps reads serial. Values >= 2 enable a "
+            "thread pool that reads splits concurrently and assembles the "
+            "result in input order. Has no effect when fewer than 2 splits "
+            "are passed.")
+    )
+
     ADD_COLUMN_BEFORE_PARTITION: ConfigOption[bool] = (
         ConfigOptions.key("add-column-before-partition")
         .boolean_type()
@@ -702,6 +714,9 @@ class CoreOptions:
     def read_batch_size(self, default=None) -> int:
         return self.options.get(CoreOptions.READ_BATCH_SIZE, default or 1024)
 
+    def read_parallelism(self, default=None) -> int:
+        return self.options.get(CoreOptions.READ_PARALLELISM, default)
+
     def add_column_before_partition(self) -> bool:
         return self.options.get(CoreOptions.ADD_COLUMN_BEFORE_PARTITION, False)
 
diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py
index c45f2a8532..b3a8edaf63 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Any, Dict, Iterator, List, Optional
 
 import pandas
@@ -32,6 +34,41 @@ from pypaimon.table.row.offset_row import OffsetRow
 ROW_KIND_COLUMN = "_row_kind"
 
 
+class _RemainingRows:
+    """Thread-safe remaining-rows counter for parallel reads.
+
+    Row quota is pre-debited under a single lock so that any rows that
+    threads commit to emit are guaranteed not to overshoot the limit, even
+    if individual readers keep decoding one extra batch after the quota is
+    exhausted.
+
+    When ``limit`` is None the counter is unbounded and ``try_consume``
+    always returns the requested row count.
+    """
+
+    def __init__(self, limit: Optional[int]):
+        self._lock = threading.Lock()
+        self._remaining = limit  # None == unlimited
+
+    def try_consume(self, requested: int) -> int:
+        if self._remaining is None:
+            return requested
+        if requested <= 0:
+            return 0
+        with self._lock:
+            if self._remaining <= 0:
+                return 0
+            allowed = min(requested, self._remaining)
+            self._remaining -= allowed
+            return allowed
+
+    def exhausted(self) -> bool:
+        if self._remaining is None:
+            return False
+        with self._lock:
+            return self._remaining <= 0
+
+
 class TableRead:
     """Implementation of TableRead for native Python reading."""
 
@@ -52,6 +89,7 @@ class TableRead:
         self.include_row_kind = include_row_kind
         self.nested_name_paths = nested_name_paths
         self.limit = limit
+        self._read_parallelism = self.table.options.read_parallelism()
 
     def to_iterator(self, splits: List[Split]) -> Iterator:
         limit = self.limit
@@ -104,13 +142,34 @@ class TableRead:
 
         return pyarrow.RecordBatch.from_arrays(columns, schema=target_schema)
 
-    def to_arrow(self, splits: List[Split]) -> Optional[pyarrow.Table]:
-        batch_reader = self.to_arrow_batch_reader(splits)
+    def to_arrow(
+        self,
+        splits: List[Split],
+        parallelism: Optional[int] = None,
+    ) -> Optional[pyarrow.Table]:
+        """Read ``splits`` into a single arrow ``Table``.
 
+        Args:
+            splits: scan-plan splits returned from a ``TableScan``.
+            parallelism: optional runtime override of the
+                ``read.parallelism`` table option. ``None`` (default) falls
+                back to the table option; a non-None value temporarily
+                overrides it for this call. ``1`` keeps reads serial;
+                ``>= 2`` enables a thread pool that reads splits
+                concurrently and assembles the final table in input order.
+                Must be ``>= 1``.
+        """
+        # TODO: default read.parallelism to min(splits, cpu_count()) once stable
+        effective = self._resolve_parallelism(parallelism)
         schema = PyarrowFieldParser.from_paimon_schema(self.read_type)
         if self.include_row_kind:
             schema = self._add_row_kind_to_schema(schema)
 
+        if self._should_run_parallel(splits, effective):
+            return self._to_arrow_parallel(splits, schema, effective)
+
+        batch_reader = self.to_arrow_batch_reader(splits)
+
         table_list = []
         for batch in iter(batch_reader.read_next_batch, None):
             if batch.num_rows == 0:
@@ -183,6 +242,146 @@ class TableRead:
             finally:
                 reader.close()
 
+    def _resolve_parallelism(self, runtime: Optional[int]) -> int:
+        """Pick the effective parallelism and reject illegal values.
+
+        Priority: explicit ``parallelism`` argument > ``read.parallelism``
+        table option > built-in default of 1. The validation message names
+        whichever source produced the offending value, so users know where
+        to fix it.
+        """
+        if runtime is not None:
+            value = runtime
+            source = "parallelism"
+        else:
+            value = self._read_parallelism
+            source = "read.parallelism"
+        if value < 1:
+            raise ValueError(f"{source} must be >= 1, got {value}")
+        return value
+
+    def _should_run_parallel(
+        self,
+        splits: List[Split],
+        effective: int,
+    ) -> bool:
+        """Decide whether to take the parallel read path.
+
+        ``effective == 1`` falls back to the serial path (no thread pool
+        overhead, no behavior change). A single split is never
+        parallelized since there is nothing to fan out across.
+        """
+        return effective >= 2 and len(splits) >= 2
+
+    def _to_arrow_parallel(
+        self,
+        splits: List[Split],
+        schema: pyarrow.Schema,
+        effective: int,
+    ) -> pyarrow.Table:
+        """Read ``splits`` concurrently and assemble the result in input order.
+
+        Each split is read in its own worker thread; row quota for ``limit``
+        is shared through :class:`_RemainingRows` so the combined output
+        never exceeds ``self.limit`` rows. Per-split batches are collected
+        by submission index, so the merged table preserves the order of the
+        input ``splits`` list.
+        """
+        remaining_state = _RemainingRows(self.limit)
+        results: List[Optional[List[pyarrow.RecordBatch]]] = [None] * len(splits)
+        workers = min(effective, len(splits))
+        with ThreadPoolExecutor(
+            max_workers=workers,
+            thread_name_prefix="pypaimon-read",
+        ) as executor:
+            futures = {
+                executor.submit(
+                    self._read_one_split_to_batches,
+                    split,
+                    schema,
+                    remaining_state,
+                ): idx
+                for idx, split in enumerate(splits)
+            }
+            for fut in as_completed(futures):
+                results[futures[fut]] = fut.result()
+
+        table_list: List[pyarrow.RecordBatch] = []
+        for split_batches in results:
+            if split_batches is None:
+                continue
+            for batch in split_batches:
+                if batch.num_rows == 0:
+                    continue
+                table_list.append(self._try_to_pad_batch_by_schema(batch, schema))
+
+        if not table_list:
+            return pyarrow.Table.from_arrays(
+                [pyarrow.array([], type=field.type) for field in schema],
+                schema=schema,
+            )
+        return pyarrow.Table.from_batches(table_list)
+
+    def _read_one_split_to_batches(
+        self,
+        split: Split,
+        schema: pyarrow.Schema,
+        remaining_state: _RemainingRows,
+    ) -> List[pyarrow.RecordBatch]:
+        """Read a single split into arrow batches under soft-stop control.
+
+        Row quota is debited against the shared ``remaining_state``; once a
+        request returns 0, the worker stops emitting further batches. The
+        reader is always closed via ``finally``.
+        """
+        chunk_size = 65536
+        out: List[pyarrow.RecordBatch] = []
+        reader = self._create_split_read(split).create_reader()
+        try:
+            if isinstance(reader, RecordBatchReader):
+                for batch in iter(reader.read_arrow_batch, None):
+                    allowed = remaining_state.try_consume(batch.num_rows)
+                    if allowed == 0:
+                        break
+                    if allowed < batch.num_rows:
+                        batch = batch.slice(0, allowed)
+                    if self.include_row_kind:
+                        batch = self._add_row_kind_column_to_batch(batch, "+I")
+                    out.append(batch)
+                    if remaining_state.exhausted():
+                        break
+            else:
+                row_tuple_chunk: List[tuple] = []
+                row_kind_chunk: List[str] = []
+                stop = False
+                while not stop:
+                    row_iterator = reader.read_batch()
+                    if row_iterator is None:
+                        break
+                    for row in iter(row_iterator.next, None):
+                        if not isinstance(row, OffsetRow):
+                            raise TypeError(
+                                f"Expected OffsetRow, but got {type(row).__name__}")
+                        if remaining_state.try_consume(1) == 0:
+                            stop = True
+                            break
+                        row_tuple_chunk.append(
+                            row.row_tuple[row.offset: row.offset + row.arity])
+                        if self.include_row_kind:
+                            row_kind_chunk.append(row.get_row_kind().to_string())
+
+                        if len(row_tuple_chunk) >= chunk_size:
+                            out.append(self._convert_rows_to_arrow_batch_with_row_kind(
+                                row_tuple_chunk, row_kind_chunk, schema))
+                            row_tuple_chunk = []
+                            row_kind_chunk = []
+                if row_tuple_chunk:
+                    out.append(self._convert_rows_to_arrow_batch_with_row_kind(
+                        row_tuple_chunk, row_kind_chunk, schema))
+        finally:
+            reader.close()
+        return out
+
     def _convert_rows_to_arrow_batch_with_row_kind(
         self,
         row_tuples: List[tuple],
@@ -216,8 +415,16 @@ class TableRead:
         columns = [row_kind_array] + [batch.column(i) for i in range(batch.num_columns)]
         return pyarrow.RecordBatch.from_arrays(columns, schema=new_schema)
 
-    def to_pandas(self, splits: List[Split]) -> pandas.DataFrame:
-        arrow_table = self.to_arrow(splits)
+    def to_pandas(
+        self,
+        splits: List[Split],
+        parallelism: Optional[int] = None,
+    ) -> pandas.DataFrame:
+        """Read ``splits`` into a pandas ``DataFrame``.
+
+        See :meth:`to_arrow` for the semantics of ``parallelism``.
+        """
+        arrow_table = self.to_arrow(splits, parallelism=parallelism)
         return arrow_table.to_pandas()
 
     def to_duckdb(self, splits: List[Split], table_name: str,
diff --git a/paimon-python/pypaimon/tests/reader_parallel_test.py b/paimon-python/pypaimon/tests/reader_parallel_test.py
new file mode 100644
index 0000000000..692d1c5b48
--- /dev/null
+++ b/paimon-python/pypaimon/tests/reader_parallel_test.py
@@ -0,0 +1,405 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import shutil
+import tempfile
+import threading
+import unittest
+from unittest import mock
+
+import pyarrow as pa
+
+from pypaimon import CatalogFactory, Schema
+from pypaimon.read.table_read import TableRead, _RemainingRows
+
+
+class RemainingRowsTest(unittest.TestCase):
+    """Pure unit tests for the row-quota counter — no Paimon table needed."""
+
+    def test_unlimited(self):
+        rr = _RemainingRows(None)
+        self.assertEqual(rr.try_consume(1_000_000), 1_000_000)
+        self.assertEqual(rr.try_consume(1), 1)
+        self.assertFalse(rr.exhausted())
+
+    def test_basic_pre_debit(self):
+        rr = _RemainingRows(100)
+        self.assertEqual(rr.try_consume(40), 40)
+        self.assertEqual(rr.try_consume(40), 40)
+        # Only 20 left, asking for 30 returns 20.
+        self.assertEqual(rr.try_consume(30), 20)
+        self.assertTrue(rr.exhausted())
+        self.assertEqual(rr.try_consume(1), 0)
+
+    def test_zero_request(self):
+        rr = _RemainingRows(100)
+        self.assertEqual(rr.try_consume(0), 0)
+        # Quota unchanged.
+        self.assertEqual(rr.try_consume(100), 100)
+
+    def test_concurrent_consume_never_overcommits(self):
+        rr = _RemainingRows(10_000)
+        granted = []
+        granted_lock = threading.Lock()
+        barrier = threading.Barrier(8)
+
+        def worker():
+            barrier.wait()
+            for _ in range(2000):
+                got = rr.try_consume(7)
+                if got:
+                    with granted_lock:
+                        granted.append(got)
+
+        threads = [threading.Thread(target=worker) for _ in range(8)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+        self.assertEqual(sum(granted), 10_000)
+        self.assertTrue(rr.exhausted())
+
+
+class ParallelReaderAppendOnlyTest(unittest.TestCase):
+    """Append-only multi-partition table — parallel must match serial exactly."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
+        cls.catalog.create_database('default', False)
+
+        cls.pa_schema = pa.schema([
+            ('user_id', pa.int64()),
+            ('item_id', pa.int64()),
+            ('behavior', pa.string()),
+            ('dt', pa.string()),
+        ])
+
+        # 8 partitions => 8 splits.
+        rows_per_partition = 250
+        user_ids, item_ids, behaviors, dts = [], [], [], []
+        for p in range(8):
+            for i in range(rows_per_partition):
+                user_ids.append(p * rows_per_partition + i)
+                item_ids.append(1000 + i)
+                behaviors.append(f"act-{i % 5}")
+                dts.append(f"p{p}")
+        cls.expected_rows = len(user_ids)
+        data = pa.Table.from_pydict({
+            'user_id': user_ids,
+            'item_id': item_ids,
+            'behavior': behaviors,
+            'dt': dts,
+        }, schema=cls.pa_schema)
+
+        # Default table — read.parallelism unset (defaults to 1, i.e. serial).
+        cls.table = cls._build_table('append_parallel_default', None, data)
+        # Option-set table — read.parallelism=4 baked into the table schema.
+        cls.table_opt_4 = cls._build_table(
+            'append_parallel_opt4', {'read.parallelism': '4'}, data)
+
+    @classmethod
+    def _build_table(cls, name, options, data):
+        schema = Schema.from_pyarrow_schema(
+            cls.pa_schema,
+            partition_keys=['dt'],
+            options=options,
+        )
+        cls.catalog.create_table(f'default.{name}', schema, False)
+        table = cls.catalog.get_table(f'default.{name}')
+        wb = table.new_batch_write_builder()
+        w, c = wb.new_write(), wb.new_commit()
+        w.write_arrow(data)
+        c.commit(w.prepare_commit())
+        w.close()
+        c.close()
+        return table
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+    def _scan_splits(self, read_builder):
+        return read_builder.new_scan().plan().splits()
+
+    def test_multi_partition_yields_multiple_splits(self):
+        splits = self._scan_splits(self.table.new_read_builder())
+        self.assertGreaterEqual(len(splits), 4,
+                                f"expected multi-split fixture, got {len(splits)}")
+
+    # ------------------------------------------------------------------
+    # Result-parity tests for the two opt-in paths.
+    # ------------------------------------------------------------------
+
+    def test_parallel_via_method_arg_matches_serial(self):
+        rb = self.table.new_read_builder()
+        splits = self._scan_splits(rb)
+        read = rb.new_read()
+        serial = read.to_arrow(splits)
+        parallel = read.to_arrow(splits, parallelism=4)
+        # Same split order preserved => byte-identical tables.
+        self.assertEqual(serial, parallel)
+
+        df_serial = serial.to_pandas().sort_values('user_id').reset_index(drop=True)
+        df_parallel = read.to_pandas(splits, parallelism=4) \
+            .sort_values('user_id').reset_index(drop=True)
+        self.assertTrue(df_serial.equals(df_parallel))
+        self.assertEqual(len(df_parallel), self.expected_rows)
+
+    def test_parallel_via_table_option_matches_serial(self):
+        rb_serial = self.table.new_read_builder()
+        rb_parallel = self.table_opt_4.new_read_builder()
+        splits_serial = self._scan_splits(rb_serial)
+        splits_parallel = self._scan_splits(rb_parallel)
+
+        serial_df = rb_serial.new_read().to_pandas(splits_serial) \
+            .sort_values('user_id').reset_index(drop=True)
+        # No explicit parallelism — must pick up read.parallelism=4 from the table option.
+        parallel_df = rb_parallel.new_read().to_pandas(splits_parallel) \
+            .sort_values('user_id').reset_index(drop=True)
+        self.assertTrue(serial_df.equals(parallel_df))
+
+    # ------------------------------------------------------------------
+    # Priority: method arg > table option > built-in default.
+    # ------------------------------------------------------------------
+
+    def test_method_arg_overrides_option_to_serial(self):
+        # option=4 but caller passes 1: should disable parallelism.
+        read = self.table_opt_4.new_read_builder().new_read()
+        with mock.patch.object(read, '_to_arrow_parallel') as patched:
+            patched.side_effect = AssertionError(
+                "_to_arrow_parallel should not be called when arg=1 overrides option")
+            splits = self._scan_splits(self.table_opt_4.new_read_builder())
+            read.to_arrow(splits, parallelism=1)
+
+    def test_method_arg_overrides_option_to_parallel(self):
+        # option=1 (default) but caller passes 4: should enable parallelism.
+        read = self.table.new_read_builder().new_read()
+        splits = self._scan_splits(self.table.new_read_builder())
+        with mock.patch.object(
+            read, '_to_arrow_parallel', wraps=read._to_arrow_parallel
+        ) as spy:
+            result = read.to_arrow(splits, parallelism=4)
+        spy.assert_called_once()
+        self.assertEqual(result.num_rows, self.expected_rows)
+
+    # ------------------------------------------------------------------
+    # Boundary / invalid value handling.
+    # ------------------------------------------------------------------
+
+    def test_parallelism_one_equals_serial(self):
+        rb = self.table.new_read_builder()
+        splits = self._scan_splits(rb)
+        read = rb.new_read()
+        self.assertEqual(read.to_arrow(splits),
+                         read.to_arrow(splits, parallelism=1))
+
+    def test_parallelism_exceeds_split_count(self):
+        rb = self.table.new_read_builder()
+        splits = self._scan_splits(rb)
+        read = rb.new_read()
+        # 64 workers but only ~8 splits — should clamp internally, no error.
+        result = read.to_arrow(splits, parallelism=64)
+        self.assertEqual(result.num_rows, self.expected_rows)
+
+    def test_invalid_method_arg_raises(self):
+        rb = self.table.new_read_builder()
+        splits = self._scan_splits(rb)
+        read = rb.new_read()
+        with self.assertRaises(ValueError) as ctx:
+            read.to_arrow(splits, parallelism=0)
+        self.assertIn("parallelism", str(ctx.exception))
+        self.assertNotIn("read.parallelism", str(ctx.exception))
+        with self.assertRaises(ValueError):
+            read.to_pandas(splits, parallelism=-1)
+
+    def test_invalid_option_value_raises(self):
+        # Build a fresh table with an invalid option value.
+        schema = Schema.from_pyarrow_schema(
+            self.pa_schema,
+            partition_keys=['dt'],
+            options={'read.parallelism': '0'},
+        )
+        self.catalog.create_table('default.bad_option', schema, False)
+        bad = self.catalog.get_table('default.bad_option')
+        # Reproduce the data so split planning yields a non-trivial plan.
+        wb = bad.new_batch_write_builder()
+        w, c = wb.new_write(), wb.new_commit()
+        w.write_arrow(pa.Table.from_pydict({
+            'user_id': [1, 2], 'item_id': [10, 20],
+            'behavior': ['a', 'b'], 'dt': ['p1', 'p2'],
+        }, schema=self.pa_schema))
+        c.commit(w.prepare_commit())
+        w.close()
+        c.close()
+
+        read = bad.new_read_builder().new_read()
+        splits = bad.new_read_builder().new_scan().plan().splits()
+        with self.assertRaises(ValueError) as ctx:
+            read.to_arrow(splits)
+        self.assertIn("read.parallelism", str(ctx.exception))
+
+    def test_empty_splits_with_parallel_arg(self):
+        rb = self.table.new_read_builder()
+        read = rb.new_read()
+        result = read.to_arrow([], parallelism=4)
+        self.assertEqual(result.num_rows, 0)
+        self.assertEqual([f.name for f in result.schema],
+                         [f.name for f in self.pa_schema])
+
+    def test_parallel_with_limit_soft_stop(self):
+        # 10 calls with limit=600 should all return exactly 600 rows.
+        limit = 600
+        for _ in range(10):
+            rb = self.table.new_read_builder().with_limit(limit)
+            splits = self._scan_splits(rb)
+            df = rb.new_read().to_pandas(splits, parallelism=4)
+            self.assertEqual(len(df), limit)
+
+    def test_parallel_reader_error_propagates(self):
+        rb = self.table.new_read_builder()
+        splits = self._scan_splits(rb)
+        self.assertGreaterEqual(len(splits), 2)
+
+        original_create = TableRead._create_split_read
+        call_counter = {'n': 0}
+        lock = threading.Lock()
+
+        def flaky(self_, split):
+            with lock:
+                call_counter['n'] += 1
+                idx = call_counter['n']
+            if idx == 2:
+                raise RuntimeError("simulated reader failure")
+            return original_create(self_, split)
+
+        with mock.patch.object(TableRead, '_create_split_read', flaky):
+            read = rb.new_read()
+            with self.assertRaises(RuntimeError) as ctx:
+                read.to_pandas(splits, parallelism=4)
+            self.assertIn("simulated reader failure", str(ctx.exception))
+
+
+class ParallelReaderPrimaryKeyTest(unittest.TestCase):
+    """PK + multi-bucket merge-on-read parity between serial and parallel."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
+        cls.catalog.create_database('default', False)
+
+        cls.pa_schema = pa.schema([
+            pa.field('user_id', pa.int32(), nullable=False),
+            ('item_id', pa.int64()),
+            ('behavior', pa.string()),
+            pa.field('dt', pa.string(), nullable=False),
+        ])
+
+        schema = Schema.from_pyarrow_schema(
+            cls.pa_schema,
+            partition_keys=['dt'],
+            primary_keys=['user_id', 'dt'],
+            options={'bucket': '4'},
+        )
+        cls.catalog.create_table('default.pk_parallel', schema, False)
+        cls.table = cls.catalog.get_table('default.pk_parallel')
+
+        # First snapshot.
+        v1 = pa.Table.from_pydict({
+            'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 4,
+            'item_id': list(range(1001, 1041)),
+            'behavior': [f"v1-{i}" for i in range(40)],
+            'dt': (['p1'] * 10 + ['p2'] * 10 + ['p3'] * 10 + ['p4'] * 10),
+        }, schema=cls.pa_schema)
+        wb = cls.table.new_batch_write_builder()
+        w, c = wb.new_write(), wb.new_commit()
+        w.write_arrow(v1)
+        c.commit(w.prepare_commit())
+        w.close()
+        c.close()
+
+        # Second snapshot — updates some rows for the same PK to exercise merge.
+        v2 = pa.Table.from_pydict({
+            'user_id': [1, 2, 3, 4, 5],
+            'item_id': [9001, 9002, 9003, 9004, 9005],
+            'behavior': ['v2-updated'] * 5,
+            'dt': ['p1', 'p1', 'p2', 'p2', 'p3'],
+        }, schema=cls.pa_schema)
+        wb = cls.table.new_batch_write_builder()
+        w, c = wb.new_write(), wb.new_commit()
+        w.write_arrow(v2)
+        c.commit(w.prepare_commit())
+        w.close()
+        c.close()
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+    def test_multi_bucket_yields_multiple_splits(self):
+        splits = self.table.new_read_builder().new_scan().plan().splits()
+        self.assertGreaterEqual(
+            len(splits), 2,
+            f"expected multi-bucket fixture to yield >= 2 splits, got {len(splits)}")
+
+    def test_parallel_merge_matches_serial(self):
+        rb = self.table.new_read_builder()
+        splits = rb.new_scan().plan().splits()
+        read = rb.new_read()
+        serial = read.to_pandas(splits).sort_values(
+            ['dt', 'user_id']).reset_index(drop=True)
+        parallel = read.to_pandas(splits, parallelism=4).sort_values(
+            ['dt', 'user_id']).reset_index(drop=True)
+        self.assertTrue(serial.equals(parallel))
+        # Ensure the merge actually picked the latest version.
+        # user_id=1 / dt=p1 must have behavior='v2-updated', item_id=9001.
+        updated = parallel[(parallel.user_id == 1) & (parallel.dt == 'p1')]
+        self.assertEqual(len(updated), 1)
+        self.assertEqual(updated.iloc[0].behavior, 'v2-updated')
+        self.assertEqual(updated.iloc[0].item_id, 9001)
+
+    def test_parallel_with_limit_pk(self):
+        limit = 12
+        rb = self.table.new_read_builder().with_limit(limit)
+        splits = rb.new_scan().plan().splits()
+        df = rb.new_read().to_pandas(splits, parallelism=4)
+        self.assertLessEqual(len(df), limit)
+        # PK table merge applies limit per-split internally; in addition the
+        # parallel soft-stop caps the global total. The combined output must
+        # never exceed the user-visible limit.
+
+    def test_include_row_kind_parallel(self):
+        rb = self.table.new_read_builder()
+        splits = rb.new_scan().plan().splits()
+        read = rb.new_read()
+        read.include_row_kind = True
+        serial = read.to_arrow(splits)
+        parallel = read.to_arrow(splits, parallelism=4)
+        self.assertEqual(serial.schema, parallel.schema)
+        self.assertIn('_row_kind', serial.schema.names)
+        df_s = serial.to_pandas().sort_values(['dt', 'user_id']).reset_index(drop=True)
+        df_p = parallel.to_pandas().sort_values(['dt', 'user_id']).reset_index(drop=True)
+        self.assertTrue(df_s.equals(df_p))
+
+
+if __name__ == '__main__':
+    unittest.main()

Reply via email to