steFaiz commented on code in PR #8064:
URL: https://github.com/apache/paimon/pull/8064#discussion_r3345902313
##########
paimon-python/pypaimon/read/datasource/torch_dataset.py:
##########
@@ -221,3 +222,123 @@ def producer(split_group: List):
stop.set()
for t in threads:
t.join(timeout=self._PREFETCH_JOIN_TIMEOUT_SEC)
+
+
+class TorchShuffledIterDataset(_BaseTorchIterDataset):
+ """
+ PyTorch IterableDataset with Paimon-controlled streaming shuffle.
+
+ This dataset consumes pre-planned splits, then mixes rows with split
+ interleaving and a shuffle buffer. Chunk-level shuffle, when needed,
+ stays in TableScan.with_chunk_shuffle().
+ """
+
+ def __init__(
+ self,
+ table_read: TableRead,
+ splits: List[Split],
+ seed: int = 0,
+ buffer_size: int = 1000,
+ max_buffer_input_splits: int = 10,
+ ):
+ super().__init__(table_read, splits)
+ self.seed = self._require_int(seed, "seed")
+ self.buffer_size = self._require_positive_int(buffer_size, "buffer_size")
+ self.max_buffer_input_splits = self._require_positive_int(
+ max_buffer_input_splits, "max_buffer_input_splits")
+ self.epoch = 0
+
+ @staticmethod
+ def _require_int(value: int, name: str) -> int:
+ if not isinstance(value, int):
+ raise ValueError("%s must be an int" % name)
+ return value
+
+ @staticmethod
+ def _require_positive_int(value: int, name: str) -> int:
+ if not isinstance(value, int) or value <= 0:
+ raise ValueError("%s must be a positive int" % name)
+ return value
+
+ def set_epoch(self, epoch: int) -> "TorchShuffledIterDataset":
Review Comment:
Thanks! I looked into how the PyTorch `DataLoader` works; the mechanism is shown below:
![Diagram of the PyTorch DataLoader worker mechanism](https://github.com/user-attachments/assets/6d50f673-c943-4627-9ffb-1bf82fed0be6)
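On Linux, the main process creates the worker processes with `fork()`. Right after the fork, the workers share the parent's memory pages, but those pages are copy-on-write (COW): any later write lands in a private copy owned by the writing process.
With `persistent_workers=True`, the same workers are reused across epochs. So if the main process changes the epoch, the workers never see it; they keep the value they had when they were forked. (Without persistent workers, the workers are re-created for every epoch and therefore pick up the new value.)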
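The visibility problem can be reproduced without PyTorch. This is a minimal sketch, assuming a Linux-style `fork` start method; `Dataset`, `worker` and `epochs_seen_by_worker` are hypothetical stand-ins (not Paimon or PyTorch code), and a pair of `multiprocessing` queues loosely plays the role of the DataLoader's request/result queues. The worker is forked once and kept alive, like a persistent worker:

```python
import multiprocessing as mp


class Dataset:
    """Hypothetical stand-in for a dataset holding a plain int epoch."""

    def __init__(self):
        self.epoch = 0


def worker(dataset, requests, replies):
    # Persistent worker: forked once, serves requests until it gets None.
    for _ in iter(requests.get, None):
        replies.put(dataset.epoch)


def epochs_seen_by_worker():
    ctx = mp.get_context("fork")  # Linux-style fork start method
    dataset = Dataset()
    requests, replies = ctx.Queue(), ctx.Queue()
    proc = ctx.Process(target=worker, args=(dataset, requests, replies))
    proc.start()  # the child gets a copy-on-write snapshot of `dataset`

    seen = []
    for epoch in range(3):
        dataset.epoch = epoch  # this write only touches the parent's memory
        requests.put("next")
        seen.append(replies.get())
    requests.put(None)
    proc.join()
    return seen


if __name__ == "__main__":
    print(epochs_seen_by_worker())  # [0, 0, 0]
```

Every request returns 0 because the worker only ever reads its own snapshot of `dataset` taken at fork time.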
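Following the approach of Hugging Face `datasets`, I now store the epoch in a tensor moved into shared memory with `torch.Tensor.share_memory_()`. This backs the tensor's storage with a shared-memory file, so changes are visible to all processes, provided the tensor is shared before the workers are forked and is then updated in place rather than replaced.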
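The fix can be sketched the same way. As a swapped-in, dependency-free analogue of a tensor after `share_memory_()`, the hypothetical `SharedEpochDataset` below keeps its epoch in a `multiprocessing.Value`, which is also allocated in shared memory. The two details that matter are that the shared slot is created before the worker is forked and that `set_epoch()` writes into it in place:

```python
import multiprocessing as mp


class SharedEpochDataset:
    """Hypothetical dataset whose epoch lives in shared memory."""

    def __init__(self, ctx):
        # Allocate the shared slot BEFORE any worker is forked, so the
        # worker inherits a mapping of the same shared memory.
        self._epoch = ctx.Value("i", 0)

    def set_epoch(self, epoch):
        # Write in place; rebinding self._epoch to a new object would only
        # change the parent's copy of the attribute.
        self._epoch.value = epoch

    @property
    def epoch(self):
        return self._epoch.value


def worker(dataset, requests, replies):
    # Persistent worker: forked once, serves requests until it gets None.
    for _ in iter(requests.get, None):
        replies.put(dataset.epoch)


def epochs_seen_by_worker():
    ctx = mp.get_context("fork")  # Linux-style fork start method
    dataset = SharedEpochDataset(ctx)
    requests, replies = ctx.Queue(), ctx.Queue()
    proc = ctx.Process(target=worker, args=(dataset, requests, replies))
    proc.start()

    seen = []
    for epoch in range(3):
        dataset.set_epoch(epoch)  # visible to the already-running worker
        requests.put("next")
        seen.append(replies.get())
    requests.put(None)
    proc.join()
    return seen


if __name__ == "__main__":
    print(epochs_seen_by_worker())  # [0, 1, 2]
```

With a tensor, the same pattern would presumably be to create it and call `share_memory_()` in `__init__`, then update it in place (for example with `fill_()`) inside `set_epoch()`.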