This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch search-template
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git

commit 919377587a90728af4879fcffb25a5910d668077
Author: imbajin <[email protected]>
AuthorDate: Wed Aug 21 17:45:57 2024 +0800

    feat(llm): support user-defined search template
---
 .../src/hugegraph_llm/demo/rag_web_demo.py         |  19 +--
 .../operators/common_op/merge_dedup_rerank.py      |   8 +-
 .../src/hugegraph_llm/operators/graph_rag_task.py  | 136 +++++++++++++++------
 .../operators/llm_op/answer_synthesize.py          |  33 ++---
 4 files changed, 137 insertions(+), 59 deletions(-)
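
At a glance: GraphRAG is renamed to RAGPipeline, MergeDedupRerank's "policy"
keyword becomes "strategy", and a user-defined answer prompt now flows from the
web demo down into AnswerSynthesize. Existing callers migrate with a rename, as
this minimal sketch shows:

    # Before this commit:
    #   from hugegraph_llm.operators.graph_rag_task import GraphRAG
    #   searcher = GraphRAG()
    # After:
    from hugegraph_llm.operators.graph_rag_task import RAGPipeline
    searcher = RAGPipeline()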

diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 3151e23..b532ac1 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -35,7 +35,7 @@ from hugegraph_llm.config import settings, resource_path
 from hugegraph_llm.enums.build_mode import BuildMode
 from hugegraph_llm.models.embeddings.init_embedding import Embeddings
 from hugegraph_llm.models.llms.init_llm import LLMs
-from hugegraph_llm.operators.graph_rag_task import GraphRAG
+from hugegraph_llm.operators.graph_rag_task import RAGPipeline
 from hugegraph_llm.operators.kg_construction_task import KgBuilder
 from hugegraph_llm.operators.llm_op.property_graph_extract import SCHEMA_EXAMPLE_PROMPT
 from hugegraph_llm.utils.hugegraph_utils import get_hg_client
@@ -58,24 +58,26 @@ def authenticate(credentials: HTTPAuthorizationCredentials = Depends(sec)):
 
 
 def rag_answer(
-        text: str, raw_answer: bool, vector_only_answer: bool, graph_only_answer: bool, graph_vector_answer: bool
-) -> tuple:
+        text: str, raw_answer: bool, vector_only_answer: bool, graph_only_answer: bool,
+         graph_vector_answer: bool, answer_prompt: str) -> tuple:
     vector_search = vector_only_answer or graph_vector_answer
     graph_search = graph_only_answer or graph_vector_answer
 
     if raw_answer is False and not vector_search and not graph_search:
         gr.Warning("Please select at least one generate mode.")
         return "", "", "", ""
-    searcher = GraphRAG()
+    searcher = RAGPipeline()
     if vector_search:
         searcher.query_vector_index_for_rag()
     if graph_search:
         searcher.extract_word().match_keyword_to_id().query_graph_for_rag()
+    # TODO: add more user-defined search strategies
     searcher.merge_dedup_rerank().synthesize_answer(
         raw_answer=raw_answer,
         vector_only_answer=vector_only_answer,
         graph_only_answer=graph_only_answer,
         graph_vector_answer=graph_vector_answer,
+        answer_prompt=answer_prompt
     )
 
     try:
@@ -449,6 +451,9 @@ def init_rag_ui() -> gr.Interface:
                 vector_only_radio = gr.Radio(choices=[True, False], value=False, label="Vector-only Answer")
                 graph_only_radio = gr.Radio(choices=[True, False], value=False, label="Graph-only Answer")
                 graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer")
+                from hugegraph_llm.operators.llm_op.answer_synthesize import DEFAULT_ANSWER_TEMPLATE
+                answer_prompt_input = gr.Textbox(value=DEFAULT_ANSWER_TEMPLATE, label="Custom Prompt",
+                                                 show_copy_button=True)
                 btn = gr.Button("Answer Question")
         btn.click(  # pylint: disable=no-member
             fn=rag_answer,
@@ -458,6 +463,7 @@ def init_rag_ui() -> gr.Interface:
                 vector_only_radio,
                 graph_only_radio,
                 graph_vector_radio,
+                answer_prompt_input,
             ],
             outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
         )
@@ -496,7 +502,6 @@ if __name__ == "__main__":
     # TODO: support multi-user login when need
     app = gr.mount_gradio_app(app, hugegraph_llm, path="/", auth=("rag", os.getenv("TOKEN")) if auth_enabled else None)
 
-    # Note: set reload to False in production environment
-    uvicorn.run(app, host=args.host, port=args.port)
     # TODO: we can't use reload now due to the config 'app' of uvicorn.run
-    # uvicorn.run("rag_web_demo:app", host="0.0.0.0", port=8001, reload=True)
+    # ❎:f'{__name__}:app' / rag_web_demo:app / hugegraph_llm.demo.rag_web_demo:app
+    uvicorn.run(app, host=args.host, port=args.port, reload=False)
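
The UI change above threads a user-editable prompt through to the answer
synthesizer. A sketch of invoking the updated rag_answer() directly, outside
Gradio; the question and the shortened prompt string are illustrative values,
not ones shipped in this commit:

    from hugegraph_llm.demo.rag_web_demo import rag_answer

    # Returns a 4-tuple matching the UI outputs:
    # (raw, vector-only, graph-only, graph-vector) answers.
    raw, vec_only, graph_only, graph_vec = rag_answer(
        text="Who is Al Pacino?",
        raw_answer=False,
        vector_only_answer=True,
        graph_only_answer=False,
        graph_vector_answer=False,
        answer_prompt="Context:\n{context_str}\nQuery: {query_str}\nAnswer:",
    )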
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
index e012479..2187096 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
@@ -34,16 +34,16 @@ class MergeDedupRerank:
             self,
             embedding: BaseEmbedding,
             topk: int = 10,
-            policy: Literal["bleu", "priority"] = "bleu"
+            strategy: Literal["bleu", "priority"] = "bleu"
     ):
         self.embedding = embedding
         self.topk = topk
-        if policy == "bleu":
+        if strategy == "bleu":
             self.rerank_func = self._bleu_rerank
-        elif policy == "priority":
+        elif strategy == "priority":
             self.rerank_func = self._priority_rerank
         else:
-            raise ValueError(f"Unimplemented policy {policy}.")
+            raise ValueError(f"Unimplemented rank strategy {strategy}.")
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         query = context.get("query")
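
With the keyword renamed from "policy" to "strategy", constructing the operator
directly looks like the sketch below; the embedding instance is whatever
Embeddings().get_embedding() resolves to in your configuration:

    from hugegraph_llm.models.embeddings.init_embedding import Embeddings
    from hugegraph_llm.operators.common_op.merge_dedup_rerank import MergeDedupRerank

    # "bleu" reranks candidates by BLEU similarity to the query; "priority"
    # selects the priority-based reranker; anything else raises ValueError.
    reranker = MergeDedupRerank(
        embedding=Embeddings().get_embedding(),
        topk=10,
        strategy="bleu",
    )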
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index f60f091..de75352 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -19,12 +19,12 @@
 import time
 from typing import Dict, Any, Optional, List
 
-from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
-from hugegraph_llm.models.llms.init_llm import LLMs
 from hugegraph_llm.models.embeddings.init_embedding import Embeddings
-from hugegraph_llm.operators.common_op.print_result import PrintResult
+from hugegraph_llm.models.llms.base import BaseLLM
+from hugegraph_llm.models.llms.init_llm import LLMs
 from hugegraph_llm.operators.common_op.merge_dedup_rerank import MergeDedupRerank
+from hugegraph_llm.operators.common_op.print_result import PrintResult
 from hugegraph_llm.operators.document_op.word_extract import WordExtract
 from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery
 from hugegraph_llm.operators.index_op.semantic_id_query import SemanticIdQuery
@@ -34,33 +34,56 @@ from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract
 from hugegraph_llm.utils.log import log
 
 
-class GraphRAG:
+class RAGPipeline:
+    """
+    RAGPipeline is a core class that encapsulates a series of operations for extracting information from text,
+    querying graph databases and vector indices, merging and re-ranking results, and generating answers.
+    """
+
     def __init__(self, llm: Optional[BaseLLM] = None, embedding: Optional[BaseEmbedding] = None):
+        """
+        Initialize the RAGPipeline with optional LLM and embedding models.
+
+        :param llm: Optional LLM model to use.
+        :param embedding: Optional embedding model to use.
+        """
         self._llm = llm or LLMs().get_llm()
         self._embedding = embedding or Embeddings().get_embedding()
         self._operators: List[Any] = []
 
     def extract_word(
-        self,
-        text: Optional[str] = None,
-        language: str = "english",
+            self,
+            text: Optional[str] = None,
+            language: str = "english",
     ):
-        self._operators.append(
-            WordExtract(
-                text=text,
-                language=language,
-            )
-        )
+        """
+        Add a word extraction operator to the pipeline.
+
+        :param text: Text to extract words from.
+        :param language: Language of the text.
+        :return: Self-instance for chaining.
+        """
+        self._operators.append(WordExtract(text=text, language=language))
         return self
 
     def extract_keyword(
-        self,
-        text: Optional[str] = None,
-        max_keywords: int = 5,
-        language: str = "english",
-        extract_template: Optional[str] = None,
-        expand_template: Optional[str] = None,
+            self,
+            text: Optional[str] = None,
+            max_keywords: int = 5,
+            language: str = "english",
+            extract_template: Optional[str] = None,
+            expand_template: Optional[str] = None,
     ):
+        """
+        Add a keyword extraction operator to the pipeline.
+
+        :param text: Text to extract keywords from.
+        :param max_keywords: Maximum number of keywords to extract.
+        :param language: Language of the text.
+        :param extract_template: Template for keyword extraction.
+        :param expand_template: Template for keyword expansion.
+        :return: Self-instance for chaining.
+        """
         self._operators.append(
             KeywordExtract(
                 text=text,
@@ -73,6 +96,12 @@ class GraphRAG:
         return self
 
     def match_keyword_to_id(self, topk_per_keyword: int = 1):
+        """
+        Add a semantic ID query operator to the pipeline.
+
+        :param topk_per_keyword: Top K results per keyword.
+        :return: Self-instance for chaining.
+        """
         self._operators.append(
             SemanticIdQuery(
                 embedding=self._embedding,
@@ -87,6 +116,14 @@ class GraphRAG:
             max_items: int = 30,
             prop_to_match: Optional[str] = None,
     ):
+        """
+        Add a graph RAG query operator to the pipeline.
+
+        :param max_deep: Maximum depth for the graph query.
+        :param max_items: Maximum number of items to retrieve.
+        :param prop_to_match: Property to match in the graph.
+        :return: Self-instance for chaining.
+        """
         self._operators.append(
             GraphRAGQuery(
                 max_deep=max_deep,
@@ -100,6 +137,12 @@ class GraphRAG:
             self,
             max_items: int = 3
     ):
+        """
+        Add a vector index query operator to the pipeline.
+
+        :param max_items: Maximum number of items to retrieve.
+        :return: Self-instance for chaining.
+        """
         self._operators.append(
             VectorIndexQuery(
                 embedding=self._embedding,
@@ -109,37 +152,62 @@ class GraphRAG:
         return self
 
     def merge_dedup_rerank(self):
-        self._operators.append(
-            MergeDedupRerank(
-                embedding=self._embedding,
-            )
+        """
+        Add a merge, deduplication, and rerank operator to the pipeline.
+
+        :return: Self-instance for chaining.
+        """
+        self._operators.append(MergeDedupRerank(
+            embedding=self._embedding,
+        )
         )
         return self
 
     def synthesize_answer(
-        self,
-        raw_answer: bool = False,
-        vector_only_answer: bool = True,
-        graph_only_answer: bool = False,
-        graph_vector_answer: bool = False,
-        prompt_template: Optional[str] = None,
+            self,
+            raw_answer: bool = False,
+            vector_only_answer: bool = True,
+            graph_only_answer: bool = False,
+            graph_vector_answer: bool = False,
+            answer_prompt: Optional[str] = None,
     ):
+        """
+        Add an answer synthesis operator to the pipeline.
+
+        :param raw_answer: Whether to return raw answers.
+        :param vector_only_answer: Whether to return vector-only answers.
+        :param graph_only_answer: Whether to return graph-only answers.
+        :param graph_vector_answer: Whether to return graph-vector combined answers.
+        :param answer_prompt: Template for the answer synthesis prompt.
+        :return: Self-instance for chaining.
+        """
         self._operators.append(
             AnswerSynthesize(
-                raw_answer = raw_answer,
-                vector_only_answer = vector_only_answer,
-                graph_only_answer = graph_only_answer,
-                graph_vector_answer = graph_vector_answer,
-                prompt_template=prompt_template,
+                raw_answer=raw_answer,
+                vector_only_answer=vector_only_answer,
+                graph_only_answer=graph_only_answer,
+                graph_vector_answer=graph_vector_answer,
+                prompt_template=answer_prompt,
             )
         )
         return self
 
     def print_result(self):
+        """
+        Add a print result operator to the pipeline.
+
+        :return: Self-instance for chaining.
+        """
         self._operators.append(PrintResult())
         return self
 
     def run(self, **kwargs) -> Dict[str, Any]:
+        """
+        Execute all operators in the pipeline in sequence.
+
+        :param kwargs: Additional context to pass to operators.
+        :return: Final context after all operators have been executed.
+        """
         if len(self._operators) == 0:
             self.extract_keyword().query_graph_for_rag().synthesize_answer()
 
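End to end, the renamed class keeps the old chaining style. A minimal sketch of
a graph-plus-vector run, assuming run() seeds the shared context from its
keyword arguments (as MergeDedupRerank.run's context.get("query") suggests);
the question string is illustrative:

    from hugegraph_llm.operators.graph_rag_task import RAGPipeline

    context = (
        RAGPipeline()                     # falls back to LLMs()/Embeddings() defaults
        .query_vector_index_for_rag()     # vector recall
        .extract_word()                   # keyword recall for the graph side
        .match_keyword_to_id()
        .query_graph_for_rag()
        .merge_dedup_rerank()
        .synthesize_answer(graph_vector_answer=True)
        .run(query="Who is Al Pacino?")
    )
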
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index f3803c7..6e050d5 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -23,19 +23,24 @@ from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.models.llms.init_llm import LLMs
 
 # TODO: we need to enhance the template to answer the question
-DEFAULT_ANSWER_SYNTHESIZE_TEMPLATE_TMPL = (
-    "Context information is below.\n"
-    "---------------------\n"
-    "{context_str}\n"
-    "---------------------\n"
-    "You need to refer to the context based on the following priority:\n"
-    "1. Graph recall > vector recall\n"
-    "2. Exact recall > Fuzzy recall\n"
-    "3. Independent vertex > 1-depth neighbor> 2-depth neighbors\n"
-    "Given the context information and not prior knowledge, answer the 
query.\n"
-    "Query: {query_str}\n"
-    "Answer: "
-)
+DEFAULT_ANSWER_TEMPLATE = f"""
+You are an expert in knowledge graphs and natural language processing. 
+Your task is to provide a precise and accurate answer based on the given context.
+
+Context information is below.
+---------------------
+{{context_str}}
+---------------------
+Please refer to the context based on the following priority:
+1. Graph recall > Vector recall
+2. Exact recall > Fuzzy recall
+3. Independent vertex > 1-depth neighbor > 2-depth neighbors
+
+Given the context information and without using prior knowledge, answer the following query in a concise and professional manner.
+Query: {{query_str}}
+Answer:
+"""
 
 
 class AnswerSynthesize:
@@ -53,7 +58,7 @@ class AnswerSynthesize:
             graph_vector_answer: bool = False,
     ):
         self._llm = llm
-        self._prompt_template = prompt_template or DEFAULT_ANSWER_SYNTHESIZE_TEMPLATE_TMPL
+        self._prompt_template = prompt_template or DEFAULT_ANSWER_TEMPLATE
         self._question = question
         self._context_body = context_body
         self._context_head = context_head

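Because the new constant is an f-string, the doubled braces are emitted as
literal {context_str} and {query_str}, so it remains a str.format() template.
A quick sketch of filling it the way a caller could; the context sentence is
illustrative:

    from hugegraph_llm.operators.llm_op.answer_synthesize import DEFAULT_ANSWER_TEMPLATE

    prompt = DEFAULT_ANSWER_TEMPLATE.format(
        context_str="Al Pacino starred in The Godfather (1972).",
        query_str="Who is Al Pacino?",
    )
    print(prompt)  # ready to send to the LLM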