This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
commit 6ad8fd9ba6ed488923ba7902dfdec7bbc064a53a Author: Linyu <[email protected]> AuthorDate: Tue Sep 30 10:55:47 2025 +0800 refactor(RAG workflow): modularize flows, add streaming, and improve node initialization (#51) --- .../src/hugegraph_llm/demo/rag_demo/rag_block.py | 223 ++++++++++++--------- hugegraph-llm/src/hugegraph_llm/flows/common.py | 26 +++ .../src/hugegraph_llm/flows/rag_flow_graph_only.py | 153 ++++++++++++++ .../hugegraph_llm/flows/rag_flow_graph_vector.py | 158 +++++++++++++++ .../src/hugegraph_llm/flows/rag_flow_raw.py | 99 +++++++++ .../hugegraph_llm/flows/rag_flow_vector_only.py | 123 ++++++++++++ hugegraph-llm/src/hugegraph_llm/flows/scheduler.py | 63 ++++++ hugegraph-llm/src/hugegraph_llm/nodes/base_node.py | 3 + .../nodes/common_node/merge_rerank_node.py | 83 ++++++++ .../nodes/document_node/chunk_split.py | 2 +- .../nodes/hugegraph_node/commit_to_hugegraph.py | 3 +- .../nodes/hugegraph_node/fetch_graph_data.py | 3 +- .../nodes/hugegraph_node/graph_query_node.py | 93 +++++++++ .../hugegraph_llm/nodes/hugegraph_node/schema.py | 3 +- .../nodes/index_node/build_semantic_index.py | 3 +- .../nodes/index_node/build_vector_index.py | 3 +- .../index_node/gremlin_example_index_query.py | 13 +- .../nodes/index_node/semantic_id_query_node.py | 91 +++++++++ .../nodes/index_node/vector_query_node.py | 74 +++++++ .../nodes/llm_node/answer_synthesize_node.py | 99 +++++++++ .../hugegraph_llm/nodes/llm_node/extract_info.py | 2 +- .../nodes/llm_node/keyword_extract_node.py | 80 ++++++++ .../nodes/llm_node/prompt_generate.py | 2 +- .../hugegraph_llm/nodes/llm_node/schema_build.py | 2 +- .../hugegraph_llm/nodes/llm_node/text2gremlin.py | 10 +- hugegraph-llm/src/hugegraph_llm/state/ai_state.py | 106 +++++++++- .../src/hugegraph_llm/utils/graph_index_utils.py | 115 ++--------- 27 files changed, 1411 insertions(+), 224 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py index 8f70c34b..ca36867d 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py @@ -21,12 +21,11 @@ import os from typing import AsyncGenerator, Literal, Optional, Tuple import gradio as gr +from hugegraph_llm.flows.scheduler import SchedulerSingleton import pandas as pd from gradio.utils import NamedString -from hugegraph_llm.config import huge_settings, llm_settings, prompt, resource_path -from hugegraph_llm.operators.graph_rag_task import RAGPipeline -from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize +from hugegraph_llm.config import resource_path, prompt, llm_settings from hugegraph_llm.utils.decorators import with_task_id from hugegraph_llm.utils.log import log @@ -72,44 +71,51 @@ def rag_answer( gr.Warning("Please select at least one generate mode.") return "", "", "", "" - rag = RAGPipeline() - if vector_search: - rag.query_vector_index() - if graph_search: - rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid( - vector_dis_threshold=vector_dis_threshold, - topk_per_keyword=topk_per_keyword, - ).import_schema(huge_settings.graph_name).query_graphdb( - num_gremlin_generate_example=gremlin_tmpl_num, - gremlin_prompt=gremlin_prompt, - max_graph_items=max_graph_items, - ) - # TODO: add more user-defined search strategies - rag.merge_dedup_rerank( - graph_ratio=graph_ratio, - rerank_method=rerank_method, - near_neighbor_first=near_neighbor_first, - topk_return_results=topk_return_results, - ) 
- rag.synthesize_answer( - raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt - ) - + scheduler = SchedulerSingleton.get_instance() try: - context = rag.run( - verbose=True, + # Select workflow by mode to avoid fetching the wrong pipeline from the pool + if graph_vector_answer or (graph_only_answer and vector_only_answer): + flow_key = "rag_graph_vector" + elif vector_only_answer: + flow_key = "rag_vector_only" + elif graph_only_answer: + flow_key = "rag_graph_only" + elif raw_answer: + flow_key = "rag_raw" + else: + raise RuntimeError("Unsupported flow type") + + res = scheduler.schedule_flow( + flow_key, query=text, vector_search=vector_search, graph_search=graph_search, + raw_answer=raw_answer, + vector_only_answer=vector_only_answer, + graph_only_answer=graph_only_answer, + graph_vector_answer=graph_vector_answer, + graph_ratio=graph_ratio, + rerank_method=rerank_method, + near_neighbor_first=near_neighbor_first, + custom_related_information=custom_related_information, + answer_prompt=answer_prompt, + keywords_extract_prompt=keywords_extract_prompt, + gremlin_tmpl_num=gremlin_tmpl_num, + gremlin_prompt=gremlin_prompt, max_graph_items=max_graph_items, + topk_return_results=topk_return_results, + vector_dis_threshold=vector_dis_threshold, + topk_per_keyword=topk_per_keyword, ) - if context.get("switch_to_bleu"): - gr.Warning("Online reranker fails, automatically switches to local bleu rerank.") + if res.get("switch_to_bleu"): + gr.Warning( + "Online reranker fails, automatically switches to local bleu rerank." + ) return ( - context.get("raw_answer", ""), - context.get("vector_only_answer", ""), - context.get("graph_only_answer", ""), - context.get("graph_vector_answer", ""), + res.get("raw_answer", ""), + res.get("vector_only_answer", ""), + res.get("graph_only_answer", ""), + res.get("graph_vector_answer", ""), ) except ValueError as e: log.critical(e) @@ -187,44 +193,47 @@ async def rag_answer_streaming( yield "", "", "", "" return - rag = RAGPipeline() - if vector_search: - rag.query_vector_index() - if graph_search: - rag.extract_keywords( - extract_template=keywords_extract_prompt - ).keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb( - num_gremlin_generate_example=gremlin_tmpl_num, - gremlin_prompt=gremlin_prompt, - ) - rag.merge_dedup_rerank( - graph_ratio, - rerank_method, - near_neighbor_first, - ) - # rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt) - try: - context = rag.run( - verbose=True, query=text, vector_search=vector_search, graph_search=graph_search - ) - if context.get("switch_to_bleu"): - gr.Warning("Online reranker fails, automatically switches to local bleu rerank.") - answer_synthesize = AnswerSynthesize( + # Select the specific streaming workflow + scheduler = SchedulerSingleton.get_instance() + if graph_vector_answer or (graph_only_answer and vector_only_answer): + flow_key = "rag_graph_vector" + elif vector_only_answer: + flow_key = "rag_vector_only" + elif graph_only_answer: + flow_key = "rag_graph_only" + elif raw_answer: + flow_key = "rag_raw" + else: + raise RuntimeError("Unsupported flow type") + + async for res in scheduler.schedule_stream_flow( + flow_key, + query=text, + vector_search=vector_search, + graph_search=graph_search, raw_answer=raw_answer, vector_only_answer=vector_only_answer, graph_only_answer=graph_only_answer, graph_vector_answer=graph_vector_answer, - prompt_template=answer_prompt, - ) - async for context in 
answer_synthesize.run_streaming(context): - if context.get("switch_to_bleu"): - gr.Warning("Online reranker fails, automatically switches to local bleu rerank.") + graph_ratio=graph_ratio, + rerank_method=rerank_method, + near_neighbor_first=near_neighbor_first, + custom_related_information=custom_related_information, + answer_prompt=answer_prompt, + keywords_extract_prompt=keywords_extract_prompt, + gremlin_tmpl_num=gremlin_tmpl_num, + gremlin_prompt=gremlin_prompt, + ): + if res.get("switch_to_bleu"): + gr.Warning( + "Online reranker fails, automatically switches to local bleu rerank." + ) yield ( - context.get("raw_answer", ""), - context.get("vector_only_answer", ""), - context.get("graph_only_answer", ""), - context.get("graph_vector_answer", ""), + res.get("raw_answer", ""), + res.get("vector_only_answer", ""), + res.get("graph_only_answer", ""), + res.get("graph_vector_answer", ""), ) except ValueError as e: log.critical(e) @@ -242,7 +251,10 @@ def create_rag_block(): with gr.Column(scale=2): # with gr.Blocks().queue(max_size=20, default_concurrency_limit=5): inp = gr.Textbox( - value=prompt.default_question, label="Question", show_copy_button=True, lines=3 + value=prompt.default_question, + label="Question", + show_copy_button=True, + lines=3, ) # TODO: Only support inline formula now. Should support block formula @@ -271,7 +283,10 @@ def create_rag_block(): latex_delimiters=[{"left": "$", "right": "$", "display": False}], ) answer_prompt_input = gr.Textbox( - value=prompt.answer_prompt, label="Query Prompt", show_copy_button=True, lines=7 + value=prompt.answer_prompt, + label="Query Prompt", + show_copy_button=True, + lines=7, ) keywords_extract_prompt_input = gr.Textbox( value=prompt.keywords_extract_prompt, @@ -282,7 +297,9 @@ def create_rag_block(): with gr.Column(scale=1): with gr.Row(): - raw_radio = gr.Radio(choices=[True, False], value=False, label="Basic LLM Answer") + raw_radio = gr.Radio( + choices=[True, False], value=False, label="Basic LLM Answer" + ) vector_only_radio = gr.Radio( choices=[True, False], value=False, label="Vector-only Answer" ) @@ -306,7 +323,9 @@ def create_rag_block(): label="Rerank method", ) example_num = gr.Number( - value=-1, label="Template Num (<0 means disable text2gql) ", precision=0 + value=-1, + label="Template Num (<0 means disable text2gql) ", + precision=0, ) graph_ratio = gr.Slider( 0, 1, 0.6, label="Graph Ratio", step=0.1, interactive=False @@ -351,7 +370,7 @@ def create_rag_block(): """## 2. (Batch) Back-testing ) > 1. Download the template file & fill in the questions you want to test. > 2. Upload the file & click the button to generate answers. (Preview shows the first 40 lines) - > 3. The answer options are the same as the above RAG/Q&A frame + > 3. The answer options are the same as the above RAG/Q&A frame """ ) tests_df_headers = [ @@ -365,7 +384,9 @@ def create_rag_block(): # FIXME: "demo" might conflict with the graph name, it should be modified. 
answers_path = os.path.join(resource_path, "demo", "questions_answers.xlsx") questions_path = os.path.join(resource_path, "demo", "questions.xlsx") - questions_template_path = os.path.join(resource_path, "demo", "questions_template.xlsx") + questions_template_path = os.path.join( + resource_path, "demo", "questions_template.xlsx" + ) def read_file_to_excel(file: NamedString, line_count: Optional[int] = None): df = None @@ -412,20 +433,23 @@ def create_rag_block(): total_rows = len(df) for index, row in df.iterrows(): question = row.iloc[0] - basic_llm_answer, vector_only_answer, graph_only_answer, graph_vector_answer = ( - rag_answer( - question, - is_raw_answer, - is_vector_only_answer, - is_graph_only_answer, - is_graph_vector_answer, - graph_ratio_ui, - rerank_method_ui, - near_neighbor_first_ui, - custom_related_information_ui, - answer_prompt, - keywords_extract_prompt, - ) + ( + basic_llm_answer, + vector_only_answer, + graph_only_answer, + graph_vector_answer, + ) = rag_answer( + question, + is_raw_answer, + is_vector_only_answer, + is_graph_only_answer, + is_graph_vector_answer, + graph_ratio_ui, + rerank_method_ui, + near_neighbor_first_ui, + custom_related_information_ui, + answer_prompt, + keywords_extract_prompt, ) df.at[index, "Basic LLM Answer"] = basic_llm_answer df.at[index, "Vector-only Answer"] = vector_only_answer @@ -442,12 +466,18 @@ def create_rag_block(): file_types=[".xlsx", ".csv"], label="Questions File (.xlsx & csv)" ) with gr.Column(): - test_template_file = os.path.join(resource_path, "demo", "questions_template.xlsx") + test_template_file = os.path.join( + resource_path, "demo", "questions_template.xlsx" + ) gr.File(value=test_template_file, label="Download Template File") - answer_max_line_count = gr.Number(1, label="Max Lines To Show", minimum=1, maximum=40) + answer_max_line_count = gr.Number( + 1, label="Max Lines To Show", minimum=1, maximum=40 + ) answers_btn = gr.Button("Generate Answer (Batch)", variant="primary") # TODO: Set individual progress bars for dataframe - qa_dataframe = gr.DataFrame(label="Questions & Answers (Preview)", headers=tests_df_headers) + qa_dataframe = gr.DataFrame( + label="Questions & Answers (Preview)", headers=tests_df_headers + ) answers_btn.click( several_rag_answer, inputs=[ @@ -465,6 +495,15 @@ def create_rag_block(): ], outputs=[qa_dataframe, gr.File(label="Download Answered File", min_width=40)], ) - questions_file.change(read_file_to_excel, questions_file, [qa_dataframe, answer_max_line_count]) - answer_max_line_count.change(change_showing_excel, answer_max_line_count, qa_dataframe) - return inp, answer_prompt_input, keywords_extract_prompt_input, custom_related_information + questions_file.change( + read_file_to_excel, questions_file, [qa_dataframe, answer_max_line_count] + ) + answer_max_line_count.change( + change_showing_excel, answer_max_line_count, qa_dataframe + ) + return ( + inp, + answer_prompt_input, + keywords_extract_prompt_input, + custom_related_information, + ) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/common.py b/hugegraph-llm/src/hugegraph_llm/flows/common.py index 4c552626..e2348466 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/common.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/common.py @@ -14,8 +14,10 @@ # limitations under the License. 
from abc import ABC, abstractmethod +from typing import Dict, Any, AsyncGenerator from hugegraph_llm.state.ai_state import WkFlowInput +from hugegraph_llm.utils.log import log class BaseFlow(ABC): @@ -43,3 +45,27 @@ class BaseFlow(ABC): Post-processing interface. """ pass + + async def post_deal_stream( + self, pipeline=None + ) -> AsyncGenerator[Dict[str, Any], None]: + """ + Streaming post-processing interface. + Subclasses can override this method as needed. + """ + flow_name = self.__class__.__name__ + if pipeline is None: + yield {"error": "No pipeline provided"} + return + try: + state_json = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + log.info(f"{flow_name} post processing success") + stream_flow = state_json.get("stream_generator") + if stream_flow is None: + yield {"error": "No stream_generator found in workflow state"} + return + async for chunk in stream_flow: + yield chunk + except Exception as e: + log.error(f"{flow_name} post processing failed: {e}") + yield {"error": f"Post processing failed: {str(e)}"} diff --git a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py new file mode 100644 index 00000000..5feb3d47 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json + +from typing import Optional, Literal + +from PyCGraph import GPipeline + +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.nodes.llm_node.keyword_extract_node import KeywordExtractNode +from hugegraph_llm.nodes.index_node.semantic_id_query_node import SemanticIdQueryNode +from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode +from hugegraph_llm.nodes.hugegraph_node.graph_query_node import GraphQueryNode +from hugegraph_llm.nodes.common_node.merge_rerank_node import MergeRerankNode +from hugegraph_llm.nodes.llm_node.answer_synthesize_node import AnswerSynthesizeNode +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.config import huge_settings, prompt +from hugegraph_llm.utils.log import log + + +class RAGGraphOnlyFlow(BaseFlow): + """ + Workflow for graph-only answering (graph_only_answer) + """ + + def prepare( + self, + prepared_input: WkFlowInput, + query: str, + vector_search: bool = None, + graph_search: bool = None, + raw_answer: bool = None, + vector_only_answer: bool = None, + graph_only_answer: bool = None, + graph_vector_answer: bool = None, + graph_ratio: float = 0.5, + rerank_method: Literal["bleu", "reranker"] = "bleu", + near_neighbor_first: bool = False, + custom_related_information: str = "", + answer_prompt: Optional[str] = None, + keywords_extract_prompt: Optional[str] = None, + gremlin_tmpl_num: Optional[int] = -1, + gremlin_prompt: Optional[str] = None, + max_graph_items: int = None, + topk_return_results: int = None, + vector_dis_threshold: float = None, + topk_per_keyword: int = None, + **_: dict, + ): + prepared_input.query = query + prepared_input.vector_search = vector_search + prepared_input.graph_search = graph_search + prepared_input.raw_answer = raw_answer + prepared_input.vector_only_answer = vector_only_answer + prepared_input.graph_only_answer = graph_only_answer + prepared_input.graph_vector_answer = graph_vector_answer + prepared_input.gremlin_tmpl_num = gremlin_tmpl_num + prepared_input.gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt + prepared_input.max_graph_items = ( + max_graph_items or huge_settings.max_graph_items + ) + prepared_input.topk_per_keyword = ( + topk_per_keyword or huge_settings.topk_per_keyword + ) + prepared_input.topk_return_results = ( + topk_return_results or huge_settings.topk_return_results + ) + prepared_input.rerank_method = rerank_method + prepared_input.near_neighbor_first = near_neighbor_first + prepared_input.keywords_extract_prompt = ( + keywords_extract_prompt or prompt.keywords_extract_prompt + ) + prepared_input.answer_prompt = answer_prompt or prompt.answer_prompt + prepared_input.custom_related_information = custom_related_information + prepared_input.vector_dis_threshold = ( + vector_dis_threshold or huge_settings.vector_dis_threshold + ) + prepared_input.schema = huge_settings.graph_name + + prepared_input.data_json = { + "query": query, + "vector_search": vector_search, + "graph_search": graph_search, + "max_graph_items": max_graph_items or huge_settings.max_graph_items, + } + return + + def build_flow(self, **kwargs): + pipeline = GPipeline() + prepared_input = WkFlowInput() + self.prepare(prepared_input, **kwargs) + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + + # Create nodes and register them with registerGElement + only_keyword_extract_node = KeywordExtractNode() + only_semantic_id_query_node = SemanticIdQueryNode() + only_schema_node = 
SchemaNode() + only_graph_query_node = GraphQueryNode() + merge_rerank_node = MergeRerankNode() + answer_synthesize_node = AnswerSynthesizeNode() + + pipeline.registerGElement(only_keyword_extract_node, set(), "only_keyword") + pipeline.registerGElement( + only_semantic_id_query_node, {only_keyword_extract_node}, "only_semantic" + ) + pipeline.registerGElement(only_schema_node, set(), "only_schema") + pipeline.registerGElement( + only_graph_query_node, + {only_schema_node, only_semantic_id_query_node}, + "only_graph", + ) + pipeline.registerGElement( + merge_rerank_node, {only_graph_query_node}, "merge_one" + ) + pipeline.registerGElement(answer_synthesize_node, {merge_rerank_node}, "graph") + log.info("RAGGraphOnlyFlow pipeline built successfully") + return pipeline + + def post_deal(self, pipeline=None): + if pipeline is None: + return json.dumps( + {"error": "No pipeline provided"}, ensure_ascii=False, indent=2 + ) + try: + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + log.info("RAGGraphOnlyFlow post processing success") + return { + "raw_answer": res.get("raw_answer", ""), + "vector_only_answer": res.get("vector_only_answer", ""), + "graph_only_answer": res.get("graph_only_answer", ""), + "graph_vector_answer": res.get("graph_vector_answer", ""), + } + except Exception as e: + log.error(f"RAGGraphOnlyFlow post processing failed: {e}") + return json.dumps( + {"error": f"Post processing failed: {str(e)}"}, + ensure_ascii=False, + indent=2, + ) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_vector.py b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_vector.py new file mode 100644 index 00000000..2f4a2bfa --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_vector.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json + +from typing import Optional, Literal + +from PyCGraph import GPipeline + +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.nodes.index_node.vector_query_node import VectorQueryNode +from hugegraph_llm.nodes.llm_node.keyword_extract_node import KeywordExtractNode +from hugegraph_llm.nodes.index_node.semantic_id_query_node import SemanticIdQueryNode +from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode +from hugegraph_llm.nodes.hugegraph_node.graph_query_node import GraphQueryNode +from hugegraph_llm.nodes.common_node.merge_rerank_node import MergeRerankNode +from hugegraph_llm.nodes.llm_node.answer_synthesize_node import AnswerSynthesizeNode +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.config import huge_settings, prompt +from hugegraph_llm.utils.log import log + + +class RAGGraphVectorFlow(BaseFlow): + """ + Workflow for graph + vector hybrid answering (graph_vector_answer) + """ + + def prepare( + self, + prepared_input: WkFlowInput, + query: str, + vector_search: bool = None, + graph_search: bool = None, + raw_answer: bool = None, + vector_only_answer: bool = None, + graph_only_answer: bool = None, + graph_vector_answer: bool = None, + graph_ratio: float = 0.5, + rerank_method: Literal["bleu", "reranker"] = "bleu", + near_neighbor_first: bool = False, + custom_related_information: str = "", + answer_prompt: Optional[str] = None, + keywords_extract_prompt: Optional[str] = None, + gremlin_tmpl_num: Optional[int] = -1, + gremlin_prompt: Optional[str] = None, + max_graph_items: int = None, + topk_return_results: int = None, + vector_dis_threshold: float = None, + topk_per_keyword: int = None, + **_: dict, + ): + prepared_input.query = query + prepared_input.vector_search = vector_search + prepared_input.graph_search = graph_search + prepared_input.raw_answer = raw_answer + prepared_input.vector_only_answer = vector_only_answer + prepared_input.graph_only_answer = graph_only_answer + prepared_input.graph_vector_answer = graph_vector_answer + prepared_input.graph_ratio = graph_ratio + prepared_input.gremlin_tmpl_num = gremlin_tmpl_num + prepared_input.gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt + prepared_input.max_graph_items = ( + max_graph_items or huge_settings.max_graph_items + ) + prepared_input.topk_return_results = ( + topk_return_results or huge_settings.topk_return_results + ) + prepared_input.topk_per_keyword = ( + topk_per_keyword or huge_settings.topk_per_keyword + ) + prepared_input.vector_dis_threshold = ( + vector_dis_threshold or huge_settings.vector_dis_threshold + ) + prepared_input.rerank_method = rerank_method + prepared_input.near_neighbor_first = near_neighbor_first + prepared_input.keywords_extract_prompt = ( + keywords_extract_prompt or prompt.keywords_extract_prompt + ) + prepared_input.answer_prompt = answer_prompt or prompt.answer_prompt + prepared_input.custom_related_information = custom_related_information + prepared_input.schema = huge_settings.graph_name + + prepared_input.data_json = { + "query": query, + "vector_search": vector_search, + "graph_search": graph_search, + "max_graph_items": max_graph_items or huge_settings.max_graph_items, + } + return + + def build_flow(self, **kwargs): + pipeline = GPipeline() + prepared_input = WkFlowInput() + self.prepare(prepared_input, **kwargs) + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + + # Create nodes (registration style consistent with 
RAGFlow) + vector_query_node = VectorQueryNode() + keyword_extract_node = KeywordExtractNode() + semantic_id_query_node = SemanticIdQueryNode() + schema_node = SchemaNode() + graph_query_node = GraphQueryNode() + merge_rerank_node = MergeRerankNode() + answer_synthesize_node = AnswerSynthesizeNode() + + # Register nodes and their dependencies + pipeline.registerGElement(vector_query_node, set(), "vector") + pipeline.registerGElement(keyword_extract_node, set(), "keyword") + pipeline.registerGElement( + semantic_id_query_node, {keyword_extract_node}, "semantic" + ) + pipeline.registerGElement(schema_node, set(), "schema") + pipeline.registerGElement( + graph_query_node, {schema_node, semantic_id_query_node}, "graph" + ) + pipeline.registerGElement( + merge_rerank_node, {graph_query_node, vector_query_node}, "merge" + ) + pipeline.registerGElement( + answer_synthesize_node, {merge_rerank_node}, "graph_vector" + ) + log.info("RAGGraphVectorFlow pipeline built successfully") + return pipeline + + def post_deal(self, pipeline=None): + if pipeline is None: + return json.dumps( + {"error": "No pipeline provided"}, ensure_ascii=False, indent=2 + ) + try: + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + log.info("RAGGraphVectorFlow post processing success") + return { + "raw_answer": res.get("raw_answer", ""), + "vector_only_answer": res.get("vector_only_answer", ""), + "graph_only_answer": res.get("graph_only_answer", ""), + "graph_vector_answer": res.get("graph_vector_answer", ""), + } + except Exception as e: + log.error(f"RAGGraphVectorFlow post processing failed: {e}") + return json.dumps( + {"error": f"Post processing failed: {str(e)}"}, + ensure_ascii=False, + indent=2, + ) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_raw.py b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_raw.py new file mode 100644 index 00000000..f62e574b --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_raw.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json + +from typing import Optional + +from PyCGraph import GPipeline + +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.nodes.llm_node.answer_synthesize_node import AnswerSynthesizeNode +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.config import huge_settings, prompt +from hugegraph_llm.utils.log import log + + +class RAGRawFlow(BaseFlow): + """ + Workflow for basic LLM answering only (raw_answer) + """ + + def prepare( + self, + prepared_input: WkFlowInput, + query: str, + vector_search: bool = None, + graph_search: bool = None, + raw_answer: bool = None, + vector_only_answer: bool = None, + graph_only_answer: bool = None, + graph_vector_answer: bool = None, + custom_related_information: str = "", + answer_prompt: Optional[str] = None, + max_graph_items: int = None, + **_: dict, + ): + prepared_input.query = query + prepared_input.raw_answer = raw_answer + prepared_input.vector_only_answer = vector_only_answer + prepared_input.graph_only_answer = graph_only_answer + prepared_input.graph_vector_answer = graph_vector_answer + prepared_input.custom_related_information = custom_related_information + prepared_input.answer_prompt = answer_prompt or prompt.answer_prompt + prepared_input.schema = huge_settings.graph_name + + prepared_input.data_json = { + "query": query, + "vector_search": vector_search, + "graph_search": graph_search, + "max_graph_items": max_graph_items or huge_settings.max_graph_items, + } + return + + def build_flow(self, **kwargs): + pipeline = GPipeline() + prepared_input = WkFlowInput() + self.prepare(prepared_input, **kwargs) + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + + # Create nodes and register with registerGElement (no GRegion required) + answer_synthesize_node = AnswerSynthesizeNode() + pipeline.registerGElement(answer_synthesize_node, set(), "raw") + log.info("RAGRawFlow pipeline built successfully") + return pipeline + + def post_deal(self, pipeline=None): + if pipeline is None: + return json.dumps( + {"error": "No pipeline provided"}, ensure_ascii=False, indent=2 + ) + try: + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + log.info("RAGRawFlow post processing success") + return { + "raw_answer": res.get("raw_answer", ""), + "vector_only_answer": res.get("vector_only_answer", ""), + "graph_only_answer": res.get("graph_only_answer", ""), + "graph_vector_answer": res.get("graph_vector_answer", ""), + } + except Exception as e: + log.error(f"RAGRawFlow post processing failed: {e}") + return json.dumps( + {"error": f"Post processing failed: {str(e)}"}, + ensure_ascii=False, + indent=2, + ) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_vector_only.py b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_vector_only.py new file mode 100644 index 00000000..c727eacc --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_vector_only.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from typing import Optional, Literal + +from PyCGraph import GPipeline + +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.nodes.index_node.vector_query_node import VectorQueryNode +from hugegraph_llm.nodes.common_node.merge_rerank_node import MergeRerankNode +from hugegraph_llm.nodes.llm_node.answer_synthesize_node import AnswerSynthesizeNode +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.config import huge_settings, prompt +from hugegraph_llm.utils.log import log + + +class RAGVectorOnlyFlow(BaseFlow): + """ + Workflow for vector-only answering (vector_only_answer) + """ + + def prepare( + self, + prepared_input: WkFlowInput, + query: str, + vector_search: bool = None, + graph_search: bool = None, + raw_answer: bool = None, + vector_only_answer: bool = None, + graph_only_answer: bool = None, + graph_vector_answer: bool = None, + rerank_method: Literal["bleu", "reranker"] = "bleu", + near_neighbor_first: bool = False, + custom_related_information: str = "", + answer_prompt: Optional[str] = None, + max_graph_items: int = None, + topk_return_results: int = None, + vector_dis_threshold: float = None, + **_: dict, + ): + prepared_input.query = query + prepared_input.vector_search = vector_search + prepared_input.graph_search = graph_search + prepared_input.raw_answer = raw_answer + prepared_input.vector_only_answer = vector_only_answer + prepared_input.graph_only_answer = graph_only_answer + prepared_input.graph_vector_answer = graph_vector_answer + prepared_input.vector_dis_threshold = ( + vector_dis_threshold or huge_settings.vector_dis_threshold + ) + prepared_input.topk_return_results = ( + topk_return_results or huge_settings.topk_return_results + ) + prepared_input.rerank_method = rerank_method + prepared_input.near_neighbor_first = near_neighbor_first + prepared_input.custom_related_information = custom_related_information + prepared_input.answer_prompt = answer_prompt or prompt.answer_prompt + prepared_input.schema = huge_settings.graph_name + + prepared_input.data_json = { + "query": query, + "vector_search": vector_search, + "graph_search": graph_search, + "max_graph_items": max_graph_items or huge_settings.max_graph_items, + } + return + + def build_flow(self, **kwargs): + pipeline = GPipeline() + prepared_input = WkFlowInput() + self.prepare(prepared_input, **kwargs) + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + + # Create nodes (do not use GRegion, use registerGElement for all nodes) + only_vector_query_node = VectorQueryNode() + merge_rerank_node = MergeRerankNode() + answer_synthesize_node = AnswerSynthesizeNode() + + # Register nodes and dependencies, keep naming consistent with original + pipeline.registerGElement(only_vector_query_node, set(), "only_vector") + pipeline.registerGElement( + merge_rerank_node, {only_vector_query_node}, "merge_two" + ) + pipeline.registerGElement(answer_synthesize_node, {merge_rerank_node}, "vector") + log.info("RAGVectorOnlyFlow pipeline built successfully") + return pipeline + + def 
post_deal(self, pipeline=None):
+        if pipeline is None:
+            return json.dumps(
+                {"error": "No pipeline provided"}, ensure_ascii=False, indent=2
+            )
+        try:
+            res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json()
+            log.info("RAGVectorOnlyFlow post processing success")
+            return {
+                "raw_answer": res.get("raw_answer", ""),
+                "vector_only_answer": res.get("vector_only_answer", ""),
+                "graph_only_answer": res.get("graph_only_answer", ""),
+                "graph_vector_answer": res.get("graph_vector_answer", ""),
+            }
+        except Exception as e:
+            log.error(f"RAGVectorOnlyFlow post processing failed: {e}")
+            return json.dumps(
+                {"error": f"Post processing failed: {str(e)}"},
+                ensure_ascii=False,
+                indent=2,
+            )
diff --git a/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py
index 3aedbe7f..5afa1bf8 100644
--- a/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py
+++ b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py
@@ -24,6 +24,11 @@ from hugegraph_llm.flows.update_vid_embeddings import UpdateVidEmbeddingsFlows
 from hugegraph_llm.flows.get_graph_index_info import GetGraphIndexInfoFlow
 from hugegraph_llm.flows.build_schema import BuildSchemaFlow
 from hugegraph_llm.flows.prompt_generate import PromptGenerateFlow
+from hugegraph_llm.flows.rag_flow_raw import RAGRawFlow
+from hugegraph_llm.flows.rag_flow_vector_only import RAGVectorOnlyFlow
+from hugegraph_llm.flows.rag_flow_graph_only import RAGGraphOnlyFlow
+from hugegraph_llm.flows.rag_flow_graph_vector import RAGGraphVectorFlow
+from hugegraph_llm.state.ai_state import WkFlowInput
 from hugegraph_llm.utils.log import log
 from hugegraph_llm.flows.text2gremlin import Text2GremlinFlow
@@ -67,6 +72,23 @@ class Scheduler:
             "manager": GPipelineManager(),
             "flow": Text2GremlinFlow(),
         }
+        # New split rag pipelines
+        self.pipeline_pool["rag_raw"] = {
+            "manager": GPipelineManager(),
+            "flow": RAGRawFlow(),
+        }
+        self.pipeline_pool["rag_vector_only"] = {
+            "manager": GPipelineManager(),
+            "flow": RAGVectorOnlyFlow(),
+        }
+        self.pipeline_pool["rag_graph_only"] = {
+            "manager": GPipelineManager(),
+            "flow": RAGGraphOnlyFlow(),
+        }
+        self.pipeline_pool["rag_graph_vector"] = {
+            "manager": GPipelineManager(),
+            "flow": RAGGraphVectorFlow(),
+        }
         self.max_pipeline = max_pipeline

     # TODO: Implement Agentic Workflow
@@ -108,6 +130,47 @@ class Scheduler:
             manager.release(pipeline)
         return res

+    async def schedule_stream_flow(self, flow: str, *args, **kwargs):
+        if flow not in self.pipeline_pool:
+            raise ValueError(f"Unsupported workflow {flow}")
+        manager: GPipelineManager = self.pipeline_pool[flow]["manager"]
+        flow: BaseFlow = self.pipeline_pool[flow]["flow"]
+        pipeline: GPipeline = manager.fetch()
+        if pipeline is None:
+            # call the corresponding flow_func to create a new workflow
+            pipeline = flow.build_flow(*args, **kwargs)
+            try:
+                pipeline.getGParamWithNoEmpty("wkflow_input").stream = True
+                status = pipeline.init()
+                if status.isErr():
+                    error_msg = f"Error in flow init: {status.getInfo()}"
+                    log.error(error_msg)
+                    raise RuntimeError(error_msg)
+                status = pipeline.run()
+                if status.isErr():
+                    error_msg = f"Error in flow execution: {status.getInfo()}"
+                    log.error(error_msg)
+                    raise RuntimeError(error_msg)
+                async for res in flow.post_deal_stream(pipeline):
+                    yield res
+            finally:
+                manager.add(pipeline)
+        else:
+            try:
+                # fetch pipeline & prepare input for flow
+                prepared_input: WkFlowInput = pipeline.getGParamWithNoEmpty(
+                    "wkflow_input"
+                )
+                prepared_input.stream = True
+                flow.prepare(prepared_input, *args, **kwargs)
+                status = pipeline.run()
+                if status.isErr():
+                    raise RuntimeError(f"Error in flow execution: {status.getInfo()}")
+                async for res in flow.post_deal_stream(pipeline):
+                    yield res
+            finally:
+                manager.release(pipeline)
+

 class SchedulerSingleton:
     _instance = None
diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py
index 0ea0675c..f9016730 100644
--- a/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py
+++ b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py
@@ -30,6 +30,9 @@ class BaseNode(GNode):
         Node initialization method, can be overridden by subclasses.
         Returns a CStatus object indicating whether initialization succeeded.
         """
+        if self.wk_input.data_json is not None:
+            self.context.assign_from_json(self.wk_input.data_json)
+            self.wk_input.data_json = None
         return CStatus()

     def run(self):
diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/common_node/merge_rerank_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/common_node/merge_rerank_node.py
new file mode 100644
index 00000000..78f53e23
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/nodes/common_node/merge_rerank_node.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Any
+from hugegraph_llm.nodes.base_node import BaseNode
+from hugegraph_llm.operators.common_op.merge_dedup_rerank import MergeDedupRerank
+from hugegraph_llm.models.embeddings.init_embedding import get_embedding
+from hugegraph_llm.config import huge_settings, llm_settings
+from hugegraph_llm.utils.log import log
+
+
+class MergeRerankNode(BaseNode):
+    """
+    Merge and rerank node, responsible for merging vector and graph query results, deduplicating, and reranking them.
+    """
+
+    operator: MergeDedupRerank
+
+    def node_init(self):
+        """
+        Initialize the merge and rerank operator.
+ """ + try: + # Read user configuration parameters from wk_input + embedding = get_embedding(llm_settings) + graph_ratio = self.wk_input.graph_ratio or 0.5 + rerank_method = self.wk_input.rerank_method or "bleu" + near_neighbor_first = self.wk_input.near_neighbor_first or False + custom_related_information = self.wk_input.custom_related_information or "" + topk_return_results = ( + self.wk_input.topk_return_results or huge_settings.topk_return_results + ) + + self.operator = MergeDedupRerank( + embedding=embedding, + graph_ratio=graph_ratio, + method=rerank_method, + near_neighbor_first=near_neighbor_first, + custom_related_information=custom_related_information, + topk_return_results=topk_return_results, + ) + return super().node_init() + except Exception as e: + log.error(f"Failed to initialize MergeRerankNode: {e}") + from PyCGraph import CStatus + + return CStatus(-1, f"MergeRerankNode initialization failed: {e}") + + def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute the merge and rerank operation. + """ + try: + # Perform merge, deduplication, and rerank + result = self.operator.run(data_json) + + # Log result statistics + vector_count = len(result.get("vector_result", [])) + graph_count = len(result.get("graph_result", [])) + merged_count = len(result.get("merged_result", [])) + + log.info( + f"Merge and rerank completed: {vector_count} vector results, " + f"{graph_count} graph results, {merged_count} merged results" + ) + + return result + + except Exception as e: + log.error(f"Merge and rerank failed: {e}") + return data_json diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py b/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py index 4c5acbe9..f71bd7bd 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py @@ -37,7 +37,7 @@ class ChunkSplitNode(BaseNode): if isinstance(texts, str): texts = [texts] self.chunk_split_op = ChunkSplit(texts, split_type, language) - return CStatus() + return super().node_init() def operator_schedule(self, data_json): return self.chunk_split_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/commit_to_hugegraph.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/commit_to_hugegraph.py index b576e817..a4ebc709 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/commit_to_hugegraph.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/commit_to_hugegraph.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from PyCGraph import CStatus from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.hugegraph_op.commit_to_hugegraph import Commit2Graph from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState @@ -29,7 +28,7 @@ class Commit2GraphNode(BaseNode): if data_json: self.context.assign_from_json(data_json) self.commit_to_graph_op = Commit2Graph() - return CStatus() + return super().node_init() def operator_schedule(self, data_json): return self.commit_to_graph_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py index b2434e52..99b428e5 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from PyCGraph import CStatus from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.hugegraph_op.fetch_graph_data import FetchGraphData from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState @@ -27,7 +26,7 @@ class FetchGraphDataNode(BaseNode): def node_init(self): self.fetch_graph_data_op = FetchGraphData(get_hg_client()) - return CStatus() + return super().node_init() def operator_schedule(self, data_json): return self.fetch_graph_data_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py new file mode 100644 index 00000000..ae65ccb3 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import CStatus +from typing import Dict, Any +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery +from hugegraph_llm.config import huge_settings, prompt +from hugegraph_llm.utils.log import log + + +class GraphQueryNode(BaseNode): + """ + Graph query node, responsible for retrieving relevant information from the graph database. + """ + + graph_rag_query: GraphRAGQuery + + def node_init(self): + """ + Initialize the graph query operator. 
+ """ + try: + graph_name = huge_settings.graph_name + if not graph_name: + return CStatus(-1, "graph_name is required in wk_input") + + max_deep = self.wk_input.max_deep or 2 + max_graph_items = ( + self.wk_input.max_graph_items or huge_settings.max_graph_items + ) + max_v_prop_len = self.wk_input.max_v_prop_len or 2048 + max_e_prop_len = self.wk_input.max_e_prop_len or 256 + prop_to_match = self.wk_input.prop_to_match + num_gremlin_generate_example = self.wk_input.gremlin_tmpl_num or -1 + gremlin_prompt = ( + self.wk_input.gremlin_prompt or prompt.gremlin_generate_prompt + ) + + # Initialize GraphRAGQuery operator + self.graph_rag_query = GraphRAGQuery( + max_deep=max_deep, + max_graph_items=max_graph_items, + max_v_prop_len=max_v_prop_len, + max_e_prop_len=max_e_prop_len, + prop_to_match=prop_to_match, + num_gremlin_generate_example=num_gremlin_generate_example, + gremlin_prompt=gremlin_prompt, + ) + + return super().node_init() + except Exception as e: + log.error(f"Failed to initialize GraphQueryNode: {e}") + + return CStatus(-1, f"GraphQueryNode initialization failed: {e}") + + def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute the graph query operation. + """ + try: + # Get the query text from input + query = data_json.get("query", "") + + if not query: + log.warning("No query text provided for graph query") + return data_json + + # Execute the graph query (assuming schema and semantic query have been completed in previous nodes) + graph_result = self.graph_rag_query.run(data_json) + data_json.update(graph_result) + + log.info( + f"Graph query completed, found {len(data_json.get('graph_result', []))} results" + ) + + return data_json + + except Exception as e: + log.error(f"Graph query failed: {e}") + return data_json diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py index 84719d9e..3face9d6 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py @@ -15,7 +15,6 @@ import json -from PyCGraph import CStatus from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.common_op.check_schema import CheckSchema from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager @@ -59,7 +58,7 @@ class SchemaNode(BaseNode): else: log.info("Get schema '%s' from graphdb.", self.schema) self.schema_manager = self._import_schema(from_hugegraph=self.schema) - return CStatus() + return super().node_init() def operator_schedule(self, data_json): log.debug("SchemaNode input state: %s", data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py index ab31fa39..c01cffc9 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from PyCGraph import CStatus from hugegraph_llm.config import llm_settings from hugegraph_llm.models.embeddings.init_embedding import get_embedding from hugegraph_llm.nodes.base_node import BaseNode @@ -28,7 +27,7 @@ class BuildSemanticIndexNode(BaseNode): def node_init(self): self.build_semantic_index_op = BuildSemanticIndex(get_embedding(llm_settings)) - return CStatus() + return super().node_init() def operator_schedule(self, data_json): return self.build_semantic_index_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py index cf2f9b67..1f6a3c75 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from PyCGraph import CStatus from hugegraph_llm.config import llm_settings from hugegraph_llm.models.embeddings.init_embedding import get_embedding from hugegraph_llm.nodes.base_node import BaseNode @@ -28,7 +27,7 @@ class BuildVectorIndexNode(BaseNode): def node_init(self): self.build_vector_index_op = BuildVectorIndex(get_embedding(llm_settings)) - return CStatus() + return super().node_init() def operator_schedule(self, data_json): return self.build_vector_index_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py index eb033d86..e9283598 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py @@ -19,9 +19,12 @@ from typing import Any, Dict from PyCGraph import CStatus +from hugegraph_llm.config import llm_settings from hugegraph_llm.nodes.base_node import BaseNode -from hugegraph_llm.operators.index_op.gremlin_example_index_query import GremlinExampleIndexQuery -from hugegraph_llm.models.embeddings.init_embedding import Embeddings +from hugegraph_llm.operators.index_op.gremlin_example_index_query import ( + GremlinExampleIndexQuery, +) +from hugegraph_llm.models.embeddings.init_embedding import get_embedding class GremlinExampleIndexQueryNode(BaseNode): @@ -29,13 +32,15 @@ class GremlinExampleIndexQueryNode(BaseNode): def node_init(self): # Build operator (index lazy-loading handled in operator) - embedding = Embeddings().get_embedding() + embedding = get_embedding(llm_settings) example_num = getattr(self.wk_input, "example_num", None) if not isinstance(example_num, int): example_num = 2 # Clamp to [0, 10] example_num = max(0, min(10, example_num)) - self.operator = GremlinExampleIndexQuery(embedding=embedding, num_examples=example_num) + self.operator = GremlinExampleIndexQuery( + embedding=embedding, num_examples=example_num + ) return CStatus() def operator_schedule(self, data_json: Dict[str, Any]): diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py new file mode 100644 index 00000000..bf605aa4 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import CStatus +from typing import Dict, Any +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.index_op.semantic_id_query import SemanticIdQuery +from hugegraph_llm.models.embeddings.init_embedding import get_embedding +from hugegraph_llm.config import huge_settings, llm_settings +from hugegraph_llm.utils.log import log + + +class SemanticIdQueryNode(BaseNode): + """ + Semantic ID query node, responsible for semantic matching based on keywords. + """ + + semantic_id_query: SemanticIdQuery + + def node_init(self): + """ + Initialize the semantic ID query operator. + """ + try: + graph_name = huge_settings.graph_name + if not graph_name: + return CStatus(-1, "graph_name is required in wk_input") + + embedding = get_embedding(llm_settings) + by = self.wk_input.semantic_by or "keywords" + topk_per_keyword = ( + self.wk_input.topk_per_keyword or huge_settings.topk_per_keyword + ) + topk_per_query = self.wk_input.topk_per_query or 10 + vector_dis_threshold = ( + self.wk_input.vector_dis_threshold or huge_settings.vector_dis_threshold + ) + + # Initialize the semantic ID query operator + self.semantic_id_query = SemanticIdQuery( + embedding=embedding, + by=by, + topk_per_keyword=topk_per_keyword, + topk_per_query=topk_per_query, + vector_dis_threshold=vector_dis_threshold, + ) + + return super().node_init() + except Exception as e: + log.error(f"Failed to initialize SemanticIdQueryNode: {e}") + + return CStatus(-1, f"SemanticIdQueryNode initialization failed: {e}") + + def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute the semantic ID query operation. + """ + try: + # Get the query text and keywords from input + query = data_json.get("query", "") + keywords = data_json.get("keywords", []) + + if not query and not keywords: + log.warning("No query text or keywords provided for semantic query") + return data_json + + # Perform the semantic query + semantic_result = self.semantic_id_query.run(data_json) + + match_vids = semantic_result.get("match_vids", []) + log.info( + f"Semantic query completed, found {len(match_vids)} matching vertex IDs" + ) + + return semantic_result + + except Exception as e: + log.error(f"Semantic query failed: {e}") + return data_json diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py new file mode 100644 index 00000000..48b50acf --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Any
+from PyCGraph import CStatus
+from hugegraph_llm.config import llm_settings
+from hugegraph_llm.nodes.base_node import BaseNode
+from hugegraph_llm.operators.index_op.vector_index_query import VectorIndexQuery
+from hugegraph_llm.models.embeddings.init_embedding import get_embedding
+from hugegraph_llm.utils.log import log
+
+
+class VectorQueryNode(BaseNode):
+    """
+    Vector query node, responsible for retrieving relevant documents from the vector index.
+    """
+
+    operator: VectorIndexQuery
+
+    def node_init(self):
+        """
+        Initialize the vector query operator.
+        """
+        try:
+            # Read user-configured parameters from wk_input
+            embedding = get_embedding(llm_settings)
+            max_items = (
+                self.wk_input.max_items if self.wk_input.max_items is not None else 3
+            )
+
+            self.operator = VectorIndexQuery(embedding=embedding, topk=max_items)
+            return super().node_init()
+        except Exception as e:
+            log.error(f"Failed to initialize VectorQueryNode: {e}")
+            return CStatus(-1, f"VectorQueryNode initialization failed: {e}")
+
+    def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Execute the vector query operation.
+        """
+        try:
+            # Get the query text from input
+            query = data_json.get("query", "")
+            if not query:
+                log.warning("No query text provided for vector query")
+                return data_json
+
+            # Perform the vector query
+            result = self.operator.run({"query": query})
+
+            # Update the state
+            data_json.update(result)
+            log.info(
+                f"Vector query completed, found {len(result.get('vector_result', []))} results"
+            )
+
+            return data_json
+
+        except Exception as e:
+            log.error(f"Vector query failed: {e}")
+            return data_json
diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py
new file mode 100644
index 00000000..22b970b4
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py
@@ -0,0 +1,98 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Any
+from PyCGraph import CStatus
+from hugegraph_llm.nodes.base_node import BaseNode
+from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
+from hugegraph_llm.utils.log import log
+
+
+class AnswerSynthesizeNode(BaseNode):
+    """
+    Answer synthesis node, responsible for generating the final answer based on retrieval results.
+    """
+
+    operator: AnswerSynthesize
+
+    def node_init(self):
+        """
+        Initialize the answer synthesis operator.
+        """
+        try:
+            prompt_template = self.wk_input.answer_prompt
+            raw_answer = self.wk_input.raw_answer or False
+            vector_only_answer = self.wk_input.vector_only_answer or False
+            graph_only_answer = self.wk_input.graph_only_answer or False
+            graph_vector_answer = self.wk_input.graph_vector_answer or False
+
+            self.operator = AnswerSynthesize(
+                prompt_template=prompt_template,
+                raw_answer=raw_answer,
+                vector_only_answer=vector_only_answer,
+                graph_only_answer=graph_only_answer,
+                graph_vector_answer=graph_vector_answer,
+            )
+            return super().node_init()
+        except Exception as e:
+            log.error(f"Failed to initialize AnswerSynthesizeNode: {e}")
+            return CStatus(-1, f"AnswerSynthesizeNode initialization failed: {e}")
+
+    def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Execute the answer synthesis operation.
+        """
+        try:
+            if self.getGParamWithNoEmpty("wkflow_input").stream:
+                # Streaming mode: return a generator for streaming output
+                data_json["stream_generator"] = self.operator.run_streaming(data_json)
+                return data_json
+            else:
+                # Non-streaming mode: execute answer synthesis
+                result = self.operator.run(data_json)
+
+                # Record the types of answers generated
+                answer_types = []
+                if result.get("raw_answer"):
+                    answer_types.append("raw")
+                if result.get("vector_only_answer"):
+                    answer_types.append("vector_only")
+                if result.get("graph_only_answer"):
+                    answer_types.append("graph_only")
+                if result.get("graph_vector_answer"):
+                    answer_types.append("graph_vector")
+
+                log.info(
+                    f"Answer synthesis completed for types: {', '.join(answer_types)}"
+                )
+
+                # Log the answer types enabled by the self.wk_input configuration
+                wk_input_types = []
+                if getattr(self.wk_input, "raw_answer", False):
+                    wk_input_types.append("raw")
+                if getattr(self.wk_input, "vector_only_answer", False):
+                    wk_input_types.append("vector_only")
+                if getattr(self.wk_input, "graph_only_answer", False):
+                    wk_input_types.append("graph_only")
+                if getattr(self.wk_input, "graph_vector_answer", False):
+                    wk_input_types.append("graph_vector")
+                log.info(
+                    f"Enabled answer types according to wk_input config: {', '.join(wk_input_types)}"
+                )
+                return result
+
+        except Exception as e:
+            log.error(f"Answer synthesis failed: {e}")
+            return data_json
diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py
index 8bceed80..628765f5 100644
--- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py
+++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py
@@ -43,7 +43,7 @@ class ExtractNode(BaseNode):
             self.property_graph_extract = PropertyGraphExtract(llm, example_prompt)
         else:
             return CStatus(-1, f"Unsupported extract_type: {extract_type}")
-        return CStatus()
+        return super().node_init()
 
     def operator_schedule(self, data_json):
         if self.extract_type == "triples":
diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py
new file mode 100644
index 00000000..76fc06eb
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Any
+from PyCGraph import CStatus
+
+from hugegraph_llm.nodes.base_node import BaseNode
+from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract
+from hugegraph_llm.utils.log import log
+
+
+class KeywordExtractNode(BaseNode):
+    """
+    Keyword extraction node, responsible for extracting keywords from query text.
+    """
+
+    operator: KeywordExtract
+
+    def node_init(self):
+        """
+        Initialize the keyword extraction operator.
+        """
+        try:
+            max_keywords = (
+                self.wk_input.max_keywords
+                if self.wk_input.max_keywords is not None
+                else 5
+            )
+            language = (
+                self.wk_input.language
+                if self.wk_input.language is not None
+                else "english"
+            )
+            extract_template = self.wk_input.keywords_extract_prompt
+
+            self.operator = KeywordExtract(
+                text=self.wk_input.query,
+                max_keywords=max_keywords,
+                language=language,
+                extract_template=extract_template,
+            )
+            return super().node_init()
+        except Exception as e:
+            log.error(f"Failed to initialize KeywordExtractNode: {e}")
+            return CStatus(-1, f"KeywordExtractNode initialization failed: {e}")
+
+    def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Execute the keyword extraction operation.
+        """
+        try:
+            # Perform keyword extraction
+            result = self.operator.run(data_json)
+            if "keywords" not in result:
+                log.warning("Keyword extraction result missing 'keywords' field")
+                result["keywords"] = []
+
+            log.info(f"Extracted keywords: {result.get('keywords', [])}")
+
+            return result
+
+        except Exception as e:
+            log.error(f"Keyword extraction failed: {e}")
+            # Add an error flag to indicate failure
+            error_result = data_json.copy()
+            error_result["error"] = str(e)
+            error_result["keywords"] = []
+            return error_result
diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py
index 317f9e6a..8c49994f 100644
--- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py
+++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py
@@ -50,7 +50,7 @@ class PromptGenerateNode(BaseNode):
             "example_name": self.wk_input.example_name,
         }
         self.context.assign_from_json(context)
-        return CStatus()
+        return super().node_init()
 
     def operator_schedule(self, data_json):
         """
diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py
index 7df2e68e..408adb10 100644
--- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py
+++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py
@@ -75,7 +75,7 @@ class SchemaBuildNode(BaseNode):
         }
 
         self.context.assign_from_json(_context_payload)
-        return CStatus()
+        return super().node_init()
 
     def operator_schedule(self, data_json):
         try:
diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py
index ffbafbaf..a3683152 100644
--- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py
+++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py
@@ -22,13 +22,15 @@ from PyCGraph import CStatus
 
 from hugegraph_llm.nodes.base_node import BaseNode
 from hugegraph_llm.operators.llm_op.gremlin_generate import GremlinGenerateSynthesize
-from hugegraph_llm.models.llms.init_llm import LLMs
-from hugegraph_llm.config import prompt as prompt_cfg
+from hugegraph_llm.models.llms.init_llm import get_text2gql_llm
+from hugegraph_llm.config import llm_settings, prompt as prompt_cfg
 
 
 def _stable_schema_string(state_json: Dict[str, Any]) -> str:
     if "simple_schema" in state_json and state_json["simple_schema"] is not None:
-        return json.dumps(state_json["simple_schema"], ensure_ascii=False, sort_keys=True)
+        return json.dumps(
+            state_json["simple_schema"], ensure_ascii=False, sort_keys=True
+        )
     if "schema" in state_json and state_json["schema"] is not None:
         return json.dumps(state_json["schema"], ensure_ascii=False, sort_keys=True)
     return ""
@@ -39,7 +41,7 @@ class Text2GremlinNode(BaseNode):
 
     def node_init(self):
         # Select LLM
-        llm = LLMs().get_text2gql_llm()
+        llm = get_text2gql_llm(llm_settings)
         # Serialize schema deterministically
         state_json = self.context.to_json()
         schema_str = _stable_schema_string(state_json)
diff --git a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py
index f941098b..3a6fd3c1 100644
--- a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py
+++ b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py
@@ -24,7 +24,6 @@ class WkFlowInput(GParam):
     split_type: str = None  # split type used by ChunkSplit Node
     example_prompt: str = None  # needed by graph information extract
     schema: str = None  # Schema information required by SchemaNode
-    graph_name: str = None
     data_json = None
     extract_type = None
     query_examples = None
@@ -34,11 +33,44 @@ class WkFlowInput(GParam):
     scenario: str = None  # Scenario description
     example_name: str = None  # Example name
     # Fields for Text2Gremlin
-    query: str = None
     example_num: int = None
     gremlin_prompt: str = None
     requested_outputs: Optional[List[str]] = None
+    # RAG Flow related fields
+    query: str = None  # User query for RAG
+    vector_search: bool = None  # Enable vector search
+    graph_search: bool = None  # Enable graph search
+    raw_answer: bool = None  # Return raw answer
+    vector_only_answer: bool = None  # Vector only answer mode
+    graph_only_answer: bool = None  # Graph only answer mode
+    graph_vector_answer: bool = None  # Combined graph and vector answer
+    graph_ratio: float = None  # Graph ratio for merging
+    rerank_method: str = None  # Reranking method
+    near_neighbor_first: bool = None  # Near neighbor first flag
+    custom_related_information: str = None  # Custom related information
+    answer_prompt: str = None  # Answer generation prompt
+    keywords_extract_prompt: str = None  # Keywords extraction prompt
+    gremlin_tmpl_num: int = None  # Number of Gremlin templates
+    max_graph_items: int = None  # Maximum graph items
+    topk_return_results: int = None  # Top-k return results
+    vector_dis_threshold: float = None  # Vector distance threshold
+    topk_per_keyword: int = None  # Top-k per keyword
+    max_keywords: int = None
+    max_items: int = None
+
+    # Semantic query related fields
+    semantic_by: str = None  # Semantic query method
+    topk_per_query: int = None  # Top-k per query
+
+    # Graph query related fields
+    max_deep: int = None  # Maximum depth for graph traversal
+    max_v_prop_len: int = None  # Maximum vertex property length
+    max_e_prop_len: int = None  # Maximum edge property length
+    prop_to_match: str = None  # Property to match
+
+    stream: bool = None  # used to recognize streaming mode
+
     def reset(self, _: CStatus) -> None:
         self.texts = None
         self.language = None
@@ -55,10 +87,39 @@ class WkFlowInput(GParam):
         self.scenario = None
         self.example_name = None
         # Text2Gremlin related configuration
-        self.query = None
         self.example_num = None
         self.gremlin_prompt = None
         self.requested_outputs = None
+        # RAG Flow related fields
+        self.query = None
+        self.vector_search = None
+        self.graph_search = None
+        self.raw_answer = None
+        self.vector_only_answer = None
+        self.graph_only_answer = None
+        self.graph_vector_answer = None
+        self.graph_ratio = None
+        self.rerank_method = None
+        self.near_neighbor_first = None
+        self.custom_related_information = None
+        self.answer_prompt = None
+        self.keywords_extract_prompt = None
+        self.gremlin_tmpl_num = None
+        self.max_graph_items = None
+        self.topk_return_results = None
+        self.vector_dis_threshold = None
+        self.topk_per_keyword = None
+        self.max_keywords = None
+        self.max_items = None
+        # Semantic query related fields
+        self.semantic_by = None
+        self.topk_per_query = None
+        # Graph query related fields
+        self.max_deep = None
+        self.max_v_prop_len = None
+        self.max_e_prop_len = None
+        self.prop_to_match = None
+        self.stream = None
 
 
 class WkFlowState(GParam):
@@ -83,6 +144,17 @@ class WkFlowState(GParam):
     template_exec_res: Optional[Any] = None
     raw_exec_res: Optional[Any] = None
 
+    match_vids = None
+    vector_result = None
+    graph_result = None
+
+    raw_answer: str = None
+    vector_only_answer: str = None
+    graph_only_answer: str = None
+    graph_vector_answer: str = None
+
+    merged_result = None
+
     def setup(self):
         self.schema = None
         self.simple_schema = None
@@ -90,7 +162,7 @@ class WkFlowState(GParam):
         self.edges = None
         self.vertices = None
         self.triples = None
-        self.call_count = 0
+        self.call_count = None
 
         self.keywords = None
         self.vector_result = None
@@ -99,12 +171,19 @@ class WkFlowState(GParam):
         self.generated_extract_prompt = None
 
         # Text2Gremlin results reset
-        self.match_result = []
-        self.result = ""
-        self.raw_result = ""
-        self.template_exec_res = ""
-        self.raw_exec_res = ""
+        self.match_result = None
+        self.result = None
+        self.raw_result = None
+        self.template_exec_res = None
+        self.raw_exec_res = None
+        self.raw_answer = None
+        self.vector_only_answer = None
+        self.graph_only_answer = None
+        self.graph_vector_answer = None
+
+        self.graph_result = None
+        self.merged_result = None
         return CStatus()
 
     def to_json(self):
        """
        dict: A dictionary containing non-None instance members and their serialized values.
        """
        # Only export instance attributes (excluding methods and class attributes) whose values are not None
-        return {k: v for k, v in self.__dict__.items() if not k.startswith("_") and v is not None}
+        return {
+            k: v
+            for k, v in self.__dict__.items()
+            if not k.startswith("_") and v is not None
+        }
 
     # Implement a method that assigns keys from data_json as WkFlowState member variables
     def assign_from_json(self, data_json: dict):
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
index 7b870033..3f527f2f 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
@@ -44,30 +44,17 @@ def get_graph_index_info():
         raise gr.Error(str(e))
 
 
-def get_graph_index_info_old():
-    builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client())
-    graph_summary_info = builder.fetch_graph_data().run()
-    folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space)
-    index_dir = str(os.path.join(resource_path, folder_name, "graph_vids"))
-    filename_prefix = get_filename_prefix(
-        llm_settings.embedding_type, getattr(builder.embedding, "model_name", None)
-    )
-    vector_index = VectorIndex.from_index_file(index_dir, filename_prefix)
-    graph_summary_info["vid_index"] = {
-        "embed_dim": vector_index.index.d,
-        "num_vectors": vector_index.index.ntotal,
-        "num_vids": len(vector_index.properties),
-    }
-    return json.dumps(graph_summary_info, ensure_ascii=False, indent=2)
-
-
 def clean_all_graph_index():
-    folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space)
+    folder_name = get_index_folder_name(
+        huge_settings.graph_name, huge_settings.graph_space
+    )
     filename_prefix = get_filename_prefix(
         llm_settings.embedding_type,
         getattr(Embeddings().get_embedding(), "model_name", None),
     )
-    VectorIndex.clean(str(os.path.join(resource_path, folder_name, "graph_vids")), filename_prefix)
+    VectorIndex.clean(
+        str(os.path.join(resource_path, folder_name, "graph_vids")), filename_prefix
+    )
     VectorIndex.clean(
         str(os.path.join(resource_path, folder_name, "gremlin_examples")),
         filename_prefix,
@@ -99,14 +86,18 @@ def parse_schema(schema: str, builder: KgBuilder) -> Optional[str]:
 
 def extract_graph_origin(input_file, input_text, schema, example_prompt) -> str:
     texts = read_documents(input_file, input_text)
-    builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client())
+ builder = KgBuilder( + LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() + ) if not schema: return "ERROR: please input with correct schema/format." error_message = parse_schema(schema, builder) if error_message: return error_message - builder.chunk_split(texts, "document", "zh").extract_info(example_prompt, "property_graph") + builder.chunk_split(texts, "document", "zh").extract_info( + example_prompt, "property_graph" + ) try: context = builder.run() @@ -155,20 +146,6 @@ def update_vid_embedding(): raise gr.Error(str(e)) -def update_vid_embedding_old(): - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) - builder.fetch_graph_data().build_vertex_id_semantic_index() - log.debug("Operators: %s", builder.operators) - try: - context = builder.run() - removed_num = context["removed_vid_vector_num"] - added_num = context["added_vid_vector_num"] - return f"Removed {removed_num} vectors, added {added_num} vectors." - except Exception as e: # pylint: disable=broad-exception-caught - log.error(e) - raise gr.Error(str(e)) - - def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]: try: scheduler = SchedulerSingleton.get_instance() @@ -181,73 +158,11 @@ def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]: return data -def import_graph_data_old(data: str, schema: str) -> Union[str, Dict[str, Any]]: - try: - data_json = json.loads(data.strip()) - log.debug("Import graph data: %s", data) - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) - if schema: - error_message = parse_schema(schema, builder) - if error_message: - return error_message - - context = builder.commit_to_hugegraph().run(data_json) - gr.Info("Import graph data successfully!") - print(context) - return json.dumps(context, ensure_ascii=False, indent=2) - except Exception as e: # pylint: disable=W0718 - log.error(e) - traceback.print_exc() - # Note: can't use gr.Error here - gr.Warning(str(e) + " Please check the graph data format/type carefully.") - return data - - def build_schema(input_text, query_example, few_shot): scheduler = SchedulerSingleton.get_instance() try: - return scheduler.schedule_flow("build_schema", input_text, query_example, few_shot) + return scheduler.schedule_flow( + "build_schema", input_text, query_example, few_shot + ) except (TypeError, ValueError) as e: raise gr.Error(f"Schema generation failed: {e}") - - -def build_schema_old(input_text, query_example, few_shot): - context = { - "raw_texts": [input_text] if input_text else [], - "query_examples": [], - "few_shot_schema": {}, - } - - if few_shot: - try: - context["few_shot_schema"] = json.loads(few_shot) - except json.JSONDecodeError as e: - raise gr.Error(f"Few Shot Schema is not in a valid JSON format: {e}") from e - - if query_example: - try: - parsed_examples = json.loads(query_example) - # Validate and retain the description and gremlin fields - context["query_examples"] = [ - { - "description": ex.get("description", ""), - "gremlin": ex.get("gremlin", ""), - } - for ex in parsed_examples - if isinstance(ex, dict) and "description" in ex and "gremlin" in ex - ] - except json.JSONDecodeError as e: - raise gr.Error(f"Query Examples is not in a valid JSON format: {e}") from e - - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) - try: - schema = builder.build_schema().run(context) - except Exception as e: - log.error("Failed to generate schema: %s", e) - raise 
gr.Error(f"Schema generation failed: {e}") from e - try: - formatted_schema = json.dumps(schema, ensure_ascii=False, indent=2) - return formatted_schema - except (TypeError, ValueError) as e: - log.error("Failed to format schema: %s", e) - return str(schema)
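
For readers following this refactor, the node classes above all share one contract: node_init() builds the wrapped operator from self.wk_input and signals readiness through super().node_init() (or returns CStatus(-1, ...) on failure), while operator_schedule() runs the operator over the shared state dict. A minimal sketch of a custom node in that style follows; UppercaseQueryNode and the uppercase_query flag are hypothetical and only illustrate the contract, not part of this commit:

    from typing import Any, Dict

    from hugegraph_llm.nodes.base_node import BaseNode
    from hugegraph_llm.utils.log import log


    class UppercaseQueryNode(BaseNode):
        """Hypothetical example node: upper-cases the query before retrieval."""

        def node_init(self):
            # Read configuration from the shared workflow input with a fallback,
            # mirroring the nodes in this commit ('uppercase_query' is made up).
            self.enabled = getattr(self.wk_input, "uppercase_query", True)
            return super().node_init()

        def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]:
            # Transform the shared state dict and pass it downstream; errors are
            # logged rather than raised so one node cannot abort the whole flow.
            try:
                if self.enabled and data_json.get("query"):
                    data_json["query"] = data_json["query"].upper()
                return data_json
            except Exception as e:
                log.error(f"UppercaseQueryNode failed: {e}")
                return data_json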
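
The streaming path deserves a caller-side illustration. When wk_input.stream is set, AnswerSynthesizeNode does not synthesize an answer itself; it stashes the generator returned by AnswerSynthesize.run_streaming() under data_json["stream_generator"] for the flow's caller to drain. A consumption sketch, assuming the generator is asynchronous and yields text chunks (the AsyncGenerator import in rag_block.py suggests, but does not prove, this shape):

    import asyncio


    async def drain_answer(flow_result: dict) -> str:
        # 'stream_generator' is only present when the flow ran with stream=True.
        gen = flow_result.get("stream_generator")
        if gen is None:
            return flow_result.get("graph_vector_answer", "")

        answer = ""
        async for chunk in gen:  # assumed: an async generator of str chunks
            answer += chunk
            print(chunk, end="", flush=True)
        return answer


    # Usage (schematic): result = <run a RAG flow with stream=True>
    # final_answer = asyncio.run(drain_answer(result))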
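
The UI helpers now funnel everything through the scheduler singleton instead of building a KgBuilder chain inline. Based on the build_schema() helper above, a standalone invocation looks roughly like this; only the "build_schema" flow name and the call shape are taken from this diff, and all argument values are placeholders:

    from hugegraph_llm.flows.scheduler import SchedulerSingleton

    scheduler = SchedulerSingleton.get_instance()

    # schedule_flow(name, *args) looks up a registered flow by name and runs
    # it with the given arguments, returning the flow's result.
    result = scheduler.schedule_flow(
        "build_schema",
        "Alice knows Bob. Bob works at Acme.",  # input_text (placeholder)
        "[]",  # query_example as a JSON string (placeholder)
        "{}",  # few_shot schema as a JSON string (placeholder)
    )
    print(result)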
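
Finally, the WkFlowState change is easiest to see in isolation: result fields now default to None rather than [] or "", and to_json() exports only non-underscore instance attributes whose value is not None, so "key present" means "this stage actually produced something". A self-contained demo of the same filter:

    class StateDemo:
        """Stand-in with the same to_json() filter as WkFlowState."""

        def __init__(self):
            self.keywords = ["hugegraph"]  # computed -> exported
            self.vector_result = None      # never computed -> omitted
            self._internal = "hidden"      # underscore-prefixed -> omitted

        def to_json(self):
            return {
                k: v
                for k, v in self.__dict__.items()
                if not k.startswith("_") and v is not None
            }


    print(StateDemo().to_json())  # -> {'keywords': ['hugegraph']}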
