This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 2281765a7f [python] support chunk shuffle for planning and 3-layer 
shuffle for pytorch Dataset (#8064)
2281765a7f is described below

commit 2281765a7f44c643353266efb1f95053f9a059ea
Author: Faiz <[email protected]>
AuthorDate: Mon Jun 8 19:14:25 2026 +0800

    [python] support chunk shuffle for planning and 3-layer shuffle for pytorch 
Dataset (#8064)
---
 docs/docs/pypaimon/pytorch.md                      |  97 +++
 .../pypaimon/read/datasource/torch_dataset.py      | 224 ++++--
 .../read/scanner/chunk_shuffle_split_generator.py  | 377 ++++++++++
 .../pypaimon/read/scanner/file_scanner.py          |  66 +-
 paimon-python/pypaimon/read/table_read.py          |  20 +
 paimon-python/pypaimon/read/table_scan.py          |   4 +
 .../scanner/chunk_shuffle_split_generator_test.py  | 800 +++++++++++++++++++++
 paimon-python/pypaimon/tests/torch_read_test.py    | 230 ++++++
 8 files changed, 1775 insertions(+), 43 deletions(-)

diff --git a/docs/docs/pypaimon/pytorch.md b/docs/docs/pypaimon/pytorch.md
index 9e98b487ee..0f7f7bbdef 100644
--- a/docs/docs/pypaimon/pytorch.md
+++ b/docs/docs/pypaimon/pytorch.md
@@ -58,3 +58,100 @@ When the `streaming` parameter is true, it will iteratively 
read;
 when it is false, it will read the full amount of data into memory.
 
 **`prefetch_concurrency`** (default: 1): When streaming is true, number of 
threads used for parallel prefetch within each DataLoader worker. Set to a 
value greater than 1 to partition splits across threads and increase read 
throughput. Has no effect when streaming is false.
+
+## Shuffle
+
+PyPaimon supports streaming shuffle for PyTorch `IterableDataset`. The shuffle
+pipeline can be composed of three layers:
+
+1. **Chunk shuffle**: split files into row chunks during scan planning and
+   shuffle the generated chunk splits. This is enabled by
+   `TableScan.with_chunk_shuffle(seed, chunk_size)`.
+2. **Split interleave**: read rows in round-robin order from up to
+   `max_buffer_input_splits` splits at a time inside each DataLoader worker.
+3. **Buffer shuffle**: pass rows through a reservoir-style shuffle buffer of
+   `buffer_size` rows before they are yielded to PyTorch.
+
+Chunk shuffle is a scan planning feature for append tables, including
+Data Evolution append tables. For Data Evolution tables, chunk shuffle keeps
+row-id-aligned data files and sidecar files together while slicing by row-id
+range. Chunk shuffle should be used with file formats that **support random
+access**. Currently, the random-access file formats are Lance, Vortex, Row, and
+Blob. Primary-key tables and deletion-vector scans are not supported by
+`with_chunk_shuffle`.
+
+The second and third layers are Dataset features. They work on the splits you
+pass to `to_torch`, so they can be used with either normal splits or
+chunk-shuffled splits.
+
+### Use Dataset Shuffle Only
+
+Use this when normal scan splits are enough and you only want split interleave
+plus row buffer shuffle:
+
+```python
+from torch.utils.data import DataLoader
+
+table_scan = read_builder.new_scan()
+table_read = read_builder.new_read()
+splits = table_scan.plan().splits()
+
+dataset = table_read.to_torch(
+    splits,
+    streaming=True,
+    shuffle=True,
+    seed=42,
+    buffer_size=1000,
+    max_buffer_input_splits=10,
+)
+
+dataloader = DataLoader(
+    dataset,
+    batch_size=32,
+    num_workers=2,
+    shuffle=False,
+)
+```
+
+`buffer_size` controls the row shuffle buffer. Larger values produce a better
+approximation of global shuffle, at the cost of more memory. If
+`max_buffer_input_splits` is `1`, split interleave is skipped and only buffer
+shuffle is applied. `shuffle=True` requires `streaming=True` and does not
+support `prefetch_concurrency > 1`.
+
+### Use All Three Layers
+
+For append tables, enable chunk shuffle during scan planning, then enable
+Dataset shuffle when converting to PyTorch:
+
+```python
+from torch.utils.data import DataLoader
+
+seed = 42
+
+table_scan = read_builder.new_scan().with_chunk_shuffle(
+    seed=seed,
+    chunk_size=1000,
+)
+table_read = read_builder.new_read()
+splits = table_scan.plan().splits()
+
+dataset = table_read.to_torch(
+    splits,
+    streaming=True,
+    shuffle=True,
+    seed=seed,
+    buffer_size=1000,
+    max_buffer_input_splits=10,
+)
+
+dataloader = DataLoader(
+    dataset,
+    batch_size=32,
+    num_workers=2,
+    shuffle=False,
+)
+```
+
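+Call `dataset.set_epoch(epoch)` before creating or iterating a DataLoader for a
+new training epoch if you want a different buffer-shuffle order for each epoch.
+The epoch only changes the buffer-shuffle seed. The chunk order produced by
+`with_chunk_shuffle` depends only on its `seed`, and split interleave is a
+deterministic round-robin. To also reshuffle chunks each epoch, plan a new scan
+with a different seed (for example `seed + epoch`).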
diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py 
b/paimon-python/pypaimon/read/datasource/torch_dataset.py
index 6012d3dc68..5eb3485ddd 100644
--- a/paimon-python/pypaimon/read/datasource/torch_dataset.py
+++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py
@@ -19,8 +19,9 @@
 Module to read a Paimon table into PyTorch Dataset.
 """
 import queue
+import random
 import threading
-from typing import List
+from typing import Iterator, List
 
 import torch
 from torch.utils.data import Dataset, IterableDataset
@@ -29,6 +30,12 @@ from pypaimon.read.split import Split
 from pypaimon.read.table_read import TableRead
 
 
+def _share_epoch_with_torch_workers(value):
+    if isinstance(value, torch.Tensor):
+        return value.share_memory_()
+    return torch.tensor(value, dtype=torch.long).share_memory_()
+
+
 class TorchDataset(Dataset):
     """
     PyTorch Dataset implementation for reading Paimon table data.
@@ -76,7 +83,44 @@ class TorchDataset(Dataset):
         return self._data[index]
 
 
-class TorchIterDataset(IterableDataset):
+class _BaseTorchIterDataset(IterableDataset):
+    """
+    Shared helpers for streaming PyTorch datasets backed by Paimon splits.
+    """
+
+    def __init__(self, table_read: TableRead, splits: List[Split]):
+        self.table_read = table_read
+        self.splits = splits
+        self.field_names = [field.name for field in table_read.read_type]
+
+    def _row_to_dict(self, offset_row) -> dict:
+        row_dict = {}
+        for i, field_name in enumerate(self.field_names):
+            value = offset_row.get_field(i)
+            row_dict[field_name] = value
+        return row_dict
+
+    def _worker_splits(self, worker_info) -> List[Split]:
+        if worker_info is None:
+            return self.splits
+
+        worker_id = worker_info.id
+        num_workers = worker_info.num_workers
+        total_splits = len(self.splits)
+        splits_per_worker = total_splits // num_workers
+        remainder = total_splits % num_workers
+
+        if worker_id < remainder:
+            start_idx = worker_id * (splits_per_worker + 1)
+            end_idx = start_idx + splits_per_worker + 1
+        else:
+            start_idx = worker_id * splits_per_worker + remainder
+            end_idx = start_idx + splits_per_worker
+
+        return self.splits[start_idx:end_idx]
+
+
+class TorchIterDataset(_BaseTorchIterDataset):
     """
     PyTorch IterableDataset implementation for reading Paimon table data.
 
@@ -104,18 +148,8 @@ class TorchIterDataset(IterableDataset):
                 this worker (default 1). When > 1, splits are partitioned 
across
                 threads to increase read throughput.
         """
-        self.table_read = table_read
-        self.splits = splits
+        super().__init__(table_read, splits)
         self.prefetch_concurrency = max(1, int(prefetch_concurrency))
-        # Get field names from read_type
-        self.field_names = [field.name for field in table_read.read_type]
-
-    def _row_to_dict(self, offset_row) -> dict:
-        row_dict = {}
-        for i, field_name in enumerate(self.field_names):
-            value = offset_row.get_field(i)
-            row_dict[field_name] = value
-        return row_dict
 
     def __iter__(self):
         """
@@ -128,30 +162,7 @@ class TorchIterDataset(IterableDataset):
             row data of dict type, where keys are column names
         """
         worker_info = torch.utils.data.get_worker_info()
-
-        if worker_info is None:
-            # Single-process data loading, iterate over all splits
-            splits_to_process = self.splits
-        else:
-            # Multi-process data loading, partition splits across workers
-            worker_id = worker_info.id
-            num_workers = worker_info.num_workers
-
-            # Calculate start and end indices for this worker
-            # Distribute splits evenly by slicing
-            total_splits = len(self.splits)
-            splits_per_worker = total_splits // num_workers
-            remainder = total_splits % num_workers
-
-            # Workers with id < remainder get one extra split
-            if worker_id < remainder:
-                start_idx = worker_id * (splits_per_worker + 1)
-                end_idx = start_idx + splits_per_worker + 1
-            else:
-                start_idx = worker_id * splits_per_worker + remainder
-                end_idx = start_idx + splits_per_worker
-
-            splits_to_process = self.splits[start_idx:end_idx]
+        splits_to_process = self._worker_splits(worker_info)
 
         if self.prefetch_concurrency > 1:
             for row in self._iter_rows(splits_to_process):
@@ -161,11 +172,7 @@ class TorchIterDataset(IterableDataset):
         worker_iterator = self.table_read.to_iterator(splits_to_process)
 
         for offset_row in worker_iterator:
-            row_dict = {}
-            for i, field_name in enumerate(self.field_names):
-                value = offset_row.get_field(i)
-                row_dict[field_name] = value
-            yield row_dict
+            yield self._row_to_dict(offset_row)
 
     def _iter_rows(self, splits: List[Split]):
         n = min(self.prefetch_concurrency, len(splits))
@@ -221,3 +228,136 @@ class TorchIterDataset(IterableDataset):
             stop.set()
             for t in threads:
                 t.join(timeout=self._PREFETCH_JOIN_TIMEOUT_SEC)
+
+
+class TorchShuffledIterDataset(_BaseTorchIterDataset):
+    """
+    PyTorch IterableDataset with Paimon-controlled streaming shuffle.
+
+    This dataset consumes pre-planned splits, then mixes rows with split
+    interleaving and a shuffle buffer. Chunk-level shuffle, when needed,
+    stays in TableScan.with_chunk_shuffle().
+    """
+
+    def __init__(
+        self,
+        table_read: TableRead,
+        splits: List[Split],
+        seed: int = 0,
+        buffer_size: int = 1000,
+        max_buffer_input_splits: int = 10,
+    ):
+        super().__init__(table_read, splits)
+        self.seed = self._require_int(seed, "seed")
+        self.buffer_size = self._require_positive_int(buffer_size, 
"buffer_size")
+        self.max_buffer_input_splits = self._require_positive_int(
+            max_buffer_input_splits, "max_buffer_input_splits")
+        self._epoch = _share_epoch_with_torch_workers(0)
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+        self._epoch = _share_epoch_with_torch_workers(self._epoch)
+
+    @property
+    def epoch(self) -> int:
+        return int(self._epoch)
+
+    @epoch.setter
+    def epoch(self, epoch: int) -> None:
+        epoch = self._require_int(epoch, "epoch")
+        self._epoch += epoch - self._epoch
+
+    @staticmethod
+    def _require_int(value: int, name: str) -> int:
+        if not isinstance(value, int):
+            raise ValueError("%s must be an int" % name)
+        return value
+
+    @staticmethod
+    def _require_positive_int(value: int, name: str) -> int:
+        if not isinstance(value, int) or value <= 0:
+            raise ValueError("%s must be a positive int" % name)
+        return value
+
+    def set_epoch(self, epoch: int) -> "TorchShuffledIterDataset":
+        self.epoch = epoch
+        return self
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        worker_id = worker_info.id if worker_info is not None else 0
+        splits_to_process = self._worker_splits(worker_info)
+
+        if self.max_buffer_input_splits == 1:
+            rows = self._iter_ordered_rows(splits_to_process)
+        else:
+            rows = self._iter_interleaved_rows(splits_to_process)
+        for row in self._iter_buffer_shuffled_rows(rows, worker_id):
+            yield row
+
+    def _iter_ordered_rows(self, splits: List[Split]) -> Iterator[dict]:
+        for offset_row in self.table_read.to_iterator(splits):
+            yield self._row_to_dict(offset_row)
+
+    def _iter_interleaved_rows(self, splits: List[Split]) -> Iterator[dict]:
+        if not splits:
+            return
+
+        split_iter = iter(splits)
+        active: List[Iterator] = []
+
+        def add_next_split() -> bool:
+            try:
+                split = next(split_iter)
+            except StopIteration:
+                return False
+            active.append(self.table_read.to_iterator([split]))
+            return True
+
+        for _ in range(min(self.max_buffer_input_splits, len(splits))):
+            add_next_split()
+
+        idx = 0
+        try:
+            while active:
+                if idx >= len(active):
+                    idx = 0
+                row_iter = active[idx]
+                try:
+                    offset_row = next(row_iter)
+                except StopIteration:
+                    self._close_iterator(row_iter)
+                    del active[idx]
+                    add_next_split()
+                    continue
+
+                yield self._row_to_dict(offset_row)
+                idx += 1
+        finally:
+            for row_iter in active:
+                self._close_iterator(row_iter)
+
+    @staticmethod
+    def _close_iterator(row_iter) -> None:
+        close = getattr(row_iter, "close", None)
+        if close is not None:
+            close()
+
+    def _iter_buffer_shuffled_rows(
+        self,
+        rows: Iterator[dict],
+        worker_id: int,
+    ) -> Iterator[dict]:
+        rng = random.Random(self.seed + self.epoch * 1000003 + worker_id)
+        buffer = []
+        for row in rows:
+            if len(buffer) < self.buffer_size:
+                buffer.append(row)
+                continue
+            idx = rng.randint(0, self.buffer_size - 1)
+            yield buffer[idx]
+            buffer[idx] = row
+
+        rng.shuffle(buffer)
+        for row in buffer:
+            yield row
diff --git 
a/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py 
b/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py
new file mode 100644
index 0000000000..504236e500
--- /dev/null
+++ b/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py
@@ -0,0 +1,377 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import random
+from abc import abstractmethod
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, List, Optional, Tuple
+
+from pypaimon.globalindex.indexed_split import IndexedSplit
+from pypaimon.manifest.schema.data_file_meta import DataFileMeta
+from pypaimon.manifest.schema.manifest_entry import ManifestEntry
+from pypaimon.read.scanner.split_generator import AbstractSplitGenerator
+from pypaimon.read.sliced_split import SlicedSplit
+from pypaimon.read.split import DataSplit, Split
+from pypaimon.table.row.generic_row import GenericRow
+from pypaimon.utils.range import Range
+
+
+def _null_safe_partition_key(partition_values) -> tuple:
+    """Wrap each partition value with a None-aware tag so tuples that mix
+    null and non-null partition values can be ordered without raising
+    ``TypeError: '<' not supported between instances of 'NoneType' and 'str'``.
+    Paimon supports null partition values; Python 3 refuses to compare
+    None against str/int directly.
+    """
+    return tuple((v is None, v) for v in partition_values)
+
+
+@dataclass
+class _Chunk:
+    """A unit of work for one DataLoader read. ``segments`` carries
+    subclass-specific payload (file segments for append, aligned-group
+    segments for data evolution).
+
+    All segments of one chunk come from a single (partition, bucket)
+    group.
+    """
+    # Partition row shared by every segment in this chunk.
+    partition: GenericRow
+    # Bucket shared by every segment in this chunk.
+    bucket: int
+    # _FileSegment (append) or _AlignedGroupSegment (data evolution) items.
+    segments: List[Any]
+
+
+class ChunkShuffleSplitGeneratorBase(AbstractSplitGenerator):
+    """Common scaffolding for deterministic chunk-shuffled split generation.
+
+    Pipeline (template method, in :meth:`create_splits`):
+      1. Stable-sort entries (key from :meth:`_sort_key`) so manifest-read
+         parallelism cannot bleed into the output.
+      2. Group by (partition, bucket); iterate groups in sorted-key order.
+      3. Per group, call :meth:`_slice_group_into_chunks` to produce a list
+         of segment lists — one segment list per chunk.
+      4. Wrap each chunk with its (partition, bucket) into ``_Chunk``,
+         concatenate across groups.
+      5. ``random.Random(seed).shuffle`` all chunks.
+      6. If sharded, take this worker's slice via balanced 
``_compute_shard_range``.
+      7. Map each chunk through :meth:`_chunk_to_split`.
+
+    Subclasses implement the three abstract hooks. Reader paths
+    (``RawFileSplitRead`` for append, ``DataEvolutionSplitRead`` for DE)
+    are unchanged because chunks ride on existing wrappers
+    (``SlicedSplit`` / ``IndexedSplit``).
+    """
+
+    def __init__(
+        self,
+        table,
+        target_split_size: int,
+        open_file_cost: int,
+        deletion_files_map=None,
+        seed: int = 0,
+        chunk_size: int = 0,
+    ):
+        super().__init__(table, target_split_size, open_file_cost, 
deletion_files_map)
+        self.seed = seed
+        self.chunk_size = chunk_size
+
+    def create_splits(self, file_entries: List[ManifestEntry]) -> List[Split]:
+        if not file_entries:
+            return []
+
+        sorted_entries = sorted(file_entries, key=self._sort_key)
+
+        partitioned: "defaultdict[Tuple[tuple, int], List[ManifestEntry]]" = 
defaultdict(list)
+        for entry in sorted_entries:
+            partitioned[(tuple(entry.partition.values), 
entry.bucket)].append(entry)
+
+        all_chunks: List[_Chunk] = []
+        for key in sorted(
+            partitioned.keys(),
+            key=lambda k: (_null_safe_partition_key(k[0]), k[1]),
+        ):
+            entries_in_group = partitioned[key]
+            partition_row = entries_in_group[0].partition
+            bucket = entries_in_group[0].bucket
+            # Materialize file_path once per unique file in this group.
+            seen_paths: set = set()
+            for entry in entries_in_group:
+                f = entry.file
+                if f.file_name in seen_paths:
+                    continue
+                seen_paths.add(f.file_name)
+                f.set_file_path(
+                    self.table.table_path,
+                    partition_row,
+                    bucket,
+                    self.default_part_value,
+                )
+            for segments in self._slice_group_into_chunks(entries_in_group):
+                all_chunks.append(_Chunk(partition_row, bucket, segments))
+
+        rng = random.Random(self.seed)
+        rng.shuffle(all_chunks)
+
+        if self.idx_of_this_subtask is not None:
+            start, end = self._compute_shard_range(len(all_chunks))
+            all_chunks = all_chunks[start:end]
+
+        return [self._chunk_to_split(c) for c in all_chunks]
+
+    @abstractmethod
+    def _sort_key(self, entry: ManifestEntry):
+        """Return a comparable, deterministic key for stable sort."""
+
+    @abstractmethod
+    def _slice_group_into_chunks(self, entries: List[ManifestEntry]) -> 
List[List[Any]]:
+        """Cut one (partition, bucket) group into chunks of segments.
+
+        Each returned inner list represents one chunk; segment shape is
+        subclass-defined.
+        """
+
+    @abstractmethod
+    def _chunk_to_split(self, chunk: _Chunk) -> Split:
+        """Wrap a chunk into a Split that the existing readers consume."""
+
+
+# ---------------------------------------------------------------------------
+# Append (non-DE, non-DV) implementation
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class _FileSegment:
+    """A contiguous slice of a data file inside one chunk.
+
+    start/end are half-open row offsets within the file when the chunk
+    boundary falls inside the file. Both are None when the chunk owns
+    the full file; such files are left out of SlicedSplit's
+    shard_file_idx_map so the file is treated as full (see sliced_split.py).
+    """
+    # Data file this segment reads from.
+    file: DataFileMeta
+    # First row offset (inclusive), or None when the whole file is owned.
+    start: Optional[int]
+    # End row offset (exclusive), or None when the whole file is owned.
+    end: Optional[int]
+
+
+class AppendChunkShuffleSplitGenerator(ChunkShuffleSplitGeneratorBase):
+    """Chunk-shuffled splits for plain append tables.
+
+    Applies to non-PK, non-DV, non-DE tables.
+    """
+
+    def _sort_key(self, entry: ManifestEntry):
+        """Order entries by (partition, bucket, file name)."""
+        return (
+            _null_safe_partition_key(entry.partition.values),
+            entry.bucket,
+            entry.file.file_name,
+        )
+
+    def _slice_group_into_chunks(
+        self, entries: List[ManifestEntry]
+    ) -> List[List[_FileSegment]]:
+        """Cut a (partition, bucket) group into chunks of at most
+        ``self.chunk_size`` rows. ``chunk_size`` is a hard upper bound:
+        the last chunk may be smaller, but no chunk exceeds it.
+
+        A chunk may span several files and a file may span several chunks.
+        Files with ``row_count == 0`` contribute no segment.
+        """
+        chunks: List[List[_FileSegment]] = []
+        current: List[_FileSegment] = []
+        current_rows = 0
+
+        for entry in entries:
+            file = entry.file
+            offset = 0
+            remaining = file.row_count
+            while remaining > 0:
+                avail = self.chunk_size - current_rows
+                # Current chunk is full. It is closed only when another row
+                # needs a home, so no empty chunk is ever emitted.
+                if avail <= 0:
+                    chunks.append(current)
+                    current = []
+                    current_rows = 0
+                    avail = self.chunk_size
+
+                take = min(remaining, avail)
+
+                # A segment covering the whole file is recorded as
+                # (None, None) so the split reads the file without slicing.
+                if take == file.row_count and offset == 0:
+                    current.append(_FileSegment(file, None, None))
+                else:
+                    current.append(_FileSegment(file, offset, offset + take))
+
+                current_rows += take
+                offset += take
+                remaining -= take
+
+        if current:
+            chunks.append(current)
+
+        return chunks
+
+    def _chunk_to_split(self, chunk: _Chunk) -> Split:
+        """Build a DataSplit for the chunk's files. If any file is only
+        partially covered, wrap it in a SlicedSplit carrying the
+        per-file (start, end) row offsets."""
+        files: List[DataFileMeta] = []
+        shard_file_idx_map = {}
+        for seg in chunk.segments:
+            files.append(seg.file)
+            # Only partially covered files need a row-offset entry.
+            if seg.start is not None and seg.end is not None:
+                shard_file_idx_map[seg.file.file_name] = (seg.start, seg.end)
+
+        # set_file_path is already done once per unique file in
+        # ChunkShuffleSplitGeneratorBase.create_splits.
+
+        data_split = DataSplit(
+            files=files,
+            partition=chunk.partition,
+            bucket=chunk.bucket,
+            raw_convertible=True,
+            data_deletion_files=None,
+        )
+
+        if shard_file_idx_map:
+            return SlicedSplit(data_split, shard_file_idx_map)
+        # Every file is covered in full: a plain DataSplit suffices.
+        return data_split
+
+
+# ---------------------------------------------------------------------------
+# Data Evolution implementation
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class _AlignedGroupSegment:
+    """A row_id sub-range over one row-id-aligned file group.
+
+    ``files`` is the entire group (may include blob/vector siblings),
+    so the reader sees every column file even when only a slice of the
+    group's row_id range lands in this chunk. ``row_range`` is the
+    inclusive global row_id range this segment owns.
+    """
+    # All files of the aligned group, regardless of which rows are owned.
+    files: List[DataFileMeta]
+    # Inclusive global row_id range owned by this segment.
+    row_range: Range
+
+
+class DataEvolutionChunkShuffleSplitGenerator(ChunkShuffleSplitGeneratorBase):
+    """Chunk-shuffled splits for data-evolution append tables.
+
+    The minimum cuttable unit is a row_id-aligned file group: cutting
+    inside one group would orphan column files relative to the row_id
+    range, so we keep groups intact and only slice along their row_id
+    axis. Each chunk maps to an :class:`IndexedSplit` whose ``row_ranges``
+    bound the readable slice for that chunk.
+    """
+
+    def _sort_key(self, entry: ManifestEntry):
+        """Order entries by (partition, bucket, first_row_id, sidecar flag,
+        file name) so grouping and slicing are deterministic."""
+        # -inf keeps files without first_row_id comparable with ints; files
+        # whose row_id_range() is None are rejected later in
+        # _split_by_row_id_with_range.
+        first_row_id = (
+            entry.file.first_row_id
+            if entry.file.first_row_id is not None
+            else float('-inf')
+        )
+        # Blob/vector sidecar files sort after regular data files that share
+        # the same first_row_id.
+        is_special = 1 if (
+            DataFileMeta.is_blob_file(entry.file.file_name)
+            or DataFileMeta.is_vector_file(entry.file.file_name)
+        ) else 0
+        return (
+            _null_safe_partition_key(entry.partition.values),
+            entry.bucket,
+            first_row_id,
+            is_special,
+            entry.file.file_name,
+        )
+
+    def _slice_group_into_chunks(
+        self, entries: List[ManifestEntry]
+    ) -> List[List[_AlignedGroupSegment]]:
+        """Walk aligned groups in row_id order and cut their row_id ranges
+        into chunks of at most ``self.chunk_size`` rows. A chunk may span
+        several groups and a group may span several chunks.
+        """
+        files = [e.file for e in entries]
+        # (Range, [files]) pairs sorted by row_id — see helper docstring.
+        aligned_groups = self._split_by_row_id_with_range(files)
+
+        chunks: List[List[_AlignedGroupSegment]] = []
+        current: List[_AlignedGroupSegment] = []
+        current_rows = 0
+
+        for group_range, group_files in aligned_groups:
+            offset = 0
+            group_rows = group_range.count()
+            while offset < group_rows:
+                avail = self.chunk_size - current_rows
+                # Current chunk is full. It is closed lazily, so no empty
+                # chunk is ever emitted.
+                if avail <= 0:
+                    chunks.append(current)
+                    current = []
+                    current_rows = 0
+                    avail = self.chunk_size
+
+                take = min(group_rows - offset, avail)
+                # Range bounds are inclusive, hence the trailing ``- 1``.
+                seg_range = Range(
+                    group_range.from_ + offset,
+                    group_range.from_ + offset + take - 1,
+                )
+                # Each segment carries the whole group's file list so every
+                # column file is read for the sliced row_id range.
+                current.append(_AlignedGroupSegment(group_files, seg_range))
+                current_rows += take
+                offset += take
+
+        if current:
+            chunks.append(current)
+
+        return chunks
+
+    def _chunk_to_split(self, chunk: _Chunk) -> Split:
+        """Wrap the chunk's files in a DataSplit and bound it with an
+        IndexedSplit over the chunk's row_id ranges (sorted by start)."""
+        segments = chunk.segments
+        if len(segments) == 1:
+            all_files = segments[0].files
+            row_ranges = [segments[0].row_range]
+        else:
+            all_files = []
+            row_ranges = []
+            for seg in segments:
+                all_files.extend(seg.files)
+                row_ranges.append(seg.row_range)
+            row_ranges.sort(key=lambda r: r.from_)
+
+        # NOTE(review): raw_convertible=False presumably routes reads through
+        # the data-evolution merge path for aligned column files; confirm.
+        data_split = DataSplit(
+            files=all_files,
+            partition=chunk.partition,
+            bucket=chunk.bucket,
+            raw_convertible=False,
+            data_deletion_files=None,
+        )
+        return IndexedSplit(data_split, row_ranges, scores=None)
+
+    @staticmethod
+    def _split_by_row_id_with_range(
+        files: List[DataFileMeta],
+    ) -> List[Tuple[Range, List[DataFileMeta]]]:
+        """Group files by overlapping row_id range, returning (range, files)
+        pairs sorted by ``range.from_``.
+
+        Mirrors :meth:`DataEvolutionSplitGenerator._split_by_row_id` but
+        also returns the merged row_id range per group, which the chunk
+        slicer needs to drive row-count accumulation.
+
+        Raises:
+            ValueError: if any file has no row_id range (row tracking off).
+        """
+        list_ranges = []
+        for f in files:
+            file_range = f.row_id_range()
+            if file_range is None:
+                raise ValueError(
+                    "chunk_shuffle for data evolution tables requires row 
tracking; "
+                    f"file {f.file_name} is missing first_row_id"
+                )
+            list_ranges.append(file_range)
+        if not list_ranges:
+            return []
+        sorted_ranges = Range.sort_and_merge_overlap(list_ranges, True, False)
+
+        # NOTE(review): keyed by Range, so this assumes Range is hashable.
+        # Each file scans sorted_ranges linearly (O(files x groups)); a
+        # bisect on from_ would be cheaper for buckets with many groups.
+        range_to_files: "dict[Range, List[DataFileMeta]]" = {}
+        for f in files:
+            file_range = f.row_id_range()
+            for r in sorted_ranges:
+                if r.overlaps(file_range):
+                    range_to_files.setdefault(r, []).append(f)
+                    break
+
+        return sorted(range_to_files.items(), key=lambda kv: kv[0].from_)
diff --git a/paimon-python/pypaimon/read/scanner/file_scanner.py 
b/paimon-python/pypaimon/read/scanner/file_scanner.py
index 5424d2d519..650bbe8ca4 100755
--- a/paimon-python/pypaimon/read/scanner/file_scanner.py
+++ b/paimon-python/pypaimon/read/scanner/file_scanner.py
@@ -40,6 +40,10 @@ from pypaimon.read.scanner.append_table_split_generator 
import \
     AppendTableSplitGenerator
 from pypaimon.read.scanner.bucket_select_converter import \
     create_bucket_selector
+from pypaimon.read.scanner.chunk_shuffle_split_generator import (
+    AppendChunkShuffleSplitGenerator,
+    DataEvolutionChunkShuffleSplitGenerator,
+)
 from pypaimon.read.scanner.data_evolution_split_generator import \
     DataEvolutionSplitGenerator
 from pypaimon.read.scanner.primary_key_table_split_generator import \
@@ -204,6 +208,7 @@ class FileScanner:
         self.number_of_para_subtasks = None
         self.start_pos_of_this_subtask = None
         self.end_pos_of_this_subtask = None
+        self.chunk_shuffle: Optional[Tuple[int, int]] = None
 
         self.only_read_real_buckets = options.bucket() == 
BucketMode.POSTPONE_BUCKET.value
         self.data_evolution = options.data_evolution_enabled()
@@ -243,7 +248,34 @@ class FileScanner:
     def scan(self) -> Plan:
         start_ms = time.time() * 1000
         # Create appropriate split generator based on table type
-        if self.table.is_primary_key_table:
+        if self.chunk_shuffle is not None:
+            self._validate_chunk_shuffle_compat()
+            seed, chunk_size = self.chunk_shuffle
+            # Both append and DE paths use plan_files() directly: the
+            # predicate is partition-only (enforced by
+            # _validate_chunk_shuffle_compat), so manifest_entry-level
+            # partition pruning in plan_files() is the only filter we
+            # want — no row_id range pushdown, no global index lookup.
+            entries = self.plan_files()
+            if self.data_evolution:
+                split_generator = DataEvolutionChunkShuffleSplitGenerator(
+                    self.table,
+                    self.target_split_size,
+                    self.open_file_cost,
+                    self._deletion_files_map(entries),
+                    seed=seed,
+                    chunk_size=chunk_size,
+                )
+            else:
+                split_generator = AppendChunkShuffleSplitGenerator(
+                    self.table,
+                    self.target_split_size,
+                    self.open_file_cost,
+                    self._deletion_files_map(entries),
+                    seed=seed,
+                    chunk_size=chunk_size,
+                )
+        elif self.table.is_primary_key_table:
             entries = self.plan_files()
             split_generator = PrimaryKeyTableSplitGenerator(
                 self.table,
@@ -441,6 +473,38 @@ class FileScanner:
         plan = self.scan()
         return plan, self.scan_stats
 
+    def with_chunk_shuffle(self, seed: int, chunk_size: int) -> 'FileScanner':
+        """Enable chunk-shuffle split planning for this scanner.
+
+        Compatibility with the rest of the scan configuration (table type,
+        deletion vectors, slice, limit, global index, predicate) is checked
+        later, when :meth:`scan` runs, so builder calls may come in any order.
+
+        Args:
+            seed: Shuffle seed. Any integral value (``int``, ``numpy.int64``,
+                ...) is accepted and normalised to ``int``; ``bool`` is
+                rejected.
+            chunk_size: Rows per chunk (the last chunk of a partition may be
+                smaller); must be a positive integral value, ``bool`` rejected.
+
+        Returns:
+            ``self``, so calls can be chained.
+
+        Raises:
+            ValueError: if ``seed`` or ``chunk_size`` is invalid.
+        """
+        import numbers
+
+        def _is_integral(value) -> bool:
+            # bool subclasses int, so True/False would otherwise slip through
+            # as 1/0 (chunk_size=True would silently mean 1-row chunks).
+            return isinstance(value, numbers.Integral) and not isinstance(value, bool)
+
+        if not _is_integral(seed):
+            raise ValueError("chunk_shuffle seed must be an int")
+        if not _is_integral(chunk_size) or chunk_size <= 0:
+            raise ValueError("chunk_shuffle chunk_size must be a positive int")
+        # Normalise numpy/other Integral types to plain int for downstream use.
+        self.chunk_shuffle = (int(seed), int(chunk_size))
+        return self
+
+    def _validate_chunk_shuffle_compat(self) -> None:
+        """Raise ``ValueError`` if the scan options conflict with chunk shuffle.
+
+        Checked in order: primary-key table, deletion vectors, ``with_slice``,
+        ``limit``, global-index result, then a predicate that references any
+        non-partition field. The first conflict found determines the message.
+        """
+        table = self.table
+        if table.is_primary_key_table:
+            raise ValueError("chunk_shuffle only supports append tables")
+        if self.deletion_vectors_enabled:
+            raise ValueError("chunk_shuffle not supported with deletion vectors")
+        if self.start_pos_of_this_subtask is not None:
+            raise ValueError("chunk_shuffle cannot combine with with_slice")
+        if self.limit is not None:
+            raise ValueError("chunk_shuffle cannot combine with limit")
+        if self._global_index_result is not None:
+            raise ValueError("chunk_shuffle cannot combine with global index")
+
+        # Row/column predicates would silently shrink each chunk's effective
+        # row count and break the chunk_size contract DataLoader callers
+        # expect; only partition predicates (handled by manifest-level
+        # partition pruning) are allowed.
+        if self.predicate is None:
+            return
+        allowed = set(table.partition_keys or [])
+        offending = _get_all_fields(self.predicate) - allowed
+        if offending:
+            raise ValueError(
+                "chunk_shuffle predicate must reference only partition keys; "
+                "got non-partition fields: "
+                f"{sorted(offending)}"
+            )
+
     def _apply_push_down_limit(self, splits: List[DataSplit]) -> 
List[DataSplit]:
         """Mirror Java ``DataTableBatchScan.applyPushDownLimit``: sum the
         DV-aware ``merged_row_count`` (== Java ``Split.mergedRowCount()``)
diff --git a/paimon-python/pypaimon/read/table_read.py 
b/paimon-python/pypaimon/read/table_read.py
index 826b2b4024..57e704b2b5 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -511,8 +511,28 @@ class TableRead:
         splits: List[Split],
         streaming: bool = False,
         prefetch_concurrency: int = 1,
+        *,
+        shuffle: bool = False,
+        seed: int = 0,
+        buffer_size: int = 1000,
+        max_buffer_input_splits: int = 10,
     ) -> "torch.utils.data.Dataset":
         """Wrap Paimon table data to PyTorch Dataset."""
+        if shuffle:
+            if not streaming:
+                raise ValueError("shuffle=True only supports streaming=True")
+            if prefetch_concurrency > 1:
+                raise ValueError("shuffle=True does not support 
prefetch_concurrency > 1")
+            from pypaimon.read.datasource.torch_dataset import 
TorchShuffledIterDataset
+            dataset = TorchShuffledIterDataset(
+                self,
+                splits,
+                seed=seed,
+                buffer_size=buffer_size,
+                max_buffer_input_splits=max_buffer_input_splits,
+            )
+            return dataset
+
         if streaming:
             from pypaimon.read.datasource.torch_dataset import TorchIterDataset
             dataset = TorchIterDataset(self, splits, prefetch_concurrency)
diff --git a/paimon-python/pypaimon/read/table_scan.py 
b/paimon-python/pypaimon/read/table_scan.py
index 03a1c8b062..36568c618f 100755
--- a/paimon-python/pypaimon/read/table_scan.py
+++ b/paimon-python/pypaimon/read/table_scan.py
@@ -161,6 +161,10 @@ class TableScan:
         self.file_scanner.with_global_index_result(result)
         return self
 
+    def with_chunk_shuffle(self, seed: int, chunk_size: int) -> 'TableScan':
+        """Turn on chunk-shuffle planning for this scan.
+
+        Delegates to ``FileScanner.with_chunk_shuffle``, which validates the
+        arguments; conflicts with other scan options surface at plan time.
+
+        Returns:
+            This ``TableScan`` so calls can be chained.
+        """
+        scanner = self.file_scanner
+        scanner.with_chunk_shuffle(seed, chunk_size)
+        return self
+
     def _validate_scan_mode(self):
         """Validate scan.mode against companion options using a whitelist 
approach.
 
diff --git 
a/paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py 
b/paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py
new file mode 100644
index 0000000000..be50cfa637
--- /dev/null
+++ b/paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py
@@ -0,0 +1,800 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for the chunk-shuffle split generators and TableScan.with_chunk_shuffle.
+
+Algorithmic tests use Mock entries so they don't touch disk; the
+end-to-end tests write real append and data-evolution tables and validate
+that all workers together cover the data exactly once.
+"""
+
+import os
+import shutil
+import tempfile
+import unittest
+from unittest.mock import Mock
+
+import pyarrow as pa
+
+from pypaimon import CatalogFactory, Schema
+from pypaimon.globalindex.indexed_split import IndexedSplit
+from pypaimon.manifest.schema.data_file_meta import DataFileMeta
+from pypaimon.read.scanner.chunk_shuffle_split_generator import (
+    AppendChunkShuffleSplitGenerator,
+    DataEvolutionChunkShuffleSplitGenerator,
+)
+from pypaimon.read.sliced_split import SlicedSplit
+from pypaimon.read.split import DataSplit
+from pypaimon.utils.range import Range
+
+
+def _mock_table(table_path='/tmp/_chunk_shuffle_test_path'):
+    """Build a Mock table exposing only ``table_path`` and a Mock ``options``."""
+    return Mock(table_path=table_path, options=Mock())
+
+
+def _mock_entry(partition_values, bucket, file_name, row_count, file_size=1024):
+    """Build a Mock manifest entry for one append-table data file.
+
+    ``set_file_path`` is stubbed so the generator can call it without
+    partition path encoding having to be mocked.
+    """
+    data_file = Mock(file_name=file_name, file_size=file_size, row_count=row_count)
+    data_file.set_file_path = Mock()
+    return Mock(
+        partition=Mock(values=partition_values),
+        bucket=bucket,
+        file=data_file,
+    )
+
+
+def _build_generator(generator_cls, seed, chunk_size, table=None):
+    """Instantiate *generator_cls* with the shared test sizing and no DV map."""
+    if table is None:
+        table = _mock_table()
+    return generator_cls(
+        table,
+        target_split_size=128 * 1024 * 1024,
+        open_file_cost=4 * 1024 * 1024,
+        deletion_files_map=None,
+        seed=seed,
+        chunk_size=chunk_size,
+    )
+
+
+def _make_generator(seed, chunk_size, table=None):
+    """Append-table chunk-shuffle generator with the shared test defaults."""
+    return _build_generator(
+        AppendChunkShuffleSplitGenerator, seed, chunk_size, table)
+
+
+def _make_de_generator(seed, chunk_size, table=None):
+    """Data-evolution chunk-shuffle generator with the shared test defaults."""
+    return _build_generator(
+        DataEvolutionChunkShuffleSplitGenerator, seed, chunk_size, table)
+
+
+def _mock_de_entry(partition_values, bucket, file_name, first_row_id, row_count, file_size=1024):
+    """Build a Mock manifest entry for a data-evolution data file.
+
+    The file mock is spec'd on ``DataFileMeta`` and carries ``first_row_id``
+    plus a ``row_id_range()`` returning a real inclusive-end ``Range``, so
+    ``Range.overlaps`` behaves as it does in production.
+    """
+    data_file = Mock(spec=DataFileMeta)
+    data_file.file_name = file_name
+    data_file.file_size = file_size
+    data_file.row_count = row_count
+    data_file.first_row_id = first_row_id
+    data_file.row_id_range = (
+        lambda start=first_row_id, count=row_count: Range(start, start + count - 1))
+    data_file.set_file_path = Mock()
+
+    entry = Mock()
+    entry.partition = Mock(values=partition_values)
+    entry.bucket = bucket
+    entry.file = data_file
+    return entry
+
+
+def _split_signature(split):
+    """Return a hashable identity describing exactly what a worker would read.
+
+    Shape: ``(partition values, bucket, file names, extra)`` where *extra* is
+    the sorted ``shard_file_idx_map`` items for a SlicedSplit, the
+    ``(from_, to)`` row ranges for an IndexedSplit, and ``()`` for a plain
+    DataSplit. IndexedSplit file names are sorted; the others keep split order.
+    """
+    if isinstance(split, SlicedSplit):
+        inner = split.data_split()
+        names = tuple(f.file_name for f in inner.files)
+        extra = tuple(sorted(split.shard_file_idx_map().items()))
+    elif isinstance(split, IndexedSplit):
+        inner = split.data_split()
+        names = tuple(sorted(f.file_name for f in inner.files))
+        extra = tuple((r.from_, r.to) for r in split.row_ranges())
+    elif isinstance(split, DataSplit):
+        inner = split
+        names = tuple(f.file_name for f in inner.files)
+        extra = ()
+    else:
+        raise AssertionError("unexpected split type: %r" % type(split))
+    return (tuple(inner.partition.values), inner.bucket, names, extra)
+
+
+def _split_rows(split):
+    """Return how many rows *split* exposes to a reader (its ``row_count``)."""
+    exposed_rows = split.row_count
+    return exposed_rows
+
+
+class ChunkShuffleSplitGeneratorAlgoTest(unittest.TestCase):
+    """Mock-based tests for the append-table chunk slicer."""
+
+    def test_no_entries_returns_empty(self):
+        gen = _make_generator(seed=1, chunk_size=100)
+        self.assertEqual(gen.create_splits([]), [])
+
+    def test_full_files_no_truncation(self):
+        entries = [
+            _mock_entry([], 0, 'f1', 100),
+            _mock_entry([], 0, 'f2', 100),
+            _mock_entry([], 0, 'f3', 100),
+        ]
+        gen = _make_generator(seed=1, chunk_size=100)
+        splits = gen.create_splits(entries)
+        # 3 chunks, each holding exactly one whole file → all DataSplit, no SlicedSplit
+        self.assertEqual(len(splits), 3)
+        for s in splits:
+            self.assertIsInstance(s, DataSplit)
+            self.assertEqual(s.row_count, 100)
+
+    def test_chunk_truncates_inside_file(self):
+        # one file of 250 rows, chunk_size 100 → 3 chunks: 100, 100, 50
+        entries = [_mock_entry([], 0, 'f1', 250)]
+        gen = _make_generator(seed=1, chunk_size=100)
+        splits = gen.create_splits(entries)
+        self.assertEqual(len(splits), 3)
+        # All three chunks slice the same file → all SlicedSplit
+        for s in splits:
+            self.assertIsInstance(s, SlicedSplit)
+        # union of (start, end) intervals must cover [0, 250)
+        intervals = sorted(s.shard_file_idx_map()['f1'] for s in splits)
+        self.assertEqual(intervals, [(0, 100), (100, 200), (200, 250)])
+        total = sum(end - start for start, end in intervals)
+        self.assertEqual(total, 250)
+
+    def test_chunk_spans_multiple_files(self):
+        # f1=30, f2=30, f3=30 (90 rows), chunk_size=50 → two chunks of 50 and
+        # 40 rows. Assuming the files are laid out in name order before the
+        # shuffle: [f1(0,30) + f2(0,20)] and [f2(20,30) + f3(0,30)].
+        entries = [
+            _mock_entry([], 0, 'f1', 30),
+            _mock_entry([], 0, 'f2', 30),
+            _mock_entry([], 0, 'f3', 30),
+        ]
+        gen = _make_generator(seed=1, chunk_size=50)
+        splits = gen.create_splits(entries)
+        self.assertEqual(len(splits), 2)
+        # Pin the per-chunk sizes, not just the total, so a 45/45 split fails.
+        self.assertEqual(sorted(_split_rows(s) for s in splits), [40, 50])
+
+    def test_chunk_size_larger_than_total(self):
+        entries = [
+            _mock_entry([], 0, 'f1', 30),
+            _mock_entry([], 0, 'f2', 30),
+        ]
+        gen = _make_generator(seed=1, chunk_size=1000)
+        splits = gen.create_splits(entries)
+        self.assertEqual(len(splits), 1)
+        # No truncation — full files inside one chunk → DataSplit not SlicedSplit
+        self.assertIsInstance(splits[0], DataSplit)
+        self.assertEqual(_split_rows(splits[0]), 60)
+
+    def test_deterministic_same_seed_same_order(self):
+        entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(20)]
+        gen1 = _make_generator(seed=42, chunk_size=50)
+        gen2 = _make_generator(seed=42, chunk_size=50)
+        splits1 = gen1.create_splits(entries)
+        splits2 = gen2.create_splits(entries)
+        self.assertEqual(
+            [_split_signature(s) for s in splits1],
+            [_split_signature(s) for s in splits2],
+        )
+
+    def test_different_seed_different_order(self):
+        entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(50)]
+        gen1 = _make_generator(seed=1, chunk_size=100)
+        gen2 = _make_generator(seed=2, chunk_size=100)
+        sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)]
+        sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)]
+        # Same set of chunks, different order — high probability they differ on 50 items
+        self.assertEqual(sorted(sigs1), sorted(sigs2))
+        self.assertNotEqual(sigs1, sigs2)
+
+    def test_shuffle_actually_reorders(self):
+        # 20 files in scan order f00..f19. After shuffle the file order should not be sorted.
+        entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(20)]
+        gen = _make_generator(seed=42, chunk_size=100)
+        splits = gen.create_splits(entries)
+        file_names = [s.files[0].file_name for s in splits]
+        self.assertNotEqual(file_names, sorted(file_names))
+
+    def test_shard_round_trip_no_overlap_no_loss(self):
+        # 13 files × 100 rows = 1300 rows. 4 workers.
+        entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(13)]
+        num_workers = 4
+        all_sigs = []
+        total_rows = 0
+        for worker in range(num_workers):
+            gen = _make_generator(seed=7, chunk_size=100)
+            gen.with_shard(worker, num_workers)
+            splits = gen.create_splits(list(entries))  # copy: shuffle is in-place on chunks list
+            for s in splits:
+                all_sigs.append(_split_signature(s))
+                total_rows += _split_rows(s)
+        self.assertEqual(total_rows, 13 * 100)
+        # No duplicate chunks across workers
+        self.assertEqual(len(all_sigs), len(set(all_sigs)))
+        # All chunks together equal an unsharded run
+        unsharded = _make_generator(seed=7, chunk_size=100).create_splits(list(entries))
+        self.assertEqual(
+            sorted(all_sigs),
+            sorted(_split_signature(s) for s in unsharded),
+        )
+
+    def test_shard_balanced_distribution(self):
+        # 10 chunks across 3 workers → 4, 3, 3 (front-loaded by _compute_shard_range)
+        entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(10)]
+        counts = []
+        for worker in range(3):
+            gen = _make_generator(seed=0, chunk_size=100)
+            gen.with_shard(worker, 3)
+            counts.append(len(gen.create_splits(list(entries))))
+        self.assertEqual(sorted(counts, reverse=True), [4, 3, 3])
+
+    def test_chunks_fewer_than_workers(self):
+        # 2 chunks, 5 workers → 3 workers get nothing
+        entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(2)]
+        empties = 0
+        non_empties = 0
+        for worker in range(5):
+            gen = _make_generator(seed=0, chunk_size=100)
+            gen.with_shard(worker, 5)
+            n = len(gen.create_splits(list(entries)))
+            if n == 0:
+                empties += 1
+            else:
+                non_empties += 1
+                self.assertEqual(n, 1)
+        self.assertEqual(empties, 3)
+        self.assertEqual(non_empties, 2)
+
+    def test_multi_partition_no_chunk_crosses_partition(self):
+        # chunk_size=150 deliberately exceeds one file, so a chunk *could*
+        # span f1→f2 or f2→f3 if partitions were not respected.
+        entries = [
+            _mock_entry(['p1'], 0, 'f1', 100),
+            _mock_entry(['p1'], 0, 'f2', 100),
+            _mock_entry(['p2'], 0, 'f3', 100),
+            _mock_entry(['p2'], 0, 'f4', 100),
+        ]
+        partition_of_file = {
+            e.file.file_name: tuple(e.partition.values) for e in entries}
+        gen = _make_generator(seed=0, chunk_size=150)
+        splits = gen.create_splits(entries)
+        self.assertTrue(splits)
+        seen = set()
+        for s in splits:
+            data_split = s.data_split() if isinstance(s, SlicedSplit) else s
+            # Every file inside one split must come from a single partition,
+            # and it must be the split's own partition.
+            file_partitions = {
+                partition_of_file[f.file_name] for f in data_split.files}
+            self.assertEqual(file_partitions, {tuple(data_split.partition.values)})
+            seen |= file_partitions
+        self.assertEqual(seen, {('p1',), ('p2',)})
+        self.assertEqual(sum(_split_rows(s) for s in splits), 400)
+
+    def test_null_and_non_null_partitions_sort_safely(self):
+        # Mixing null and non-null partition values used to raise
+        # ``TypeError: '<' not supported between instances of 'NoneType' and 'str'``
+        # before _null_safe_partition_key. Validate planning succeeds and
+        # both partitions produce splits.
+        entries = [
+            _mock_entry(['p1'], 0, 'f1', 100),
+            _mock_entry([None], 0, 'f2', 100),
+            _mock_entry(['p2'], 0, 'f3', 100),
+        ]
+        gen = _make_generator(seed=1, chunk_size=100)
+        splits = gen.create_splits(entries)
+        self.assertEqual(len(splits), 3)
+        # _split_signature(s)[0] is already a tuple of partition values.
+        partitions = {_split_signature(s)[0] for s in splits}
+        self.assertEqual(partitions, {('p1',), ('p2',), (None,)})
+
+    def test_input_order_does_not_affect_output_when_same_files(self):
+        """Manifest read parallelism shouldn't bleed through — sorting is internal."""
+        a = _mock_entry([], 0, 'f1', 100)
+        b = _mock_entry([], 0, 'f2', 100)
+        c = _mock_entry([], 0, 'f3', 100)
+        gen1 = _make_generator(seed=99, chunk_size=100)
+        gen2 = _make_generator(seed=99, chunk_size=100)
+        sigs1 = [_split_signature(s) for s in gen1.create_splits([a, b, c])]
+        sigs2 = [_split_signature(s) for s in gen2.create_splits([c, a, b])]
+        self.assertEqual(sigs1, sigs2)
+
+
+class ChunkShuffleEndToEndTest(unittest.TestCase):
+    """Real append table → with_chunk_shuffle → multiple workers → union == original."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
+        cls.catalog.create_database('default', True)
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+    def _create_append_table(self, name, partition_keys=None):
+        """Create ``default.<name>`` (id/value/part); return (table, arrow schema)."""
+        arrow_schema = pa.schema([
+            ('id', pa.int64()),
+            ('value', pa.string()),
+            ('part', pa.string()),
+        ])
+        identifier = f'default.{name}'
+        self.catalog.create_table(
+            identifier,
+            Schema.from_pyarrow_schema(arrow_schema, partition_keys=partition_keys or []),
+            False,
+        )
+        return self.catalog.get_table(identifier), arrow_schema
+
+    def _write_n_batches(self, table, pa_schema, batches):
+        """Commit each column dict in *batches* as its own snapshot."""
+        builder = table.new_batch_write_builder()
+        for columns in batches:
+            writer, committer = builder.new_write(), builder.new_commit()
+            writer.write_arrow(pa.Table.from_pydict(columns, schema=pa_schema))
+            committer.commit(writer.prepare_commit())
+            writer.close()
+            committer.close()
+
+    def test_workers_union_equals_full_table(self):
+        table, pa_schema = self._create_append_table('cs_union')
+        # 4 commits × 50 rows = 200 rows across several files
+        self._write_n_batches(table, pa_schema, [
+            {
+                'id': list(range(start, start + 50)),
+                'value': [f'v{i}' for i in range(start, start + 50)],
+                'part': ['p1'] * 50,
+            }
+            for start in range(0, 200, 50)
+        ])
+
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+
+        num_workers = 3
+        worker_tables = []
+        for worker in range(num_workers):
+            worker_splits = (
+                read_builder.new_scan()
+                .with_chunk_shuffle(seed=123, chunk_size=37)
+                .with_shard(worker, num_workers)
+                .plan()
+                .splits()
+            )
+            if worker_splits:
+                worker_tables.append(table_read.to_arrow(worker_splits))
+
+        actual = None
+        if worker_tables:
+            actual = pa.concat_tables(worker_tables).sort_by('id')
+        self.assertIsNotNone(actual)
+        self.assertEqual(actual.num_rows, 200)
+        self.assertEqual(actual.column('id').to_pylist(), list(range(200)))
+
+    def test_deterministic_plan_across_calls(self):
+        table, pa_schema = self._create_append_table('cs_determinism')
+        rows = 100
+        self._write_n_batches(table, pa_schema, [{
+            'id': list(range(rows)),
+            'value': [f'v{i}' for i in range(rows)],
+            'part': ['p'] * rows,
+        }])
+
+        def signatures_for(worker):
+            scan = (table.new_read_builder().new_scan()
+                    .with_chunk_shuffle(seed=42, chunk_size=20)
+                    .with_shard(worker, 3))
+            return [_split_signature(s) for s in scan.plan().splits()]
+
+        for worker in range(3):
+            first = signatures_for(worker)
+            second = signatures_for(worker)
+            self.assertEqual(first, second)
+
+
+class ChunkShuffleCompatibilityTest(unittest.TestCase):
+    """Validates the reject-on-incompatible-combination matrix."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
+        cls.catalog.create_database('default', True)
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+    def _append_table(self, name, options=None, partition_keys=None):
+        if partition_keys:
+            pa_schema = pa.schema([
+                ('id', pa.int64()),
+                ('value', pa.string()),
+                ('part', pa.string()),
+            ])
+        else:
+            pa_schema = pa.schema([('id', pa.int64()), ('value', pa.string())])
+        schema = Schema.from_pyarrow_schema(
+            pa_schema, partition_keys=partition_keys, options=options or {})
+        self.catalog.create_table(f'default.{name}', schema, False)
+        return self.catalog.get_table(f'default.{name}')
+
+    def _pk_table(self, name):
+        pa_schema = pa.schema([
+            pa.field('id', pa.int64(), nullable=False),
+            ('value', pa.string()),
+        ])
+        schema = Schema.from_pyarrow_schema(
+            pa_schema, primary_keys=['id'], options={'bucket': '1'})
+        self.catalog.create_table(f'default.{name}', schema, False)
+        return self.catalog.get_table(f'default.{name}')
+
+    def test_pk_table_rejected(self):
+        table = self._pk_table('cs_pk')
+        scan = table.new_read_builder().new_scan()
+        scan.with_chunk_shuffle(seed=1, chunk_size=100)
+        with self.assertRaisesRegex(ValueError, "only supports append tables"):
+            scan.plan()
+
+    def test_dv_table_rejected(self):
+        table = self._append_table('cs_dv', options={'deletion-vectors.enabled': 'true'})
+        scan = table.new_read_builder().new_scan()
+        scan.with_chunk_shuffle(seed=1, chunk_size=100)
+        with self.assertRaisesRegex(ValueError, "deletion vectors"):
+            scan.plan()
+
+    def test_with_slice_then_chunk_shuffle_rejected(self):
+        table = self._append_table('cs_slice')
+        scan = table.new_read_builder().new_scan()
+        scan.with_slice(0, 100).with_chunk_shuffle(seed=1, chunk_size=100)
+        with self.assertRaisesRegex(ValueError, "with_slice"):
+            scan.plan()
+
+    def test_limit_with_chunk_shuffle_rejected(self):
+        table = self._append_table('cs_limit')
+        scan = table.new_read_builder().with_limit(50).new_scan()
+        scan.with_chunk_shuffle(seed=1, chunk_size=100)
+        with self.assertRaisesRegex(ValueError, "limit"):
+            scan.plan()
+
+    def test_invalid_chunk_size(self):
+        table = self._append_table('cs_invalid')
+        scan = table.new_read_builder().new_scan()
+        with self.assertRaisesRegex(ValueError, "chunk_size"):
+            scan.with_chunk_shuffle(seed=1, chunk_size=0)
+        with self.assertRaisesRegex(ValueError, "chunk_size"):
+            scan.with_chunk_shuffle(seed=1, chunk_size=-5)
+
+    def test_column_predicate_rejected(self):
+        # Non-partition predicate would silently shrink effective chunk
+        # row counts inside the reader → not allowed.
+        table = self._append_table('cs_col_pred', partition_keys=['part'])
+        rb = table.new_read_builder()
+        col_pred = rb.new_predicate_builder().equal('id', 5)
+        rb = rb.with_filter(col_pred)
+        scan = rb.new_scan().with_chunk_shuffle(seed=1, chunk_size=10)
+        with self.assertRaisesRegex(ValueError, "partition keys"):
+            scan.plan()
+
+    def test_partition_predicate_allowed(self):
+        # Filter is partition-only → planning must succeed and return
+        # exactly the matching partition's rows.
+        table, _ = self._partitioned_table_with_data('cs_part_pred')
+
+        rb = table.new_read_builder()
+        pred = rb.new_predicate_builder().equal('part', 'p1')
+        scan = rb.with_filter(pred).new_scan() \
+            .with_chunk_shuffle(seed=1, chunk_size=10)
+        splits = scan.plan().splits()
+        # p1 holds 50 rows; without this guard an empty plan would make the
+        # per-split loop below pass vacuously.
+        self.assertTrue(splits)
+        for split in splits:
+            self.assertEqual(tuple(split.partition.values), ('p1',))
+        # All of p1 and nothing from p2.
+        self.assertEqual(sum(split.row_count for split in splits), 50)
+
+    def _partitioned_table_with_data(self, name):
+        """Create a table partitioned on ``part`` holding 50 rows in p1 and 50 in p2."""
+        pa_schema = pa.schema([
+            ('id', pa.int64()),
+            ('value', pa.string()),
+            ('part', pa.string()),
+        ])
+        schema = Schema.from_pyarrow_schema(pa_schema, partition_keys=['part'])
+        identifier = f'default.{name}'
+        self.catalog.create_table(identifier, schema, False)
+        table = self.catalog.get_table(identifier)
+        wb = table.new_batch_write_builder()
+        for part, ids in [('p1', range(50)), ('p2', range(50, 100))]:
+            tw = wb.new_write()
+            tc = wb.new_commit()
+            tw.write_arrow(pa.Table.from_pydict(
+                {'id': list(ids),
+                 'value': [f'v{i}' for i in ids],
+                 'part': [part] * 50},
+                schema=pa_schema,
+            ))
+            tc.commit(tw.prepare_commit())
+            tw.close()
+            tc.close()
+        return table, pa_schema
+
+
+class DataEvolutionChunkShuffleAlgoTest(unittest.TestCase):
+    """Mock-based tests for the DE chunk slicer."""
+
+    def test_no_entries_returns_empty(self):
+        gen = _make_de_generator(seed=1, chunk_size=100)
+        self.assertEqual(gen.create_splits([]), [])
+
+    def test_full_aligned_groups_one_per_chunk(self):
+        # Three commits of 100 rows each → three aligned groups.
+        # chunk_size = 100 → 3 chunks, each holding one group whole.
+        entries = [
+            _mock_de_entry([], 0, 'g0.parquet', 0, 100),
+            _mock_de_entry([], 0, 'g1.parquet', 100, 100),
+            _mock_de_entry([], 0, 'g2.parquet', 200, 100),
+        ]
+        gen = _make_de_generator(seed=1, chunk_size=100)
+        splits = gen.create_splits(entries)
+        self.assertEqual(len(splits), 3)
+        for s in splits:
+            self.assertIsInstance(s, IndexedSplit)
+            self.assertEqual(s.row_count, 100)
+            self.assertEqual(len(s.row_ranges()), 1)
+
+    def test_aligned_group_split_across_chunks(self):
+        # One 250-row group, chunk_size=100 → 3 chunks (100, 100, 50).
+        # All three chunks reference the SAME aligned group's files but
+        # each carries a distinct row_range slice.
+        entries = [_mock_de_entry([], 0, 'g0.parquet', 1000, 250)]
+        gen = _make_de_generator(seed=1, chunk_size=100)
+        splits = gen.create_splits(entries)
+        self.assertEqual(len(splits), 3)
+
+        # Union of the three chunks' row_ranges must cover the whole group [1000, 1249].
+        ranges = []
+        for s in splits:
+            self.assertIsInstance(s, IndexedSplit)
+            ranges.extend((r.from_, r.to) for r in s.row_ranges())
+        ranges.sort()
+        self.assertEqual(ranges, [(1000, 1099), (1100, 1199), (1200, 1249)])
+        total = sum(r[1] - r[0] + 1 for r in ranges)
+        self.assertEqual(total, 250)
+
+    def test_chunk_pulls_in_blob_siblings(self):
+        # One aligned group with a main parquet and a blob sibling sharing the
+        # row_id range. A single chunk must include BOTH files so the reader
+        # can union the columns.
+        entries = [
+            _mock_de_entry([], 0, 'g0.parquet', 0, 100),
+            _mock_de_entry([], 0, 'g0.blob', 0, 100),  # .blob ext → is_blob_file
+        ]
+        gen = _make_de_generator(seed=1, chunk_size=100)
+        splits = gen.create_splits(entries)
+        self.assertEqual(len(splits), 1)
+        files = sorted(f.file_name for f in splits[0].files)
+        self.assertEqual(files, ['g0.blob', 'g0.parquet'])
+
+    def test_blob_propagates_when_group_split(self):
+        # Same scenario but chunk_size halves the group → the blob sibling
+        # must appear in BOTH chunk splits.
+        entries = [
+            _mock_de_entry([], 0, 'g0.parquet', 0, 100),
+            _mock_de_entry([], 0, 'g0.blob', 0, 100),
+        ]
+        gen = _make_de_generator(seed=1, chunk_size=50)
+        splits = gen.create_splits(entries)
+        self.assertEqual(len(splits), 2)
+        for s in splits:
+            files = sorted(f.file_name for f in s.files)
+            self.assertEqual(files, ['g0.blob', 'g0.parquet'])
+
+    def test_deterministic_same_seed(self):
+        entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100) for i in range(20)]
+        gen1 = _make_de_generator(seed=42, chunk_size=100)
+        gen2 = _make_de_generator(seed=42, chunk_size=100)
+        sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)]
+        sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)]
+        self.assertEqual(sigs1, sigs2)
+
+    def test_different_seed_reorders(self):
+        entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100) for i in range(50)]
+        gen1 = _make_de_generator(seed=1, chunk_size=100)
+        gen2 = _make_de_generator(seed=2, chunk_size=100)
+        sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)]
+        sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)]
+        self.assertEqual(sorted(sigs1), sorted(sigs2))
+        self.assertNotEqual(sigs1, sigs2)
+
+    def test_input_order_does_not_affect_output(self):
+        a = _mock_de_entry([], 0, 'g0.parquet', 0, 100)
+        b = _mock_de_entry([], 0, 'g1.parquet', 100, 100)
+        c = _mock_de_entry([], 0, 'g2.parquet', 200, 100)
+        gen1 = _make_de_generator(seed=99, chunk_size=100)
+        gen2 = _make_de_generator(seed=99, chunk_size=100)
+        sigs1 = [_split_signature(s) for s in gen1.create_splits([a, b, c])]
+        sigs2 = [_split_signature(s) for s in gen2.create_splits([c, a, b])]
+        self.assertEqual(sigs1, sigs2)
+
+    def test_shard_round_trip_no_overlap_no_loss(self):
+        # 13 aligned groups × 100 rows = 1300 rows. Shard across 4 workers.
+        entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100) for i in range(13)]
+        num_workers = 4
+
+        unsharded = _make_de_generator(seed=7, chunk_size=100).create_splits(list(entries))
+        unsharded_sigs = sorted(_split_signature(s) for s in unsharded)
+
+        sharded_sigs = []
+        total_rows = 0
+        for w in range(num_workers):
+            gen = _make_de_generator(seed=7, chunk_size=100)
+            gen.with_shard(w, num_workers)
+            for s in gen.create_splits(list(entries)):
+                sharded_sigs.append(_split_signature(s))
+                total_rows += s.row_count
+        self.assertEqual(total_rows, 13 * 100)
+        # No duplicate splits across workers
+        self.assertEqual(len(sharded_sigs), len(set(sharded_sigs)))
+        self.assertEqual(sorted(sharded_sigs), unsharded_sigs)
+
+    def test_multi_partition_no_chunk_crosses_partition(self):
+        entries = [
+            _mock_de_entry(['p1'], 0, 'g0.parquet', 0, 100),
+            _mock_de_entry(['p1'], 0, 'g1.parquet', 100, 100),
+            _mock_de_entry(['p2'], 0, 'g2.parquet', 200, 100),
+            _mock_de_entry(['p2'], 0, 'g3.parquet', 300, 100),
+        ]
+        partition_of_file = {
+            e.file.file_name: tuple(e.partition.values) for e in entries}
+        gen = _make_de_generator(seed=0, chunk_size=100)
+        splits = gen.create_splits(entries)
+        self.assertEqual(len(splits), 4)
+        seen = set()
+        for s in splits:
+            data_split = s.data_split() if isinstance(s, IndexedSplit) else s
+            # Every file inside one split must come from a single partition,
+            # and it must be the split's own partition.
+            file_partitions = {
+                partition_of_file[f.file_name] for f in data_split.files}
+            self.assertEqual(file_partitions, {tuple(data_split.partition.values)})
+            seen |= file_partitions
+        self.assertEqual(seen, {('p1',), ('p2',)})
+
+    def test_null_and_non_null_partitions_sort_safely(self):
+        # Same null-vs-non-null sort guard, exercised on the DE path.
+        entries = [
+            _mock_de_entry(['p1'], 0, 'g0.parquet', 0, 100),
+            _mock_de_entry([None], 0, 'g1.parquet', 100, 100),
+            _mock_de_entry(['p2'], 0, 'g2.parquet', 200, 100),
+        ]
+        gen = _make_de_generator(seed=1, chunk_size=100)
+        splits = gen.create_splits(entries)
+        self.assertEqual(len(splits), 3)
+        partitions = {_split_signature(s)[0] for s in splits}
+        self.assertEqual(partitions, {('p1',), ('p2',), (None,)})
+
+
+class DataEvolutionChunkShuffleEndToEndTest(unittest.TestCase):
+    """Real DE table → with_chunk_shuffle → multi-worker → union == full 
table."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
+        cls.catalog.create_database('default', True)
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+    def _create_de_table(self, name):
+        pa_schema = pa.schema([
+            ('id', pa.int32()),
+            ('value', pa.string()),
+            ('payload', pa.large_binary()),
+        ])
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            options={
+                'row-tracking.enabled': 'true',
+                'data-evolution.enabled': 'true',
+                'blob.target-file-size': '1 b',
+            },
+        )
+        identifier = f'default.{name}'
+        self.catalog.create_table(identifier, schema, False)
+        return self.catalog.get_table(identifier), pa_schema
+
+    @staticmethod
+    def _payloads(ids):
+        return [f'payload-{i:03d}'.encode('utf-8') for i in ids]
+
+    def _commit_full_rows(self, table, pa_schema, ids):
+        wb = table.new_batch_write_builder()
+        tw = wb.new_write()
+        tc = wb.new_commit()
+        tw.write_arrow(pa.Table.from_pydict(
+            {
+                'id': ids,
+                'value': [f'v{i}' for i in ids],
+                'payload': self._payloads(ids),
+            },
+            schema=pa_schema))
+        commit_messages = tw.prepare_commit()
+        tc.commit(commit_messages)
+        tw.close()
+        tc.close()
+        return commit_messages
+
+    def _assert_commit_has_main_and_multiple_blob_files(self, commit_messages):
+        all_files = [f for msg in commit_messages for f in msg.new_files]
+        main_files = [f for f in all_files if not 
DataFileMeta.is_blob_file(f.file_name)]
+        blob_files = [f for f in all_files if 
DataFileMeta.is_blob_file(f.file_name)]
+        self.assertGreaterEqual(len(main_files), 1)
+        self.assertGreater(
+            len(blob_files), 1,
+            "DE chunk-shuffle tests should exercise one row-id group with 
multiple blob files",
+        )
+
+    def _assert_splits_include_blob_files(self, splits):
+        self.assertGreater(len(splits), 0)
+        for split in splits:
+            data_split = split.data_split() if isinstance(split, IndexedSplit) 
else split
+            blob_files = [
+                f for f in data_split.files
+                if DataFileMeta.is_blob_file(f.file_name)
+            ]
+            self.assertGreater(
+                len(blob_files), 0,
+                "Each DE chunk should keep blob sidecar files with its aligned 
row-id group",
+            )
+
+    def test_workers_union_equals_full_table(self):
+        table, pa_schema = self._create_de_table('cs_de_union')
+        # 4 commits → 4 aligned groups. Each group has one normal file and
+        # multiple blob sidecar files because blob.target-file-size is 1 byte.
+        for c in range(4):
+            base = c * 50
+            commit_messages = self._commit_full_rows(
+                table, pa_schema, list(range(base, base + 50)))
+            
self._assert_commit_has_main_and_multiple_blob_files(commit_messages)
+
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+
+        num_workers = 3
+        worker_tables = []
+        for w in range(num_workers):
+            scan = read_builder.new_scan() \
+                .with_chunk_shuffle(seed=123, chunk_size=37) \
+                .with_shard(w, num_workers)
+            splits = scan.plan().splits()
+            if splits:
+                self._assert_splits_include_blob_files(splits)
+                worker_tables.append(table_read.to_arrow(splits))
+
+        actual = pa.concat_tables(worker_tables).sort_by('id')
+        self.assertEqual(actual.num_rows, 200)
+        self.assertEqual(actual.column('id').to_pylist(), list(range(200)))
+        self.assertEqual(actual.column('payload').to_pylist(), 
self._payloads(range(200)))
+
+    def test_deterministic_plan_across_calls(self):
+        table, pa_schema = self._create_de_table('cs_de_determinism')
+        for c in range(3):
+            base = c * 40
+            commit_messages = self._commit_full_rows(
+                table, pa_schema, list(range(base, base + 40)))
+            
self._assert_commit_has_main_and_multiple_blob_files(commit_messages)
+
+        def plan_sigs(worker):
+            scan = table.new_read_builder().new_scan() \
+                .with_chunk_shuffle(seed=42, chunk_size=15) \
+                .with_shard(worker, 4)
+            splits = scan.plan().splits()
+            if splits:
+                self._assert_splits_include_blob_files(splits)
+            return [_split_signature(s) for s in splits]
+
+        for worker in range(4):
+            self.assertEqual(plan_sigs(worker), plan_sigs(worker))
+
+
+# Allow running this test module directly with `python <file>`.
+if __name__ == '__main__':
+    unittest.main()
diff --git a/paimon-python/pypaimon/tests/torch_read_test.py b/paimon-python/pypaimon/tests/torch_read_test.py
index ac6088ece3..5f55cb2bc8 100644
--- a/paimon-python/pypaimon/tests/torch_read_test.py
+++ b/paimon-python/pypaimon/tests/torch_read_test.py
@@ -645,6 +645,236 @@ class TorchReadTest(unittest.TestCase):
         print("✓ All predicate test cases passed!")
         print(f"{'=' * 60}\n")
 
+    def test_torch_streaming_shuffle_single_worker(self):
+        table = 
self._create_shuffle_append_table('default.test_torch_shuffle_single')
+        read_builder = table.new_read_builder().with_projection(['user_id'])
+        table_read = read_builder.new_read()
+        splits = read_builder.new_scan().plan().splits()
+
+        expected = list(range(80))
+        for max_buffer_input_splits in [1, 3]:
+            with self.subTest(max_buffer_input_splits=max_buffer_input_splits):
+                dataset = table_read.to_torch(
+                    splits,
+                    streaming=True,
+                    shuffle=True,
+                    seed=17,
+                    buffer_size=7,
+                    max_buffer_input_splits=max_buffer_input_splits,
+                )
+                ids = self._collect_torch_user_ids(dataset, num_workers=0)
+                self.assertEqual(sorted(ids), expected)
+                self.assertNotEqual(ids, expected)
+
+    def test_torch_streaming_shuffle_seed_and_epoch(self):
+        table = 
self._create_shuffle_append_table('default.test_torch_shuffle_epoch')
+        read_builder = table.new_read_builder().with_projection(['user_id'])
+        table_read = read_builder.new_read()
+        splits = read_builder.new_scan().plan().splits()
+
+        dataset = table_read.to_torch(
+            splits,
+            streaming=True,
+            shuffle=True,
+            seed=23,
+            buffer_size=11,
+            max_buffer_input_splits=4,
+        )
+        epoch0 = self._collect_torch_user_ids(dataset, num_workers=0)
+        epoch0_again = self._collect_torch_user_ids(dataset, num_workers=0)
+        self.assertEqual(epoch0, epoch0_again)
+
+        dataset.set_epoch(1)
+        epoch1 = self._collect_torch_user_ids(dataset, num_workers=0)
+        self.assertEqual(sorted(epoch1), list(range(80)))
+        self.assertNotEqual(epoch0, epoch1)
+
+        dataset.set_epoch(0)
+        self.assertEqual(epoch0, self._collect_torch_user_ids(dataset, 
num_workers=0))
+
+        other_seed_dataset = table_read.to_torch(
+            splits,
+            streaming=True,
+            shuffle=True,
+            seed=24,
+            buffer_size=11,
+            max_buffer_input_splits=4,
+        )
+        self.assertNotEqual(
+            epoch0,
+            self._collect_torch_user_ids(other_seed_dataset, num_workers=0),
+        )
+
+    def test_torch_streaming_shuffle_epoch_with_persistent_workers(self):
+        table = 
self._create_shuffle_append_table('default.test_torch_shuffle_persistent_epoch')
+        read_builder = table.new_read_builder().with_projection(['user_id'])
+        table_read = read_builder.new_read()
+        splits = read_builder.new_scan().plan().splits()
+
+        dataset = table_read.to_torch(
+            splits,
+            streaming=True,
+            shuffle=True,
+            seed=23,
+            buffer_size=11,
+            max_buffer_input_splits=4,
+        )
+        dataloader = DataLoader(
+            dataset,
+            batch_size=8,
+            num_workers=2,
+            persistent_workers=True,
+            shuffle=False,
+        )
+
+        epoch0 = self._collect_torch_user_ids_from_dataloader(dataloader)
+        self.assertEqual(epoch0, 
self._collect_torch_user_ids_from_dataloader(dataloader))
+
+        dataset.set_epoch(1)
+        epoch1 = self._collect_torch_user_ids_from_dataloader(dataloader)
+        self.assertEqual(sorted(epoch1), list(range(80)))
+        self.assertNotEqual(epoch0, epoch1)
+
+    def test_torch_streaming_shuffle_multi_worker(self):
+        table = 
self._create_shuffle_append_table('default.test_torch_shuffle_multi')
+        read_builder = table.new_read_builder().with_projection(['user_id'])
+        table_read = read_builder.new_read()
+        splits = read_builder.new_scan() \
+            .with_chunk_shuffle(seed=31, chunk_size=5) \
+            .plan() \
+            .splits()
+
+        dataset = table_read.to_torch(
+            splits,
+            streaming=True,
+            shuffle=True,
+            seed=31,
+            buffer_size=13,
+            max_buffer_input_splits=4,
+        )
+        ids = self._collect_torch_user_ids(dataset, num_workers=2)
+
+        expected = list(range(80))
+        self.assertEqual(len(ids), len(expected))
+        self.assertEqual(sorted(ids), expected)
+
+    def test_torch_streaming_shuffle_rejects_non_streaming(self):
+        table = 
self._create_shuffle_append_table('default.test_torch_shuffle_non_streaming')
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+        splits = read_builder.new_scan().plan().splits()
+
+        with self.assertRaisesRegex(ValueError, "streaming=True"):
+            table_read.to_torch(splits, streaming=False, shuffle=True)
+
+    def test_torch_streaming_shuffle_accepts_pk_table_splits(self):
+        pa_schema = pa.schema([
+            pa.field('user_id', pa.int32(), nullable=False),
+            ('item_id', pa.int64()),
+            ('behavior', pa.string()),
+            ('dt', pa.string())
+        ])
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            primary_keys=['user_id'],
+            options={'bucket': '1'},
+        )
+        self.catalog.create_table('default.test_torch_shuffle_pk', schema, 
False)
+        table = self.catalog.get_table('default.test_torch_shuffle_pk')
+        self._write_test_table(table)
+
+        read_builder = table.new_read_builder().with_projection(['user_id'])
+        splits = read_builder.new_scan().plan().splits()
+        dataset = read_builder.new_read().to_torch(
+            splits,
+            streaming=True,
+            shuffle=True,
+            seed=7,
+            buffer_size=3,
+        )
+        ids = self._collect_torch_user_ids(dataset, num_workers=0)
+
+        self.assertEqual(sorted(ids), [1, 2, 3, 4, 5, 6, 7, 8])
+
+    def test_torch_streaming_shuffle_rejects_invalid_dataset_options(self):
+        table = 
self._create_shuffle_append_table('default.test_torch_shuffle_invalid_options')
+        read_builder = table.new_read_builder().with_projection(['user_id'])
+        table_read = read_builder.new_read()
+        splits = read_builder.new_scan().plan().splits()
+
+        with self.assertRaisesRegex(ValueError, "prefetch_concurrency"):
+            table_read.to_torch(
+                splits,
+                streaming=True,
+                shuffle=True,
+                prefetch_concurrency=2,
+            )
+        with self.assertRaisesRegex(ValueError, "buffer_size"):
+            table_read.to_torch(
+                splits,
+                streaming=True,
+                shuffle=True,
+                buffer_size=0,
+            )
+        with self.assertRaisesRegex(ValueError, "max_buffer_input_splits"):
+            table_read.to_torch(
+                splits,
+                streaming=True,
+                shuffle=True,
+                max_buffer_input_splits=0,
+            )
+
+    def _create_shuffle_append_table(
+        self,
+        identifier,
+        total_rows=80,
+        rows_per_commit=10,
+        partition_keys=None,
+    ):
+        schema = Schema.from_pyarrow_schema(
+            self.pa_schema,
+            partition_keys=partition_keys or [],
+        )
+        self.catalog.create_table(identifier, schema, False)
+        table = self.catalog.get_table(identifier)
+
+        write_builder = table.new_batch_write_builder()
+        for start in range(0, total_rows, rows_per_commit):
+            end = min(start + rows_per_commit, total_rows)
+            table_write = write_builder.new_write()
+            table_commit = write_builder.new_commit()
+            pa_table = pa.Table.from_pydict({
+                'user_id': list(range(start, end)),
+                'item_id': [1000 + i for i in range(start, end)],
+                'behavior': [chr(ord('a') + (i % 26)) for i in range(start, 
end)],
+                'dt': [f'p{i % 4}' for i in range(start, end)],
+            }, schema=self.pa_schema)
+            table_write.write_arrow(pa_table)
+            table_commit.commit(table_write.prepare_commit())
+            table_write.close()
+            table_commit.close()
+        return table
+
+    @staticmethod
+    def _collect_torch_user_ids(dataset, num_workers=0):
+        dataloader = DataLoader(
+            dataset,
+            batch_size=8,
+            num_workers=num_workers,
+            shuffle=False,
+        )
+        all_user_ids = []
+        for batch_data in dataloader:
+            all_user_ids.extend(batch_data['user_id'].tolist())
+        return all_user_ids
+
+    @staticmethod
+    def _collect_torch_user_ids_from_dataloader(dataloader):
+        all_user_ids = []
+        for batch_data in dataloader:
+            all_user_ids.extend(batch_data['user_id'].tolist())
+        return all_user_ids
+
     def _write_test_table(self, table):
         write_builder = table.new_batch_write_builder()
         table_pa_schema = self.pk_pa_schema if table.primary_keys else 
self.pa_schema

Reply via email to