This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
     new 8a61cf7  feat(llm): support batch embedding (#238)
8a61cf7 is described below

commit 8a61cf7991fe9dc264ad197e8ed1fd52e61494a9
Author: Linyu <94553312+weijing...@users.noreply.github.com>
AuthorDate: Wed May 21 19:49:15 2025 +0800

    feat(llm): support batch embedding (#238)

    ## Implement Batch Embedding by modifying the underlying LLM interaction interface

    - [✔] ollama
    - [✔] openai
    - [✔] qianfan

    I've modified the original concurrent call to a single batch call in
    build_semantic_index.py & performed a simple test

    close #233

    ---------

    Co-authored-by: imbajin <j...@apache.org>
---
 hugegraph-llm/pyproject.toml                        |  2 +-
 hugegraph-llm/requirements.txt                      |  2 +-
 .../src/hugegraph_llm/models/embeddings/base.py     | 23 ++++++++++++++++
 .../src/hugegraph_llm/models/embeddings/ollama.py   | 26 +++++++++++++++++-
 .../src/hugegraph_llm/models/embeddings/openai.py   | 24 +++++++++++++++++
 .../src/hugegraph_llm/models/embeddings/qianfan.py  |  8 ++++++
 .../index_op/build_gremlin_example_index.py         |  3 ++-
 .../operators/index_op/build_semantic_index.py      | 31 +++++++++++++---------
 .../index_op/gremlin_example_index_query.py         |  2 +-
 9 files changed, 103 insertions(+), 18 deletions(-)

diff --git a/hugegraph-llm/pyproject.toml b/hugegraph-llm/pyproject.toml
index 3f72aa5..dfb2681 100644
--- a/hugegraph-llm/pyproject.toml
+++ b/hugegraph-llm/pyproject.toml
@@ -39,7 +39,7 @@ documentation = "https://hugegraph.apache.org/docs/quickstart/hugegraph-ai/"
 [tool.poetry.dependencies]
 python = "^3.10,<3.12"
 openai = "~1.61.0"
-ollama = "~0.2.1"
+ollama = "~0.4.8"
 qianfan = "~0.3.18"
 retry = "~0.9.2"
 tiktoken = ">=0.7.0"
diff --git a/hugegraph-llm/requirements.txt b/hugegraph-llm/requirements.txt
index 7467ec6..3abe63e 100644
--- a/hugegraph-llm/requirements.txt
+++ b/hugegraph-llm/requirements.txt
@@ -1,5 +1,5 @@
 openai~=1.61.0
-ollama~=0.2.1
+ollama~=0.4.8
 qianfan~=0.3.18
 retry~=0.9.2
 tiktoken>=0.7.0
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
index 2ea8786..73e973e 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py
@@ -60,6 +60,29 @@ class BaseEmbedding(ABC):
     ) -> List[float]:
         """Comment"""
 
+    @abstractmethod
+    def get_texts_embeddings(
+            self,
+            texts: List[str]
+    ) -> List[List[float]]:
+        """Get embeddings for multiple texts in a single batch.
+
+        This method should efficiently process multiple texts at once by leveraging
+        the embedding model's batching capabilities, which is typically more efficient
+        than processing texts individually.
+
+        Parameters
+        ----------
+        texts : List[str]
+            A list of text strings to be embedded.
+
+        Returns
+        -------
+        List[List[float]]
+            A list of embedding vectors, where each vector is a list of floats.
+            The order of embeddings should match the order of input texts.
+        """
+
     @abstractmethod
     async def async_get_text_embedding(
         self,
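In caller terms, the new abstract method trades N single-text calls for one
batched call. A minimal sketch of the contract (the `embedding` variable
stands in for any concrete BaseEmbedding subclass; it is illustrative, not
part of the commit):

    texts = ["Alice", "Bob", "Carol"]
    vectors = embedding.get_texts_embeddings(texts)  # one batched request
    assert len(vectors) == len(texts)  # one vector per text, in input order
    # previous per-text pattern, for comparison:
    # vectors = [embedding.get_text_embedding(t) for t in texts]
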
+ """ + @abstractmethod async def async_get_text_embedding( self, diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py index 81e11cc..062e098 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py @@ -40,7 +40,31 @@ class OllamaEmbedding(BaseEmbedding): text: str ) -> List[float]: """Comment""" - return list(self.client.embeddings(model=self.model, prompt=text)["embedding"]) + return list(self.client.embed(model=self.model, input=text)["embeddings"][0]) + + def get_texts_embeddings( + self, + texts: List[str] + ) -> List[List[float]]: + """Get embeddings for multiple texts in a single batch. + + This method efficiently processes multiple texts at once by leveraging + Ollama's batching capabilities, which is more efficient than processing + texts individually. + + Parameters + ---------- + texts : List[str] + A list of text strings to be embedded. + + Returns + ------- + List[List[float]] + A list of embedding vectors, where each vector is a list of floats. + The order of embeddings matches the order of input texts. + """ + response = self.client.embed(model=self.model, input=texts)["embeddings"] + return [list(inner_sequence) for inner_sequence in response] async def async_get_text_embedding( self, diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py index aacef1e..890f491 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py @@ -38,6 +38,30 @@ class OpenAIEmbedding: response = self.client.embeddings.create(input=text, model=self.embedding_model_name) return response.data[0].embedding + def get_texts_embeddings( + self, + texts: List[str] + ) -> List[List[float]]: + """Get embeddings for multiple texts in a single batch. + + This method efficiently processes multiple texts at once by leveraging + OpenAI's batching capabilities, which is more efficient than processing + texts individually. + + Parameters + ---------- + texts : List[str] + A list of text strings to be embedded. + + Returns + ------- + List[List[float]] + A list of embedding vectors, where each vector is a list of floats. + The order of embeddings matches the order of input texts. 
+ """ + response = self.client.embeddings.create(input=texts, model=self.embedding_model_name) + return [data.embedding for data in response.data] + async def async_get_text_embedding(self, text: str) -> List[float]: """Comment""" response = await self.aclient.embeddings.create(input=text, model=self.embedding_model_name) diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py index 1745eb2..99eeb59 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/qianfan.py @@ -51,6 +51,14 @@ class QianFanEmbedding: ) return response["body"]["data"][0]["embedding"] + def get_texts_embeddings(self, texts: List[str]) -> List[List[float]]: + """ Usage refer: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlmokk9qn""" + response = self.client.do( + model=self.embedding_model_name, + texts=texts + ) + return [data["embedding"] for data in response["body"]["data"]] + async def async_get_text_embedding(self, text: str) -> List[float]: """ Usage refer: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlmokk9qn""" response = await self.client.ado( diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py index 4e15274..b865bc6 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py @@ -19,11 +19,12 @@ import os from typing import Dict, Any, List -from hugegraph_llm.models.embeddings.base import BaseEmbedding from hugegraph_llm.config import resource_path from hugegraph_llm.indices.vector_index import VectorIndex +from hugegraph_llm.models.embeddings.base import BaseEmbedding +# FIXME: we need keep the logic same with build_semantic_index.py class BuildGremlinExampleIndex: def __init__(self, embedding: BaseEmbedding, examples: List[Dict[str, str]]): self.index_dir = os.path.join(resource_path, "gremlin_examples") diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py index f8a911e..ce64442 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py @@ -25,9 +25,8 @@ from tqdm import tqdm from hugegraph_llm.config import resource_path, huge_settings from hugegraph_llm.indices.vector_index import VectorIndex from hugegraph_llm.models.embeddings.base import BaseEmbedding -from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager from hugegraph_llm.utils.log import log - +from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager class BuildSemanticIndex: def __init__(self, embedding: BaseEmbedding): @@ -41,35 +40,38 @@ class BuildSemanticIndex: async def _get_embeddings_parallel(self, vids: list[str]) -> list[Any]: sem = asyncio.Semaphore(10) - - async def get_embedding_with_semaphore(vid: str) -> Any: + batch_size = 1000 + async def get_embeddings_with_semaphore(vid_list: list[str]) -> Any: # Executes sync embedding method in a thread pool via loop.run_in_executor, combining async programming # with multi-threading capabilities. # This pattern avoids blocking the event loop and prepares for a future fully async pipeline. 
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
index 31d9c2b..b8acd50 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
@@ -57,7 +57,7 @@ class GremlinExampleIndexQuery:
     def _build_default_example_index(self):
         properties = pd.read_csv(os.path.join(resource_path, "demo", "text2gremlin.csv")).to_dict(orient="records")
         from concurrent.futures import ThreadPoolExecutor
-        # TODO: use asyncio for IO tasks
+        # TODO: reuse the logic in build_semantic_index.py (consider extracting the batch-embedding method)
         with ThreadPoolExecutor() as executor:
             embeddings = list(
                 tqdm(executor.map(self.embedding.get_text_embedding, [row["query"] for row in properties]),
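For the TODO above, one possible shape is to drop the thread pool in favor of
the new batch method (untested sketch; `properties` and `self.embedding` as in
the surrounding method):

    queries = [row["query"] for row in properties]
    embeddings = self.embedding.get_texts_embeddings(queries)  # single batched call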