This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 2281765a7f [python] support chunk shuffle for planning and 3-layer
shuffle for pytorch Dataset (#8064)
2281765a7f is described below
commit 2281765a7f44c643353266efb1f95053f9a059ea
Author: Faiz <[email protected]>
AuthorDate: Mon Jun 8 19:14:25 2026 +0800
[python] support chunk shuffle for planning and 3-layer shuffle for pytorch
Dataset (#8064)
---
docs/docs/pypaimon/pytorch.md | 97 +++
.../pypaimon/read/datasource/torch_dataset.py | 224 ++++--
.../read/scanner/chunk_shuffle_split_generator.py | 377 ++++++++++
.../pypaimon/read/scanner/file_scanner.py | 66 +-
paimon-python/pypaimon/read/table_read.py | 20 +
paimon-python/pypaimon/read/table_scan.py | 4 +
.../scanner/chunk_shuffle_split_generator_test.py | 800 +++++++++++++++++++++
paimon-python/pypaimon/tests/torch_read_test.py | 230 ++++++
8 files changed, 1775 insertions(+), 43 deletions(-)
diff --git a/docs/docs/pypaimon/pytorch.md b/docs/docs/pypaimon/pytorch.md
index 9e98b487ee..0f7f7bbdef 100644
--- a/docs/docs/pypaimon/pytorch.md
+++ b/docs/docs/pypaimon/pytorch.md
@@ -58,3 +58,100 @@ When the `streaming` parameter is true, it will iteratively
read;
when it is false, it will read the full amount of data into memory.
**`prefetch_concurrency`** (default: 1): When streaming is true, number of
threads used for parallel prefetch within each DataLoader worker. Set to a
value greater than 1 to partition splits across threads and increase read
throughput. Has no effect when streaming is false.
+
+## Shuffle
+
+PyPaimon supports streaming shuffle for PyTorch `IterableDataset`. The shuffle
+pipeline can be composed of three layers:
+
+1. **Chunk shuffle**: split files into row chunks during scan planning and
+ shuffle the generated chunk splits. This is enabled by
+ `TableScan.with_chunk_shuffle(seed, chunk_size)`.
+2. **Split interleave**: read from multiple splits in round-robin order inside
+ each DataLoader worker.
+3. **Buffer shuffle**: apply a reservoir-style row shuffle buffer before rows
+ are yielded to PyTorch.
+
+Chunk shuffle is a scan planning feature for append tables, including
+Data Evolution append tables. For Data Evolution tables, chunk shuffle keeps
+row-id-aligned data files and sidecar files together while slicing by row-id
+range. Chunk shuffle should be used with file formats that **support random
+access**. Currently, the random-access file formats are Lance, Vortex, Row, and
+Blob. Primary-key tables and deletion-vector scans are not supported by
+`with_chunk_shuffle`.
+
+The second and third layers are Dataset features. They work on the splits you
+pass to `to_torch`, so they can be used with either normal splits or
+chunk-shuffled splits.
+
+### Use Dataset Shuffle Only
+
+Use this when normal scan splits are enough and you only want split interleave
+plus row buffer shuffle:
+
+```python
+from torch.utils.data import DataLoader
+
+table_scan = read_builder.new_scan()
+table_read = read_builder.new_read()
+splits = table_scan.plan().splits()
+
+dataset = table_read.to_torch(
+ splits,
+ streaming=True,
+ shuffle=True,
+ seed=42,
+ buffer_size=1000,
+ max_buffer_input_splits=10,
+)
+
+dataloader = DataLoader(
+ dataset,
+ batch_size=32,
+ num_workers=2,
+ shuffle=False,
+)
+```
+
+`buffer_size` controls the row shuffle buffer. Larger values produce a better
+approximation of global shuffle, at the cost of more memory. If
+`max_buffer_input_splits` is `1`, split interleave is skipped and only buffer
+shuffle is applied. `shuffle=True` requires `streaming=True` and does not
+support `prefetch_concurrency > 1`.
+
+### Use All Three Layers
+
+For append tables, enable chunk shuffle during scan planning, then enable
+Dataset shuffle when converting to PyTorch:
+
+```python
+from torch.utils.data import DataLoader
+
+seed = 42
+
+table_scan = read_builder.new_scan().with_chunk_shuffle(
+ seed=seed,
+ chunk_size=1000,
+)
+table_read = read_builder.new_read()
+splits = table_scan.plan().splits()
+
+dataset = table_read.to_torch(
+ splits,
+ streaming=True,
+ shuffle=True,
+ seed=seed,
+ buffer_size=1000,
+ max_buffer_input_splits=10,
+)
+
+dataloader = DataLoader(
+ dataset,
+ batch_size=32,
+ num_workers=2,
+ shuffle=False,
+)
+```
+
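+Call `dataset.set_epoch(epoch)` before creating or iterating a DataLoader for a
+new training epoch if you want a different buffer-shuffle order for each epoch.
+`set_epoch` only reseeds the buffer shuffle; the chunk order is fixed when the
+scan is planned, so re-plan with a different `with_chunk_shuffle` seed if the
+chunk order should change as well.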
diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py
b/paimon-python/pypaimon/read/datasource/torch_dataset.py
index 6012d3dc68..5eb3485ddd 100644
--- a/paimon-python/pypaimon/read/datasource/torch_dataset.py
+++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py
@@ -19,8 +19,9 @@
Module to read a Paimon table into PyTorch Dataset.
"""
import queue
+import random
import threading
-from typing import List
+from typing import Iterator, List
import torch
from torch.utils.data import Dataset, IterableDataset
@@ -29,6 +30,12 @@ from pypaimon.read.split import Split
from pypaimon.read.table_read import TableRead
+def _share_epoch_with_torch_workers(value):
+    """Return ``value`` as a tensor whose storage lives in shared memory.
+
+    A value that is already a ``torch.Tensor`` is moved to shared memory
+    in place and returned; anything else is first wrapped in a
+    ``torch.long`` tensor.
+    """
+    epoch_tensor = (
+        value
+        if isinstance(value, torch.Tensor)
+        else torch.tensor(value, dtype=torch.long)
+    )
+    return epoch_tensor.share_memory_()
+
+
class TorchDataset(Dataset):
"""
PyTorch Dataset implementation for reading Paimon table data.
@@ -76,7 +83,44 @@ class TorchDataset(Dataset):
return self._data[index]
-class TorchIterDataset(IterableDataset):
+class _BaseTorchIterDataset(IterableDataset):
+    """
+    Common base for streaming PyTorch datasets that read Paimon splits.
+
+    Stores the reader, the planned splits and the output column names, and
+    offers row conversion plus per-worker split assignment.
+    """
+
+    def __init__(self, table_read: TableRead, splits: List[Split]):
+        self.table_read = table_read
+        self.splits = splits
+        # Column names in read_type order; used as keys of the yielded rows.
+        self.field_names = [field.name for field in table_read.read_type]
+
+    def _row_to_dict(self, offset_row) -> dict:
+        """Convert a positional row into a ``{column name: value}`` dict."""
+        return {
+            name: offset_row.get_field(position)
+            for position, name in enumerate(self.field_names)
+        }
+
+    def _worker_splits(self, worker_info) -> List[Split]:
+        """Return the contiguous slice of ``self.splits`` for one worker.
+
+        With ``worker_info`` set to ``None`` every split is returned.
+        Otherwise the splits are divided as evenly as possible and the
+        workers with the lowest ids each take one extra split when the
+        division leaves a remainder.
+        """
+        if worker_info is None:
+            return self.splits
+
+        worker_id = worker_info.id
+        base, extra = divmod(len(self.splits), worker_info.num_workers)
+        # Every worker before this one owns `base` splits, plus one more
+        # for each of them that falls inside the first `extra` workers.
+        begin = worker_id * base + min(worker_id, extra)
+        size = base + 1 if worker_id < extra else base
+        return self.splits[begin:begin + size]
+
+
+class TorchIterDataset(_BaseTorchIterDataset):
"""
PyTorch IterableDataset implementation for reading Paimon table data.
@@ -104,18 +148,8 @@ class TorchIterDataset(IterableDataset):
this worker (default 1). When > 1, splits are partitioned
across
threads to increase read throughput.
"""
- self.table_read = table_read
- self.splits = splits
+ super().__init__(table_read, splits)
self.prefetch_concurrency = max(1, int(prefetch_concurrency))
- # Get field names from read_type
- self.field_names = [field.name for field in table_read.read_type]
-
- def _row_to_dict(self, offset_row) -> dict:
- row_dict = {}
- for i, field_name in enumerate(self.field_names):
- value = offset_row.get_field(i)
- row_dict[field_name] = value
- return row_dict
def __iter__(self):
"""
@@ -128,30 +162,7 @@ class TorchIterDataset(IterableDataset):
row data of dict type, where keys are column names
"""
worker_info = torch.utils.data.get_worker_info()
-
- if worker_info is None:
- # Single-process data loading, iterate over all splits
- splits_to_process = self.splits
- else:
- # Multi-process data loading, partition splits across workers
- worker_id = worker_info.id
- num_workers = worker_info.num_workers
-
- # Calculate start and end indices for this worker
- # Distribute splits evenly by slicing
- total_splits = len(self.splits)
- splits_per_worker = total_splits // num_workers
- remainder = total_splits % num_workers
-
- # Workers with id < remainder get one extra split
- if worker_id < remainder:
- start_idx = worker_id * (splits_per_worker + 1)
- end_idx = start_idx + splits_per_worker + 1
- else:
- start_idx = worker_id * splits_per_worker + remainder
- end_idx = start_idx + splits_per_worker
-
- splits_to_process = self.splits[start_idx:end_idx]
+ splits_to_process = self._worker_splits(worker_info)
if self.prefetch_concurrency > 1:
for row in self._iter_rows(splits_to_process):
@@ -161,11 +172,7 @@ class TorchIterDataset(IterableDataset):
worker_iterator = self.table_read.to_iterator(splits_to_process)
for offset_row in worker_iterator:
- row_dict = {}
- for i, field_name in enumerate(self.field_names):
- value = offset_row.get_field(i)
- row_dict[field_name] = value
- yield row_dict
+ yield self._row_to_dict(offset_row)
def _iter_rows(self, splits: List[Split]):
n = min(self.prefetch_concurrency, len(splits))
@@ -221,3 +228,136 @@ class TorchIterDataset(IterableDataset):
stop.set()
for t in threads:
t.join(timeout=self._PREFETCH_JOIN_TIMEOUT_SEC)
+
+
+class TorchShuffledIterDataset(_BaseTorchIterDataset):
+    """
+    PyTorch IterableDataset with Paimon-controlled streaming shuffle.
+
+    This dataset consumes pre-planned splits, then mixes rows with split
+    interleaving and a shuffle buffer. Chunk-level shuffle, when needed,
+    stays in TableScan.with_chunk_shuffle().
+
+    Args:
+        table_read: reader used to open row iterators over splits.
+        splits: pre-planned splits; divided between DataLoader workers.
+        seed: base seed for the buffer-shuffle random generator.
+        buffer_size: number of rows kept in the shuffle buffer.
+        max_buffer_input_splits: maximum number of splits read in
+            round-robin order at the same time; ``1`` skips interleaving.
+
+    Raises:
+        ValueError: if ``seed`` is not an int, or ``buffer_size`` /
+            ``max_buffer_input_splits`` is not a positive int.
+    """
+
+    def __init__(
+        self,
+        table_read: TableRead,
+        splits: List[Split],
+        seed: int = 0,
+        buffer_size: int = 1000,
+        max_buffer_input_splits: int = 10,
+    ):
+        super().__init__(table_read, splits)
+        self.seed = self._require_int(seed, "seed")
+        self.buffer_size = self._require_positive_int(buffer_size, "buffer_size")
+        self.max_buffer_input_splits = self._require_positive_int(
+            max_buffer_input_splits, "max_buffer_input_splits")
+        # The epoch counter is kept in a shared-memory tensor.
+        self._epoch = _share_epoch_with_torch_workers(0)
+
+    def __setstate__(self, state):
+        # Restore attributes, then put the epoch tensor back in shared memory.
+        self.__dict__ = state
+        self._epoch = _share_epoch_with_torch_workers(self._epoch)
+
+    @property
+    def epoch(self) -> int:
+        """Current epoch as a plain int."""
+        return int(self._epoch)
+
+    @epoch.setter
+    def epoch(self, epoch: int) -> None:
+        epoch = self._require_int(epoch, "epoch")
+        # In-place update keeps the same tensor object instead of rebinding
+        # the attribute -- presumably so workers holding the shared tensor
+        # observe the new value; confirm against the DataLoader setup.
+        self._epoch += epoch - self._epoch
+
+    @staticmethod
+    def _require_int(value: int, name: str) -> int:
+        """Return ``value`` unchanged, or raise ValueError if not an int."""
+        if not isinstance(value, int):
+            raise ValueError("%s must be an int" % name)
+        return value
+
+    @staticmethod
+    def _require_positive_int(value: int, name: str) -> int:
+        """Return ``value`` unchanged, or raise ValueError unless it is an
+        int greater than zero."""
+        if not isinstance(value, int) or value <= 0:
+            raise ValueError("%s must be a positive int" % name)
+        return value
+
+    def set_epoch(self, epoch: int) -> "TorchShuffledIterDataset":
+        """Set the epoch used to seed the buffer shuffle; returns ``self``."""
+        self.epoch = epoch
+        return self
+
+    def __iter__(self):
+        """Yield this worker's rows as dicts after buffer shuffling."""
+        worker_info = torch.utils.data.get_worker_info()
+        worker_id = worker_info.id if worker_info is not None else 0
+        splits_to_process = self._worker_splits(worker_info)
+
+        if self.max_buffer_input_splits == 1:
+            rows = self._iter_ordered_rows(splits_to_process)
+        else:
+            rows = self._iter_interleaved_rows(splits_to_process)
+        yield from self._iter_buffer_shuffled_rows(rows, worker_id)
+
+    def _iter_ordered_rows(self, splits: List[Split]) -> Iterator[dict]:
+        """Yield rows of ``splits`` in reader order, without interleaving."""
+        row_iter = self.table_read.to_iterator(splits)
+        try:
+            for offset_row in row_iter:
+                yield self._row_to_dict(offset_row)
+        finally:
+            # Release the reader even when the consumer stops early.
+            self._close_iterator(row_iter)
+
+    def _iter_interleaved_rows(self, splits: List[Split]) -> Iterator[dict]:
+        """Yield rows round-robin from up to ``max_buffer_input_splits``
+        concurrently open splits; a finished split is replaced by the next
+        unopened one."""
+        if not splits:
+            return
+
+        split_iter = iter(splits)
+        active: List[Iterator] = []
+
+        def add_next_split() -> bool:
+            try:
+                split = next(split_iter)
+            except StopIteration:
+                return False
+            active.append(self.table_read.to_iterator([split]))
+            return True
+
+        idx = 0
+        try:
+            # Open the initial splits inside the try block so that a failure
+            # while opening one of them still closes those already opened.
+            for _ in range(min(self.max_buffer_input_splits, len(splits))):
+                add_next_split()
+
+            while active:
+                if idx >= len(active):
+                    idx = 0
+                row_iter = active[idx]
+                try:
+                    offset_row = next(row_iter)
+                except StopIteration:
+                    self._close_iterator(row_iter)
+                    del active[idx]
+                    # The replacement is appended at the end; idx now points
+                    # at the iterator that followed the removed one.
+                    add_next_split()
+                    continue
+
+                yield self._row_to_dict(offset_row)
+                idx += 1
+        finally:
+            for row_iter in active:
+                self._close_iterator(row_iter)
+
+    @staticmethod
+    def _close_iterator(row_iter) -> None:
+        """Call ``row_iter.close()`` when the iterator provides one."""
+        close = getattr(row_iter, "close", None)
+        if close is not None:
+            close()
+
+    def _iter_buffer_shuffled_rows(
+        self,
+        rows: Iterator[dict],
+        worker_id: int,
+    ) -> Iterator[dict]:
+        """Shuffle ``rows`` through a fixed-size buffer.
+
+        The buffer is filled first; afterwards each incoming row evicts and
+        yields a randomly chosen buffered row. The remaining rows are
+        shuffled and yielded once the input is exhausted. The random
+        generator is seeded from seed, epoch and worker id.
+        """
+        rng = random.Random(self.seed + self.epoch * 1000003 + worker_id)
+        buffer = []
+        for row in rows:
+            if len(buffer) < self.buffer_size:
+                buffer.append(row)
+                continue
+            idx = rng.randint(0, self.buffer_size - 1)
+            yield buffer[idx]
+            buffer[idx] = row
+
+        rng.shuffle(buffer)
+        yield from buffer
diff --git
a/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py
b/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py
new file mode 100644
index 0000000000..504236e500
--- /dev/null
+++ b/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py
@@ -0,0 +1,377 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import random
+from abc import abstractmethod
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, List, Optional, Tuple
+
+from pypaimon.globalindex.indexed_split import IndexedSplit
+from pypaimon.manifest.schema.data_file_meta import DataFileMeta
+from pypaimon.manifest.schema.manifest_entry import ManifestEntry
+from pypaimon.read.scanner.split_generator import AbstractSplitGenerator
+from pypaimon.read.sliced_split import SlicedSplit
+from pypaimon.read.split import DataSplit, Split
+from pypaimon.table.row.generic_row import GenericRow
+from pypaimon.utils.range import Range
+
+
+def _null_safe_partition_key(partition_values) -> tuple:
+    """Build a sort key that tolerates ``None`` partition values.
+
+    Each value ``v`` becomes the pair ``(v is None, v)``. A null and a
+    non-null value then differ on the leading boolean, so Python never has
+    to compare ``None`` against a str/int (which raises ``TypeError`` in
+    Python 3). Paimon supports null partition values, so plain tuples of
+    partition values cannot be ordered safely.
+    """
+    tagged = [(value is None, value) for value in partition_values]
+    return tuple(tagged)
+
+
+@dataclass
+class _Chunk:
+ """A unit of work for one DataLoader read. ``segments`` carries
+ subclass-specific payload (file segments for append, aligned-group
+ segments for data evolution).
+ """
+ # Partition row of the (partition, bucket) group the chunk was cut from.
+ partition: GenericRow
+ # Bucket of that same group.
+ bucket: int
+ # Segments making up the chunk; element type depends on the generator.
+ segments: List[Any]
+
+
+class ChunkShuffleSplitGeneratorBase(AbstractSplitGenerator):
+    """Common scaffolding for deterministic chunk-shuffled split generation.
+
+    Pipeline (template method, in :meth:`create_splits`):
+      1. Stable-sort entries (key from :meth:`_sort_key`) so manifest-read
+         parallelism cannot bleed into the output.
+      2. Group by (partition, bucket); iterate groups in sorted-key order.
+      3. Per group, call :meth:`_slice_group_into_chunks` to produce a list
+         of segment lists -- one segment list per chunk.
+      4. Wrap each chunk with its (partition, bucket) into ``_Chunk``,
+         concatenate across groups.
+      5. ``random.Random(seed).shuffle`` all chunks.
+      6. If sharded, take this worker's slice via balanced
+         ``_compute_shard_range``.
+      7. Map each chunk through :meth:`_chunk_to_split`.
+
+    Subclasses implement the three abstract hooks. Reader paths
+    (``RawFileSplitRead`` for append, ``DataEvolutionSplitRead`` for DE)
+    are unchanged because chunks ride on existing wrappers
+    (``SlicedSplit`` / ``IndexedSplit``).
+    """
+
+    def __init__(
+        self,
+        table,
+        target_split_size: int,
+        open_file_cost: int,
+        deletion_files_map=None,
+        seed: int = 0,
+        chunk_size: int = 0,
+    ):
+        super().__init__(table, target_split_size, open_file_cost, deletion_files_map)
+        self.seed = seed
+        self.chunk_size = chunk_size
+
+    def create_splits(self, file_entries: List[ManifestEntry]) -> List[Split]:
+        """Turn manifest entries into deterministically shuffled chunk splits.
+
+        Args:
+            file_entries: manifest entries to plan; may be empty.
+
+        Returns:
+            One split per chunk, in shuffled order (restricted to this
+            subtask's shard range when ``idx_of_this_subtask`` is set).
+
+        Raises:
+            ValueError: if ``chunk_size`` is not positive while there are
+                entries to slice.
+        """
+        if not file_entries:
+            return []
+
+        # A non-positive chunk_size (the constructor default is 0) would make
+        # the slicers take zero rows per iteration and never terminate.
+        if self.chunk_size <= 0:
+            raise ValueError(
+                "chunk_shuffle chunk_size must be positive, got %r" % (self.chunk_size,)
+            )
+
+        sorted_entries = sorted(file_entries, key=self._sort_key)
+
+        partitioned: "defaultdict[Tuple[tuple, int], List[ManifestEntry]]" = defaultdict(list)
+        for entry in sorted_entries:
+            partitioned[(tuple(entry.partition.values), entry.bucket)].append(entry)
+
+        all_chunks: List[_Chunk] = []
+        for key in sorted(
+            partitioned.keys(),
+            key=lambda k: (_null_safe_partition_key(k[0]), k[1]),
+        ):
+            entries_in_group = partitioned[key]
+            partition_row = entries_in_group[0].partition
+            bucket = entries_in_group[0].bucket
+            # Materialize file_path once per unique file in this group.
+            seen_paths: set = set()
+            for entry in entries_in_group:
+                f = entry.file
+                if f.file_name in seen_paths:
+                    continue
+                seen_paths.add(f.file_name)
+                f.set_file_path(
+                    self.table.table_path,
+                    partition_row,
+                    bucket,
+                    self.default_part_value,
+                )
+            for segments in self._slice_group_into_chunks(entries_in_group):
+                all_chunks.append(_Chunk(partition_row, bucket, segments))
+
+        # A fresh generator seeded only by `seed` keeps the order reproducible.
+        rng = random.Random(self.seed)
+        rng.shuffle(all_chunks)
+
+        if self.idx_of_this_subtask is not None:
+            start, end = self._compute_shard_range(len(all_chunks))
+            all_chunks = all_chunks[start:end]
+
+        return [self._chunk_to_split(c) for c in all_chunks]
+
+    @abstractmethod
+    def _sort_key(self, entry: ManifestEntry):
+        """Return a comparable, deterministic key for stable sort."""
+
+    @abstractmethod
+    def _slice_group_into_chunks(self, entries: List[ManifestEntry]) -> List[List[Any]]:
+        """Cut one (partition, bucket) group into chunks of segments.
+
+        Each returned inner list represents one chunk; segment shape is
+        subclass-defined.
+        """
+
+    @abstractmethod
+    def _chunk_to_split(self, chunk: _Chunk) -> Split:
+        """Wrap a chunk into a Split that the existing readers consume."""
+
+
+# ---------------------------------------------------------------------------
+# Append (non-DE, non-DV) implementation
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class _FileSegment:
+ """A contiguous slice of a data file inside one chunk.
+
+ start/end are half-open row offsets within the file when the chunk
+ boundary falls inside the file; both are None when the chunk owns
+ the full file (so SlicedSplit's shard_file_idx_map can skip it and
+ treat the file as full — see sliced_split.py:73-78).
+ """
+ # Data file this segment reads from.
+ file: DataFileMeta
+ # First row offset of the slice (inclusive); None for a whole file.
+ start: Optional[int]
+ # End row offset of the slice (exclusive); None for a whole file.
+ end: Optional[int]
+
+
+class AppendChunkShuffleSplitGenerator(ChunkShuffleSplitGeneratorBase):
+    """Chunk-shuffled splits for plain append tables
+    (non-PK, non-DV, non-DE)."""
+
+    def _sort_key(self, entry: ManifestEntry):
+        """Order entries by partition, then bucket, then file name."""
+        partition_key = _null_safe_partition_key(entry.partition.values)
+        return partition_key, entry.bucket, entry.file.file_name
+
+    def _slice_group_into_chunks(
+        self, entries: List[ManifestEntry]
+    ) -> List[List[_FileSegment]]:
+        """Pack the rows of one (partition, bucket) group into chunks.
+
+        Files are consumed in the given order. A chunk is closed as soon as
+        it holds ``self.chunk_size`` rows, so only the final chunk can be
+        smaller and none can be larger.
+        """
+        finished: List[List[_FileSegment]] = []
+        pending: List[_FileSegment] = []
+        pending_rows = 0
+
+        for entry in entries:
+            data_file = entry.file
+            total = data_file.row_count
+            consumed = 0
+            while consumed < total:
+                room = self.chunk_size - pending_rows
+                if room <= 0:
+                    # Current chunk is full: seal it and start a new one.
+                    finished.append(pending)
+                    pending, pending_rows = [], 0
+                    room = self.chunk_size
+
+                take = min(total - consumed, room)
+                covers_whole_file = consumed == 0 and take == total
+                if covers_whole_file:
+                    segment = _FileSegment(data_file, None, None)
+                else:
+                    segment = _FileSegment(data_file, consumed, consumed + take)
+                pending.append(segment)
+
+                pending_rows += take
+                consumed += take
+
+        if pending:
+            finished.append(pending)
+        return finished
+
+    def _chunk_to_split(self, chunk: _Chunk) -> Split:
+        """Wrap a chunk in a ``DataSplit``, or in a ``SlicedSplit`` when at
+        least one of its files is only partially covered."""
+        segments = chunk.segments
+        files: List[DataFileMeta] = [segment.file for segment in segments]
+        # Only partially covered files get an entry; whole files are omitted.
+        shard_file_idx_map = {
+            segment.file.file_name: (segment.start, segment.end)
+            for segment in segments
+            if segment.start is not None and segment.end is not None
+        }
+
+        # File paths were already set once per unique file in
+        # ChunkShuffleSplitGeneratorBase.create_splits.
+        data_split = DataSplit(
+            files=files,
+            partition=chunk.partition,
+            bucket=chunk.bucket,
+            raw_convertible=True,
+            data_deletion_files=None,
+        )
+
+        if not shard_file_idx_map:
+            return data_split
+        return SlicedSplit(data_split, shard_file_idx_map)
+
+
+# ---------------------------------------------------------------------------
+# Data Evolution implementation
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class _AlignedGroupSegment:
+ """A row_id sub-range over one row-id-aligned file group.
+
+ ``files`` is the entire group (may include blob/vector siblings),
+ so the reader sees every column file even when only a slice of the
+ group's row_id range lands in this chunk. ``row_range`` is the
+ inclusive global row_id range this segment owns.
+ """
+ # Every file of the aligned group, regardless of the slice taken.
+ files: List[DataFileMeta]
+ # Inclusive row_id range covered by this segment.
+ row_range: Range
+
+
+class DataEvolutionChunkShuffleSplitGenerator(ChunkShuffleSplitGeneratorBase):
+    """Chunk-shuffled splits for data-evolution append tables.
+
+    The smallest unit that may be cut is a row_id-aligned file group:
+    splitting the files of a group apart would orphan column files relative
+    to the row_id range, so groups stay whole and are sliced only along
+    their row_id axis. Every chunk becomes an :class:`IndexedSplit` whose
+    ``row_ranges`` bound the rows readable for that chunk.
+    """
+
+    def _sort_key(self, entry: ManifestEntry):
+        """Order by partition, bucket, first_row_id (missing ids first),
+        regular files before blob/vector files, then file name."""
+        data_file = entry.file
+        first_row_id = data_file.first_row_id
+        if first_row_id is None:
+            first_row_id = float('-inf')
+        if (
+            DataFileMeta.is_blob_file(data_file.file_name)
+            or DataFileMeta.is_vector_file(data_file.file_name)
+        ):
+            is_special = 1
+        else:
+            is_special = 0
+        return (
+            _null_safe_partition_key(entry.partition.values),
+            entry.bucket,
+            first_row_id,
+            is_special,
+            data_file.file_name,
+        )
+
+    def _slice_group_into_chunks(
+        self, entries: List[ManifestEntry]
+    ) -> List[List[_AlignedGroupSegment]]:
+        """Pack the row_id ranges of the aligned groups into chunks of at
+        most ``self.chunk_size`` rows each."""
+        # (Range, [files]) pairs ordered by the start of the row_id range.
+        aligned_groups = self._split_by_row_id_with_range(
+            [entry.file for entry in entries]
+        )
+
+        finished: List[List[_AlignedGroupSegment]] = []
+        pending: List[_AlignedGroupSegment] = []
+        pending_rows = 0
+
+        for group_range, group_files in aligned_groups:
+            group_rows = group_range.count()
+            consumed = 0
+            while consumed < group_rows:
+                room = self.chunk_size - pending_rows
+                if room <= 0:
+                    # Current chunk is full: seal it and start a new one.
+                    finished.append(pending)
+                    pending, pending_rows = [], 0
+                    room = self.chunk_size
+
+                take = min(group_rows - consumed, room)
+                first_id = group_range.from_ + consumed
+                # Range bounds are inclusive, hence the -1 on the upper end.
+                segment_range = Range(first_id, first_id + take - 1)
+                pending.append(_AlignedGroupSegment(group_files, segment_range))
+                pending_rows += take
+                consumed += take
+
+        if pending:
+            finished.append(pending)
+        return finished
+
+    def _chunk_to_split(self, chunk: _Chunk) -> Split:
+        """Wrap a chunk in an ``IndexedSplit`` over the files of all its
+        segments, restricted to the segments' row_id ranges."""
+        segments = chunk.segments
+        if len(segments) == 1:
+            only = segments[0]
+            all_files = only.files
+            row_ranges = [only.row_range]
+        else:
+            all_files = [f for segment in segments for f in segment.files]
+            row_ranges = sorted(
+                (segment.row_range for segment in segments),
+                key=lambda r: r.from_,
+            )
+
+        data_split = DataSplit(
+            files=all_files,
+            partition=chunk.partition,
+            bucket=chunk.bucket,
+            raw_convertible=False,
+            data_deletion_files=None,
+        )
+        return IndexedSplit(data_split, row_ranges, scores=None)
+
+    @staticmethod
+    def _split_by_row_id_with_range(
+        files: List[DataFileMeta],
+    ) -> List[Tuple[Range, List[DataFileMeta]]]:
+        """Group files whose row_id ranges overlap.
+
+        Returns ``(merged range, files)`` pairs ordered by ``range.from_``.
+        Mirrors :meth:`DataEvolutionSplitGenerator._split_by_row_id` but
+        additionally hands back the merged row_id range of each group,
+        which the chunk slicer uses to count rows.
+
+        Raises:
+            ValueError: if a file has no row_id range.
+        """
+        file_ranges = []
+        for f in files:
+            file_range = f.row_id_range()
+            if file_range is None:
+                raise ValueError(
+                    "chunk_shuffle for data evolution tables requires row tracking; "
+                    f"file {f.file_name} is missing first_row_id"
+                )
+            file_ranges.append(file_range)
+        if not file_ranges:
+            return []
+        merged_ranges = Range.sort_and_merge_overlap(file_ranges, True, False)
+
+        range_to_files: "dict[Range, List[DataFileMeta]]" = {}
+        for f in files:
+            file_range = f.row_id_range()
+            # A file joins the first merged range it overlaps.
+            owner = next(
+                (r for r in merged_ranges if r.overlaps(file_range)), None
+            )
+            if owner is not None:
+                range_to_files.setdefault(owner, []).append(f)
+
+        return sorted(range_to_files.items(), key=lambda kv: kv[0].from_)
diff --git a/paimon-python/pypaimon/read/scanner/file_scanner.py
b/paimon-python/pypaimon/read/scanner/file_scanner.py
index 5424d2d519..650bbe8ca4 100755
--- a/paimon-python/pypaimon/read/scanner/file_scanner.py
+++ b/paimon-python/pypaimon/read/scanner/file_scanner.py
@@ -40,6 +40,10 @@ from pypaimon.read.scanner.append_table_split_generator
import \
AppendTableSplitGenerator
from pypaimon.read.scanner.bucket_select_converter import \
create_bucket_selector
+from pypaimon.read.scanner.chunk_shuffle_split_generator import (
+ AppendChunkShuffleSplitGenerator,
+ DataEvolutionChunkShuffleSplitGenerator,
+)
from pypaimon.read.scanner.data_evolution_split_generator import \
DataEvolutionSplitGenerator
from pypaimon.read.scanner.primary_key_table_split_generator import \
@@ -204,6 +208,7 @@ class FileScanner:
self.number_of_para_subtasks = None
self.start_pos_of_this_subtask = None
self.end_pos_of_this_subtask = None
+ self.chunk_shuffle: Optional[Tuple[int, int]] = None
self.only_read_real_buckets = options.bucket() ==
BucketMode.POSTPONE_BUCKET.value
self.data_evolution = options.data_evolution_enabled()
@@ -243,7 +248,34 @@ class FileScanner:
def scan(self) -> Plan:
start_ms = time.time() * 1000
# Create appropriate split generator based on table type
- if self.table.is_primary_key_table:
+ if self.chunk_shuffle is not None:
+ self._validate_chunk_shuffle_compat()
+ seed, chunk_size = self.chunk_shuffle
+ # Both append and DE paths use plan_files() directly: the
+ # predicate is partition-only (enforced by
+ # _validate_chunk_shuffle_compat), so manifest_entry-level
+ # partition pruning in plan_files() is the only filter we
+ # want — no row_id range pushdown, no global index lookup.
+ entries = self.plan_files()
+ if self.data_evolution:
+ split_generator = DataEvolutionChunkShuffleSplitGenerator(
+ self.table,
+ self.target_split_size,
+ self.open_file_cost,
+ self._deletion_files_map(entries),
+ seed=seed,
+ chunk_size=chunk_size,
+ )
+ else:
+ split_generator = AppendChunkShuffleSplitGenerator(
+ self.table,
+ self.target_split_size,
+ self.open_file_cost,
+ self._deletion_files_map(entries),
+ seed=seed,
+ chunk_size=chunk_size,
+ )
+ elif self.table.is_primary_key_table:
entries = self.plan_files()
split_generator = PrimaryKeyTableSplitGenerator(
self.table,
@@ -441,6 +473,38 @@ class FileScanner:
plan = self.scan()
return plan, self.scan_stats
+ def with_chunk_shuffle(self, seed: int, chunk_size: int) -> 'FileScanner':
+     """Record the chunk-shuffle settings as ``(seed, chunk_size)``.
+
+     Raises:
+         ValueError: if ``seed`` is not an int, or ``chunk_size`` is not
+             a positive int.
+
+     Returns:
+         This scanner, for call chaining.
+     """
+     if not isinstance(seed, int):
+         raise ValueError("chunk_shuffle seed must be an int")
+     chunk_size_ok = isinstance(chunk_size, int) and chunk_size > 0
+     if not chunk_size_ok:
+         raise ValueError("chunk_shuffle chunk_size must be a positive int")
+     self.chunk_shuffle = (seed, chunk_size)
+     return self
+
+ def _validate_chunk_shuffle_compat(self) -> None:
+     """Reject scan settings that chunk shuffle cannot be combined with.
+
+     Raises:
+         ValueError: for primary-key tables, deletion vectors,
+             ``with_slice``, a limit, a global index result, or a
+             predicate that references non-partition fields.
+     """
+     if self.table.is_primary_key_table:
+         raise ValueError("chunk_shuffle only supports append tables")
+     if self.deletion_vectors_enabled:
+         raise ValueError("chunk_shuffle not supported with deletion vectors")
+     if self.start_pos_of_this_subtask is not None:
+         raise ValueError("chunk_shuffle cannot combine with with_slice")
+     if self.limit is not None:
+         raise ValueError("chunk_shuffle cannot combine with limit")
+     if self._global_index_result is not None:
+         raise ValueError("chunk_shuffle cannot combine with global index")
+
+     if self.predicate is None:
+         return
+
+     # Partition predicates are the only ones permitted: a row-level or
+     # column-level predicate would silently reduce the rows each chunk
+     # yields, breaking the chunk_size contract DataLoader callers expect.
+     partition_keys = set(self.table.partition_keys or [])
+     offending_fields = _get_all_fields(self.predicate) - partition_keys
+     if not offending_fields:
+         return
+     raise ValueError(
+         "chunk_shuffle predicate must reference only partition keys; "
+         "got non-partition fields: "
+         f"{sorted(offending_fields)}"
+     )
+
def _apply_push_down_limit(self, splits: List[DataSplit]) ->
List[DataSplit]:
"""Mirror Java ``DataTableBatchScan.applyPushDownLimit``: sum the
DV-aware ``merged_row_count`` (== Java ``Split.mergedRowCount()``)
diff --git a/paimon-python/pypaimon/read/table_read.py
b/paimon-python/pypaimon/read/table_read.py
index 826b2b4024..57e704b2b5 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -511,8 +511,28 @@ class TableRead:
splits: List[Split],
streaming: bool = False,
prefetch_concurrency: int = 1,
+ *,
+ shuffle: bool = False,
+ seed: int = 0,
+ buffer_size: int = 1000,
+ max_buffer_input_splits: int = 10,
) -> "torch.utils.data.Dataset":
"""Wrap Paimon table data to PyTorch Dataset."""
+ if shuffle:
+ if not streaming:
+ raise ValueError("shuffle=True only supports streaming=True")
+ if prefetch_concurrency > 1:
+ raise ValueError("shuffle=True does not support
prefetch_concurrency > 1")
+ from pypaimon.read.datasource.torch_dataset import
TorchShuffledIterDataset
+ dataset = TorchShuffledIterDataset(
+ self,
+ splits,
+ seed=seed,
+ buffer_size=buffer_size,
+ max_buffer_input_splits=max_buffer_input_splits,
+ )
+ return dataset
+
if streaming:
from pypaimon.read.datasource.torch_dataset import TorchIterDataset
dataset = TorchIterDataset(self, splits, prefetch_concurrency)
diff --git a/paimon-python/pypaimon/read/table_scan.py
b/paimon-python/pypaimon/read/table_scan.py
index 03a1c8b062..36568c618f 100755
--- a/paimon-python/pypaimon/read/table_scan.py
+++ b/paimon-python/pypaimon/read/table_scan.py
@@ -161,6 +161,10 @@ class TableScan:
self.file_scanner.with_global_index_result(result)
return self
+ def with_chunk_shuffle(self, seed: int, chunk_size: int) -> 'TableScan':
+     """Forward the chunk-shuffle settings to the file scanner.
+
+     Returns:
+         This scan, for call chaining.
+     """
+     scanner = self.file_scanner
+     scanner.with_chunk_shuffle(seed, chunk_size)
+     return self
+
def _validate_scan_mode(self):
"""Validate scan.mode against companion options using a whitelist
approach.
diff --git
a/paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py
b/paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py
new file mode 100644
index 0000000000..be50cfa637
--- /dev/null
+++ b/paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py
@@ -0,0 +1,800 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for ChunkShuffleSplitGenerator and TableScan.with_chunk_shuffle.
+
+Algorithmic tests use Mock entries so they don't touch disk; the
+end-to-end test writes a real append table and validates that all
+workers together cover the data exactly once.
+"""
+
+import os
+import shutil
+import tempfile
+import unittest
+from unittest.mock import Mock
+
+import pyarrow as pa
+
+from pypaimon import CatalogFactory, Schema
+from pypaimon.globalindex.indexed_split import IndexedSplit
+from pypaimon.manifest.schema.data_file_meta import DataFileMeta
+from pypaimon.read.scanner.chunk_shuffle_split_generator import (
+ AppendChunkShuffleSplitGenerator,
+ DataEvolutionChunkShuffleSplitGenerator,
+)
+from pypaimon.read.sliced_split import SlicedSplit
+from pypaimon.read.split import DataSplit
+from pypaimon.utils.range import Range
+
+
+def _mock_table(table_path='/tmp/_chunk_shuffle_test_path'):
+ """Return a ``Mock`` table with ``table_path`` set and a ``Mock`` ``options``."""
+ table = Mock()
+ table.table_path = table_path
+ table.options = Mock()
+ return table
+
+
+def _mock_entry(partition_values, bucket, file_name, row_count,
file_size=1024):
+ """Build a ``Mock`` manifest entry for the append-table generator tests.
+
+ The entry exposes ``partition.values``, ``bucket`` and a ``file`` mock
+ carrying ``file_name``, ``file_size`` and ``row_count``.
+ """
+ entry = Mock()
+ entry.partition = Mock()
+ entry.partition.values = partition_values
+ entry.bucket = bucket
+ entry.file = Mock()
+ entry.file.file_name = file_name
+ entry.file.file_size = file_size
+ entry.file.row_count = row_count
+ # Swallow set_file_path so we don't need to mock partition path encoding.
+ entry.file.set_file_path = Mock()
+ return entry
+
+
+def _make_generator(seed, chunk_size, table=None):
+ """Build an ``AppendChunkShuffleSplitGenerator`` with fixed test sizing.
+
+ A mock table is created when ``table`` is None. The generator gets a
+ 128 MiB ``target_split_size``, a 4 MiB ``open_file_cost`` and no
+ deletion-files map; only ``seed`` and ``chunk_size`` vary per test.
+ """
+ if table is None:
+ table = _mock_table()
+ return AppendChunkShuffleSplitGenerator(
+ table,
+ target_split_size=128 * 1024 * 1024,
+ open_file_cost=4 * 1024 * 1024,
+ deletion_files_map=None,
+ seed=seed,
+ chunk_size=chunk_size,
+ )
+
+
+def _make_de_generator(seed, chunk_size, table=None):
+ """Build a ``DataEvolutionChunkShuffleSplitGenerator`` with fixed test sizing.
+
+ Uses the same constructor arguments as ``_make_generator``: a mock table
+ when ``table`` is None, 128 MiB ``target_split_size``, 4 MiB
+ ``open_file_cost`` and no deletion-files map.
+ """
+ if table is None:
+ table = _mock_table()
+ return DataEvolutionChunkShuffleSplitGenerator(
+ table,
+ target_split_size=128 * 1024 * 1024,
+ open_file_cost=4 * 1024 * 1024,
+ deletion_files_map=None,
+ seed=seed,
+ chunk_size=chunk_size,
+ )
+
+
+def _mock_de_entry(partition_values, bucket, file_name, first_row_id,
row_count, file_size=1024):
+ """A DE-flavoured mock entry: file carries first_row_id and a real
+ Range so :meth:`row_id_range` and ``Range.overlaps`` work."""
+ entry = Mock()
+ entry.partition = Mock()
+ entry.partition.values = partition_values
+ entry.bucket = bucket
+ # spec=DataFileMeta restricts the mock to attributes DataFileMeta defines.
+ file = Mock(spec=DataFileMeta)
+ file.file_name = file_name
+ file.file_size = file_size
+ file.row_count = row_count
+ file.first_row_id = first_row_id
+ # Default arguments bind this entry's values when the lambda is created,
+ # so it is called with no arguments. The range ends at first_row_id +
+ # row_count - 1.
+ file.row_id_range = lambda f=first_row_id, c=row_count: Range(f, f + c - 1)
+ file.set_file_path = Mock()
+ entry.file = file
+ return entry
+
+
+def _split_signature(split):
+ """A stable, comparable identity for a split — what the worker would
actually read."""
+ # Every branch returns a 4-tuple: (partition values, bucket, files, extra).
+ # SlicedSplit: file order is kept as-is; "extra" is the sorted items of
+ # shard_file_idx_map().
+ if isinstance(split, SlicedSplit):
+ underlying = split.data_split()
+ files = tuple(f.file_name for f in underlying.files)
+ idx_map = tuple(sorted(split.shard_file_idx_map().items()))
+ return (tuple(underlying.partition.values), underlying.bucket, files,
idx_map)
+ # IndexedSplit: file names are sorted here; "extra" is the (from_, to)
+ # pair of each row range.
+ if isinstance(split, IndexedSplit):
+ underlying = split.data_split()
+ files = tuple(sorted(f.file_name for f in underlying.files))
+ ranges = tuple((r.from_, r.to) for r in split.row_ranges())
+ return (tuple(underlying.partition.values), underlying.bucket, files,
ranges)
+ # Plain DataSplit: "extra" is an empty tuple.
+ if isinstance(split, DataSplit):
+ files = tuple(f.file_name for f in split.files)
+ return (tuple(split.partition.values), split.bucket, files, ())
+ raise AssertionError("unexpected split type: %r" % type(split))
+
+
+def _split_rows(split):
+ """Effective row count this split actually exposes (``split.row_count``)."""
+ return split.row_count
+
+
+class ChunkShuffleSplitGeneratorAlgoTest(unittest.TestCase):
+
+ def test_no_entries_returns_empty(self):
+ """An empty entry list must plan to an empty split list."""
+ gen = _make_generator(seed=1, chunk_size=100)
+ self.assertEqual(gen.create_splits([]), [])
+
+ def test_full_files_no_truncation(self):
+ """Three 100-row files with chunk_size=100 yield three 100-row DataSplits."""
+ entries = [
+ _mock_entry([], 0, 'f1', 100),
+ _mock_entry([], 0, 'f2', 100),
+ _mock_entry([], 0, 'f3', 100),
+ ]
+ gen = _make_generator(seed=1, chunk_size=100)
+ splits = gen.create_splits(entries)
+ # 3 chunks, each holding exactly one whole file → all DataSplit, no
SlicedSplit
+ self.assertEqual(len(splits), 3)
+ for s in splits:
+ self.assertIsInstance(s, DataSplit)
+ self.assertEqual(s.row_count, 100)
+
+ def test_chunk_truncates_inside_file(self):
+ """A single 250-row file is sliced into (0,100), (100,200), (200,250)."""
+ # one file of 250 rows, chunk_size 100 → 3 chunks: 100, 100, 50
+ entries = [_mock_entry([], 0, 'f1', 250)]
+ gen = _make_generator(seed=1, chunk_size=100)
+ splits = gen.create_splits(entries)
+ self.assertEqual(len(splits), 3)
+ # All three chunks slice the same file → all SlicedSplit
+ for s in splits:
+ self.assertIsInstance(s, SlicedSplit)
+ # union of (start, end) intervals must cover [0, 250)
+ # Sorting makes the comparison independent of the shuffled split order.
+ intervals = sorted(s.shard_file_idx_map()['f1'] for s in splits)
+ self.assertEqual(intervals, [(0, 100), (100, 200), (200, 250)])
+ total = sum(end - start for start, end in intervals)
+ self.assertEqual(total, 250)
+
+ def test_chunk_spans_multiple_files(self):
+ """Three 30-row files with chunk_size=50 plan to two splits totalling 90 rows."""
+ # f1=30, f2=30, f3=30, chunk_size=50 → intended chunk layout:
+ # [f1(0,30) + f2(0,20)] and [f2(20,30) + f3(0,30)].
+ # Only the split count and the row total are asserted below.
+ entries = [
+ _mock_entry([], 0, 'f1', 30),
+ _mock_entry([], 0, 'f2', 30),
+ _mock_entry([], 0, 'f3', 30),
+ ]
+ gen = _make_generator(seed=1, chunk_size=50)
+ splits = gen.create_splits(entries)
+ # total 90 rows, chunk_size 50 → 2 chunks (50 + 40)
+ self.assertEqual(len(splits), 2)
+ total_rows = sum(_split_rows(s) for s in splits)
+ self.assertEqual(total_rows, 90)
+
+ def test_chunk_size_larger_than_total(self):
+ """When chunk_size exceeds the total rows, one 60-row DataSplit is planned."""
+ entries = [
+ _mock_entry([], 0, 'f1', 30),
+ _mock_entry([], 0, 'f2', 30),
+ ]
+ gen = _make_generator(seed=1, chunk_size=1000)
+ splits = gen.create_splits(entries)
+ self.assertEqual(len(splits), 1)
+ # No truncation — full files inside one chunk → DataSplit not
SlicedSplit
+ self.assertIsInstance(splits[0], DataSplit)
+ self.assertEqual(_split_rows(splits[0]), 60)
+
+ def test_deterministic_same_seed_same_order(self):
+ """Two generators with the same seed produce the same signatures in the same order."""
+ entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(20)]
+ gen1 = _make_generator(seed=42, chunk_size=50)
+ gen2 = _make_generator(seed=42, chunk_size=50)
+ splits1 = gen1.create_splits(entries)
+ splits2 = gen2.create_splits(entries)
+ self.assertEqual(
+ [_split_signature(s) for s in splits1],
+ [_split_signature(s) for s in splits2],
+ )
+
+ def test_different_seed_different_order(self):
+ """Seeds 1 and 2 yield the same set of splits in a different order."""
+ entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(50)]
+ gen1 = _make_generator(seed=1, chunk_size=100)
+ gen2 = _make_generator(seed=2, chunk_size=100)
+ sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)]
+ sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)]
+ # Same set of chunks, different order — high probability they differ
on 50 items
+ self.assertEqual(sorted(sigs1), sorted(sigs2))
+ self.assertNotEqual(sigs1, sigs2)
+
+ def test_shuffle_actually_reorders(self):
+ """The planned order of 20 single-file splits must differ from sorted file-name order."""
+ # 20 files in scan order f0..f19. After shuffle the file order should
not be sorted.
+ # Zero-padded names (f00..f19) make lexical order equal numeric order.
+ entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(20)]
+ gen = _make_generator(seed=42, chunk_size=100)
+ splits = gen.create_splits(entries)
+ file_names = [s.files[0].file_name for s in splits]
+ self.assertNotEqual(file_names, sorted(file_names))
+
+ def test_shard_round_trip_no_overlap_no_loss(self):
+ """Splits from 4 shards are disjoint and together equal the unsharded plan."""
+ # 13 files × 100 rows = 1300 rows. 4 workers.
+ entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(13)]
+ num_workers = 4
+ all_sigs = []
+ total_rows = 0
+ # Each worker builds its own generator with the same seed and chunk_size.
+ for worker in range(num_workers):
+ gen = _make_generator(seed=7, chunk_size=100)
+ gen.with_shard(worker, num_workers)
+ splits = gen.create_splits(list(entries)) # copy: shuffle is
in-place on chunks list
+ for s in splits:
+ all_sigs.append(_split_signature(s))
+ total_rows += _split_rows(s)
+ self.assertEqual(total_rows, 13 * 100)
+ # No duplicate chunks across workers
+ self.assertEqual(len(all_sigs), len(set(all_sigs)))
+ # All chunks together equal an unsharded run
+ unsharded = _make_generator(seed=7,
chunk_size=100).create_splits(list(entries))
+ self.assertEqual(
+ sorted(all_sigs),
+ sorted(_split_signature(s) for s in unsharded),
+ )
+
+ def test_shard_balanced_distribution(self):
+ """Ten splits over three shards are distributed as 4, 3 and 3."""
+ # 10 chunks across 3 workers → 4, 3, 3 (front-loaded by
_compute_shard_range)
+ entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(10)]
+ counts = []
+ for worker in range(3):
+ gen = _make_generator(seed=0, chunk_size=100)
+ gen.with_shard(worker, 3)
+ counts.append(len(gen.create_splits(list(entries))))
+ # Sorted descending, so the check does not pin which worker gets the extra split.
+ self.assertEqual(sorted(counts, reverse=True), [4, 3, 3])
+
+ def test_chunks_fewer_than_workers(self):
+ """With 2 splits and 5 shards, two shards get one split each and three get none."""
+ # 2 chunks, 5 workers → 3 workers get nothing
+ entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(2)]
+ empties = 0
+ non_empties = 0
+ for worker in range(5):
+ gen = _make_generator(seed=0, chunk_size=100)
+ gen.with_shard(worker, 5)
+ n = len(gen.create_splits(list(entries)))
+ if n == 0:
+ empties += 1
+ else:
+ non_empties += 1
+ self.assertEqual(n, 1)
+ self.assertEqual(empties, 3)
+ self.assertEqual(non_empties, 2)
+
+ def test_multi_partition_no_chunk_crosses_partition(self):
+ entries = [
+ _mock_entry(['p1'], 0, 'f1', 100),
+ _mock_entry(['p1'], 0, 'f2', 100),
+ _mock_entry(['p2'], 0, 'f3', 100),
+ _mock_entry(['p2'], 0, 'f4', 100),
+ ]
+ gen = _make_generator(seed=0, chunk_size=100)
+ splits = gen.create_splits(entries)
+ # Each split's underlying files come from one partition only
+ for s in splits:
+ partitions_in_files = set()
+ data_split = s.data_split() if isinstance(s, SlicedSplit) else s
+ partitions_in_files.add(tuple(data_split.partition.values))
+ self.assertEqual(len(partitions_in_files), 1)
+
+ def test_null_and_non_null_partitions_sort_safely(self):
+ """Planning succeeds when one partition value is None and the others are strings."""
+ # Mixing null and non-null partition values used to raise
+ # ``TypeError: '<' not supported between instances of 'NoneType' and
'str'``
+ # before _null_safe_partition_key. Validate planning succeeds and
+ # both partitions produce splits.
+ entries = [
+ _mock_entry(['p1'], 0, 'f1', 100),
+ _mock_entry([None], 0, 'f2', 100),
+ _mock_entry(['p2'], 0, 'f3', 100),
+ ]
+ gen = _make_generator(seed=1, chunk_size=100)
+ splits = gen.create_splits(entries)
+ self.assertEqual(len(splits), 3)
+ # Element 0 of the signature is the split's partition-values tuple.
+ partitions = {tuple(_split_signature(s)[0]) for s in splits}
+ self.assertEqual(partitions, {('p1',), ('p2',), (None,)})
+
+ def test_input_order_does_not_affect_output_when_same_files(self):
+ """Manifest read parallelism shouldn't bleed through — sorting is
internal."""
+ a = _mock_entry([], 0, 'f1', 100)
+ b = _mock_entry([], 0, 'f2', 100)
+ c = _mock_entry([], 0, 'f3', 100)
+ gen1 = _make_generator(seed=99, chunk_size=100)
+ gen2 = _make_generator(seed=99, chunk_size=100)
+ # Same three entries and same seed; only the input list order differs.
+ sigs1 = [_split_signature(s) for s in gen1.create_splits([a, b, c])]
+ sigs2 = [_split_signature(s) for s in gen2.create_splits([c, a, b])]
+ self.assertEqual(sigs1, sigs2)
+
+
+class ChunkShuffleEndToEndTest(unittest.TestCase):
+ """Real append table → with_chunk_shuffle → multiple workers → union ==
original."""
+
+ @classmethod
+ def setUpClass(cls):
+ # One temporary warehouse and catalog shared by every test in this class.
+ cls.tempdir = tempfile.mkdtemp()
+ cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+ cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
+ cls.catalog.create_database('default', True)
+
+ @classmethod
+ def tearDownClass(cls):
+ # Best-effort cleanup: errors while removing the temp dir are ignored.
+ shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+ def _create_append_table(self, name, partition_keys=None):
+ """Create ``default.<name>`` with (id, value, part) and return ``(table, pa_schema)``."""
+ pa_schema = pa.schema([
+ ('id', pa.int64()),
+ ('value', pa.string()),
+ ('part', pa.string()),
+ ])
+ schema = Schema.from_pyarrow_schema(
+ pa_schema, partition_keys=partition_keys or [])
+ identifier = f'default.{name}'
+ self.catalog.create_table(identifier, schema, False)
+ return self.catalog.get_table(identifier), pa_schema
+
+ def _write_n_batches(self, table, pa_schema, batches):
+ """Write each dict in ``batches`` with its own write/commit pair."""
+ wb = table.new_batch_write_builder()
+ for batch in batches:
+ tw = wb.new_write()
+ tc = wb.new_commit()
+ tw.write_arrow(pa.Table.from_pydict(batch, schema=pa_schema))
+ tc.commit(tw.prepare_commit())
+ tw.close()
+ tc.close()
+
+ def test_workers_union_equals_full_table(self):
+ """Rows read by three shards, concatenated and sorted by id, are exactly ids 0..199."""
+ table, pa_schema = self._create_append_table('cs_union')
+ # 4 commits × 50 rows = 200 rows across several files
+ batches = []
+ for c in range(4):
+ base = c * 50
+ batches.append({
+ 'id': list(range(base, base + 50)),
+ 'value': [f'v{i}' for i in range(base, base + 50)],
+ 'part': ['p1'] * 50,
+ })
+ self._write_n_batches(table, pa_schema, batches)
+
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+
+ num_workers = 3
+ worker_tables = []
+ # Every worker plans with the same seed and chunk_size; only the shard index differs.
+ for w in range(num_workers):
+ scan = read_builder.new_scan() \
+ .with_chunk_shuffle(seed=123, chunk_size=37) \
+ .with_shard(w, num_workers)
+ splits = scan.plan().splits()
+ if splits:
+ worker_tables.append(table_read.to_arrow(splits))
+
+ actual = pa.concat_tables(worker_tables).sort_by('id') if
worker_tables else None
+ self.assertIsNotNone(actual)
+ self.assertEqual(actual.num_rows, 200)
+ self.assertEqual(actual.column('id').to_pylist(), list(range(200)))
+
+ def test_deterministic_plan_across_calls(self):
+ """Planning the same shard twice yields identical split signatures."""
+ table, pa_schema = self._create_append_table('cs_determinism')
+ self._write_n_batches(table, pa_schema, [{
+ 'id': list(range(100)),
+ 'value': [f'v{i}' for i in range(100)],
+ 'part': ['p'] * 100,
+ }])
+
+ def plan_files(worker):
+ # A fresh read builder and scan are created on every call.
+ scan = table.new_read_builder().new_scan() \
+ .with_chunk_shuffle(seed=42, chunk_size=20) \
+ .with_shard(worker, 3)
+ return [_split_signature(s) for s in scan.plan().splits()]
+
+ for worker in range(3):
+ self.assertEqual(plan_files(worker), plan_files(worker))
+
+
+class ChunkShuffleCompatibilityTest(unittest.TestCase):
+ """Validates the reject-on-incompatible-combination matrix."""
+
+ @classmethod
+ def setUpClass(cls):
+ # One temporary warehouse and catalog shared by every test in this class.
+ cls.tempdir = tempfile.mkdtemp()
+ cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+ cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
+ cls.catalog.create_database('default', True)
+
+ @classmethod
+ def tearDownClass(cls):
+ # Best-effort cleanup: errors while removing the temp dir are ignored.
+ shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+ def _append_table(self, name, options=None, partition_keys=None):
+ """Create and return ``default.<name>``; a ``part`` column is added only when partition keys are given."""
+ if partition_keys:
+ pa_schema = pa.schema([
+ ('id', pa.int64()),
+ ('value', pa.string()),
+ ('part', pa.string()),
+ ])
+ else:
+ pa_schema = pa.schema([('id', pa.int64()), ('value', pa.string())])
+ schema = Schema.from_pyarrow_schema(
+ pa_schema, partition_keys=partition_keys, options=options or {})
+ self.catalog.create_table(f'default.{name}', schema, False)
+ return self.catalog.get_table(f'default.{name}')
+
+ def _pk_table(self, name):
+ """Create and return ``default.<name>`` with primary key ``id`` and option ``bucket=1``."""
+ pa_schema = pa.schema([
+ pa.field('id', pa.int64(), nullable=False),
+ ('value', pa.string()),
+ ])
+ schema = Schema.from_pyarrow_schema(
+ pa_schema, primary_keys=['id'], options={'bucket': '1'})
+ self.catalog.create_table(f'default.{name}', schema, False)
+ return self.catalog.get_table(f'default.{name}')
+
+ def test_pk_table_rejected(self):
+ """plan() raises ValueError for chunk shuffle on a primary-key table."""
+ table = self._pk_table('cs_pk')
+ scan = table.new_read_builder().new_scan()
+ scan.with_chunk_shuffle(seed=1, chunk_size=100)
+ with self.assertRaisesRegex(ValueError, "only supports append tables"):
+ scan.plan()
+
+ def test_dv_table_rejected(self):
+ """plan() raises ValueError when ``deletion-vectors.enabled`` is true."""
+ table = self._append_table('cs_dv',
options={'deletion-vectors.enabled': 'true'})
+ scan = table.new_read_builder().new_scan()
+ scan.with_chunk_shuffle(seed=1, chunk_size=100)
+ with self.assertRaisesRegex(ValueError, "deletion vectors"):
+ scan.plan()
+
+ def test_with_slice_then_chunk_shuffle_rejected(self):
+ """Combining with_slice and with_chunk_shuffle is rejected at plan() time."""
+ table = self._append_table('cs_slice')
+ scan = table.new_read_builder().new_scan()
+ scan.with_slice(0, 100).with_chunk_shuffle(seed=1, chunk_size=100)
+ with self.assertRaisesRegex(ValueError, "with_slice"):
+ scan.plan()
+
+ def test_limit_with_chunk_shuffle_rejected(self):
+ """A read builder limit combined with chunk shuffle is rejected at plan() time."""
+ table = self._append_table('cs_limit')
+ scan = table.new_read_builder().with_limit(50).new_scan()
+ scan.with_chunk_shuffle(seed=1, chunk_size=100)
+ with self.assertRaisesRegex(ValueError, "limit"):
+ scan.plan()
+
+ def test_invalid_chunk_size(self):
+ """chunk_size of 0 or -5 is rejected by with_chunk_shuffle itself, before plan()."""
+ table = self._append_table('cs_invalid')
+ scan = table.new_read_builder().new_scan()
+ with self.assertRaisesRegex(ValueError, "chunk_size"):
+ scan.with_chunk_shuffle(seed=1, chunk_size=0)
+ with self.assertRaisesRegex(ValueError, "chunk_size"):
+ scan.with_chunk_shuffle(seed=1, chunk_size=-5)
+
+ def test_column_predicate_rejected(self):
+ """A filter on the non-partition column ``id`` is rejected at plan() time."""
+ # Non-partition predicate would silently shrink effective chunk
+ # row counts inside the reader → not allowed.
+ table = self._append_table('cs_col_pred', partition_keys=['part'])
+ rb = table.new_read_builder()
+ col_pred = rb.new_predicate_builder().equal('id', 5)
+ rb = rb.with_filter(col_pred)
+ scan = rb.new_scan().with_chunk_shuffle(seed=1, chunk_size=10)
+ with self.assertRaisesRegex(ValueError, "partition keys"):
+ scan.plan()
+
+ def test_partition_predicate_allowed(self):
+ """A filter on the partition column ``part`` plans successfully and yields only ``p1`` splits."""
+ # Filter is partition-only → must succeed and read only the
+ # matching partition.
+ table, pa_schema = self._partitioned_table_with_data('cs_part_pred')
+
+ rb = table.new_read_builder()
+ pred = rb.new_predicate_builder().equal('part', 'p1')
+ scan = rb.with_filter(pred).new_scan() \
+ .with_chunk_shuffle(seed=1, chunk_size=10)
+ plan = scan.plan()
+ # NOTE(review): the loop below passes vacuously if plan.splits() is
+ # empty — consider also asserting that at least one split is planned.
+ # All splits should be from partition 'p1'
+ for split in plan.splits():
+ partition_values = split.partition.values
+ self.assertEqual(tuple(partition_values), ('p1',))
+
+ def _partitioned_table_with_data(self, name):
+ """Create ``default.<name>`` partitioned by ``part`` and write ids 0..49 to ``p1`` and 50..99 to ``p2``.
+
+ Each partition is written in its own write/commit pair. Returns
+ ``(table, pa_schema)``.
+ """
+ pa_schema = pa.schema([
+ ('id', pa.int64()),
+ ('value', pa.string()),
+ ('part', pa.string()),
+ ])
+ schema = Schema.from_pyarrow_schema(pa_schema, partition_keys=['part'])
+ identifier = f'default.{name}'
+ self.catalog.create_table(identifier, schema, False)
+ table = self.catalog.get_table(identifier)
+ wb = table.new_batch_write_builder()
+ for part, ids in [('p1', range(50)), ('p2', range(50, 100))]:
+ tw = wb.new_write()
+ tc = wb.new_commit()
+ tw.write_arrow(pa.Table.from_pydict(
+ {'id': list(ids),
+ 'value': [f'v{i}' for i in ids],
+ 'part': [part] * 50},
+ schema=pa_schema,
+ ))
+ tc.commit(tw.prepare_commit())
+ tw.close()
+ tc.close()
+ return table, pa_schema
+
+
+class DataEvolutionChunkShuffleAlgoTest(unittest.TestCase):
+ """Mock-based tests for the DE chunk slicer."""
+
+ def test_no_entries_returns_empty(self):
+ """An empty entry list must plan to an empty split list on the DE path."""
+ gen = _make_de_generator(seed=1, chunk_size=100)
+ self.assertEqual(gen.create_splits([]), [])
+
+ def test_full_aligned_groups_one_per_chunk(self):
+ """Three 100-row groups with chunk_size=100 yield three IndexedSplits with one row range each."""
+ # Three commits of 100 rows each → three aligned groups.
+ # chunk_size = 100 → 3 chunks, each holding one group whole.
+ entries = [
+ _mock_de_entry([], 0, 'g0.parquet', 0, 100),
+ _mock_de_entry([], 0, 'g1.parquet', 100, 100),
+ _mock_de_entry([], 0, 'g2.parquet', 200, 100),
+ ]
+ gen = _make_de_generator(seed=1, chunk_size=100)
+ splits = gen.create_splits(entries)
+ self.assertEqual(len(splits), 3)
+ for s in splits:
+ self.assertIsInstance(s, IndexedSplit)
+ self.assertEqual(s.row_count, 100)
+ self.assertEqual(len(s.row_ranges()), 1)
+
+ def test_aligned_group_split_across_chunks(self):
+ """A 250-row group starting at row id 1000 is cut into three contiguous row ranges."""
+ # One 250-row group, chunk_size=100 → 3 chunks (100, 100, 50).
+ # All three chunks reference the SAME aligned group's files but
+ # each carries a distinct row_range slice.
+ entries = [_mock_de_entry([], 0, 'g0.parquet', 1000, 250)]
+ gen = _make_de_generator(seed=1, chunk_size=100)
+ splits = gen.create_splits(entries)
+ self.assertEqual(len(splits), 3)
+
+ # Union of the three chunks' row_ranges must cover the whole group
[1000, 1249].
+ ranges = []
+ for s in splits:
+ self.assertIsInstance(s, IndexedSplit)
+ ranges.extend((r.from_, r.to) for r in s.row_ranges())
+ # Sorting makes the comparison independent of the shuffled split order.
+ ranges.sort()
+ self.assertEqual(ranges, [(1000, 1099), (1100, 1199), (1200, 1249)])
+ # Both bounds are inclusive here, hence the + 1.
+ total = sum(r[1] - r[0] + 1 for r in ranges)
+ self.assertEqual(total, 250)
+
+ def test_chunk_pulls_in_blob_siblings(self):
+ """A parquet file and a blob file with the same row-id range end up in one split."""
+ # One aligned group with a main parquet and a blob sibling sharing the
+ # row_id range. A single chunk must include BOTH files so the reader
+ # can union the columns.
+ entries = [
+ _mock_de_entry([], 0, 'g0.parquet', 0, 100),
+ _mock_de_entry([], 0, 'g0.blob', 0, 100), # .blob ext →
is_blob_file
+ ]
+ gen = _make_de_generator(seed=1, chunk_size=100)
+ splits = gen.create_splits(entries)
+ self.assertEqual(len(splits), 1)
+ files = sorted(f.file_name for f in splits[0].files)
+ self.assertEqual(files, ['g0.blob', 'g0.parquet'])
+
+ def test_blob_propagates_when_group_split(self):
+ """When chunk_size=50 halves a 100-row group, both splits list the parquet and the blob file."""
+ # Same scenario but chunk_size halves the group → the blob sibling
+ # must appear in BOTH chunk splits.
+ entries = [
+ _mock_de_entry([], 0, 'g0.parquet', 0, 100),
+ _mock_de_entry([], 0, 'g0.blob', 0, 100),
+ ]
+ gen = _make_de_generator(seed=1, chunk_size=50)
+ splits = gen.create_splits(entries)
+ self.assertEqual(len(splits), 2)
+ for s in splits:
+ files = sorted(f.file_name for f in s.files)
+ self.assertEqual(files, ['g0.blob', 'g0.parquet'])
+
+ def test_deterministic_same_seed(self):
+ """Two DE generators with the same seed produce the same signatures in the same order."""
+ entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100)
for i in range(20)]
+ gen1 = _make_de_generator(seed=42, chunk_size=100)
+ gen2 = _make_de_generator(seed=42, chunk_size=100)
+ sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)]
+ sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)]
+ self.assertEqual(sigs1, sigs2)
+
+ def test_different_seed_reorders(self):
+ """Seeds 1 and 2 yield the same set of DE splits in a different order."""
+ entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100)
for i in range(50)]
+ gen1 = _make_de_generator(seed=1, chunk_size=100)
+ gen2 = _make_de_generator(seed=2, chunk_size=100)
+ sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)]
+ sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)]
+ # Same multiset of splits, but the two sequences must not be identical.
+ self.assertEqual(sorted(sigs1), sorted(sigs2))
+ self.assertNotEqual(sigs1, sigs2)
+
+ def test_input_order_does_not_affect_output(self):
+ """The same entries passed in a different list order plan to the same split sequence."""
+ a = _mock_de_entry([], 0, 'g0.parquet', 0, 100)
+ b = _mock_de_entry([], 0, 'g1.parquet', 100, 100)
+ c = _mock_de_entry([], 0, 'g2.parquet', 200, 100)
+ gen1 = _make_de_generator(seed=99, chunk_size=100)
+ gen2 = _make_de_generator(seed=99, chunk_size=100)
+ # Same three entries and same seed; only the input list order differs.
+ sigs1 = [_split_signature(s) for s in gen1.create_splits([a, b, c])]
+ sigs2 = [_split_signature(s) for s in gen2.create_splits([c, a, b])]
+ self.assertEqual(sigs1, sigs2)
+
+ def test_shard_round_trip_no_overlap_no_loss(self):
+ """DE splits from 4 shards are disjoint and together equal the unsharded plan."""
+ # 13 aligned groups × 100 rows = 1300 rows. Shard across 4 workers.
+ entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100)
for i in range(13)]
+ num_workers = 4
+
+ unsharded = _make_de_generator(seed=7,
chunk_size=100).create_splits(list(entries))
+ unsharded_sigs = sorted(_split_signature(s) for s in unsharded)
+
+ sharded_sigs = []
+ total_rows = 0
+ # Each worker builds its own generator with the same seed and chunk_size.
+ for w in range(num_workers):
+ gen = _make_de_generator(seed=7, chunk_size=100)
+ gen.with_shard(w, num_workers)
+ for s in gen.create_splits(list(entries)):
+ sharded_sigs.append(_split_signature(s))
+ total_rows += s.row_count
+ self.assertEqual(total_rows, 13 * 100)
+ # No duplicate splits across workers
+ self.assertEqual(len(sharded_sigs), len(set(sharded_sigs)))
+ self.assertEqual(sorted(sharded_sigs), unsharded_sigs)
+
+ def test_multi_partition_no_chunk_crosses_partition(self):
+ entries = [
+ _mock_de_entry(['p1'], 0, 'g0.parquet', 0, 100),
+ _mock_de_entry(['p1'], 0, 'g1.parquet', 100, 100),
+ _mock_de_entry(['p2'], 0, 'g2.parquet', 200, 100),
+ _mock_de_entry(['p2'], 0, 'g3.parquet', 300, 100),
+ ]
+ gen = _make_de_generator(seed=0, chunk_size=100)
+ splits = gen.create_splits(entries)
+ for s in splits:
+ data_split = s.data_split() if isinstance(s, IndexedSplit) else s
+ self.assertEqual(len({tuple(data_split.partition.values)}), 1)
+
+ def test_null_and_non_null_partitions_sort_safely(self):
+ """DE planning succeeds when one partition value is None and the others are strings."""
+ # Same null-vs-non-null sort guard, exercised on the DE path.
+ entries = [
+ _mock_de_entry(['p1'], 0, 'g0.parquet', 0, 100),
+ _mock_de_entry([None], 0, 'g1.parquet', 100, 100),
+ _mock_de_entry(['p2'], 0, 'g2.parquet', 200, 100),
+ ]
+ gen = _make_de_generator(seed=1, chunk_size=100)
+ splits = gen.create_splits(entries)
+ self.assertEqual(len(splits), 3)
+ # Element 0 of the signature is the split's partition-values tuple.
+ partitions = {_split_signature(s)[0] for s in splits}
+ self.assertEqual(partitions, {('p1',), ('p2',), (None,)})
+
+
+class DataEvolutionChunkShuffleEndToEndTest(unittest.TestCase):
+ """Real DE table → with_chunk_shuffle → multi-worker → union == full
table."""
+
+ @classmethod
+ def setUpClass(cls):
+ # One temporary warehouse and catalog shared by every test in this class.
+ cls.tempdir = tempfile.mkdtemp()
+ cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+ cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
+ cls.catalog.create_database('default', True)
+
+ @classmethod
+ def tearDownClass(cls):
+ # Best-effort cleanup: errors while removing the temp dir are ignored.
+ shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+ def _create_de_table(self, name):
+ """Create ``default.<name>`` with row tracking and data evolution enabled.
+
+ Columns are (id int32, value string, payload large_binary). Returns
+ ``(table, pa_schema)``.
+ """
+ pa_schema = pa.schema([
+ ('id', pa.int32()),
+ ('value', pa.string()),
+ ('payload', pa.large_binary()),
+ ])
+ schema = Schema.from_pyarrow_schema(
+ pa_schema,
+ options={
+ 'row-tracking.enabled': 'true',
+ 'data-evolution.enabled': 'true',
+ # A 1-byte target is used so each commit produces multiple
+ # blob files (checked by
+ # _assert_commit_has_main_and_multiple_blob_files).
+ 'blob.target-file-size': '1 b',
+ },
+ )
+ identifier = f'default.{name}'
+ self.catalog.create_table(identifier, schema, False)
+ return self.catalog.get_table(identifier), pa_schema
+
+ @staticmethod
+ def _payloads(ids):
+ """Return the UTF-8 bytes ``payload-NNN`` (zero-padded to 3 digits) for each id."""
+ return [f'payload-{i:03d}'.encode('utf-8') for i in ids]
+
+ def _commit_full_rows(self, table, pa_schema, ids):
+ """Write one commit containing all three columns for ``ids`` and return its commit messages."""
+ wb = table.new_batch_write_builder()
+ tw = wb.new_write()
+ tc = wb.new_commit()
+ tw.write_arrow(pa.Table.from_pydict(
+ {
+ 'id': ids,
+ 'value': [f'v{i}' for i in ids],
+ 'payload': self._payloads(ids),
+ },
+ schema=pa_schema))
+ commit_messages = tw.prepare_commit()
+ tc.commit(commit_messages)
+ tw.close()
+ tc.close()
+ return commit_messages
+
+ def _assert_commit_has_main_and_multiple_blob_files(self, commit_messages):
+ """Assert the commit added at least one non-blob file and more than one blob file."""
+ all_files = [f for msg in commit_messages for f in msg.new_files]
+ main_files = [f for f in all_files if not
DataFileMeta.is_blob_file(f.file_name)]
+ blob_files = [f for f in all_files if
DataFileMeta.is_blob_file(f.file_name)]
+ self.assertGreaterEqual(len(main_files), 1)
+ self.assertGreater(
+ len(blob_files), 1,
+ "DE chunk-shuffle tests should exercise one row-id group with
multiple blob files",
+ )
+
+ def _assert_splits_include_blob_files(self, splits):
+ """Assert ``splits`` is non-empty and every split lists at least one blob file."""
+ self.assertGreater(len(splits), 0)
+ for split in splits:
+ data_split = split.data_split() if isinstance(split, IndexedSplit)
else split
+ blob_files = [
+ f for f in data_split.files
+ if DataFileMeta.is_blob_file(f.file_name)
+ ]
+ self.assertGreater(
+ len(blob_files), 0,
+ "Each DE chunk should keep blob sidecar files with its aligned
row-id group",
+ )
+
+ def test_workers_union_equals_full_table(self):
+ """Rows read by three shards, sorted by id, are ids 0..199 with matching payloads."""
+ table, pa_schema = self._create_de_table('cs_de_union')
+ # 4 commits → 4 aligned groups. Each group has one normal file and
+ # multiple blob sidecar files because blob.target-file-size is 1 byte.
+ for c in range(4):
+ base = c * 50
+ commit_messages = self._commit_full_rows(
+ table, pa_schema, list(range(base, base + 50)))
+
self._assert_commit_has_main_and_multiple_blob_files(commit_messages)
+
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+
+ num_workers = 3
+ worker_tables = []
+ # Every worker plans with the same seed and chunk_size; only the shard index differs.
+ for w in range(num_workers):
+ scan = read_builder.new_scan() \
+ .with_chunk_shuffle(seed=123, chunk_size=37) \
+ .with_shard(w, num_workers)
+ splits = scan.plan().splits()
+ if splits:
+ self._assert_splits_include_blob_files(splits)
+ worker_tables.append(table_read.to_arrow(splits))
+
+ actual = pa.concat_tables(worker_tables).sort_by('id')
+ self.assertEqual(actual.num_rows, 200)
+ self.assertEqual(actual.column('id').to_pylist(), list(range(200)))
+ self.assertEqual(actual.column('payload').to_pylist(),
self._payloads(range(200)))
+
+ def test_deterministic_plan_across_calls(self):
+ """Planning the same DE shard twice yields identical split signatures."""
+ table, pa_schema = self._create_de_table('cs_de_determinism')
+ for c in range(3):
+ base = c * 40
+ commit_messages = self._commit_full_rows(
+ table, pa_schema, list(range(base, base + 40)))
+
self._assert_commit_has_main_and_multiple_blob_files(commit_messages)
+
+ def plan_sigs(worker):
+ # A fresh read builder and scan are created on every call.
+ scan = table.new_read_builder().new_scan() \
+ .with_chunk_shuffle(seed=42, chunk_size=15) \
+ .with_shard(worker, 4)
+ splits = scan.plan().splits()
+ if splits:
+ self._assert_splits_include_blob_files(splits)
+ return [_split_signature(s) for s in splits]
+
+ for worker in range(4):
+ self.assertEqual(plan_sigs(worker), plan_sigs(worker))
+
+
+# Run the tests in this module when the file is executed directly.
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paimon-python/pypaimon/tests/torch_read_test.py
b/paimon-python/pypaimon/tests/torch_read_test.py
index ac6088ece3..5f55cb2bc8 100644
--- a/paimon-python/pypaimon/tests/torch_read_test.py
+++ b/paimon-python/pypaimon/tests/torch_read_test.py
@@ -645,6 +645,236 @@ class TorchReadTest(unittest.TestCase):
print("✓ All predicate test cases passed!")
print(f"{'=' * 60}\n")
+ def test_torch_streaming_shuffle_single_worker(self):
+     """Shuffled streaming read in the main process keeps every row.
+
+     For ``max_buffer_input_splits`` of 1 and 3, the collected ``user_id``
+     values must be a permutation of 0..79 that is not in ascending order.
+     """
+     source_table = self._create_shuffle_append_table(
+         'default.test_torch_shuffle_single')
+     builder = source_table.new_read_builder().with_projection(['user_id'])
+     reader = builder.new_read()
+     planned_splits = builder.new_scan().plan().splits()
+
+     ascending_ids = list(range(80))
+     for input_split_cap in [1, 3]:
+         with self.subTest(max_buffer_input_splits=input_split_cap):
+             shuffle_kwargs = {
+                 'streaming': True,
+                 'shuffle': True,
+                 'seed': 17,
+                 'buffer_size': 7,
+                 'max_buffer_input_splits': input_split_cap,
+             }
+             torch_dataset = reader.to_torch(
+                 planned_splits, **shuffle_kwargs)
+             seen_ids = self._collect_torch_user_ids(
+                 torch_dataset, num_workers=0)
+             # Nothing lost or duplicated, but the order must have changed.
+             self.assertEqual(sorted(seen_ids), ascending_ids)
+             self.assertNotEqual(seen_ids, ascending_ids)
+
+ def test_torch_streaming_shuffle_seed_and_epoch(self):
+     """Order is fixed by (seed, epoch) and changes when either changes.
+
+     Re-reading the same dataset repeats the order; ``set_epoch(1)`` gives
+     a different permutation of 0..79; ``set_epoch(0)`` restores the first
+     order; a dataset built with another seed differs from the first order.
+     """
+     source_table = self._create_shuffle_append_table(
+         'default.test_torch_shuffle_epoch')
+     builder = source_table.new_read_builder().with_projection(['user_id'])
+     reader = builder.new_read()
+     planned_splits = builder.new_scan().plan().splits()
+
+     def shuffled_dataset(seed):
+         # Identical options for both datasets; only the seed varies.
+         return reader.to_torch(
+             planned_splits,
+             streaming=True,
+             shuffle=True,
+             seed=seed,
+             buffer_size=11,
+             max_buffer_input_splits=4,
+         )
+
+     def read_ids(torch_dataset):
+         return self._collect_torch_user_ids(torch_dataset, num_workers=0)
+
+     torch_dataset = shuffled_dataset(23)
+     first_order = read_ids(torch_dataset)
+     repeated_order = read_ids(torch_dataset)
+     self.assertEqual(first_order, repeated_order)
+
+     torch_dataset.set_epoch(1)
+     next_epoch_order = read_ids(torch_dataset)
+     self.assertEqual(sorted(next_epoch_order), list(range(80)))
+     self.assertNotEqual(first_order, next_epoch_order)
+
+     # Going back to epoch 0 must reproduce the very first order.
+     torch_dataset.set_epoch(0)
+     self.assertEqual(first_order, read_ids(torch_dataset))
+
+     reseeded_dataset = shuffled_dataset(24)
+     self.assertNotEqual(first_order, read_ids(reseeded_dataset))
+
+ def test_torch_streaming_shuffle_epoch_with_persistent_workers(self):
+     """``set_epoch`` must take effect through a persistent-worker loader.
+
+     Iterating the same two-worker ``persistent_workers=True`` DataLoader
+     twice must give the same order; after ``set_epoch(1)`` on the dataset
+     the ids must still cover 0..79 but in a different order.
+     """
+     source_table = self._create_shuffle_append_table(
+         'default.test_torch_shuffle_persistent_epoch')
+     builder = source_table.new_read_builder().with_projection(['user_id'])
+     reader = builder.new_read()
+     planned_splits = builder.new_scan().plan().splits()
+
+     torch_dataset = reader.to_torch(
+         planned_splits,
+         streaming=True,
+         shuffle=True,
+         seed=23,
+         buffer_size=11,
+         max_buffer_input_splits=4,
+     )
+     loader_options = {
+         'batch_size': 8,
+         'num_workers': 2,
+         'persistent_workers': True,
+         'shuffle': False,
+     }
+     loader = DataLoader(torch_dataset, **loader_options)
+     drain = self._collect_torch_user_ids_from_dataloader
+
+     first_order = drain(loader)
+     self.assertEqual(first_order, drain(loader))
+
+     torch_dataset.set_epoch(1)
+     next_epoch_order = drain(loader)
+     self.assertEqual(sorted(next_epoch_order), list(range(80)))
+     self.assertNotEqual(first_order, next_epoch_order)
+
+ def test_torch_streaming_shuffle_multi_worker(self):
+     """Chunk-shuffled splits read by two workers yield each id once.
+
+     Plans with ``with_chunk_shuffle(seed=31, chunk_size=5)``, reads through
+     a shuffled streaming dataset with ``num_workers=2`` and expects exactly
+     the ids 0..79.
+     """
+     source_table = self._create_shuffle_append_table(
+         'default.test_torch_shuffle_multi')
+     builder = source_table.new_read_builder().with_projection(['user_id'])
+     reader = builder.new_read()
+     chunked_scan = builder.new_scan().with_chunk_shuffle(
+         seed=31, chunk_size=5)
+     planned_splits = chunked_scan.plan().splits()
+
+     shuffle_kwargs = {
+         'streaming': True,
+         'shuffle': True,
+         'seed': 31,
+         'buffer_size': 13,
+         'max_buffer_input_splits': 4,
+     }
+     torch_dataset = reader.to_torch(planned_splits, **shuffle_kwargs)
+     seen_ids = self._collect_torch_user_ids(torch_dataset, num_workers=2)
+
+     all_ids = list(range(80))
+     self.assertEqual(len(seen_ids), len(all_ids))
+     self.assertEqual(sorted(seen_ids), all_ids)
+
+ def test_torch_streaming_shuffle_rejects_non_streaming(self):
+     """``shuffle=True`` with ``streaming=False`` must raise ValueError.
+
+     The error message is expected to match ``streaming=True``.
+     """
+     source_table = self._create_shuffle_append_table(
+         'default.test_torch_shuffle_non_streaming')
+     builder = source_table.new_read_builder()
+     reader = builder.new_read()
+     planned_splits = builder.new_scan().plan().splits()
+
+     # Callable form of assertRaisesRegex: it invokes to_torch itself.
+     self.assertRaisesRegex(
+         ValueError,
+         "streaming=True",
+         reader.to_torch,
+         planned_splits,
+         streaming=False,
+         shuffle=True,
+     )
+
+ def test_torch_streaming_shuffle_accepts_pk_table_splits(self):
+     """Shuffled streaming read accepts splits from a primary-key table.
+
+     Creates a ``bucket=1`` table keyed on ``user_id``, fills it through
+     ``_write_test_table`` and expects the shuffled ids to sort to 1..8.
+     """
+     identifier = 'default.test_torch_shuffle_pk'
+     pk_fields = [
+         pa.field('user_id', pa.int32(), nullable=False),
+         ('item_id', pa.int64()),
+         ('behavior', pa.string()),
+         ('dt', pa.string())
+     ]
+     pk_schema = Schema.from_pyarrow_schema(
+         pa.schema(pk_fields),
+         primary_keys=['user_id'],
+         options={'bucket': '1'},
+     )
+     self.catalog.create_table(identifier, pk_schema, False)
+     pk_table = self.catalog.get_table(identifier)
+     self._write_test_table(pk_table)
+
+     builder = pk_table.new_read_builder().with_projection(['user_id'])
+     planned_splits = builder.new_scan().plan().splits()
+     shuffle_kwargs = {
+         'streaming': True,
+         'shuffle': True,
+         'seed': 7,
+         'buffer_size': 3,
+     }
+     torch_dataset = builder.new_read().to_torch(
+         planned_splits, **shuffle_kwargs)
+     seen_ids = self._collect_torch_user_ids(torch_dataset, num_workers=0)
+
+     self.assertEqual(sorted(seen_ids), [1, 2, 3, 4, 5, 6, 7, 8])
+
+ def test_torch_streaming_shuffle_rejects_invalid_dataset_options(self):
+     """Invalid shuffle options must raise ValueError naming the option.
+
+     Checked in order: ``prefetch_concurrency=2``, ``buffer_size=0`` and
+     ``max_buffer_input_splits=0``, each combined with a shuffled
+     streaming read.
+     """
+     source_table = self._create_shuffle_append_table(
+         'default.test_torch_shuffle_invalid_options')
+     builder = source_table.new_read_builder().with_projection(['user_id'])
+     reader = builder.new_read()
+     planned_splits = builder.new_scan().plan().splits()
+
+     # (expected message pattern, offending keyword argument)
+     rejected_options = [
+         ("prefetch_concurrency", {'prefetch_concurrency': 2}),
+         ("buffer_size", {'buffer_size': 0}),
+         ("max_buffer_input_splits", {'max_buffer_input_splits': 0}),
+     ]
+     for message_pattern, bad_option in rejected_options:
+         with self.assertRaisesRegex(ValueError, message_pattern):
+             reader.to_torch(
+                 planned_splits,
+                 streaming=True,
+                 shuffle=True,
+                 **bad_option
+             )
+
+ def _create_shuffle_append_table(
+     self,
+     identifier,
+     total_rows=80,
+     rows_per_commit=10,
+     partition_keys=None,
+ ):
+     """Create a table from ``self.pa_schema`` and fill it commit by commit.
+
+     No primary keys are passed to the schema. Rows carry ``user_id`` values
+     ``0 .. total_rows - 1`` and are written in consecutive batches of at
+     most ``rows_per_commit`` rows, one write/commit pair per batch.
+
+     Args:
+         identifier: Table identifier handed to the catalog.
+         total_rows: Number of rows to write in total.
+         rows_per_commit: Maximum number of rows per commit.
+         partition_keys: Optional partition key names; ``None`` means no
+             partitioning.
+
+     Returns:
+         The table object fetched from the catalog after creation.
+     """
+     schema = Schema.from_pyarrow_schema(
+         self.pa_schema,
+         partition_keys=partition_keys or [],
+     )
+     self.catalog.create_table(identifier, schema, False)
+     table = self.catalog.get_table(identifier)
+
+     write_builder = table.new_batch_write_builder()
+     for start in range(0, total_rows, rows_per_commit):
+         end = min(start + rows_per_commit, total_rows)
+         table_write = write_builder.new_write()
+         table_commit = write_builder.new_commit()
+         try:
+             pa_table = pa.Table.from_pydict({
+                 'user_id': list(range(start, end)),
+                 'item_id': [1000 + i for i in range(start, end)],
+                 'behavior': [
+                     chr(ord('a') + (i % 26)) for i in range(start, end)
+                 ],
+                 'dt': [f'p{i % 4}' for i in range(start, end)],
+             }, schema=self.pa_schema)
+             table_write.write_arrow(pa_table)
+             table_commit.commit(table_write.prepare_commit())
+         finally:
+             # Release both handles even when writing or committing fails,
+             # so a failing batch does not leak the writer or the commit.
+             table_write.close()
+             table_commit.close()
+     return table
+
+ @staticmethod
+ def _collect_torch_user_ids(dataset, num_workers=0):
+     """Read ``dataset`` through a DataLoader and return all ``user_id``s.
+
+     Args:
+         dataset: Dataset object accepted by ``DataLoader``.
+         num_workers: Number of DataLoader worker processes.
+
+     Returns:
+         A flat list of the ``user_id`` values in the order the batches
+         were yielded.
+     """
+     dataloader = DataLoader(
+         dataset,
+         batch_size=8,
+         num_workers=num_workers,
+         shuffle=False,
+     )
+     # Reuse the sibling helper instead of repeating the batch loop here.
+     return TorchReadTest._collect_torch_user_ids_from_dataloader(dataloader)
+
+ @staticmethod
+ def _collect_torch_user_ids_from_dataloader(dataloader):
+     """Return every ``user_id`` yielded by ``dataloader`` as one flat list.
+
+     Each batch is indexed with ``'user_id'`` and converted with
+     ``.tolist()``; batch order and in-batch order are preserved.
+     """
+     return [
+         user_id
+         for batch in dataloader
+         for user_id in batch['user_id'].tolist()
+     ]
+
def _write_test_table(self, table):
write_builder = table.new_batch_write_builder()
table_pa_schema = self.pk_pa_schema if table.primary_keys else
self.pa_schema