This is an automated email from the ASF dual-hosted git repository. jin pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push: new 0c9c15c2 fix(llm): refactor embedding parallelization to preserve order (#295) 0c9c15c2 is described below commit 0c9c15c20fee3748718c5a77b3120469b6b06d02 Author: imbajin <j...@apache.org> AuthorDate: Thu Jul 31 17:30:41 2025 +0800 fix(llm): refactor embedding parallelization to preserve order (#295) Reworked get_embeddings_parallel to use asyncio.gather for batch processing, ensuring output order matches input. Added a helper for batch progress updates and improved progress bar accuracy. --- .../src/hugegraph_llm/utils/embedding_utils.py | 35 ++++++++++++++-------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py index 2c7b3874..4209890f 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py @@ -24,9 +24,13 @@ from tqdm import tqdm from hugegraph_llm.models.embeddings.base import BaseEmbedding -async def get_embeddings_parallel( - embedding: BaseEmbedding, vids: list[str] -) -> list[Any]: +async def _get_batch_with_progress(embedding: BaseEmbedding, batch: list[str], pbar: tqdm) -> list[Any]: + result = await embedding.async_get_texts_embeddings(batch) + pbar.update(1) + return result + + +async def get_embeddings_parallel(embedding: BaseEmbedding, vids: list[str]) -> list[Any]: """Get embeddings for texts in parallel. This function processes text embeddings asynchronously in parallel, using batching and semaphore @@ -43,20 +47,27 @@ async def get_embeddings_parallel( Note: - Note: Uses a semaphore to limit maximum concurrency if we need - Processes texts in batches of 500 - - Displays progress using a progress bar + - Displays progress using a progress bar that updates as each batch completes + - Uses asyncio.gather() to preserve order correspondence between input and output """ batch_size = 500 # Split vids into batches of size batch_size vid_batches = [vids[i : i + batch_size] for i in range(0, len(vids), batch_size)] - # Create tasks for each batch - tasks = [embedding.async_get_texts_embeddings(batch) for batch in vid_batches] - embeddings = [] - with tqdm(total=len(tasks)) as pbar: - for future in asyncio.as_completed(tasks): - batch_embeddings = await future - embeddings.extend(batch_embeddings) # Extend the list with batch results - pbar.update(1) + with tqdm(total=len(vid_batches)) as pbar: + # Create tasks for each batch with progress bar updates + tasks = [ + _get_batch_with_progress(embedding, batch, pbar) + for batch in vid_batches + ] + + # Use asyncio.gather() to preserve order + batch_results = await asyncio.gather(*tasks) + + # Combine all batch results in order + for batch_embeddings in batch_results: + embeddings.extend(batch_embeddings) + return embeddings