This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch bak-api
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
commit a629571dfa6fe860e167fb4f302691c613acc069
Author: imbajin <[email protected]>
AuthorDate: Thu Sep 5 16:25:42 2024 +0800

    refactor some params

    feat(llm): support graph_rag_recall api
    refact: fix some bugs & query logic bugfix
    refact(llm): enhance the graph/gremlin query phrase
---
 hugegraph-llm/README.md                            |  11 +-
 .../src/hugegraph_llm/api/models/rag_requests.py   |  15 +-
 hugegraph-llm/src/hugegraph_llm/api/rag_api.py     |  51 ++++++-
 .../src/hugegraph_llm/demo/rag_web_demo.py         |  73 +++++----
 .../src/hugegraph_llm/models/llms/init_llm.py      |   4 +-
 .../src/hugegraph_llm/models/llms/openai.py        |   2 +-
 .../operators/document_op/word_extract.py          |   3 +-
 .../src/hugegraph_llm/operators/graph_rag_task.py  |  36 ++---
 .../operators/hugegraph_op/graph_rag_query.py      | 169 +++++++++++----------
 .../operators/hugegraph_op/schema_manager.py       |   6 +-
 .../operators/index_op/semantic_id_query.py        |  10 +-
 .../operators/llm_op/answer_synthesize.py          |  21 +--
 .../operators/llm_op/disambiguate_data.py          |   2 +-
 .../operators/llm_op/keyword_extract.py            |  48 +++---
 .../src/hugegraph_llm/utils/decorators.py          |   2 +-
 .../src/hugegraph_llm/utils/graph_index_utils.py   |   6 +-
 16 files changed, 248 insertions(+), 211 deletions(-)

diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md
index 25de21b..8df48bb 100644
--- a/hugegraph-llm/README.md
+++ b/hugegraph-llm/README.md
@@ -130,22 +130,19 @@ The methods of the `KgBuilder` class can be chained together to perform a sequen
 
 Run example like `python3 ./hugegraph_llm/examples/graph_rag_test.py`
 
-The `GraphRAG` class is used to integrate HugeGraph with large language models to provide retrieval-augmented generation capabilities.
+The `RAGPipeline` class is used to integrate HugeGraph with large language models to provide retrieval-augmented generation capabilities.
 
 Here is a brief usage guide:
 
 1. **Extract Keyword:**: Extract keywords and expand synonyms.
-
+
     ```python
-    graph_rag.extract_keyword(text="Tell me about Al Pacino.").print_result()
+    graph_rag.extract_keywords(text="Tell me about Al Pacino.").print_result()
     ```
 2. **Query Graph for Rag**: Retrieve the corresponding keywords and their multi-degree associated relationships from HugeGraph.
 
     ```python
-    graph_rag.query_graph_for_rag(
-        max_deep=2,
-        max_items=30
-    ).print_result()
+    graph_rag.query_graph_db(max_deep=2, max_items=30).print_result()
     ```
 3. **Synthesize Answer**: Summarize the results and organize the language to answer the question.
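[Editor's note] Taken together, the README steps above chain into one call sequence. A minimal sketch, assuming `graph_rag = RAGPipeline()` and already-configured LLM/embedding settings; the method names come from this commit and mirror the default chain in `RAGPipeline.run()` further down in the diff, while the question text is arbitrary:

```python
# Minimal end-to-end sketch of the renamed README API (illustrative only).
from hugegraph_llm.operators.graph_rag_task import RAGPipeline

graph_rag = RAGPipeline()
(graph_rag
 .extract_keywords(text="Tell me about Al Pacino.")   # step 1
 .query_graph_db(max_deep=2, max_items=30)            # step 2
 .synthesize_answer())                                # step 3
context = graph_rag.run(verbose=True, query="Tell me about Al Pacino.", graph_search=True)
```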
diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index a211bb8..0e7c666 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -15,20 +15,31 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Optional
+from typing import Optional, Literal
 
 from pydantic import BaseModel
 
 
 class RAGRequest(BaseModel):
     query: str
-    raw_llm: Optional[bool] = True
+    raw_llm: Optional[bool] = False
     vector_only: Optional[bool] = False
     graph_only: Optional[bool] = False
     graph_vector: Optional[bool] = False
+    graph_ratio: float = 0.5
+    rerank_method: Literal["bleu", "reranker"] = "bleu"
+    near_neighbor_first: bool = False
+    custom_related_information: str = None
     answer_prompt: Optional[str] = None
 
 
+class GraphRAGRequest(BaseModel):
+    query: str
+    rerank_method: Literal["bleu", "reranker"] = "bleu"
+    near_neighbor_first: bool = False
+    custom_related_information: str = None
+
+
 class GraphConfigRequest(BaseModel):
     ip: str = "127.0.0.1"
     port: str = "8080"
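[Editor's note] For reference, a request body matching the new `GraphRAGRequest` model; the field names and defaults come straight from the pydantic definitions above, while the query text is only an example:

```python
# Example payload for the new GraphRAGRequest model (values illustrative).
payload = {
    "query": "Tell me about Al Pacino.",
    "rerank_method": "bleu",              # or "reranker"
    "near_neighbor_first": False,
    "custom_related_information": "",
}
```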
+ # log.critical(f"## {type(result)}, {json.dumps(result)}") + if isinstance(result, dict): + log.critical(f"1## {type(result)}") + return {"graph_recall": result} + else: + log.critical(f"2## {type(result)}") + return {"graph_recall": json.dumps(result)} + + except TypeError as e: + log.error(f"TypeError in graph_rag_recall_api: {e}") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + except Exception as e: + log.error(f"Unexpected error occurred: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.") + @router.post("/config/graph", status_code=status.HTTP_201_CREATED) def graph_config_api(req: GraphConfigRequest): # Accept status code diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py index c4c68c0..f145d4d 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py @@ -68,37 +68,40 @@ def rag_answer( graph_only_answer: bool, graph_vector_answer: bool, graph_ratio: float, - rerank_method: Literal["bleu", "reranker"], - near_neighbor_first: bool, - custom_related_information: str, - answer_prompt: str, + rerank_method: Literal["bleu", "reranker"] = "bleu", + near_neighbor_first: bool = False, + custom_related_information: str = None, + answer_prompt: str = None, ) -> Tuple: + """ + Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline. + + 1. Initialize the RAGPipeline. + 2. Select vector search or graph search based on parameters. + 3. Merge, deduplicate, and rerank the results. + 4. Synthesize the final answer. + 5. Run the pipeline and return the results. + """ vector_search = vector_only_answer or graph_vector_answer graph_search = graph_only_answer or graph_vector_answer - if raw_answer is False and not vector_search and not graph_search: gr.Warning("Please select at least one generate mode.") return "", "", "", "" - searcher = RAGPipeline() + + rag = RAGPipeline() if vector_search: - searcher.query_vector_index_for_rag() + rag.query_vector_index() if graph_search: - searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag() + rag.extract_keywords().keywords_to_vid().query_graph_db() + # TODO: add more user-defined search strategies - searcher.merge_dedup_rerank( - graph_ratio, rerank_method, near_neighbor_first, custom_related_information - ).synthesize_answer( - raw_answer=raw_answer, - vector_only_answer=vector_only_answer, - graph_only_answer=graph_only_answer, - graph_vector_answer=graph_vector_answer, - answer_prompt=answer_prompt, - ) + rag.merge_dedup_rerank(graph_ratio, rerank_method, near_neighbor_first, custom_related_information) + rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt) try: - context = searcher.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search) + context = rag.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search) if context.get("switch_to_bleu"): - gr.Warning("Online reranker fails, automatically switches to local bleu method.") + gr.Warning("Online reranker fails, automatically switches to local bleu rerank.") return ( context.get("raw_answer", ""), context.get("vector_only_answer", ""), @@ -106,10 +109,10 @@ def rag_answer( context.get("graph_vector_answer", ""), ) except ValueError as e: - log.error(e) + log.critical(e) raise gr.Error(str(e)) except Exception as e: - 
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index c4c68c0..f145d4d 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -68,37 +68,40 @@ def rag_answer(
     graph_only_answer: bool,
     graph_vector_answer: bool,
     graph_ratio: float,
-    rerank_method: Literal["bleu", "reranker"],
-    near_neighbor_first: bool,
-    custom_related_information: str,
-    answer_prompt: str,
+    rerank_method: Literal["bleu", "reranker"] = "bleu",
+    near_neighbor_first: bool = False,
+    custom_related_information: str = None,
+    answer_prompt: str = None,
 ) -> Tuple:
+    """
+    Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
+
+    1. Initialize the RAGPipeline.
+    2. Select vector search or graph search based on parameters.
+    3. Merge, deduplicate, and rerank the results.
+    4. Synthesize the final answer.
+    5. Run the pipeline and return the results.
+    """
     vector_search = vector_only_answer or graph_vector_answer
     graph_search = graph_only_answer or graph_vector_answer
-
     if raw_answer is False and not vector_search and not graph_search:
         gr.Warning("Please select at least one generate mode.")
         return "", "", "", ""
-    searcher = RAGPipeline()
+
+    rag = RAGPipeline()
     if vector_search:
-        searcher.query_vector_index_for_rag()
+        rag.query_vector_index()
     if graph_search:
-        searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag()
+        rag.extract_keywords().keywords_to_vid().query_graph_db()
+
     # TODO: add more user-defined search strategies
-    searcher.merge_dedup_rerank(
-        graph_ratio, rerank_method, near_neighbor_first, custom_related_information
-    ).synthesize_answer(
-        raw_answer=raw_answer,
-        vector_only_answer=vector_only_answer,
-        graph_only_answer=graph_only_answer,
-        graph_vector_answer=graph_vector_answer,
-        answer_prompt=answer_prompt,
-    )
+    rag.merge_dedup_rerank(graph_ratio, rerank_method, near_neighbor_first, custom_related_information)
+    rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt)
 
     try:
-        context = searcher.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
+        context = rag.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
         if context.get("switch_to_bleu"):
-            gr.Warning("Online reranker fails, automatically switches to local bleu method.")
+            gr.Warning("Online reranker fails, automatically switches to local bleu rerank.")
         return (
             context.get("raw_answer", ""),
             context.get("vector_only_answer", ""),
@@ -106,10 +109,10 @@ def rag_answer(
             context.get("graph_vector_answer", ""),
         )
     except ValueError as e:
-        log.error(e)
+        log.critical(e)
         raise gr.Error(str(e))
     except Exception as e:
-        log.error(e)
+        log.critical(e)
         raise gr.Error(f"An unexpected error occurred: {str(e)}")
 
@@ -529,12 +532,14 @@ def init_rag_ui() -> gr.Interface:
             input_file = gr.File(
                 value=[os.path.join(resource_path, "demo", "test.txt")],
                 label="Docs (multi-files can be selected together)",
-                file_count="multiple"
+                file_count="multiple",
             )
         with gr.Tab("text") as tab_upload_text:
-            input_text = gr.Textbox(value="", label="Doc(s)")
-        input_schema = gr.Textbox(value=schema, label="Schema")
-        info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head")
+            input_text = gr.Textbox(value="", label="Doc(s)", lines=20, show_copy_button=True)
+        input_schema = gr.Textbox(value=schema, label="Schema", lines=15, show_copy_button=True)
+        info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head", lines=15,
+                                           show_copy_button=True)
+        out = gr.Textbox(label="Output", lines=15, show_copy_button=True)
     with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
         with gr.Row():
             import_btn = gr.Button("Import Into Graph")
@@ -552,8 +557,8 @@ def init_rag_ui() -> gr.Interface:
             graph_index_btn1 = gr.Button("Clear Graph Index")
             graph_index_btn2 = gr.Button("Extract Graph", variant="primary")
             graph_index_btn3 = gr.Button("Fit Vid Index")
-        with gr.Row():
-            out = gr.Textbox(label="Output", show_copy_button=True)
+        # with gr.Row():
+        #     out = gr.Textbox(label="Output", show_copy_button=True)
         vector_index_btn0.click(get_vector_index_info, outputs=out)  # pylint: disable=no-member
         vector_index_btn1.click(clean_vector_index)  # pylint: disable=no-member
         vector_index_btn2.click(build_vector_index, inputs=[input_file, input_text], outputs=out)  # pylint: disable=no-member
@@ -663,23 +668,15 @@ if __name__ == "__main__":
     parser.add_argument("--port", type=int, default=8001, help="port")
     args = parser.parse_args()
     app = FastAPI()
-    app_auth = APIRouter(dependencies=[Depends(authenticate)])
+    auth = APIRouter(dependencies=[Depends(authenticate)])
 
     hugegraph_llm = init_rag_ui()
-    rag_http_api(
-        app_auth,
-        rag_answer,
-        apply_graph_config,
-        apply_llm_config,
-        apply_embedding_config,
-        apply_reranker_config,
-    )
+    rag_http_api(auth, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config, apply_reranker_config)
 
-    app.include_router(app_auth)
+    app.include_router(auth)
     auth_enabled = os.getenv("ENABLE_LOGIN", "False").lower() == "true"
     log.info("Authentication is %s.", "enabled" if auth_enabled else "disabled")
     # TODO: support multi-user login when need
-
     app = gr.mount_gradio_app(app, hugegraph_llm, path="/", auth=("rag", os.getenv("TOKEN")) if auth_enabled else None)
 
     # TODO: we can't use reload now due to the config 'app' of uvicorn.run
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
index 22a82cb..2c90748 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
@@ -17,7 +17,7 @@
 
 from hugegraph_llm.models.llms.ollama import OllamaClient
-from hugegraph_llm.models.llms.openai import OpenAIChat
+from hugegraph_llm.models.llms.openai import OpenAIClient
 from hugegraph_llm.models.llms.qianfan import QianfanClient
 from hugegraph_llm.config import settings
 
@@ -34,7 +34,7 @@ class LLMs:
                 secret_key=settings.qianfan_secret_key
             )
         if self.llm_type == "openai":
-            return OpenAIChat(
+            return OpenAIClient(
                 api_key=settings.openai_api_key,
                 api_base=settings.openai_api_base,
                 model_name=settings.openai_language_model,
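[Editor's note] The import-site change above pairs with the class rename in openai.py just below; a quick sanity-check sketch, assuming `settings.llm_type == "openai"`:

```python
# With llm_type set to "openai" (an assumption), the factory now returns
# OpenAIClient instead of the old OpenAIChat.
from hugegraph_llm.models.llms.init_llm import LLMs

llm = LLMs().get_llm()
print(type(llm).__name__)  # expected: OpenAIClient
```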
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
index 30ac805..263f4f2 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
@@ -27,7 +27,7 @@ from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.utils.log import log
 
 
-class OpenAIChat(BaseLLM):
+class OpenAIClient(BaseLLM):
     """Wrapper around OpenAI Chat large language models."""
 
     def __init__(
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
index 43d0a29..b8f40df 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
@@ -62,7 +62,8 @@ class WordExtract:
         verbose = context.get("verbose") or False
         if verbose:
-            print(f"\033[92mKEYWORDS: {context['keywords']}\033[0m")
+            from hugegraph_llm.utils.log import log
+            log.info(f"KEYWORDS: {context['keywords']}")
 
         return context
""" self._operators.append( - VectorIndexQuery( - embedding=self._embedding, - topk=max_items, - ) + VectorIndexQuery(embedding=self._embedding, topk=max_items) ) return self @@ -230,7 +216,7 @@ class RAGPipeline: :return: Final context after all operators have been executed. """ if len(self._operators) == 0: - self.extract_keyword().query_graph_for_rag().synthesize_answer() + self.extract_keywords().query_graph_db().synthesize_answer() context = kwargs context["llm"] = self._llm diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py index fe225c2..01cb0fa 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py @@ -20,64 +20,63 @@ import re from typing import Any, Dict, Optional, List, Set, Tuple from hugegraph_llm.config import settings +from hugegraph_llm.utils.log import log from pyhugegraph.client import PyHugeClient -class GraphRAGQuery: - VERTEX_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').toList()" - # ID_RAG_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').repeat(bothE({edge_labels}).as('rel').otherV( - # ).as('obj')).times({max_deep}).path().by(project('label', 'id', 'props').by(label()).by(id()).by(valueMap().by( - # unfold()))).by(project('label', 'inV', 'outV', 'props').by(label()).by(inV().id()).by(outV().id()).by(valueMap( - # ).by(unfold()))).limit({max_items}).toList()" +VERTEX_QUERY_TPL = "g.V({keywords}).as('subj').toList()" +# ID_RAG_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').repeat(bothE({edge_labels}).as('rel').otherV( +# ).as('obj')).times({max_deep}).path().by(project('label', 'id', 'props').by(label()).by(id()).by(valueMap().by( +# unfold()))).by(project('label', 'inV', 'outV', 'props').by(label()).by(inV().id()).by(outV().id()).by(valueMap( +# ).by(unfold()))).limit({max_items}).toList()" + +# TODO: we could use a simpler query (like kneighbor-api to get the edges) +# TODO: use dedup() to filter duplicate paths +ID_QUERY_NEIGHBOR_TPL = """ +g.V({keywords}).as('subj') +.repeat( + bothE({edge_labels}).as('rel').otherV().as('obj') +).times({max_deep}) +.path() +.by(project('label', 'id', 'props') + .by(label()) + .by(id()) + .by(valueMap().by(unfold())) +) +.by(project('label', 'inV', 'outV', 'props') + .by(label()) + .by(inV().id()) + .by(outV().id()) + .by(valueMap().by(unfold())) +) +.limit({max_items}) +.toList() +""" - # TODO: we could use a simpler query (like kneighbor-api to get the edges) - ID_RAG_GREMLIN_QUERY_TEMPL = """ - g.V().hasId({keywords}).as('subj') - .repeat( - bothE({edge_labels}).as('rel').otherV().as('obj') - ).times({max_deep}) - .path() - .by(project('label', 'id', 'props') - .by(label()) - .by(id()) - .by(valueMap().by(unfold())) - ) - .by(project('label', 'inV', 'outV', 'props') - .by(label()) - .by(inV().id()) - .by(outV().id()) - .by(valueMap().by(unfold())) - ) - .limit({max_items}) - .toList() - """ +PROPERTY_QUERY_NEIGHBOR_TPL = """ +g.V().has('{prop}', within({keywords})).as('subj') +.repeat( + bothE({edge_labels}).as('rel').otherV().as('obj') +).times({max_deep}) +.path() +.by(project('label', 'props') + .by(label()) + .by(valueMap().by(unfold())) +) +.by(project('label', 'inV', 'outV', 'props') + .by(label()) + .by(inV().values('{prop}')) + .by(outV().values('{prop}')) + .by(valueMap().by(unfold())) +) +.limit({max_items}) +.toList() +""" - PROP_RAG_GREMLIN_QUERY_TEMPL = """ - 
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index fe225c2..01cb0fa 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -20,64 +20,63 @@ import re
 from typing import Any, Dict, Optional, List, Set, Tuple
 
 from hugegraph_llm.config import settings
+from hugegraph_llm.utils.log import log
 from pyhugegraph.client import PyHugeClient
 
 
-class GraphRAGQuery:
-    VERTEX_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').toList()"
-    # ID_RAG_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').repeat(bothE({edge_labels}).as('rel').otherV(
-    # ).as('obj')).times({max_deep}).path().by(project('label', 'id', 'props').by(label()).by(id()).by(valueMap().by(
-    # unfold()))).by(project('label', 'inV', 'outV', 'props').by(label()).by(inV().id()).by(outV().id()).by(valueMap(
-    # ).by(unfold()))).limit({max_items}).toList()"
+VERTEX_QUERY_TPL = "g.V({keywords}).as('subj').toList()"
+# ID_RAG_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').repeat(bothE({edge_labels}).as('rel').otherV(
+# ).as('obj')).times({max_deep}).path().by(project('label', 'id', 'props').by(label()).by(id()).by(valueMap().by(
+# unfold()))).by(project('label', 'inV', 'outV', 'props').by(label()).by(inV().id()).by(outV().id()).by(valueMap(
+# ).by(unfold()))).limit({max_items}).toList()"
+
+# TODO: we could use a simpler query (like kneighbor-api to get the edges)
+# TODO: use dedup() to filter duplicate paths
+ID_QUERY_NEIGHBOR_TPL = """
+g.V({keywords}).as('subj')
+.repeat(
+    bothE({edge_labels}).as('rel').otherV().as('obj')
+).times({max_deep})
+.path()
+.by(project('label', 'id', 'props')
+    .by(label())
+    .by(id())
+    .by(valueMap().by(unfold()))
+)
+.by(project('label', 'inV', 'outV', 'props')
+    .by(label())
+    .by(inV().id())
+    .by(outV().id())
+    .by(valueMap().by(unfold()))
+)
+.limit({max_items})
+.toList()
+"""
 
-    # TODO: we could use a simpler query (like kneighbor-api to get the edges)
-    ID_RAG_GREMLIN_QUERY_TEMPL = """
-    g.V().hasId({keywords}).as('subj')
-    .repeat(
-        bothE({edge_labels}).as('rel').otherV().as('obj')
-    ).times({max_deep})
-    .path()
-    .by(project('label', 'id', 'props')
-        .by(label())
-        .by(id())
-        .by(valueMap().by(unfold()))
-    )
-    .by(project('label', 'inV', 'outV', 'props')
-        .by(label())
-        .by(inV().id())
-        .by(outV().id())
-        .by(valueMap().by(unfold()))
-    )
-    .limit({max_items})
-    .toList()
-    """
+PROPERTY_QUERY_NEIGHBOR_TPL = """
+g.V().has('{prop}', within({keywords})).as('subj')
+.repeat(
+    bothE({edge_labels}).as('rel').otherV().as('obj')
+).times({max_deep})
+.path()
+.by(project('label', 'props')
+    .by(label())
+    .by(valueMap().by(unfold()))
+)
+.by(project('label', 'inV', 'outV', 'props')
+    .by(label())
+    .by(inV().values('{prop}'))
+    .by(outV().values('{prop}'))
+    .by(valueMap().by(unfold()))
+)
+.limit({max_items})
+.toList()
+"""
 
-    PROP_RAG_GREMLIN_QUERY_TEMPL = """
-    g.V().has('{prop}', within({keywords})).as('subj')
-    .repeat(
-        bothE({edge_labels}).as('rel').otherV().as('obj')
-    ).times({max_deep})
-    .path()
-    .by(project('label', 'props')
-        .by(label())
-        .by(valueMap().by(unfold()))
-    )
-    .by(project('label', 'inV', 'outV', 'props')
-        .by(label())
-        .by(inV().values('{prop}'))
-        .by(outV().values('{prop}'))
-        .by(valueMap().by(unfold()))
-    )
-    .limit({max_items})
-    .toList()
-    """
 
-    def __init__(
-        self,
-        max_deep: int = 2,
-        max_items: int = 30,
-        prop_to_match: Optional[str] = None,
-    ):
+class GraphRAGQuery:
+
+    def __init__(self, max_deep: int = 2, max_items: int = 30, prop_to_match: Optional[str] = None):
         self._client = PyHugeClient(
             settings.graph_ip,
             settings.graph_port,
@@ -106,7 +105,7 @@ class GraphRAGQuery:
         assert self._client is not None, "No valid graph to search."
 
         keywords = context.get("keywords")
-        entrance_vids = context.get("entrance_vids")
+        match_vids = context.get("match_vids")
 
         if isinstance(context.get("max_deep"), int):
             self._max_deep = context["max_deep"]
@@ -119,40 +118,48 @@ class GraphRAGQuery:
         edge_labels_str = ",".join("'" + label + "'" for label in edge_labels)
 
         use_id_to_match = self._prop_to_match is None
+        if use_id_to_match:
+            if not match_vids:
+                return context
-        if not use_id_to_match:
-            assert keywords is not None, "No keywords for graph query."
-            keywords_str = ",".join("'" + kw + "'" for kw in keywords)
-            rag_gremlin_query = self.PROP_RAG_GREMLIN_QUERY_TEMPL.format(
-                prop=self._prop_to_match,
-                keywords=keywords_str,
+            gremlin_query = VERTEX_QUERY_TPL.format(keywords=match_vids)
+            result: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
+            log.debug(f"Vids query: {gremlin_query}")
+
+            vertex_knowledge = self._format_graph_from_vertex(query_result=result)
+            gremlin_query = ID_QUERY_NEIGHBOR_TPL.format(
+                keywords=match_vids,
                 max_deep=self._max_deep,
                 max_items=self._max_items,
                 edge_labels=edge_labels_str,
             )
-            result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
-            graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_knowledge_from_query_result(
+            log.debug(f"Kneighbor query: {gremlin_query}")
+
+            result: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
+            graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_graph_from_query_result(
                 query_result=result
             )
+            graph_chain_knowledge.update(vertex_knowledge)
+            if vertex_degree_list:
+                vertex_degree_list[0].update(vertex_knowledge)
+            else:
+                vertex_degree_list.append(vertex_knowledge)
         else:
-            assert entrance_vids is not None, "No entrance vertices for query."
-            rag_gremlin_query = self.VERTEX_GREMLIN_QUERY_TEMPL.format(
-                keywords=entrance_vids,
-            )
-            result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
-            vertex_knowledge = self._format_knowledge_from_vertex(query_result=result)
-            rag_gremlin_query = self.ID_RAG_GREMLIN_QUERY_TEMPL.format(
-                keywords=entrance_vids,
+            # When will the query enter here?
+            assert keywords, "No related property(keywords) for graph query."
+            keywords_str = ",".join("'" + kw + "'" for kw in keywords)
+            gremlin_query = PROPERTY_QUERY_NEIGHBOR_TPL.format(
+                prop=self._prop_to_match,
+                keywords=keywords_str,
                 max_deep=self._max_deep,
                 max_items=self._max_items,
                 edge_labels=edge_labels_str,
             )
-            result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
-            graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_knowledge_from_query_result(
+            log.warning("Unable to find vid, downgraded to property query, please confirm if it meets expectation.")
+
+            result: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
+            graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_graph_from_query_result(
                 query_result=result
             )
-            graph_chain_knowledge.update(vertex_knowledge)
-            vertex_degree_list[0].update(vertex_knowledge)
 
         context["graph_result"] = list(graph_chain_knowledge)
         context["vertex_degree_list"] = [list(vertex_degree) for vertex_degree in vertex_degree_list]
@@ -172,7 +179,7 @@ class GraphRAGQuery:
 
         return context
 
-    def _format_knowledge_from_vertex(self, query_result: List[Any]) -> Set[str]:
+    def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]:
         knowledge = set()
         for item in query_result:
             props_str = ", ".join(f"{k}: {v}" for k, v in item["properties"].items())
@@ -180,8 +187,8 @@ class GraphRAGQuery:
             knowledge.add(node_str)
         return knowledge
 
-    def _format_knowledge_from_query_result(
-        self, query_result: List[Any]
+    def _format_graph_from_query_result(
+        self, query_result: List[Any]
     ) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]:
         use_id_to_match = self._prop_to_match is None
         knowledge = set()
@@ -234,18 +241,14 @@ class GraphRAGQuery:
     def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]:
         schema = self._get_graph_schema()
         node_props_str, edge_props_str = schema.split("\n")[:2]
-        node_props_str = node_props_str[len("Node properties: ") :].strip("[").strip("]")
-        edge_props_str = edge_props_str[len("Edge properties: ") :].strip("[").strip("]")
+        node_props_str = node_props_str[len("Node properties: "):].strip("[").strip("]")
+        edge_props_str = edge_props_str[len("Edge properties: "):].strip("[").strip("]")
         node_labels = self._extract_label_names(node_props_str)
         edge_labels = self._extract_label_names(edge_props_str)
         return node_labels, edge_labels
 
     @staticmethod
-    def _extract_label_names(
-        source: str,
-        head: str = "name: ",
-        tail: str = ", ",
-    ) -> List[str]:
+    def _extract_label_names(source: str, head: str = "name: ", tail: str = ", ") -> List[str]:
         result = []
         for s in source.split(head):
             end = s.find(tail)
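[Editor's note] To make the new module-level templates concrete, this is roughly how `run()` fills in the vid-based neighbor query. The vid and edge labels below are made-up sample values; the numbers are the class defaults:

```python
# Illustrative fill-in of ID_QUERY_NEIGHBOR_TPL from the diff above.
gremlin_query = ID_QUERY_NEIGHBOR_TPL.format(
    keywords="'1:Al Pacino'",             # from context["match_vids"] (sample vid)
    edge_labels="'acted_in','directed'",  # from _extract_labels_from_schema()
    max_deep=2,                           # repeat(...).times(2)
    max_items=30,                         # limit(30)
)
```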
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py
index 5c002ae..a215bb4 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+from typing import Dict, Any
 from hugegraph_llm.config import settings
 from pyhugegraph.client import PyHugeClient
 
@@ -33,7 +33,9 @@ class SchemaManager:
         )
         self.schema = self.client.schema()
 
-    def run(self):
+    # FIXME: This method is not working as expected
+    # def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
+    def run(self) -> Dict[str, Any]:
         schema = self.schema.getSchema()
         vertices = []
         for vl in schema["vertexlabels"]:
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
index ef036d2..60be886 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
@@ -40,19 +40,19 @@ class SemanticIdQuery:
         self.topk_per_keyword = topk_per_keyword
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
-        graph_query_entrance = []
+        graph_query_list = []
         if self.by == "query":
             query = context["query"]
             query_vector = self.embedding.get_text_embedding(query)
             results = self.vector_index.search(query_vector, top_k=self.topk_per_query)
             if results:
-                graph_query_entrance.extend(results[:self.topk_per_query])
-        else:  # by keywords
+                graph_query_list.extend(results[:self.topk_per_query])
+        else:
             keywords = context["keywords"]
             for keyword in keywords:
                 keyword_vector = self.embedding.get_text_embedding(keyword)
                 results = self.vector_index.search(keyword_vector, top_k=self.topk_per_keyword)
                 if results:
-                    graph_query_entrance.extend(results[:self.topk_per_keyword])
-        context["entrance_vids"] = list(set(graph_query_entrance))
+                    graph_query_list.extend(results[:self.topk_per_keyword])
+        context["match_vids"] = list(set(graph_query_list))
         return context
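[Editor's note] The renamed context key ties the two operators above together: `SemanticIdQuery.run()` now writes `match_vids` (was `entrance_vids`), and `GraphRAGQuery.run()` reads the same key. Illustrated with made-up values:

```python
# Hypothetical context flow between the two operators (sample values only).
context = {"query": "Tell me about Al Pacino.", "keywords": ["Al Pacino"]}
# ... after SemanticIdQuery.run(context):
context["match_vids"] = ["1:Al Pacino"]  # sample vid; real ids depend on the graph
```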
{res}" for i, res in enumerate(graph_result) ) + else: + graph_result_context = "No related graph data found for current query." + log.warning(graph_result_context) + context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str, vector_result_context, graph_result_context)) - return context async def async_generate(self, context: Dict[str, Any], context_head_str: str, diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py index fffa03c..dcc0ab6 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py @@ -46,7 +46,7 @@ class DisambiguateData: # only disambiguate triples if "triples" in data: # TODO: ensure the logic here - log.debug(data) + # log.debug(data) triples = data["triples"] prompt = generate_disambiguate_prompt(triples) llm_output = self.llm.generate(prompt=prompt) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py index e4854c1..b9f75db 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py @@ -23,20 +23,18 @@ from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.models.llms.init_llm import LLMs from hugegraph_llm.operators.common_op.nltk_helper import NLTKHelper -DEFAULT_KEYWORDS_EXTRACT_TEMPLATE_TMPL = """extract {max_keywords} keywords from the text: - {question} - Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>' - """ - -DEFAULT_KEYWORDS_EXPAND_TEMPLATE_TMPL = ( - "Generate synonyms or possible form of keywords up to {max_keywords} in total,\n" - "considering possible cases of capitalization, pluralization, common expressions, etc.\n" - "Provide all synonyms of keywords in comma-separated format: 'SYNONYMS: <keywords>'\n" - "Note, result should be in one-line with only one 'SYNONYMS: ' prefix\n" - "----\n" - "KEYWORDS: {question}\n" - "----" -) +KEYWORDS_EXTRACT_TPL = """extract {max_keywords} keywords from the text: +{question} +Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>' +""" + +KEYWORDS_EXPAND_TPL = """Generate synonyms or possible form of keywords up to {max_keywords} in total, +considering possible cases of capitalization, pluralization, common expressions, etc. 
+Provide all synonyms of keywords in comma-separated format: 'SYNONYMS: <keywords>'
+Note, result should be in one-line with only one 'SYNONYMS: ' prefix
+----
+KEYWORDS: {question}
+----"""
 
 
 class KeywordExtract:
@@ -53,8 +51,8 @@ class KeywordExtract:
         self._query = text
         self._language = language.lower()
         self._max_keywords = max_keywords
-        self._extract_template = extract_template or DEFAULT_KEYWORDS_EXTRACT_TEMPLATE_TMPL
-        self._expand_template = expand_template or DEFAULT_KEYWORDS_EXPAND_TEMPLATE_TMPL
+        self._extract_template = extract_template or KEYWORDS_EXTRACT_TPL
+        self._expand_template = expand_template or KEYWORDS_EXPAND_TPL
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         if self._query is None:
@@ -84,14 +82,15 @@ class KeywordExtract:
         response = self._llm.generate(prompt=prompt)
 
         keywords = self._extract_keywords_from_response(
-            response=response, lowercase=False, start_token="KEYWORDS:"
+            response=response, start_token="KEYWORDS:"
         )
         keywords.union(self._expand_synonyms(keywords=keywords))
         context["keywords"] = list(keywords)
 
         verbose = context.get("verbose") or False
         if verbose:
-            print(f"\033[92mKEYWORDS: {context['keywords']}\033[0m")
+            from hugegraph_llm.utils.log import log
+            log.info(f"KEYWORDS: {context['keywords']}")
 
         # extracting keywords & expanding synonyms increase the call count by 2
         context["call_count"] = context.get("call_count", 0) + 2
@@ -104,16 +103,11 @@ class KeywordExtract:
         )
         response = self._llm.generate(prompt=prompt)
         keywords = self._extract_keywords_from_response(
-            response=response, lowercase=False, start_token="SYNONYMS:"
+            response=response, start_token="SYNONYMS:"
         )
         return keywords
 
-    def _extract_keywords_from_response(
-        self,
-        response: str,
-        lowercase: bool = True,
-        start_token: str = "",
-    ) -> Set[str]:
+    def _extract_keywords_from_response(self, response: str, start_token: str = "", ) -> Set[str]:
         keywords = []
         matches = re.findall(rf'{start_token}[^\n]+\n?', response)
         for match in matches:
@@ -130,8 +124,6 @@ class KeywordExtract:
                 results.add(token)
                 sub_tokens = re.findall(r"\w+", token)
                 if len(sub_tokens) > 1:
-                    results.update(
-                        {w for w in sub_tokens if w not in NLTKHelper().stopwords(lang=self._language)}
-                    )
+                    results.update({w for w in sub_tokens if w not in NLTKHelper().stopwords(lang=self._language)})
 
         return results
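[Editor's note] The reworked templates above rely on the parsing contract in `_extract_keywords_from_response`: the LLM must answer on one line with a single `KEYWORDS:`/`SYNONYMS:` prefix. A simplified stand-in for that parsing (the regex matches the diff; the splitting below is an illustration, not the class's exact code):

```python
# Illustration of the keyword-parsing contract behind the templates above.
import re

response = "KEYWORDS: Al Pacino, actor, film"
match = re.findall(r"KEYWORDS:[^\n]+\n?", response)[0]
keywords = {kw.strip() for kw in match[len("KEYWORDS:"):].split(",")}
print(keywords)  # {'Al Pacino', 'actor', 'film'}
```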
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/decorators.py b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
index 173c1f7..8cbe475 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
@@ -72,7 +72,7 @@ def log_operator_time(func: Callable) -> Callable:
         # Only record time ≥ 0.01s (10ms)
         if op_time >= 0.01:
             log.debug("Operator %s finished in %.2f seconds", operator.__class__.__name__, op_time)
-            log.debug("Context:\n%s", result)
+            log.debug("Current context: %s", result)
         return result
     return wrapper
 
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
index 0db406e..10bf72e 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
@@ -64,7 +64,7 @@ def extract_graph(input_file, input_text, schema, example_prompt):
     (builder
      .chunk_split(texts, "paragraph", "zh")
      .extract_info(example_prompt, "property_graph"))
-    log.debug(builder.operators)
+    log.debug("Operators: %s", builder.operators)
     try:
         context = builder.run()
         return (
@@ -80,7 +80,7 @@ def extract_graph(input_file, input_text, schema, example_prompt):
 def fit_vid_index():
     builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
     builder.fetch_graph_data().build_vertex_id_semantic_index()
-    log.debug(builder.operators)
+    log.debug("Operators: %s", builder.operators)
     try:
         context = builder.run()
         removed_num = context["removed_vid_vector_num"]
@@ -109,7 +109,7 @@ def build_graph_index(input_file, input_text, schema, example_prompt):
      .extract_info(example_prompt, "property_graph")
      .commit_to_hugegraph()
      .build_vertex_id_semantic_index())
-    log.debug(builder.operators)
+    log.debug("Operators: %s", builder.operators)
     try:
         context = builder.run()
         return json.dumps(context, ensure_ascii=False, indent=2)
