This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git


The following commit(s) were added to refs/heads/main by this push:
     new 0c9c15c2 fix(llm): refactor embedding parallelization to preserve order (#295)
0c9c15c2 is described below

commit 0c9c15c20fee3748718c5a77b3120469b6b06d02
Author: imbajin <j...@apache.org>
AuthorDate: Thu Jul 31 17:30:41 2025 +0800

    fix(llm): refactor embedding parallelization to preserve order (#295)
    
    Reworked get_embeddings_parallel to use asyncio.gather for batch
    processing, ensuring output order matches input. Added a helper for
    batch progress updates and improved progress bar accuracy.
---
 .../src/hugegraph_llm/utils/embedding_utils.py     | 35 ++++++++++++++--------
 1 file changed, 23 insertions(+), 12 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py
index 2c7b3874..4209890f 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py
@@ -24,9 +24,13 @@ from tqdm import tqdm
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
 
 
-async def get_embeddings_parallel(
-    embedding: BaseEmbedding, vids: list[str]
-) -> list[Any]:
+async def _get_batch_with_progress(embedding: BaseEmbedding, batch: list[str], pbar: tqdm) -> list[Any]:
+    result = await embedding.async_get_texts_embeddings(batch)
+    pbar.update(1)
+    return result
+
+
+async def get_embeddings_parallel(embedding: BaseEmbedding, vids: list[str]) -> list[Any]:
     """Get embeddings for texts in parallel.
 
     This function processes text embeddings asynchronously in parallel, using batching and semaphore
@@ -43,20 +47,27 @@ async def get_embeddings_parallel(
     Note:
         - Note: Uses a semaphore to limit maximum concurrency if we need
         - Processes texts in batches of 500
-        - Displays progress using a progress bar
+        - Displays progress using a progress bar that updates as each batch completes
+        - Uses asyncio.gather() to preserve order correspondence between input and output
     """
     batch_size = 500
 
     # Split vids into batches of size batch_size
     vid_batches = [vids[i : i + batch_size] for i in range(0, len(vids), batch_size)]
 
-    # Create tasks for each batch
-    tasks = [embedding.async_get_texts_embeddings(batch) for batch in vid_batches]
-
     embeddings = []
-    with tqdm(total=len(tasks)) as pbar:
-        for future in asyncio.as_completed(tasks):
-            batch_embeddings = await future
-            embeddings.extend(batch_embeddings)  # Extend the list with batch results
-            pbar.update(1)
+    with tqdm(total=len(vid_batches)) as pbar:
+        # Create tasks for each batch with progress bar updates
+        tasks = [
+            _get_batch_with_progress(embedding, batch, pbar)
+            for batch in vid_batches
+        ]
+
+        # Use asyncio.gather() to preserve order
+        batch_results = await asyncio.gather(*tasks)
+
+        # Combine all batch results in order
+        for batch_embeddings in batch_results:
+            embeddings.extend(batch_embeddings)
+
     return embeddings

Reply via email to