This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 36278e3627 [python] Add multi-threaded prefetch for pytorch streaming read (#7143)
36278e3627 is described below
commit 36278e3627d8c715273a76b549622ed3ea6558c4
Author: XiaoHongbo <[email protected]>
AuthorDate: Fri Feb 27 13:12:22 2026 +0800
[python] Add multi-threaded prefetch for pytorch streaming read (#7143)
---
docs/content/pypaimon/pytorch.md | 4 +-
.../pypaimon/read/datasource/torch_dataset.py | 83 +++++++++++++++++++++-
paimon-python/pypaimon/read/table_read.py | 9 ++-
paimon-python/pypaimon/tests/torch_read_test.py | 36 ++++++++++
4 files changed, 128 insertions(+), 4 deletions(-)
diff --git a/docs/content/pypaimon/pytorch.md b/docs/content/pypaimon/pytorch.md
index b34f49edcd..6ab485f2c9 100644
--- a/docs/content/pypaimon/pytorch.md
+++ b/docs/content/pypaimon/pytorch.md
@@ -37,7 +37,7 @@ You can read all the data into a `torch.utils.data.Dataset` or `torch.utils.data
from torch.utils.data import DataLoader
table_read = read_builder.new_read()
-dataset = table_read.to_torch(splits, streaming=True)
+dataset = table_read.to_torch(splits, streaming=True, prefetch_concurrency=2)
dataloader = DataLoader(
    dataset,
    batch_size=2,
@@ -58,3 +58,5 @@ for batch_idx, batch_data in enumerate(dataloader):
When the `streaming` parameter is true, it will read iteratively;
when it is false, it will read the full dataset into memory.
+
+**`prefetch_concurrency`** (default: 1): When `streaming` is true, the number of threads used for parallel prefetch within each DataLoader worker. Set it to a value greater than 1 to partition splits across threads and increase read throughput. Has no effect when `streaming` is false.
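
For context, a minimal usage sketch of the documented option; the `table` object and its catalog setup are assumed to exist elsewhere, and only `to_torch`, `streaming`, and `prefetch_concurrency` come from this commit:

    # Minimal sketch: stream a Paimon table into PyTorch with prefetching.
    from torch.utils.data import DataLoader

    read_builder = table.new_read_builder()           # `table` assumed obtained from a catalog
    splits = read_builder.new_scan().plan().splits()
    table_read = read_builder.new_read()

    # prefetch_concurrency > 1 partitions the splits across reader threads
    # inside each DataLoader worker.
    dataset = table_read.to_torch(splits, streaming=True, prefetch_concurrency=2)
    dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
    for batch in dataloader:
        pass  # consume batches in the training loop
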
diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py b/paimon-python/pypaimon/read/datasource/torch_dataset.py
index a800295f9e..97ebc53566 100644
--- a/paimon-python/pypaimon/read/datasource/torch_dataset.py
+++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py
@@ -18,6 +18,8 @@
"""
Module to read a Paimon table into PyTorch Dataset.
"""
+import queue
+import threading
from typing import List

import torch
@@ -83,19 +85,38 @@ class TorchIterDataset(IterableDataset):
    rather than loading everything into memory upfront.
    """

-    def __init__(self, table_read: TableRead, splits: List[Split]):
+    _SENTINEL = 0
+    _ROW = 1
+    _ERR = 2
+    _PREFETCH_QUEUE_MAXSIZE = 512
+    _PREFETCH_PUT_TIMEOUT_SEC = 30.0
+    _PREFETCH_GET_TIMEOUT_SEC = 300.0
+    _PREFETCH_JOIN_TIMEOUT_SEC = 5.0
+
+    def __init__(self, table_read: TableRead, splits: List[Split], prefetch_concurrency: int = 1):
        """
        Initialize TorchIterDataset.

        Args:
            table_read: TableRead instance for reading data
            splits: List of splits to read
+            prefetch_concurrency: Number of threads to use for parallel OSS reads within
+                this worker (default 1). When > 1, splits are partitioned across
+                threads to increase read throughput.
        """
        self.table_read = table_read
        self.splits = splits
+        self.prefetch_concurrency = max(1, int(prefetch_concurrency))
        # Get field names from read_type
        self.field_names = [field.name for field in table_read.read_type]

+    def _row_to_dict(self, offset_row) -> dict:
+        row_dict = {}
+        for i, field_name in enumerate(self.field_names):
+            value = offset_row.get_field(i)
+            row_dict[field_name] = value
+        return row_dict
+
    def __iter__(self):
        """
        Iterate over the dataset, converting each OffsetRow to a dictionary.
@@ -132,6 +153,11 @@ class TorchIterDataset(IterableDataset):

        splits_to_process = self.splits[start_idx:end_idx]

+        if self.prefetch_concurrency > 1:
+            for row in self._iter_rows(splits_to_process):
+                yield row
+            return
+
        worker_iterator = self.table_read.to_iterator(splits_to_process)

        for offset_row in worker_iterator:
@@ -140,3 +166,58 @@ class TorchIterDataset(IterableDataset):
                value = offset_row.get_field(i)
                row_dict[field_name] = value
            yield row_dict
+
+    def _iter_rows(self, splits: List[Split]):
+        n = min(self.prefetch_concurrency, len(splits))
+        if n == 0:
+            return
+        split_groups = [splits[i::n] for i in range(n)]
+
+        q = queue.Queue(maxsize=self._PREFETCH_QUEUE_MAXSIZE)
+        stop = threading.Event()
+
+        def put_item(tag: int, payload):
+            while not stop.is_set():
+                try:
+                    q.put((tag, payload), timeout=self._PREFETCH_PUT_TIMEOUT_SEC)
+                    return True
+                except queue.Full:
+                    continue
+            return False
+
+        def producer(split_group: List):
+            try:
+                for offset_row in self.table_read.to_iterator(split_group):
+                    if stop.is_set():
+                        break
+                    row_dict = self._row_to_dict(offset_row)
+                    if not put_item(self._ROW, row_dict):
+                        break
+                put_item(self._SENTINEL, None)
+            except Exception as e:
+                put_item(self._ERR, e)
+
+        threads = [threading.Thread(target=producer, args=(split_groups[i],), daemon=True)
+                   for i in range(n)]
+        for t in threads:
+            t.start()
+
+        try:
+            done = 0
+            while done < n:
+                try:
+                    tag, payload = q.get(timeout=self._PREFETCH_GET_TIMEOUT_SEC)
+                except queue.Empty:
+                    if stop.is_set():
+                        break
+                    continue
+                if tag == self._SENTINEL:
+                    done += 1
+                elif tag == self._ERR:
+                    raise payload
+                else:
+                    yield payload
+        finally:
+            stop.set()
+            for t in threads:
+                t.join(timeout=self._PREFETCH_JOIN_TIMEOUT_SEC)
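
The `_iter_rows` addition above is a classic bounded-queue producer/consumer: splits are dealt round-robin to n reader threads (`splits[i::n]`), each thread pushes tagged items onto a shared queue, and the consumer counts sentinels to know when every producer is done, re-raising any producer exception on the consumer side. A minimal standalone sketch of the same pattern, with plain iterables standing in for Paimon split groups (nothing below is pypaimon API):

    import queue
    import threading

    # Tag values mirror the committed constants: sentinel, row, error.
    SENTINEL, ROW, ERR = 0, 1, 2

    def prefetch(iterables, maxsize=64):
        q = queue.Queue(maxsize=maxsize)  # bounded, so producers cannot run far ahead
        stop = threading.Event()

        def producer(it):
            try:
                for item in it:
                    if stop.is_set():
                        return
                    q.put((ROW, item))
                q.put((SENTINEL, None))   # this producer is finished
            except Exception as e:
                q.put((ERR, e))           # forward the error to the consumer

        threads = [threading.Thread(target=producer, args=(it,), daemon=True)
                   for it in iterables]
        for t in threads:
            t.start()
        try:
            done = 0
            while done < len(threads):    # drain until every producer has signalled
                tag, payload = q.get()
                if tag == SENTINEL:
                    done += 1
                elif tag == ERR:
                    raise payload         # re-raise in the consumer thread
                else:
                    yield payload
        finally:
            stop.set()                    # ask any remaining producers to exit

    # Round-robin partitioning, as in `splits[i::n]` above:
    items = list(range(8))
    groups = [items[i::2] for i in range(2)]
    print(sorted(prefetch(groups)))       # [0, 1, 2, 3, 4, 5, 6, 7]

Unlike this sketch, the committed code bounds every put and get with a timeout so that a producer blocked on a full queue can still observe the stop flag and exit instead of deadlocking if the consumer goes away early.
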
diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py
index 142444dcd4..5206147f80 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -225,11 +225,16 @@ class TableRead:
            **read_args
        )

-    def to_torch(self, splits: List[Split], streaming: bool = False) -> "torch.utils.data.Dataset":
+    def to_torch(
+        self,
+        splits: List[Split],
+        streaming: bool = False,
+        prefetch_concurrency: int = 1,
+    ) -> "torch.utils.data.Dataset":
        """Wrap Paimon table data to PyTorch Dataset."""
        if streaming:
            from pypaimon.read.datasource.torch_dataset import TorchIterDataset
-            dataset = TorchIterDataset(self, splits)
+            dataset = TorchIterDataset(self, splits, prefetch_concurrency)
            return dataset
        else:
            from pypaimon.read.datasource.torch_dataset import TorchDataset
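
Both call styles remain valid after this change, since the new parameter defaults to the old single-threaded behavior; a quick sketch (names taken from this diff):

    # Unchanged default: reads everything into an in-memory TorchDataset.
    dataset = table_read.to_torch(splits)
    # Streaming read with four prefetch threads per DataLoader worker.
    dataset = table_read.to_torch(splits, streaming=True, prefetch_concurrency=4)
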
diff --git a/paimon-python/pypaimon/tests/torch_read_test.py b/paimon-python/pypaimon/tests/torch_read_test.py
index b6862c6cb1..6e4d5cdbb4 100644
--- a/paimon-python/pypaimon/tests/torch_read_test.py
+++ b/paimon-python/pypaimon/tests/torch_read_test.py
@@ -100,6 +100,42 @@ class TorchReadTest(unittest.TestCase):
        print(f"✓ Test passed: Successfully read {len(all_user_ids)} rows with correct data")

+    def test_torch_streaming_prefetch_concurrency(self):
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'])
+        self.catalog.create_table('default.test_torch_prefetch_concurrency', schema, False)
+        table = self.catalog.get_table('default.test_torch_prefetch_concurrency')
+        self._write_test_table(table)
+
+        read_builder = table.new_read_builder().with_projection(['user_id', 'behavior'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        self.assertGreater(len(splits), 0, "Need at least one split to test prefetch")
+
+        dataset = table_read.to_torch(splits, streaming=True, prefetch_concurrency=4)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=2,
+            num_workers=0,
+            shuffle=False
+        )
+
+        all_user_ids = []
+        all_behaviors = []
+        for batch_data in dataloader:
+            all_user_ids.extend(batch_data['user_id'].tolist())
+            all_behaviors.extend(batch_data['behavior'])
+
+        sorted_data = sorted(zip(all_user_ids, all_behaviors), key=lambda x: x[0])
+        sorted_user_ids = [x[0] for x in sorted_data]
+        sorted_behaviors = [x[1] for x in sorted_data]
+
+        expected_user_ids = [1, 2, 3, 4, 5, 6, 7, 8]
+        expected_behaviors = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
+        self.assertEqual(len(all_user_ids), 8, "Should read 8 rows with prefetch_concurrency")
+        self.assertEqual(sorted_user_ids, expected_user_ids)
+        self.assertEqual(sorted_behaviors, expected_behaviors)
+
    def test_blob_torch_read(self):
        """Test end-to-end blob functionality using blob descriptors."""
        import random