This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
     new d2fcfdb  fix(llm): update prompt to fit prefix cache (#137)
d2fcfdb is described below

commit d2fcfdb2348baea1d5ebf526d6e687271d5cda03
Author: HaoJin Yang <1454...@gmail.com>
AuthorDate: Mon Dec 23 00:32:45 2024 +0800

    fix(llm): update prompt to fit prefix cache (#137)
    
    * fix vid not readable for LLM in gremlin prompt
    
    ---------
    
    Co-authored-by: imbajin <j...@apache.org>
---
 .../src/hugegraph_llm/api/models/rag_requests.py   |   2 +-
 hugegraph-llm/src/hugegraph_llm/api/rag_api.py     |  88 ++++++++----------
 .../src/hugegraph_llm/config/prompt_config.py      |  79 +++++++++-------
 .../src/hugegraph_llm/demo/rag_demo/app.py         |  10 +-
 .../demo/rag_demo/text2gremlin_block.py            | 102 ++++++++++++-------
 .../operators/llm_op/gremlin_generate.py           |  14 +--
 6 files changed, 166 insertions(+), 129 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index 1eaa1e2..489ce0a 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -46,11 +46,11 @@ class RAGRequest(BaseModel):
     )
 
 
+# TODO: import the default value of prompt.* dynamically
 class GraphRAGRequest(BaseModel):
     query: str = Query("", description="Query you want to ask")
     gremlin_tmpl_num: int = Query(1, description="Number of Gremlin templates to use.")
     with_gremlin_tmpl: bool = Query(True, description="Use example template in text2gremlin")
-    answer_prompt: Optional[str] = Query(prompt.answer_prompt, description="Prompt to guide the answer generation.")
     rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.")
     near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.")
     custom_priority_info: str = Query("", description="Custom information to prioritize certain results.")
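The dropped `answer_prompt` field was not consumed by the graph-recall path, so a `/rag/graph` request body now only carries retrieval options. A hypothetical client call against the trimmed model could look like the sketch below (the base URL and port are assumptions, not part of this commit):

```python
# Sketch: POST to /rag/graph with the fields GraphRAGRequest still declares.
# "http://localhost:8001" is a placeholder for wherever the API is served.
import requests

payload = {
    "query": "who is peter?",
    "gremlin_tmpl_num": 1,
    "with_gremlin_tmpl": True,
    "rerank_method": "bleu",
    "near_neighbor_first": False,
    "custom_priority_info": "",
    # "answer_prompt" is gone: the recall endpoint never used it.
}
resp = requests.post("http://localhost:8001/rag/graph", json=payload)
print(resp.json().get("graph_recall"))  # keys are filtered server-side, see rag_api.py below
```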
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index 8a59c8b..4036496 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -16,7 +16,6 @@
 # under the License.
 
 import json
-from typing import Literal
 
 from fastapi import status, APIRouter, HTTPException
 
@@ -29,72 +28,52 @@ from hugegraph_llm.api.models.rag_requests import (
     GraphRAGRequest,
 )
 from hugegraph_llm.api.models.rag_response import RAGResponse
-from hugegraph_llm.config import llm_settings, huge_settings, prompt
+from hugegraph_llm.config import llm_settings, prompt
 from hugegraph_llm.utils.log import log
 
 
-def graph_rag_recall(
-    query: str,
-    gremlin_tmpl_num: int,
-    with_gremlin_tmpl: bool,
-    answer_prompt: str,  # FIXME: should be used in the query
-    rerank_method: Literal["bleu", "reranker"],
-    near_neighbor_first: bool,
-    custom_related_information: str,
-    gremlin_prompt: str,
-) -> dict:
-    from hugegraph_llm.operators.graph_rag_task import RAGPipeline
-
-    rag = RAGPipeline()
-
-    rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
-        with_gremlin_template=with_gremlin_tmpl,
-        num_gremlin_generate_example=gremlin_tmpl_num,
-        gremlin_prompt=gremlin_prompt,
-    ).merge_dedup_rerank(
-        rerank_method=rerank_method,
-        near_neighbor_first=near_neighbor_first,
-        custom_related_information=custom_related_information,
-    )
-    context = rag.run(verbose=True, query=query, graph_search=True)
-    return context
-
-
 def rag_http_api(
-    router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
+    router: APIRouter,
+    rag_answer_func,
+    graph_rag_recall_func,
+    apply_graph_conf,
+    apply_llm_conf,
+    apply_embedding_conf,
+    apply_reranker_conf,
 ):
     @router.post("/rag", status_code=status.HTTP_200_OK)
     def rag_answer_api(req: RAGRequest):
         result = rag_answer_func(
-            req.query,
-            req.raw_answer,
-            req.vector_only,
-            req.graph_only,
-            req.graph_vector_answer,
-            req.with_gremlin_tmpl,
-            req.graph_ratio,
-            req.rerank_method,
-            req.near_neighbor_first,
-            req.custom_priority_info,
-            req.answer_prompt or prompt.answer_prompt,
-            req.keywords_extract_prompt or prompt.keywords_extract_prompt,
-            req.gremlin_tmpl_num,
-            req.gremlin_prompt or prompt.gremlin_generate_prompt,
+            text=req.query,
+            raw_answer=req.raw_answer,
+            vector_only_answer=req.vector_only,
+            graph_only_answer=req.graph_only,
+            graph_vector_answer=req.graph_vector_answer,
+            with_gremlin_template=req.with_gremlin_tmpl,
+            graph_ratio=req.graph_ratio,
+            rerank_method=req.rerank_method,
+            near_neighbor_first=req.near_neighbor_first,
+            custom_related_information=req.custom_priority_info,
+            answer_prompt=req.answer_prompt or prompt.answer_prompt,
+            keywords_extract_prompt=req.keywords_extract_prompt or prompt.keywords_extract_prompt,
+            gremlin_tmpl_num=req.gremlin_tmpl_num,
+            gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
         )
+        # TODO: we need more info in the response for users to understand the query logic
         return {
-            key: value
-            for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
-            if getattr(req, key)
+            "query": req.query,
+            **{key: value
+               for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
+               if getattr(req, key)}
         }
 
     @router.post("/rag/graph", status_code=status.HTTP_200_OK)
     def graph_rag_recall_api(req: GraphRAGRequest):
         try:
-            result = graph_rag_recall(
+            result = graph_rag_recall_func(
                 query=req.query,
                 gremlin_tmpl_num=req.gremlin_tmpl_num,
                 with_gremlin_tmpl=req.with_gremlin_tmpl,
-                answer_prompt=req.answer_prompt or prompt.answer_prompt,
                 rerank_method=req.rerank_method,
                 near_neighbor_first=req.near_neighbor_first,
                custom_related_information=req.custom_priority_info,
@@ -102,8 +81,15 @@ def rag_http_api(
             )
 
             if isinstance(result, dict):
-                params = ["query", "keywords", "match_vids", "graph_result_flag", "gremlin", "graph_result",
-                          "vertex_degree_list"]
+                params = [
+                    "query",
+                    "keywords",
+                    "match_vids",
+                    "graph_result_flag",
+                    "gremlin",
+                    "graph_result",
+                    "vertex_degree_list",
+                ]
                 user_result = {key: result[key] for key in params if key in result}
                 return {"graph_recall": user_result}
             # Note: Maybe only for qianfan/wenxin
diff --git a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
index 4d37ced..f6aeef5 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
@@ -18,18 +18,19 @@
 from hugegraph_llm.config.models.base_prompt_config import BasePromptConfig
 
+
 class PromptConfig(BasePromptConfig):
     # Data is detached from llm_op/answer_synthesize.py
     answer_prompt: str = """You are an expert in knowledge graphs and natural language processing.
 Your task is to provide a precise and accurate answer based on the given context.
 
+Given the context information and without using fictive knowledge,
+answer the following query in a concise and professional manner.
+
 Context information is below.
 ---------------------
 {context_str}
 ---------------------
-
-Given the context information and without using fictive knowledge,
-answer the following query in a concise and professional manner.
 Query: {query_str}
 Answer:
 """
@@ -131,7 +132,7 @@ Meet Sarah, a 30-year-old attorney, and her roommate, James, whom she's shared a
     keywords_extract_prompt: str = """指令:
 请对以下文本执行以下任务:
 1. 从文本中提取关键词:
-   - 最少 0 个,最多 {max_keywords} 个。
+   - 最少 0 个,最多 MAX_KEYWORDS 个。
    - 关键词应为具有完整语义的词语或短语,确保信息完整。
 2. 识别需改写的关键词:
    - 从提取的关键词中,识别那些在原语境中具有歧义或存在信息缺失的关键词。
@@ -151,47 +152,55 @@ Meet Sarah, a 30-year-old attorney, and her roommate, James, whom she's shared a
 - 仅输出一行内容, 以 KEYWORDS: 为前缀,后跟所有关键词或对应的同义词,之间用逗号分隔。抽取的关键词中不允许出现空格或空字符
 - 格式示例:
   KEYWORDS:关键词1,关键词2,...,关键词n
+
+MAX_KEYWORDS: {max_keywords}
 文本:
 {question}
 """
-#pylint: disable=C0301
+    # pylint: disable=C0301
     # keywords_extract_prompt_EN = """
-# Instruction:
-# Please perform the following tasks on the text below:
-# 1. Extract Keywords and Generate Synonyms from text:
-#     - At least 0, at most {max_keywords} keywords.
-#     - For each keyword, generate its synonyms or possible variant forms.
-# Requirements:
-#     - Keywords should be meaningful and specific entities; avoid using meaningless or overly broad terms (e.g., “object,” “the,” “he”).
-#     - Prioritize extracting subjects, verbs, and objects; avoid extracting function words or auxiliary words.
-#     - Do not expand into unrelated generalized categories.
-# Note:
-#     - Only consider semantic synonyms and other words with similar meanings in the given context.
-# Output Format:
-#     - Output only one line, prefixed with KEYWORDS:, followed by all keywords and synonyms, separated by commas.No spaces or empty characters are allowed in the extracted keywords.
-#     - Format example:
-#     KEYWORDS: keyword1, keyword2, ..., keywordn, synonym1, synonym2, ..., synonymn
-# Text:
-# {question}
-# """
-
-    gremlin_generate_prompt = """\
-Given the example query-gremlin pairs:
-{example}
-
-Given the graph schema:
+    # Instruction:
+    # Please perform the following tasks on the text below:
+    # 1. Extract Keywords and Generate Synonyms from the text:
+    #     - At least 0, at most {max_keywords} keywords.
+    #     - For each keyword, generate its synonyms or possible variant forms.
+    # Requirements:
+    #     - Keywords should be meaningful and specific entities; avoid using meaningless or overly broad terms (e.g., “object,” “the,” “he”).
+    #     - Prioritize extracting subjects, verbs, and objects; avoid extracting function words or auxiliary words.
+    #     - Do not expand into unrelated generalized categories.
+    # Note:
+    #     - Only consider semantic synonyms and other words with similar meanings in the given context.
+    # Output Format:
+    #     - Output only one line, prefixed with KEYWORDS:, followed by all keywords and synonyms, separated by commas.No spaces or empty characters are allowed in the extracted keywords.
+    #     - Format example:
+    #     KEYWORDS: keyword1, keyword2, ..., keywordN, synonym1, synonym2, ..., synonymN
+    # Text:
+    # {question}
+    # """
+
+    gremlin_generate_prompt = """
+You are an expert in graph query language(Gremlin), your role is to understand the schema of the graph and generate
+accurate Gremlin code based on the given instructions.
+
+# Graph Schema:
 ```json
 {schema}
 ```
+# Rule:
+1. Could use the vertex ID directly if it's given in the context.
+2. The output format must be like:
+```gremlin
+g.V().limit(10)
+```
 
-Given the extracted vertex vid:
+# Extracted vertex vid:
 {vertices}
 
-Generate gremlin from the following user query.
+# Given the example query-gremlin pairs:
+{example}
+
+# Generate gremlin from the following user query.
 {query}
 
-The output format must be like:
-```gremlin
-g.V().limit(10)
-```
+
 The generated gremlin is:
 """
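This prompt_config.py rewrite is what "fit prefix cache" means: the static parts of each prompt (role instruction, graph schema, rules; in the keywords prompt, the `{max_keywords}` placeholder likewise moves to a trailing `MAX_KEYWORDS:` line) now come before the per-request parts (`{example}`, `{query}`, `{question}`), so inference backends that reuse KV-cache entries for the longest shared token prefix (such as vLLM's automatic prefix caching) can skip recomputing everything up to the first request-specific token. A rough, self-contained sketch of the effect, with hypothetical helper names that are not project code:

```python
# Static-first layout: two requests now share a long, cacheable prefix.
STATIC_HEAD = (
    "You are an expert in graph query language(Gremlin)...\n"
    "# Graph Schema:\n{schema}\n"
    "# Rule: ...\n"
)

def build_prompt(schema: str, examples: str, query: str) -> str:
    # Fixed instructions and schema first, volatile examples/query last.
    return STATIC_HEAD.format(schema=schema) + f"{examples}\n# Query:\n{query}\n"

def shared_prefix_len(a: str, b: str) -> int:
    # Characters two prompts share from the start, a rough proxy for the
    # KV-cache entries a prefix-caching server could reuse between them.
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

schema = '{"vertexlabels": [...]}'
examples = "who is peter -> g.V().has('name', 'peter')"
p1 = build_prompt(schema, examples, "Who is Peter?")
p2 = build_prompt(schema, examples, "Where does Sarah live?")
print(shared_prefix_len(p1, p2), len(p1))  # prefix covers everything before the query
```

Under the old layout the similarity-retrieved `{example}` pairs came first, so two prompts diverged almost immediately and little of the cache could be shared.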
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
index 3fe6c0f..700e60b 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
@@ -35,7 +35,7 @@ from hugegraph_llm.demo.rag_demo.configs_block import (
     apply_graph_config,
 )
 from hugegraph_llm.demo.rag_demo.other_block import create_other_block
-from hugegraph_llm.demo.rag_demo.text2gremlin_block import create_text2gremlin_block
+from hugegraph_llm.demo.rag_demo.text2gremlin_block import create_text2gremlin_block, graph_rag_recall
 from hugegraph_llm.demo.rag_demo.rag_block import create_rag_block, rag_answer
 from hugegraph_llm.demo.rag_demo.vector_graph_block import create_vector_graph_block
 from hugegraph_llm.resources.demo.css import CSS
@@ -171,7 +171,13 @@ if __name__ == "__main__":
     hugegraph_llm = init_rag_ui()
 
     rag_http_api(
-        api_auth, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config, apply_reranker_config
+        api_auth,
+        rag_answer,
+        graph_rag_recall,
+        apply_graph_config,
+        apply_llm_config,
+        apply_embedding_config,
+        apply_reranker_config,
     )
 
     admin_http_api(api_auth, log_stream)
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
index eaa37bc..c47ce7c 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
@@ -17,14 +17,15 @@
 
 import json
 import os
-from typing import Any, Tuple, Dict, Union
+from typing import Any, Tuple, Dict, Union, Literal
 
 import gradio as gr
 import pandas as pd
-from hugegraph_llm.config import prompt, resource_path
+from hugegraph_llm.config import prompt, resource_path, huge_settings
 from hugegraph_llm.models.embeddings.init_embedding import Embeddings
 from hugegraph_llm.models.llms.init_llm import LLMs
+from hugegraph_llm.operators.graph_rag_task import RAGPipeline
 from hugegraph_llm.operators.gremlin_generate_task import GremlinGenerator
 from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager
 from hugegraph_llm.utils.hugegraph_utils import run_gremlin_query
@@ -32,8 +33,11 @@ from hugegraph_llm.utils.log import log
 
 
 def store_schema(schema, question, gremlin_prompt):
-    if (prompt.text2gql_graph_schema != schema or prompt.default_question != question or
-            prompt.gremlin_generate_prompt != gremlin_prompt):
+    if (
+        prompt.text2gql_graph_schema != schema or
+        prompt.default_question != question or
+        prompt.gremlin_generate_prompt != gremlin_prompt
+    ):
         prompt.text2gql_graph_schema = schema
         prompt.default_question = question
         prompt.gremlin_generate_prompt = gremlin_prompt
@@ -49,7 +53,7 @@ def build_example_vector_index(temp_file) -> dict:
         with open(full_path, "r", encoding="utf-8") as f:
             examples = json.load(f)
     elif full_path.endswith(".csv"):
-        examples = pd.read_csv(full_path).to_dict('records')
+        examples = pd.read_csv(full_path).to_dict("records")
     else:
         log.critical("Unsupported file format. Please input a JSON or CSV file.")
         return {"error": "Unsupported file format. Please input a JSON or CSV file."}
@@ -60,8 +64,9 @@ def build_example_vector_index(temp_file) -> dict:
     return builder.example_index_build(examples).run()
 
 
-def gremlin_generate(inp, example_num, schema, gremlin_prompt) -> Union[
-    tuple[str, str], tuple[str, Any, Any, Any, Any]]:
+def gremlin_generate(
+    inp, example_num, schema, gremlin_prompt
+) -> Union[tuple[str, str], tuple[str, Any, Any, Any, Any]]:
     generator = GremlinGenerator(llm=LLMs().get_text2gql_llm(), embedding=Embeddings().get_embedding())
     sm = SchemaManager(graph_name=schema)
     short_schema = False
@@ -83,19 +88,28 @@ def gremlin_generate(inp, example_num, schema, gremlin_prompt) -> Union[
         return "Invalid JSON schema, please check the format carefully.", ""
     # FIXME: schema is not used in gremlin_generate() step, no context for it (enhance the logic here)
     updated_schema = sm.simple_schema(schema) if short_schema else schema
-    context = generator.example_index_query(example_num).gremlin_generate_synthesize(updated_schema,
-                                                                                     gremlin_prompt).run(query=inp)
+    store_schema(str(updated_schema), inp, gremlin_prompt)
+    context = (
+        generator.example_index_query(example_num).gremlin_generate_synthesize(updated_schema, gremlin_prompt)
+        .run(query=inp)
+    )
     try:
         context["template_exec_res"] = run_gremlin_query(query=context["result"])
-    except Exception as e: # pylint: disable=broad-except
+    except Exception as e:  # pylint: disable=broad-except
         context["template_exec_res"] = f"{e}"
     try:
         context["raw_exec_res"] = run_gremlin_query(query=context["raw_result"])
-    except Exception as e: # pylint: disable=broad-except
+    except Exception as e:  # pylint: disable=broad-except
         context["raw_exec_res"] = f"{e}"
 
     match_result = json.dumps(context.get("match_result", "No Results"), ensure_ascii=False, indent=2)
-    return match_result, context["result"], context["raw_result"], context["template_exec_res"], context["raw_exec_res"]
+    return (
+        match_result,
+        context["result"],
+        context["raw_result"],
+        context["template_exec_res"],
+        context["raw_exec_res"],
+    )
 
 
 def simple_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
@@ -112,24 +126,24 @@ def simple_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
     if "edgelabels" in schema:
         mini_schema["edgelabels"] = []
         for edge in schema["edgelabels"]:
-            new_edge = {key: edge[key] for key in
-                        ["name", "source_label", "target_label", "properties"] if key in edge}
"target_label", "properties"] if key in edge} + new_edge = {key: edge[key] for key in ["name", "source_label", "target_label", "properties"] if key in edge} mini_schema["edgelabels"].append(new_edge) return mini_schema def create_text2gremlin_block() -> Tuple: - gr.Markdown("""## Build Vector Template Index (Optional) + gr.Markdown( + """## Build Vector Template Index (Optional) > Uploaded CSV file should be in `query,gremlin` format below: > e.g. `who is peter?`,`g.V().has('name', 'peter')` > JSON file should be in format below: > e.g. `[{"query":"who is peter", "gremlin":"g.V().has('name', 'peter')"}]` - """) + """ + ) with gr.Row(): file = gr.File( - value=os.path.join(resource_path, "demo", "text2gremlin.csv"), - label="Upload Text-Gremlin Pairs File" + value=os.path.join(resource_path, "demo", "text2gremlin.csv"), label="Upload Text-Gremlin Pairs File" ) out = gr.Textbox(label="Result Message") with gr.Row(): @@ -143,27 +157,49 @@ def create_text2gremlin_block() -> Tuple: match = gr.Code(label="Similar Template (TopN)", language="javascript", elem_classes="code-container-show") initialized_out = gr.Textbox(label="Gremlin With Template", show_copy_button=True) raw_out = gr.Textbox(label="Gremlin Without Template", show_copy_button=True) - tmpl_exec_out = gr.Code(label="Query With Template Output", language="json", - elem_classes="code-container-show") - raw_exec_out = gr.Code(label="Query Without Template Output", language="json", - elem_classes="code-container-show") + tmpl_exec_out = gr.Code( + label="Query With Template Output", language="json", elem_classes="code-container-show" + ) + raw_exec_out = gr.Code( + label="Query Without Template Output", language="json", elem_classes="code-container-show" + ) with gr.Column(scale=1): - example_num_slider = gr.Slider( - minimum=0, - maximum=10, - step=1, - value=2, - label="Number of refer examples" - ) + example_num_slider = gr.Slider(minimum=0, maximum=10, step=1, value=2, label="Number of refer examples") schema_box = gr.Textbox(value=prompt.text2gql_graph_schema, label="Schema", lines=2, show_copy_button=True) - prompt_box = gr.Textbox(value=prompt.gremlin_generate_prompt, label="Prompt", lines=2, - show_copy_button=True) + prompt_box = gr.Textbox( + value=prompt.gremlin_generate_prompt, label="Prompt", lines=20, show_copy_button=True + ) btn = gr.Button("Text2Gremlin", variant="primary") btn.click( # pylint: disable=no-member fn=gremlin_generate, inputs=[input_box, example_num_slider, schema_box, prompt_box], - outputs=[match, initialized_out, raw_out, tmpl_exec_out, raw_exec_out] - ).then(store_schema, inputs=[schema_box, input_box, prompt_box], ) + outputs=[match, initialized_out, raw_out, tmpl_exec_out, raw_exec_out], + ) return input_box, schema_box, prompt_box + + +def graph_rag_recall( + query: str, + gremlin_tmpl_num: int, + with_gremlin_tmpl: bool, + rerank_method: Literal["bleu", "reranker"], + near_neighbor_first: bool, + custom_related_information: str, + gremlin_prompt: str, +) -> dict: + store_schema(prompt.text2gql_graph_schema, query, gremlin_prompt) + rag = RAGPipeline() + + rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb( + with_gremlin_template=with_gremlin_tmpl, + num_gremlin_generate_example=gremlin_tmpl_num, + gremlin_prompt=gremlin_prompt, + ).merge_dedup_rerank( + rerank_method=rerank_method, + near_neighbor_first=near_neighbor_first, + custom_related_information=custom_related_information, + ) + context = rag.run(verbose=True, query=query, graph_search=True) + 
+    return context
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
index 955fc9e..9694647 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
@@ -20,19 +20,19 @@ import json
 import re
 from typing import Optional, List, Dict, Any, Union
 
+from hugegraph_llm.config import prompt
 from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.models.llms.init_llm import LLMs
 from hugegraph_llm.utils.log import log
-from hugegraph_llm.config import prompt
 
 
 class GremlinGenerateSynthesize:
     def __init__(
-            self,
-            llm: BaseLLM = None,
-            schema: Optional[Union[dict, str]] = None,
-            vertices: Optional[List[str]] = None,
-            gremlin_prompt: Optional[str] = None
+        self,
+        llm: BaseLLM = None,
+        schema: Optional[Union[dict, str]] = None,
+        vertices: Optional[List[str]] = None,
+        gremlin_prompt: Optional[str] = None
     ) -> None:
         self.llm = llm or LLMs().get_text2gql_llm()
         if isinstance(schema, dict):
@@ -59,7 +59,7 @@ class GremlinGenerateSynthesize:
     def _format_vertices(self, vertices: Optional[List[str]]) -> Optional[str]:
         if not vertices:
             return None
-        return "\n".join([f"- {vid}" for vid in vertices])
+        return "\n".join([f"- '{vid}'" for vid in vertices])
 
     async def async_generate(self, context: Dict[str, Any]):
         async_tasks = {}
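The final hunk is the co-authored "fix vid not readable for LLM" change: each matched vertex id is wrapped in single quotes before being listed in the prompt, so the model sees an explicit string literal it can copy into `g.V('...')` rather than a bare token of ambiguous type. A standalone before/after sketch (reimplemented here for illustration, not the project's class):

```python
from typing import List, Optional

def format_vertices_old(vertices: Optional[List[str]]) -> Optional[str]:
    # Old behavior: bare ids; "1:peter" could be read as anything.
    if not vertices:
        return None
    return "\n".join([f"- {vid}" for vid in vertices])

def format_vertices_new(vertices: Optional[List[str]]) -> Optional[str]:
    # New behavior: quoted ids, ready to paste into g.V('1:peter').
    if not vertices:
        return None
    return "\n".join([f"- '{vid}'" for vid in vertices])

vids = ["1:peter", "2:lop"]
print(format_vertices_old(vids))  # - 1:peter
print(format_vertices_new(vids))  # - '1:peter'
```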