This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 36278e3627 [python] Add multi-threaded prefetch for pytorch streaming read (#7143)
36278e3627 is described below

commit 36278e3627d8c715273a76b549622ed3ea6558c4
Author: XiaoHongbo <[email protected]>
AuthorDate: Fri Feb 27 13:12:22 2026 +0800

    [python] Add multi-threaded prefetch for pytorch streaming read (#7143)
---
 docs/content/pypaimon/pytorch.md                   |  4 +-
 .../pypaimon/read/datasource/torch_dataset.py      | 83 +++++++++++++++++++++-
 paimon-python/pypaimon/read/table_read.py          |  9 ++-
 paimon-python/pypaimon/tests/torch_read_test.py    | 36 ++++++++++
 4 files changed, 128 insertions(+), 4 deletions(-)

diff --git a/docs/content/pypaimon/pytorch.md b/docs/content/pypaimon/pytorch.md
index b34f49edcd..6ab485f2c9 100644
--- a/docs/content/pypaimon/pytorch.md
+++ b/docs/content/pypaimon/pytorch.md
@@ -37,7 +37,7 @@ You can read all the data into a `torch.utils.data.Dataset` or `torch.utils.data
 from torch.utils.data import DataLoader
 
 table_read = read_builder.new_read()
-dataset = table_read.to_torch(splits, streaming=True)
+dataset = table_read.to_torch(splits, streaming=True, prefetch_concurrency=2)
 dataloader = DataLoader(
     dataset,
     batch_size=2,
@@ -58,3 +58,5 @@ for batch_idx, batch_data in enumerate(dataloader):
 
 When the `streaming` parameter is true, it will iteratively read;
 when it is false, it will read the full amount of data into memory.
+
+**`prefetch_concurrency`** (default: 1): When streaming is true, number of threads used for parallel prefetch within each DataLoader worker. Set to a value greater than 1 to partition splits across threads and increase read throughput. Has no effect when streaming is false.
diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py b/paimon-python/pypaimon/read/datasource/torch_dataset.py
index a800295f9e..97ebc53566 100644
--- a/paimon-python/pypaimon/read/datasource/torch_dataset.py
+++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py
@@ -18,6 +18,8 @@
 """
 Module to read a Paimon table into PyTorch Dataset.
 """
+import queue
+import threading
 from typing import List
 
 import torch
@@ -83,19 +85,38 @@ class TorchIterDataset(IterableDataset):
     rather than loading everything into memory upfront.
     """
 
-    def __init__(self, table_read: TableRead, splits: List[Split]):
+    _SENTINEL = 0
+    _ROW = 1
+    _ERR = 2
+    _PREFETCH_QUEUE_MAXSIZE = 512
+    _PREFETCH_PUT_TIMEOUT_SEC = 30.0
+    _PREFETCH_GET_TIMEOUT_SEC = 300.0
+    _PREFETCH_JOIN_TIMEOUT_SEC = 5.0
+
+    def __init__(self, table_read: TableRead, splits: List[Split], 
prefetch_concurrency: int = 1):
         """
         Initialize TorchIterDataset.
 
         Args:
             table_read: TableRead instance for reading data
             splits: List of splits to read
+            prefetch_concurrency: Number of threads to use for parallel OSS 
reads within
+                this worker (default 1). When > 1, splits are partitioned 
across
+                threads to increase read throughput.
         """
         self.table_read = table_read
         self.splits = splits
+        self.prefetch_concurrency = max(1, int(prefetch_concurrency))
         # Get field names from read_type
         self.field_names = [field.name for field in table_read.read_type]
 
+    def _row_to_dict(self, offset_row) -> dict:
+        row_dict = {}
+        for i, field_name in enumerate(self.field_names):
+            value = offset_row.get_field(i)
+            row_dict[field_name] = value
+        return row_dict
+
     def __iter__(self):
         """
         Iterate over the dataset, converting each OffsetRow to a dictionary.
@@ -132,6 +153,11 @@ class TorchIterDataset(IterableDataset):
 
             splits_to_process = self.splits[start_idx:end_idx]
 
+        if self.prefetch_concurrency > 1:
+            for row in self._iter_rows(splits_to_process):
+                yield row
+            return
+
         worker_iterator = self.table_read.to_iterator(splits_to_process)
 
         for offset_row in worker_iterator:
@@ -140,3 +166,58 @@ class TorchIterDataset(IterableDataset):
                 value = offset_row.get_field(i)
                 row_dict[field_name] = value
             yield row_dict
+
+    def _iter_rows(self, splits: List[Split]):
+        n = min(self.prefetch_concurrency, len(splits))
+        if n == 0:
+            return
+        split_groups = [splits[i::n] for i in range(n)]
+
+        q = queue.Queue(maxsize=self._PREFETCH_QUEUE_MAXSIZE)
+        stop = threading.Event()
+
+        def put_item(tag: int, payload):
+            while not stop.is_set():
+                try:
+                    q.put((tag, payload), 
timeout=self._PREFETCH_PUT_TIMEOUT_SEC)
+                    return True
+                except queue.Full:
+                    continue
+            return False
+
+        def producer(split_group: List):
+            try:
+                for offset_row in self.table_read.to_iterator(split_group):
+                    if stop.is_set():
+                        break
+                    row_dict = self._row_to_dict(offset_row)
+                    if not put_item(self._ROW, row_dict):
+                        break
+                put_item(self._SENTINEL, None)
+            except Exception as e:
+                put_item(self._ERR, e)
+
+        threads = [threading.Thread(target=producer, args=(split_groups[i],), 
daemon=True)
+                   for i in range(n)]
+        for t in threads:
+            t.start()
+
+        try:
+            done = 0
+            while done < n:
+                try:
+                    tag, payload = 
q.get(timeout=self._PREFETCH_GET_TIMEOUT_SEC)
+                except queue.Empty:
+                    if stop.is_set():
+                        break
+                    continue
+                if tag == self._SENTINEL:
+                    done += 1
+                elif tag == self._ERR:
+                    raise payload
+                else:
+                    yield payload
+        finally:
+            stop.set()
+            for t in threads:
+                t.join(timeout=self._PREFETCH_JOIN_TIMEOUT_SEC)
diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py
index 142444dcd4..5206147f80 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -225,11 +225,16 @@ class TableRead:
             **read_args
         )
 
-    def to_torch(self, splits: List[Split], streaming: bool = False) -> 
"torch.utils.data.Dataset":
+    def to_torch(
+        self,
+        splits: List[Split],
+        streaming: bool = False,
+        prefetch_concurrency: int = 1,
+    ) -> "torch.utils.data.Dataset":
         """Wrap Paimon table data to PyTorch Dataset."""
         if streaming:
             from pypaimon.read.datasource.torch_dataset import TorchIterDataset
-            dataset = TorchIterDataset(self, splits)
+            dataset = TorchIterDataset(self, splits, prefetch_concurrency)
             return dataset
         else:
             from pypaimon.read.datasource.torch_dataset import TorchDataset
diff --git a/paimon-python/pypaimon/tests/torch_read_test.py b/paimon-python/pypaimon/tests/torch_read_test.py
index b6862c6cb1..6e4d5cdbb4 100644
--- a/paimon-python/pypaimon/tests/torch_read_test.py
+++ b/paimon-python/pypaimon/tests/torch_read_test.py
@@ -100,6 +100,42 @@ class TorchReadTest(unittest.TestCase):
 
         print(f"✓ Test passed: Successfully read {len(all_user_ids)} rows with 
correct data")
 
+    def test_torch_streaming_prefetch_concurrency(self):
+        schema = Schema.from_pyarrow_schema(self.pa_schema, 
partition_keys=['user_id'])
+        self.catalog.create_table('default.test_torch_prefetch_concurrency', 
schema, False)
+        table = 
self.catalog.get_table('default.test_torch_prefetch_concurrency')
+        self._write_test_table(table)
+
+        read_builder = table.new_read_builder().with_projection(['user_id', 
'behavior'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        self.assertGreater(len(splits), 0, "Need at least one split to test 
prefetch")
+
+        dataset = table_read.to_torch(splits, streaming=True, 
prefetch_concurrency=4)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=2,
+            num_workers=0,
+            shuffle=False
+        )
+
+        all_user_ids = []
+        all_behaviors = []
+        for batch_data in dataloader:
+            all_user_ids.extend(batch_data['user_id'].tolist())
+            all_behaviors.extend(batch_data['behavior'])
+
+        sorted_data = sorted(zip(all_user_ids, all_behaviors), key=lambda x: 
x[0])
+        sorted_user_ids = [x[0] for x in sorted_data]
+        sorted_behaviors = [x[1] for x in sorted_data]
+
+        expected_user_ids = [1, 2, 3, 4, 5, 6, 7, 8]
+        expected_behaviors = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
+        self.assertEqual(len(all_user_ids), 8, "Should read 8 rows with 
prefetch_concurrency")
+        self.assertEqual(sorted_user_ids, expected_user_ids)
+        self.assertEqual(sorted_behaviors, expected_behaviors)
+
     def test_blob_torch_read(self):
         """Test end-to-end blob functionality using blob descriptors."""
         import random

Reply via email to