This is an automated email from the ASF dual-hosted git repository. jin pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push: new 927e17c fix(llm): fix tiny bugs & optimize reranker layout (#202) 927e17c is described below commit 927e17cde6bff339e912f29c4ec87815b4794e4a Author: HaoJin Yang <1454...@gmail.com> AuthorDate: Tue Mar 18 20:32:22 2025 +0800 fix(llm): fix tiny bugs & optimize reranker layout (#202) * also update answer prompt --- .../src/hugegraph_llm/config/prompt_config.py | 9 ++- .../src/hugegraph_llm/demo/rag_demo/rag_block.py | 87 +++++++++++++++------- .../models/rerankers/init_reranker.py | 11 +-- .../operators/common_op/merge_dedup_rerank.py | 4 +- 4 files changed, 76 insertions(+), 35 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py index ad32bbc..ca2b7b7 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py @@ -18,11 +18,13 @@ from hugegraph_llm.config.models.base_prompt_config import BasePromptConfig + # pylint: disable=C0301 class PromptConfig(BasePromptConfig): # Data is detached from llm_op/answer_synthesize.py - answer_prompt: str = """You are an expert in knowledge graphs and natural language processing. -Your task is to provide a precise and accurate answer based on the given context. + answer_prompt: str = """You are an expert in the fields of knowledge graphs and natural language processing. + +Please provide precise and accurate answers based on the following context information, which is sorted in order of importance from high to low, without using any fabricated knowledge. Given the context information and without using fictive knowledge, answer the following query in a concise and professional manner. @@ -246,7 +248,8 @@ and experiences. 
answer_prompt_CN: str = """你是知识图谱和自然语言处理领域的专家。 你的任务是基于给定的上下文提供精确和准确的答案。 -根据提供的上下文信息,不使用虚构知识, +请根据以下按重要性从高到低排序的上下文信息,提供基于上下文的精确、准确的答案,不使用任何虚构的知识。 + 请以简洁专业的方式回答以下问题。 请使用 Markdown 格式编写答案,其中行内数学公式用 `$...$` 包裹 diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py index df82568..e51e25e 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py @@ -57,10 +57,16 @@ def rag_answer( 4. Synthesize the final answer. 5. Run the pipeline and return the results. """ - graph_search, gremlin_prompt, vector_search = update_ui_configs(answer_prompt, custom_related_information, - graph_only_answer, graph_vector_answer, - gremlin_prompt, keywords_extract_prompt, text, - vector_only_answer) + graph_search, gremlin_prompt, vector_search = update_ui_configs( + answer_prompt, + custom_related_information, + graph_only_answer, + graph_vector_answer, + gremlin_prompt, + keywords_extract_prompt, + text, + vector_only_answer, + ) if raw_answer is False and not vector_search and not graph_search: gr.Warning("Please select at least one generate mode.") return "", "", "", "" @@ -72,25 +78,28 @@ def rag_answer( rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid( vector_dis_threshold=vector_dis_threshold, topk_per_keyword=topk_per_keyword, - ).import_schema( - huge_settings.graph_name - ).query_graphdb( + ).import_schema(huge_settings.graph_name).query_graphdb( num_gremlin_generate_example=gremlin_tmpl_num, gremlin_prompt=gremlin_prompt, - max_graph_items=max_graph_items + max_graph_items=max_graph_items, ) # TODO: add more user-defined search strategies rag.merge_dedup_rerank( graph_ratio=graph_ratio, rerank_method=rerank_method, near_neighbor_first=near_neighbor_first, - topk_return_results=topk_return_results + topk_return_results=topk_return_results, ) rag.synthesize_answer(raw_answer, vector_only_answer, 
graph_only_answer, graph_vector_answer, answer_prompt) try: - context = rag.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search, - max_graph_items=max_graph_items) + context = rag.run( + verbose=True, + query=text, + vector_search=vector_search, + graph_search=graph_search, + max_graph_items=max_graph_items, + ) if context.get("switch_to_bleu"): gr.Warning("Online reranker fails, automatically switches to local bleu rerank.") return ( @@ -107,8 +116,16 @@ def rag_answer( raise gr.Error(f"An unexpected error occurred: {str(e)}") -def update_ui_configs(answer_prompt, custom_related_information, graph_only_answer, graph_vector_answer, gremlin_prompt, - keywords_extract_prompt, text, vector_only_answer): +def update_ui_configs( + answer_prompt, + custom_related_information, + graph_only_answer, + graph_vector_answer, + gremlin_prompt, + keywords_extract_prompt, + text, + vector_only_answer, +): gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt should_update_prompt = ( prompt.default_question != text @@ -153,10 +170,16 @@ async def rag_answer_streaming( 5. Run the pipeline and return the results. """ - graph_search, gremlin_prompt, vector_search = update_ui_configs(answer_prompt, custom_related_information, - graph_only_answer, graph_vector_answer, - gremlin_prompt, keywords_extract_prompt, text, - vector_only_answer) + graph_search, gremlin_prompt, vector_search = update_ui_configs( + answer_prompt, + custom_related_information, + graph_only_answer, + graph_vector_answer, + gremlin_prompt, + keywords_extract_prompt, + text, + vector_only_answer, + ) if raw_answer is False and not vector_search and not graph_search: gr.Warning("Please select at least one generate mode.") yield "", "", "", "" @@ -216,17 +239,29 @@ def create_rag_block(): # TODO: Only support inline formula now. 
Should support block formula gr.Markdown("Basic LLM Answer", elem_classes="output-box-label") - raw_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, - latex_delimiters=[{"left": "$", "right": "$", "display": False}]) + raw_out = gr.Markdown( + elem_classes="output-box", + show_copy_button=True, + latex_delimiters=[{"left": "$", "right": "$", "display": False}], + ) gr.Markdown("Vector-only Answer", elem_classes="output-box-label") - vector_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, - latex_delimiters=[{"left": "$", "right": "$", "display": False}]) + vector_only_out = gr.Markdown( + elem_classes="output-box", + show_copy_button=True, + latex_delimiters=[{"left": "$", "right": "$", "display": False}], + ) gr.Markdown("Graph-only Answer", elem_classes="output-box-label") - graph_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, - latex_delimiters=[{"left": "$", "right": "$", "display": False}]) + graph_only_out = gr.Markdown( + elem_classes="output-box", + show_copy_button=True, + latex_delimiters=[{"left": "$", "right": "$", "display": False}], + ) gr.Markdown("Graph-Vector Answer", elem_classes="output-box-label") - graph_vector_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, - latex_delimiters=[{"left": "$", "right": "$", "display": False}]) + graph_vector_out = gr.Markdown( + elem_classes="output-box", + show_copy_button=True, + latex_delimiters=[{"left": "$", "right": "$", "display": False}], + ) answer_prompt_input = gr.Textbox( value=prompt.answer_prompt, label="Query Prompt", show_copy_button=True, lines=7 @@ -252,7 +287,7 @@ def create_rag_block(): with gr.Row(): online_rerank = llm_settings.reranker_type rerank_method = gr.Dropdown( - choices=["bleu", ("rerank (online)", "reranker")] if online_rerank else ["bleu"], + choices=["bleu", ("rerank (online)", "reranker")], value="reranker" if online_rerank else "bleu", label="Rerank method", ) diff --git 
a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py index b1d6ef5..aa9f0c0 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py +++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py @@ -15,21 +15,22 @@ # specific language governing permissions and limitations # under the License. -from hugegraph_llm.config import huge_settings +from hugegraph_llm.config import llm_settings from hugegraph_llm.models.rerankers.cohere import CohereReranker from hugegraph_llm.models.rerankers.siliconflow import SiliconReranker class Rerankers: def __init__(self): - self.reranker_type = huge_settings.reranker_type + self.reranker_type = llm_settings.reranker_type def get_reranker(self): if self.reranker_type == "cohere": return CohereReranker( - api_key=huge_settings.reranker_api_key, base_url=huge_settings.cohere_base_url, - model=huge_settings.reranker_model + api_key=llm_settings.reranker_api_key, + base_url=llm_settings.cohere_base_url, + model=llm_settings.reranker_model, ) if self.reranker_type == "siliconflow": - return SiliconReranker(api_key=huge_settings.reranker_api_key, model=huge_settings.reranker_model) + return SiliconReranker(api_key=llm_settings.reranker_api_key, model=llm_settings.reranker_model) raise Exception("Reranker type is not supported!") diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py index 62968fd..910de20 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py @@ -22,7 +22,7 @@ import jieba import requests from nltk.translate.bleu_score import sentence_bleu -from hugegraph_llm.config import huge_settings +from hugegraph_llm.config import huge_settings, llm_settings from hugegraph_llm.models.embeddings.base 
import BaseEmbedding from hugegraph_llm.models.rerankers.init_reranker import Rerankers from hugegraph_llm.utils.log import log @@ -52,6 +52,8 @@ class MergeDedupRerank: priority: bool = False, # TODO: implement priority ): assert method in ["bleu", "reranker"], f"Unimplemented rerank method '{method}'." + if llm_settings.reranker_type is None: + assert method == "bleu", "Please set the online reranker first" self.embedding = embedding self.graph_ratio = graph_ratio self.topk_return_results = topk_return_results