This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
     new cd404b8c  perf(llm): optimize vector index with asyncio embedding (#264)
cd404b8c is described below

commit cd404b8c7facade86bc541212e6e071ec6a61402
Author: Linyu <94553312+weijing...@users.noreply.github.com>
AuthorDate: Thu Jul 31 17:02:58 2025 +0800

perf(llm): optimize vector index with asyncio embedding (#264)

## Changes

This PR introduces performance optimizations for vector index building and querying by implementing parallel text embedding generation.

### Key Improvements

1. Added a new utility module `embedding_utils.py` with parallel batch-processing capabilities
   - Implements the `get_embeddings_parallel` function for efficient batch processing
   - Uses asyncio with a semaphore for controlled concurrency
   - Supports a batch size of 1000 with at most 10 concurrent tasks
2. Refactored all index operation classes to use parallel processing:
   - `BuildGremlinExampleIndex`
   - `BuildSemanticIndex`
   - `BuildVectorIndex`
   - `GremlinExampleIndexQuery`
   - `SemanticIdQuery`
   - `VectorIndexQuery`
3. Unified the embedding generation approach:
   - Replaced individual `get_text_embedding` calls with batch `get_texts_embeddings`
   - Removed duplicate parallel-processing code
   - Improved code reusability and maintainability

---------

Co-authored-by: imbajin <j...@apache.org>
---
 .asf.yaml                                          |  2 +-
 .../config/models/base_prompt_config.py            | 79 ++++++++++++++--------
 .../src/hugegraph_llm/indices/vector_index.py      |  8 ++-
 .../src/hugegraph_llm/models/embeddings/base.py    | 25 +++++--
 .../src/hugegraph_llm/models/embeddings/litellm.py |  6 +-
 .../src/hugegraph_llm/models/embeddings/ollama.py  | 21 ++++--
 .../src/hugegraph_llm/models/embeddings/openai.py  | 24 +++++--
 .../index_op/build_gremlin_example_index.py        |  9 ++-
 .../operators/index_op/build_semantic_index.py     | 64 +++++++-----------
 .../operators/index_op/build_vector_index.py       |  7 +-
 .../index_op/gremlin_example_index_query.py        | 12 ++--
 .../operators/index_op/semantic_id_query.py        |  4 +-
 .../operators/index_op/vector_index_query.py       |  2 +-
 .../src/hugegraph_llm/utils/embedding_utils.py     | 62 +++++++++++++++++
 .../src/hugegraph_llm/utils/vector_index_utils.py  | 10 ++-
 15 files changed, 228 insertions(+), 107 deletions(-)

diff --git a/.asf.yaml b/.asf.yaml index cafdba4f..21bb671e 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -53,7 +53,7 @@ github: # (for non-committer): assign/edit/close issues & PR, without write access to the code collaborators: - ChenZiHong-Gavin - - MrJs133 + - weijinglin - HJ-Young - afterimagex - returnToInnocence

diff --git a/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py b/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py index 691247b3..23832bf9 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License.
-import os import sys +import os from pathlib import Path import yaml @@ -30,22 +30,24 @@ yaml_file_path = os.path.join(os.getcwd(), "src/hugegraph_llm/resources/demo", F class BasePromptConfig: - graph_schema: str = '' - extract_graph_prompt: str = '' - default_question: str = '' - custom_rerank_info: str = '' - answer_prompt: str = '' - keywords_extract_prompt: str = '' - text2gql_graph_schema: str = '' - gremlin_generate_prompt: str = '' - doc_input_text: str = '' - generate_extract_prompt_template: str = '' + graph_schema: str = "" + extract_graph_prompt: str = "" + default_question: str = "" + custom_rerank_info: str = "" + answer_prompt: str = "" + keywords_extract_prompt: str = "" + text2gql_graph_schema: str = "" + gremlin_generate_prompt: str = "" + doc_input_text: str = "" + generate_extract_prompt_template: str = "" def ensure_yaml_file_exists(self): current_dir = Path.cwd().resolve() project_root = get_project_root() if current_dir == project_root: - log.info("Current working directory is the project root, proceeding to run the app.") + log.info( + "Current working directory is the project root, proceeding to run the app." + ) else: error_msg = ( f"Current working directory is not the project root. " @@ -66,22 +68,42 @@ class BasePromptConfig: log.info("Prompt file '%s' doesn't exist, create it.", yaml_file_path) def save_to_yaml(self): - indented_schema = "\n".join([f" {line}" for line in self.graph_schema.splitlines()]) - indented_text2gql_schema = "\n".join([f" {line}" for line in self.text2gql_graph_schema.splitlines()]) - indented_gremlin_prompt = "\n".join([f" {line}" for line in self.gremlin_generate_prompt.splitlines()]) - indented_example_prompt = "\n".join([f" {line}" for line in self.extract_graph_prompt.splitlines()]) - indented_question = "\n".join([f" {line}" for line in self.default_question.splitlines()]) - indented_custom_related_information = ( - "\n".join([f" {line}" for line in self.custom_rerank_info.splitlines()]) + indented_schema = "\n".join( + [f" {line}" for line in self.graph_schema.splitlines()] + ) + indented_text2gql_schema = "\n".join( + [f" {line}" for line in self.text2gql_graph_schema.splitlines()] + ) + indented_gremlin_prompt = "\n".join( + [f" {line}" for line in self.gremlin_generate_prompt.splitlines()] + ) + indented_example_prompt = "\n".join( + [f" {line}" for line in self.extract_graph_prompt.splitlines()] + ) + indented_question = "\n".join( + [f" {line}" for line in self.default_question.splitlines()] ) - indented_default_answer_template = "\n".join([f" {line}" for line in self.answer_prompt.splitlines()]) - indented_keywords_extract_template = ( - "\n".join([f" {line}" for line in self.keywords_extract_prompt.splitlines()]) + indented_custom_related_information = "\n".join( + [f" {line}" for line in self.custom_rerank_info.splitlines()] + ) + indented_default_answer_template = "\n".join( + [f" {line}" for line in self.answer_prompt.splitlines()] + ) + indented_keywords_extract_template = "\n".join( + [f" {line}" for line in self.keywords_extract_prompt.splitlines()] + ) + indented_doc_input_text = "\n".join( + [f" {line}" for line in self.doc_input_text.splitlines()] + ) + indented_generate_extract_prompt = ( + "\n".join( + [ + f" {line}" + for line in self.generate_extract_prompt_template.splitlines() + ] + ) + + "\n" ) - indented_doc_input_text = "\n".join([f" {line}" for line in self.doc_input_text.splitlines()]) - indented_generate_extract_prompt = "\n".join( - [f" {line}" for line in 
self.generate_extract_prompt_template.splitlines()] - ) + "\n" # This can be extended to add storage fields according to the data needs to be stored yaml_content = f"""graph_schema: | {indented_schema} @@ -118,7 +140,10 @@ generate_extract_prompt_template: | def generate_yaml_file(self): if os.path.exists(yaml_file_path): - log.info("%s already exists, do you want to override with the default configuration? (y/n)", yaml_file_path) + log.info( + "%s already exists, do you want to override with the default configuration? (y/n)", + yaml_file_path, + ) update = input() if update.lower() != "y": return diff --git a/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py b/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py index 5e810ed9..301d741f 100644 --- a/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py @@ -37,11 +37,13 @@ class VectorIndex: self.properties = [] @staticmethod - def from_index_file(dir_path: str) -> "VectorIndex": + def from_index_file(dir_path: str, record_miss: bool = True) -> "VectorIndex": index_file = os.path.join(dir_path, INDEX_FILE_NAME) properties_file = os.path.join(dir_path, PROPERTIES_FILE_NAME) - if not os.path.exists(index_file) or not os.path.exists(properties_file): - log.warning("No index file found, create a new one.") + miss_files = [f for f in [index_file, properties_file] if not os.path.exists(f)] + if miss_files: + if record_miss: + log.warning("Missing vector files: %s. \nNeed create a new one for it.", ", ".join(miss_files)) return VectorIndex() faiss_index = faiss.read_index(index_file) diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py index d6b66294..db9b2f10 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py @@ -86,13 +86,28 @@ class BaseEmbedding(ABC): The order of embeddings should match the order of input texts. """ - # TODO: [PR-238] Add & implement batch processing for async_get_texts_embeddings (refactor here) @abstractmethod - async def async_get_text_embedding( + async def async_get_texts_embeddings( self, - text: str - ) -> List[float]: - """Comment""" + texts: List[str] + ) -> List[List[float]]: + """Get embeddings for multiple texts in a single batch asynchronously. + + This method should efficiently process multiple texts at once by leveraging + the embedding model's batching capabilities, which is typically more efficient + than processing texts individually. + + Parameters + ---------- + texts : List[str] + A list of text strings to be embedded. + + Returns + ------- + List[List[float]] + A list of embedding vectors, where each vector is a list of floats. + The order of embeddings should match the order of input texts. 
+ """ @staticmethod def similarity( diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/litellm.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/litellm.py index ee808b09..b793800e 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/litellm.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/litellm.py @@ -77,17 +77,17 @@ class LiteLLMEmbedding(BaseEmbedding): log.error("Error in LiteLLM batch embedding call: %s", e) raise - async def async_get_text_embedding(self, text: str) -> List[float]: + async def async_get_texts_embeddings(self, texts: List[str]) -> List[List[float]]: """Get embedding for a single text asynchronously.""" try: response = await aembedding( model=self.model, - input=text, + input=texts, api_key=self.api_key, api_base=self.api_base, ) log.info("Token usage: %s", response.usage) - return response.data[0]["embedding"] + return [data["embedding"] for data in response.data] except (RateLimitError, APIConnectionError, APIError) as e: log.error("Error in async LiteLLM embedding call: %s", e) raise diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py index e54750f0..78c5bd08 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py @@ -53,7 +53,20 @@ class OllamaEmbedding(BaseEmbedding): response = self.client.embed(model=self.model, input=texts)["embeddings"] return [list(inner_sequence) for inner_sequence in response] - # TODO: Add & implement batch processing for async_get_texts_embeddings (refactor here) - async def async_get_text_embedding(self, text: str) -> List[float]: - response = await self.async_client.embeddings(model=self.model, prompt=text) - return list(response["embedding"]) + async def async_get_texts_embeddings(self, texts: List[str]) -> List[List[float]]: + """Get embeddings for multiple texts in a single batch asynchronously. + + Returns + ------- + List[List[float]] + A list of embedding vectors, where each vector is a list of floats. + The order of embeddings matches the order of input texts. + """ + if not hasattr(self.client, "embed"): + error_message = ( + "The required 'embed' method was not found on the Ollama client. " + "Please ensure your ollama library is up-to-date and supports batch embedding. " + ) + raise AttributeError(error_message) + response = await self.async_client.embed(model=self.model, input=texts) + return [list(inner_sequence) for inner_sequence in response["embeddings"]] diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py index 890f4918..c18a0fb1 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py @@ -62,7 +62,23 @@ class OpenAIEmbedding: response = self.client.embeddings.create(input=texts, model=self.embedding_model_name) return [data.embedding for data in response.data] - async def async_get_text_embedding(self, text: str) -> List[float]: - """Comment""" - response = await self.aclient.embeddings.create(input=text, model=self.embedding_model_name) - return response.data[0].embedding + async def async_get_texts_embeddings(self, texts: List[str]) -> List[List[float]]: + """Get embeddings for multiple texts in a single batch asynchronously. 
+ + This method should efficiently process multiple texts at once by leveraging + the embedding model's batching capabilities, which is typically more efficient + than processing texts individually. + + Parameters + ---------- + texts : List[str] + A list of text strings to be embedded. + + Returns + ------- + List[List[float]] + A list of embedding vectors, where each vector is a list of floats. + The order of embeddings should match the order of input texts. + """ + response = await self.aclient.embeddings.create(input=texts, model=self.embedding_model_name) + return [data.embedding for data in response.data] diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py index b865bc65..b444dcd6 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py @@ -16,12 +16,14 @@ # under the License. +import asyncio import os from typing import Dict, Any, List from hugegraph_llm.config import resource_path from hugegraph_llm.indices.vector_index import VectorIndex from hugegraph_llm.models.embeddings.base import BaseEmbedding +from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel # FIXME: we need keep the logic same with build_semantic_index.py @@ -32,9 +34,10 @@ class BuildGremlinExampleIndex: self.embedding = embedding def run(self, context: Dict[str, Any]) -> Dict[str, Any]: - examples_embedding = [] - for example in self.examples: - examples_embedding.append(self.embedding.get_text_embedding(example["query"])) + # !: We have assumed that self.example is not empty + queries = [example["query"] for example in self.examples] + # TODO: refactor function chain async to avoid blocking + examples_embedding = asyncio.run(get_embeddings_parallel(self.embedding, queries)) embed_dim = len(examples_embedding[0]) if len(self.examples) > 0: vector_index = VectorIndex(embed_dim) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py index e6b4080a..a2a0412f 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py @@ -20,18 +20,19 @@ import asyncio import os from typing import Any, Dict -from tqdm import tqdm - from hugegraph_llm.config import resource_path, huge_settings from hugegraph_llm.indices.vector_index import VectorIndex from hugegraph_llm.models.embeddings.base import BaseEmbedding -from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager from hugegraph_llm.utils.log import log +from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel +from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager class BuildSemanticIndex: def __init__(self, embedding: BaseEmbedding): - self.index_dir = str(os.path.join(resource_path, huge_settings.graph_name, "graph_vids")) + self.index_dir = str( + os.path.join(resource_path, huge_settings.graph_name, "graph_vids") + ) self.vid_index = VectorIndex.from_index_file(self.index_dir) self.embedding = embedding self.sm = SchemaManager(huge_settings.graph_name) @@ -39,55 +40,38 @@ class BuildSemanticIndex: def _extract_names(self, vertices: list[str]) -> list[str]: return [v.split(":")[1] for v in vertices] - async def 
_get_embeddings_parallel(self, vids: list[str]) -> list[Any]: - sem = asyncio.Semaphore(10) - batch_size = 1000 - - # TODO: refactor the logic here (call async method) - async def get_embeddings_with_semaphore(vid_list: list[str]) -> Any: - # Executes sync embedding method in a thread pool via loop.run_in_executor, combining async programming - # with multi-threading capabilities. - # This pattern avoids blocking the event loop and prepares for a future fully async pipeline. - async with sem: - loop = asyncio.get_running_loop() - # FIXME: [PR-238] add & use async_get_texts_embedding instead of sync method - return await loop.run_in_executor(None, self.embedding.get_texts_embeddings, vid_list) - - # Split vids into batches of size batch_size - vid_batches = [vids[i:i + batch_size] for i in range(0, len(vids), batch_size)] - - # Create tasks for each batch - tasks = [get_embeddings_with_semaphore(batch) for batch in vid_batches] - - embeddings = [] - with tqdm(total=len(tasks)) as pbar: - for future in asyncio.as_completed(tasks): - batch_embeddings = await future - embeddings.extend(batch_embeddings) # Extend the list with batch results - pbar.update(1) - return embeddings - def run(self, context: Dict[str, Any]) -> Dict[str, Any]: vertexlabels = self.sm.schema.getSchema()["vertexlabels"] - all_pk_flag = all(data.get('id_strategy') == 'PRIMARY_KEY' for data in vertexlabels) + all_pk_flag = all( + data.get("id_strategy") == "PRIMARY_KEY" for data in vertexlabels + ) past_vids = self.vid_index.properties # TODO: We should build vid vector index separately, especially when the vertices may be very large - present_vids = context["vertices"] # Warning: data truncated by fetch_graph_data.py + + present_vids = context[ + "vertices" + ] # Warning: data truncated by fetch_graph_data.py removed_vids = set(past_vids) - set(present_vids) removed_num = self.vid_index.remove(removed_vids) added_vids = list(set(present_vids) - set(past_vids)) if added_vids: - vids_to_process = self._extract_names(added_vids) if all_pk_flag else added_vids - added_embeddings = asyncio.run(self._get_embeddings_parallel(vids_to_process)) + vids_to_process = ( + self._extract_names(added_vids) if all_pk_flag else added_vids + ) + added_embeddings = asyncio.run( + get_embeddings_parallel(self.embedding, vids_to_process) + ) log.info("Building vector index for %s vertices...", len(added_vids)) self.vid_index.add(added_embeddings, added_vids) self.vid_index.to_index_file(self.index_dir) else: log.debug("No update vertices to build vector index.") - context.update({ - "removed_vid_vector_num": removed_num, - "added_vid_vector_num": len(added_vids) - }) + context.update( + { + "removed_vid_vector_num": removed_num, + "added_vid_vector_num": len(added_vids), + } + ) return context diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py index 01499f52..8a66d0b0 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py @@ -16,15 +16,15 @@ # under the License. 
+import asyncio import os from typing import Dict, Any -from tqdm import tqdm - from hugegraph_llm.config import huge_settings, resource_path from hugegraph_llm.indices.vector_index import VectorIndex from hugegraph_llm.models.embeddings.base import BaseEmbedding from hugegraph_llm.utils.log import log +from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel class BuildVectorIndex: @@ -40,8 +40,7 @@ class BuildVectorIndex: chunks_embedding = [] log.debug("Building vector index for %s chunks...", len(context["chunks"])) # TODO: use async_get_texts_embedding instead of single sync method - for chunk in tqdm(chunks): - chunks_embedding.append(self.embedding.get_text_embedding(chunk)) + chunks_embedding = asyncio.run(get_embeddings_parallel(self.embedding, chunks)) if len(chunks_embedding) > 0: self.vector_index.add(chunks_embedding, chunks) self.vector_index.to_index_file(self.index_dir) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py index b8acd506..b14da3d5 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py @@ -16,17 +16,18 @@ # under the License. +import asyncio import os from typing import Dict, Any, List import pandas as pd -from tqdm import tqdm from hugegraph_llm.config import resource_path from hugegraph_llm.indices.vector_index import VectorIndex from hugegraph_llm.models.embeddings.base import BaseEmbedding from hugegraph_llm.models.embeddings.init_embedding import Embeddings from hugegraph_llm.utils.log import log +from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel class GremlinExampleIndexQuery: @@ -51,17 +52,14 @@ class GremlinExampleIndexQuery: query_embedding = context.get("query_embedding") if not isinstance(query_embedding, list): - query_embedding = self.embedding.get_text_embedding(query) + query_embedding = self.embedding.get_texts_embeddings([query])[0] return self.vector_index.search(query_embedding, self.num_examples, dis_threshold=1.8) def _build_default_example_index(self): properties = pd.read_csv(os.path.join(resource_path, "demo", "text2gremlin.csv")).to_dict(orient="records") - from concurrent.futures import ThreadPoolExecutor # TODO: reuse the logic in build_semantic_index.py (consider extract the batch-embedding method) - with ThreadPoolExecutor() as executor: - embeddings = list( - tqdm(executor.map(self.embedding.get_text_embedding, [row["query"] for row in properties]), - total=len(properties))) + queries = [row["query"] for row in properties] + embeddings = asyncio.run(get_embeddings_parallel(self.embedding, queries)) vector_index = VectorIndex(len(embeddings[0])) vector_index.add(embeddings, properties) vector_index.to_index_file(self.index_dir) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py index e3375ef0..10233a6c 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py @@ -75,7 +75,7 @@ class SemanticIdQuery: def _fuzzy_match_vids(self, keywords: List[str]) -> List[str]: fuzzy_match_result = [] for keyword in keywords: - keyword_vector = self.embedding.get_text_embedding(keyword) + keyword_vector = 
self.embedding.get_texts_embeddings([keyword])[0] results = self.vector_index.search(keyword_vector, top_k=self.topk_per_keyword, dis_threshold=float(self.vector_dis_threshold)) if results: @@ -86,7 +86,7 @@ class SemanticIdQuery: graph_query_list = set() if self.by == "query": query = context["query"] - query_vector = self.embedding.get_text_embedding(query) + query_vector = self.embedding.get_texts_embeddings([query])[0] results = self.vector_index.search(query_vector, top_k=self.topk_per_query) if results: graph_query_list.update(results[:self.topk_per_query]) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py index 976155c3..f2d5d600 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py @@ -34,7 +34,7 @@ class VectorIndexQuery: def run(self, context: Dict[str, Any]) -> Dict[str, Any]: query = context.get("query") - query_embedding = self.embedding.get_text_embedding(query) + query_embedding = self.embedding.get_texts_embeddings([query])[0] # TODO: why set dis_threshold=2? results = self.vector_index.search(query_embedding, self.topk, dis_threshold=2) # TODO: check format results diff --git a/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py new file mode 100644 index 00000000..2c7b3874 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import asyncio +from typing import Any + +from tqdm import tqdm + +from hugegraph_llm.models.embeddings.base import BaseEmbedding + + +async def get_embeddings_parallel( + embedding: BaseEmbedding, vids: list[str] +) -> list[Any]: + """Get embeddings for texts in parallel. + + This function processes text embeddings asynchronously in parallel, using batching and semaphore + to control concurrency, improving processing efficiency while preventing resource overuse. + + Args: + embedding (BaseEmbedding): The embedding model instance used to compute text embeddings. + vids (list[str]): List of texts to compute embeddings for. + + Returns: + list[Any]: List of embedding vectors corresponding to the input texts, maintaining the same + order as the input vids list. 
+ + Note: + - Note: Uses a semaphore to limit maximum concurrency if we need + - Processes texts in batches of 500 + - Displays progress using a progress bar + """ + batch_size = 500 + + # Split vids into batches of size batch_size + vid_batches = [vids[i : i + batch_size] for i in range(0, len(vids), batch_size)] + + # Create tasks for each batch + tasks = [embedding.async_get_texts_embeddings(batch) for batch in vid_batches] + + embeddings = [] + with tqdm(total=len(tasks)) as pbar: + for future in asyncio.as_completed(tasks): + batch_embeddings = await future + embeddings.extend(batch_embeddings) # Extend the list with batch results + pbar.update(1) + return embeddings diff --git a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py index ef2b5e9b..0542b906 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import json import os @@ -55,10 +56,13 @@ def read_documents(input_file, input_text): return texts -#pylint: disable=C0301 +# pylint: disable=C0301 def get_vector_index_info(): - chunk_vector_index = VectorIndex.from_index_file(str(os.path.join(resource_path, huge_settings.graph_name, "chunks"))) - graph_vid_vector_index = VectorIndex.from_index_file(str(os.path.join(resource_path, huge_settings.graph_name, "graph_vids"))) + chunk_vector_index = VectorIndex.from_index_file( + str(os.path.join(resource_path, huge_settings.graph_name, "chunks")), record_miss=False, + ) + graph_vid_vector_index = VectorIndex.from_index_file(str(os.path.join(resource_path, + huge_settings.graph_name, "graph_vids"))) return json.dumps({ "embed_dim": chunk_vector_index.index.d, "vector_info": {
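The pieces of the concurrency pattern described above are spread across several files in the patch. As a reading aid, here is a small self-contained sketch of that pattern: requests to `async_get_texts_embeddings` are issued per batch, bounded by a semaphore, with a tqdm progress bar. The `FakeEmbedding` class, the helper name `get_embeddings_parallel_sketch`, and the default batch size / concurrency values are illustrative assumptions for this sketch only; the shipped implementation is `get_embeddings_parallel` in `hugegraph_llm/utils/embedding_utils.py`.

```python
import asyncio
from typing import List

from tqdm import tqdm


class FakeEmbedding:
    """Illustrative stand-in for a BaseEmbedding implementation (not part of the patch)."""

    async def async_get_texts_embeddings(self, texts: List[str]) -> List[List[float]]:
        await asyncio.sleep(0.01)  # pretend we called a remote embedding API
        return [[float(len(t)), 0.0, 0.0, 0.0] for t in texts]


async def get_embeddings_parallel_sketch(
    embedding: FakeEmbedding,
    texts: List[str],
    batch_size: int = 500,
    max_concurrency: int = 10,
) -> List[List[float]]:
    """Embed `texts` batch by batch, with at most `max_concurrency` batches in flight."""
    sem = asyncio.Semaphore(max_concurrency)
    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

    with tqdm(total=len(batches)) as pbar:

        async def embed_batch(batch: List[str]) -> List[List[float]]:
            async with sem:  # the semaphore caps concurrent embedding requests
                result = await embedding.async_get_texts_embeddings(batch)
            pbar.update(1)
            return result

        # gather() returns batch results in submission order, so the flattened
        # output lines up with the input texts
        batch_results = await asyncio.gather(*(embed_batch(b) for b in batches))

    return [vector for batch in batch_results for vector in batch]


if __name__ == "__main__":
    vids = [f"person:{i}" for i in range(1200)]
    vectors = asyncio.run(get_embeddings_parallel_sketch(FakeEmbedding(), vids))
    print(len(vectors), len(vectors[0]))  # -> 1200 4
```

One design note: the sketch uses `asyncio.gather` because it preserves submission order, whereas the committed helper iterates `asyncio.as_completed`, which yields batches in completion order; callers that must keep embeddings aligned with their input texts need to account for that difference.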