This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git


The following commit(s) were added to refs/heads/main by this push:
     new 8a61cf7  feat(llm): support batch embedding (#238)
8a61cf7 is described below

commit 8a61cf7991fe9dc264ad197e8ed1fd52e61494a9
Author: Linyu <94553312+weijing...@users.noreply.github.com>
AuthorDate: Wed May 21 19:49:15 2025 +0800

    feat(llm): support batch embedding (#238)
    
    ## Implement batch embedding by modifying the underlying LLM interaction interface
    
    - [✔] ollama
    - [✔] openai
    - [✔] qianfan
    
    I've replaced the original per-vertex concurrent embedding calls with batched
    calls (up to 1000 texts per request) in build_semantic_index.py and performed
    a simple test.
    
    Closes #233
    
    ---------
    
    Co-authored-by: imbajin <j...@apache.org>
---
 hugegraph-llm/pyproject.toml                       |  2 +-
 hugegraph-llm/requirements.txt                     |  2 +-
 .../src/hugegraph_llm/models/embeddings/base.py    | 23 ++++++++++++++++
 .../src/hugegraph_llm/models/embeddings/ollama.py  | 26 +++++++++++++++++-
 .../src/hugegraph_llm/models/embeddings/openai.py  | 24 +++++++++++++++++
 .../src/hugegraph_llm/models/embeddings/qianfan.py |  8 ++++++
 .../index_op/build_gremlin_example_index.py        |  3 ++-
 .../operators/index_op/build_semantic_index.py     | 31 +++++++++++++---------
 .../index_op/gremlin_example_index_query.py        |  2 +-
 9 files changed, 103 insertions(+), 18 deletions(-)

diff --git a/hugegraph-llm/pyproject.toml b/hugegraph-llm/pyproject.toml
index 3f72aa5..dfb2681 100644
--- a/hugegraph-llm/pyproject.toml
+++ b/hugegraph-llm/pyproject.toml
@@ -39,7 +39,7 @@ documentation = "https://hugegraph.apache.org/docs/quickstart/hugegraph-ai/"
 [tool.poetry.dependencies]
 python = "^3.10,<3.12"
 openai = "~1.61.0"
-ollama = "~0.2.1"
+ollama = "~0.4.8"
 qianfan = "~0.3.18"
 retry = "~0.9.2"
 tiktoken = ">=0.7.0"
diff --git a/hugegraph-llm/requirements.txt b/hugegraph-llm/requirements.txt
index 7467ec6..3abe63e 100644
--- a/hugegraph-llm/requirements.txt
+++ b/hugegraph-llm/requirements.txt
@@ -1,5 +1,5 @@
 openai~=1.61.0
-ollama~=0.2.1
+ollama~=0.4.8
 qianfan~=0.3.18
 retry~=0.9.2
 tiktoken>=0.7.0
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py 
b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
index 2ea8786..73e973e 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
@@ -60,6 +60,29 @@ class BaseEmbedding(ABC):
     ) -> List[float]:
         """Comment"""
 
+    @abstractmethod
+    def get_texts_embeddings(
+            self,
+            texts: List[str]
+    ) -> List[List[float]]:
+        """Get embeddings for multiple texts in a single batch.
+        
+        This method should efficiently process multiple texts at once by 
leveraging
+        the embedding model's batching capabilities, which is typically more 
efficient
+        than processing texts individually.
+        
+        Parameters
+        ----------
+        texts : List[str]
+            A list of text strings to be embedded.
+            
+        Returns
+        -------
+        List[List[float]]
+            A list of embedding vectors, where each vector is a list of floats.
+            The order of embeddings should match the order of input texts.
+        """
+
     @abstractmethod
     async def async_get_text_embedding(
             self,
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py 
b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
index 81e11cc..062e098 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
@@ -40,7 +40,31 @@ class OllamaEmbedding(BaseEmbedding):
             text: str
     ) -> List[float]:
         """Comment"""
-        return list(self.client.embeddings(model=self.model, 
prompt=text)["embedding"])
+        return list(self.client.embed(model=self.model, 
input=text)["embeddings"][0])
+
+    def get_texts_embeddings(
+            self,
+            texts: List[str]
+    ) -> List[List[float]]:
+        """Get embeddings for multiple texts in a single batch.
+        
+        This method efficiently processes multiple texts at once by leveraging
+        Ollama's batching capabilities, which is more efficient than processing
+        texts individually.
+        
+        Parameters
+        ----------
+        texts : List[str]
+            A list of text strings to be embedded.
+            
+        Returns
+        -------
+        List[List[float]]
+            A list of embedding vectors, where each vector is a list of floats.
+            The order of embeddings matches the order of input texts.
+        """
+        response = self.client.embed(model=self.model, 
input=texts)["embeddings"]
+        return [list(inner_sequence) for inner_sequence in response]
 
     async def async_get_text_embedding(
             self,
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py 
b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
index aacef1e..890f491 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py
@@ -38,6 +38,30 @@ class OpenAIEmbedding:
         response = self.client.embeddings.create(input=text, 
model=self.embedding_model_name)
         return response.data[0].embedding
 
+    def get_texts_embeddings(
+            self,
+            texts: List[str]
+    ) -> List[List[float]]:
+        """Get embeddings for multiple texts in a single batch.
+        
+        This method efficiently processes multiple texts at once by leveraging
+        OpenAI's batching capabilities, which is more efficient than processing
+        texts individually.
+        
+        Parameters
+        ----------
+        texts : List[str]
+            A list of text strings to be embedded.
+            
+        Returns
+        -------
+        List[List[float]]
+            A list of embedding vectors, where each vector is a list of floats.
+            The order of embeddings matches the order of input texts.
+        """
+        response = self.client.embeddings.create(input=texts, 
model=self.embedding_model_name)
+        return [data.embedding for data in response.data]
+
     async def async_get_text_embedding(self, text: str) -> List[float]:
         """Comment"""
         response = await self.aclient.embeddings.create(input=text, 
model=self.embedding_model_name)
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py 
b/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py
index 1745eb2..99eeb59 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py
@@ -51,6 +51,14 @@ class QianFanEmbedding:
         )
         return response["body"]["data"][0]["embedding"]
 
+    def get_texts_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """ Usage refer: 
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlmokk9qn""";
+        response = self.client.do(
+            model=self.embedding_model_name,
+            texts=texts
+        )
+        return [data["embedding"] for data in response["body"]["data"]]
+
     async def async_get_text_embedding(self, text: str) -> List[float]:
         """ Usage refer: 
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlmokk9qn""";
         response = await self.client.ado(
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py
 
b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py
index 4e15274..b865bc6 100644
--- 
a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py
+++ 
b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py
@@ -19,11 +19,12 @@
 import os
 from typing import Dict, Any, List
 
-from hugegraph_llm.models.embeddings.base import BaseEmbedding
 from hugegraph_llm.config import resource_path
 from hugegraph_llm.indices.vector_index import VectorIndex
+from hugegraph_llm.models.embeddings.base import BaseEmbedding
 
 
+# FIXME: keep this logic consistent with build_semantic_index.py
 class BuildGremlinExampleIndex:
     def __init__(self, embedding: BaseEmbedding, examples: List[Dict[str, 
str]]):
         self.index_dir = os.path.join(resource_path, "gremlin_examples")
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py 
b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
index f8a911e..ce64442 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
@@ -25,9 +25,8 @@ from tqdm import tqdm
 from hugegraph_llm.config import resource_path, huge_settings
 from hugegraph_llm.indices.vector_index import VectorIndex
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
-from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager
 from hugegraph_llm.utils.log import log
-
+from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager
 
 class BuildSemanticIndex:
     def __init__(self, embedding: BaseEmbedding):
@@ -41,35 +40,38 @@ class BuildSemanticIndex:
 
     async def _get_embeddings_parallel(self, vids: list[str]) -> list[Any]:
         sem = asyncio.Semaphore(10)
-
-        async def get_embedding_with_semaphore(vid: str) -> Any:
+        batch_size = 1000
+        async def get_embeddings_with_semaphore(vid_list: list[str]) -> Any:
             # Executes sync embedding method in a thread pool via 
loop.run_in_executor, combining async programming
             # with multi-threading capabilities.
             # This pattern avoids blocking the event loop and prepares for a 
future fully async pipeline.
             async with sem:
                 loop = asyncio.get_running_loop()
-                return await loop.run_in_executor(None, 
self.embedding.get_text_embedding, vid)
+                return await loop.run_in_executor(None, 
self.embedding.get_texts_embeddings, vid_list)
+
+        # Split vids into batches of size batch_size
+        vid_batches = [vids[i:i + batch_size] for i in range(0, len(vids), 
batch_size)]
+
+        # Create tasks for each batch
+        tasks = [get_embeddings_with_semaphore(batch) for batch in vid_batches]
 
-        tasks = [get_embedding_with_semaphore(vid) for vid in vids]
         embeddings = []
         with tqdm(total=len(tasks)) as pbar:
             for future in asyncio.as_completed(tasks):
-                embedding = await future
-                embeddings.append(embedding)
+                batch_embeddings = await future
+                embeddings.extend(batch_embeddings) # Extend the list with 
batch results
                 pbar.update(1)
         return embeddings
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         vertexlabels = self.sm.schema.getSchema()["vertexlabels"]
-        all_pk_flag = all(data.get("id_strategy") == "PRIMARY_KEY" for data in 
vertexlabels)
+        all_pk_flag = all(data.get('id_strategy') == 'PRIMARY_KEY' for data in 
vertexlabels)
 
         past_vids = self.vid_index.properties
         # TODO: We should build vid vector index separately, especially when 
the vertices may be very large
-        present_vids = context["vertices"]  # Warning: data truncated by 
fetch_graph_data.py
+        present_vids = context["vertices"] # Warning: data truncated by 
fetch_graph_data.py
         removed_vids = set(past_vids) - set(present_vids)
         removed_num = self.vid_index.remove(removed_vids)
-        if removed_vids:
-            self.vid_index.to_index_file(self.index_dir)
         added_vids = list(set(present_vids) - set(past_vids))
 
         if added_vids:
@@ -80,5 +82,8 @@ class BuildSemanticIndex:
             self.vid_index.to_index_file(self.index_dir)
         else:
             log.debug("No update vertices to build vector index.")
-        context.update({"removed_vid_vector_num": removed_num, 
"added_vid_vector_num": len(added_vids)})
+        context.update({
+            "removed_vid_vector_num": removed_num,
+            "added_vid_vector_num": len(added_vids)
+        })
         return context
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
 
b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
index 31d9c2b..b8acd50 100644
--- 
a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
+++ 
b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
@@ -57,7 +57,7 @@ class GremlinExampleIndexQuery:
     def _build_default_example_index(self):
         properties = pd.read_csv(os.path.join(resource_path, "demo", 
"text2gremlin.csv")).to_dict(orient="records")
         from concurrent.futures import ThreadPoolExecutor
-        # TODO: use asyncio for IO tasks
+        # TODO: reuse the logic in build_semantic_index.py (consider extracting the batch-embedding method)
         with ThreadPoolExecutor() as executor:
             embeddings = list(
                 tqdm(executor.map(self.embedding.get_text_embedding, 
[row["query"] for row in properties]),

Reply via email to