This is an automated email from the ASF dual-hosted git repository.
ming pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 0168b85 feat(llm): support user-defined search template (#71)
0168b85 is described below
commit 0168b85a7f92409771b89c65dd3abf8fd5b1a816
Author: imbajin <[email protected]>
AuthorDate: Sat Aug 24 21:49:06 2024 +0800
feat(llm): support user-defined search template (#71)
* feat: Add word segmentation extraction component and multi-file import
adaptation
1. Change LLM keyword extraction to word tokenization extraction and change
the vid matching method
2. Change file uploading in the import stage to support uploading multiple
files
* fix Ollama test API connection and avoid invoking graph RAG twice
* feat(llm): support graphspace & unify the port to str
* chore: update gitignore files
* feat(llm): support user-defined search template
* enhance the search prompt
---------
Co-authored-by: vichayturen <[email protected]>
---
.../src/hugegraph_llm/demo/rag_web_demo.py | 21 ++--
.../operators/common_op/merge_dedup_rerank.py | 8 +-
.../src/hugegraph_llm/operators/graph_rag_task.py | 128 ++++++++++++++++-----
.../operators/llm_op/answer_synthesize.py | 35 +++---
.../operators/llm_op/property_graph_extract.py | 1 +
5 files changed, 137 insertions(+), 56 deletions(-)
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 74704ba..f42d6e7 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -35,7 +35,7 @@ from hugegraph_llm.config import settings, resource_path
from hugegraph_llm.enums.build_mode import BuildMode
from hugegraph_llm.models.embeddings.init_embedding import Embeddings
from hugegraph_llm.models.llms.init_llm import LLMs
-from hugegraph_llm.operators.graph_rag_task import GraphRAG
+from hugegraph_llm.operators.graph_rag_task import RAGPipeline
from hugegraph_llm.operators.kg_construction_task import KgBuilder
from hugegraph_llm.operators.llm_op.property_graph_extract import SCHEMA_EXAMPLE_PROMPT
from hugegraph_llm.utils.hugegraph_utils import get_hg_client
@@ -58,24 +58,26 @@ def authenticate(credentials: HTTPAuthorizationCredentials = Depends(sec)):
def rag_answer(
- text: str, raw_answer: bool, vector_only_answer: bool, graph_only_answer: bool, graph_vector_answer: bool
-) -> tuple:
+ text: str, raw_answer: bool, vector_only_answer: bool, graph_only_answer: bool,
+ graph_vector_answer: bool, answer_prompt: str) -> tuple:
vector_search = vector_only_answer or graph_vector_answer
graph_search = graph_only_answer or graph_vector_answer
if raw_answer is False and not vector_search and not graph_search:
gr.Warning("Please select at least one generate mode.")
return "", "", "", ""
- searcher = GraphRAG()
+ searcher = RAGPipeline()
if vector_search:
searcher.query_vector_index_for_rag()
if graph_search:
searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag()
+ # TODO: add more user-defined search strategies
searcher.merge_dedup_rerank().synthesize_answer(
raw_answer=raw_answer,
vector_only_answer=vector_only_answer,
graph_only_answer=graph_only_answer,
graph_vector_answer=graph_vector_answer,
+ answer_prompt=answer_prompt
)
try:
@@ -421,7 +423,7 @@ def init_rag_ui() -> gr.Interface:
with gr.Row():
input_file = gr.File(
value=[os.path.join(resource_path, "demo", "test.txt")],
- label="Doc(s) (multi-files can be selected together)",
+ label="Docs (multi-files can be selected together)",
file_count="multiple")
input_schema = gr.Textbox(value=schema, label="Schema")
info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head")
@@ -452,6 +454,9 @@ def init_rag_ui() -> gr.Interface:
graph_only_radio = gr.Radio(choices=[True, False],
value=False, label="Graph-only Answer")
graph_vector_radio = gr.Radio(choices=[True, False],
value=False, label="Graph-Vector Answer")
btn = gr.Button("Answer Question")
+ from hugegraph_llm.operators.llm_op.answer_synthesize import DEFAULT_ANSWER_TEMPLATE
+ answer_prompt_input = gr.Textbox(value=DEFAULT_ANSWER_TEMPLATE, label="Custom Prompt",
+ show_copy_button=True)
btn.click( # pylint: disable=no-member
fn=rag_answer,
inputs=[
@@ -460,6 +465,7 @@ def init_rag_ui() -> gr.Interface:
vector_only_radio,
graph_only_radio,
graph_vector_radio,
+ answer_prompt_input,
],
outputs=[raw_out, vector_only_out, graph_only_out,
graph_vector_out],
)
@@ -498,7 +504,6 @@ if __name__ == "__main__":
# TODO: support multi-user login when need
app = gr.mount_gradio_app(app, hugegraph_llm, path="/", auth=("rag", os.getenv("TOKEN")) if auth_enabled else None)
- # Note: set reload to False in production environment
- uvicorn.run(app, host=args.host, port=args.port)
# TODO: we can't use reload now due to the config 'app' of uvicorn.run
- # uvicorn.run("rag_web_demo:app", host="0.0.0.0", port=8001, reload=True)
+ # ❎:f'{__name__}:app' / rag_web_demo:app / hugegraph_llm.demo.rag_web_demo:app
+ uvicorn.run(app, host=args.host, port=args.port, reload=False)
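
For reference, the reworked rag_answer() wiring can be exercised directly; the
sketch below mirrors the chained calls in this diff. It assumes a configured
HugeGraph backend with default LLM/embedding models, and that run() forwards
its kwargs (e.g. the query) into the shared context, per its **kwargs signature:

    from hugegraph_llm.operators.graph_rag_task import RAGPipeline
    from hugegraph_llm.operators.llm_op.answer_synthesize import DEFAULT_ANSWER_TEMPLATE

    searcher = RAGPipeline()
    searcher.query_vector_index_for_rag()                                   # vector recall
    searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag()  # graph recall
    searcher.merge_dedup_rerank().synthesize_answer(
        raw_answer=False,
        vector_only_answer=True,
        graph_only_answer=False,
        graph_vector_answer=True,
        answer_prompt=DEFAULT_ANSWER_TEMPLATE,  # or any user-defined template
    )
    context = searcher.run(query="your question here")  # assumed kwarg, see run(**kwargs)
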
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
index e012479..a34cbed 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
@@ -34,16 +34,16 @@ class MergeDedupRerank:
self,
embedding: BaseEmbedding,
topk: int = 10,
- policy: Literal["bleu", "priority"] = "bleu"
+ strategy: Literal["bleu", "priority"] = "bleu"
):
self.embedding = embedding
self.topk = topk
- if policy == "bleu":
+ if strategy == "bleu":
self.rerank_func = self._bleu_rerank
- elif policy == "priority":
+ elif strategy == "priority":
self.rerank_func = self._priority_rerank
else:
- raise ValueError(f"Unimplemented policy {policy}.")
+ raise ValueError(f"Unimplemented rerank strategy {strategy}.")
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
query = context.get("query")
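
The rename from "policy" to "strategy" only touches the constructor keyword; a
minimal usage sketch follows (the embedding instance can be any BaseEmbedding
implementation, here taken from the default factory):

    from hugegraph_llm.models.embeddings.init_embedding import Embeddings
    from hugegraph_llm.operators.common_op.merge_dedup_rerank import MergeDedupRerank

    embedding = Embeddings().get_embedding()
    # "bleu" selects _bleu_rerank, "priority" selects _priority_rerank;
    # any other value now raises ValueError("Unimplemented rerank strategy ...")
    reranker = MergeDedupRerank(embedding=embedding, topk=10, strategy="priority")
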
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index cacac61..444ff45 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -19,12 +19,12 @@
import time
from typing import Dict, Any, Optional, List, Literal
-from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.models.embeddings.base import BaseEmbedding
-from hugegraph_llm.models.llms.init_llm import LLMs
from hugegraph_llm.models.embeddings.init_embedding import Embeddings
-from hugegraph_llm.operators.common_op.print_result import PrintResult
+from hugegraph_llm.models.llms.base import BaseLLM
+from hugegraph_llm.models.llms.init_llm import LLMs
from hugegraph_llm.operators.common_op.merge_dedup_rerank import MergeDedupRerank
+from hugegraph_llm.operators.common_op.print_result import PrintResult
from hugegraph_llm.operators.document_op.word_extract import WordExtract
from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery
from hugegraph_llm.operators.index_op.semantic_id_query import SemanticIdQuery
@@ -34,17 +34,35 @@ from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract
from hugegraph_llm.utils.log import log
-class GraphRAG:
+class RAGPipeline:
+ """
+ RAGPipeline is a core class that encapsulates a series of operations for extracting information from text,
+ querying graph databases and vector indices, merging and re-ranking results, and generating answers.
+ """
+
def __init__(self, llm: Optional[BaseLLM] = None, embedding: Optional[BaseEmbedding] = None):
+ """
+ Initialize the RAGPipeline with optional LLM and embedding models.
+
+ :param llm: Optional LLM model to use.
+ :param embedding: Optional embedding model to use.
+ """
self._llm = llm or LLMs().get_llm()
self._embedding = embedding or Embeddings().get_embedding()
self._operators: List[Any] = []
def extract_word(
- self,
- text: Optional[str] = None,
- language: str = "english",
+ self,
+ text: Optional[str] = None,
+ language: str = "english",
):
+ """
+ Add a word extraction operator to the pipeline.
+
+ :param text: Text to extract words from.
+ :param language: Language of the text.
+ :return: Self-instance for chaining.
+ """
self._operators.append(
WordExtract(
text=text,
@@ -54,13 +72,23 @@ class GraphRAG:
return self
def extract_keyword(
- self,
- text: Optional[str] = None,
- max_keywords: int = 5,
- language: str = "english",
- extract_template: Optional[str] = None,
- expand_template: Optional[str] = None,
+ self,
+ text: Optional[str] = None,
+ max_keywords: int = 5,
+ language: str = "english",
+ extract_template: Optional[str] = None,
+ expand_template: Optional[str] = None,
):
+ """
+ Add a keyword extraction operator to the pipeline.
+
+ :param text: Text to extract keywords from.
+ :param max_keywords: Maximum number of keywords to extract.
+ :param language: Language of the text.
+ :param extract_template: Template for keyword extraction.
+ :param expand_template: Template for keyword expansion.
+ :return: Self-instance for chaining.
+ """
self._operators.append(
KeywordExtract(
text=text,
@@ -78,6 +106,12 @@ class GraphRAG:
topk_per_keyword: int = 1,
topk_per_query: int = 10
):
+ """
+ Add a semantic ID query operator to the pipeline.
+
+ :param topk_per_keyword: Top K results per keyword.
+ :param topk_per_query: Top K results per query.
+ :return: Self-instance for chaining.
+ """
self._operators.append(
SemanticIdQuery(
embedding=self._embedding,
@@ -94,6 +128,14 @@ class GraphRAG:
max_items: int = 30,
prop_to_match: Optional[str] = None,
):
+ """
+ Add a graph RAG query operator to the pipeline.
+
+ :param max_deep: Maximum depth for the graph query.
+ :param max_items: Maximum number of items to retrieve.
+ :param prop_to_match: Property to match in the graph.
+ :return: Self-instance for chaining.
+ """
self._operators.append(
GraphRAGQuery(
max_deep=max_deep,
@@ -107,6 +149,12 @@ class GraphRAG:
self,
max_items: int = 3
):
+ """
+ Add a vector index query operator to the pipeline.
+
+ :param max_items: Maximum number of items to retrieve.
+ :return: Self-instance for chaining.
+ """
self._operators.append(
VectorIndexQuery(
embedding=self._embedding,
@@ -116,37 +164,59 @@ class GraphRAG:
return self
def merge_dedup_rerank(self):
- self._operators.append(
- MergeDedupRerank(
- embedding=self._embedding,
- )
- )
+ """
+ Add a merge, deduplication, and rerank operator to the pipeline.
+
+ :return: Self-instance for chaining.
+ """
+ self._operators.append(MergeDedupRerank(embedding=self._embedding))
return self
def synthesize_answer(
- self,
- raw_answer: bool = False,
- vector_only_answer: bool = True,
- graph_only_answer: bool = False,
- graph_vector_answer: bool = False,
- prompt_template: Optional[str] = None,
+ self,
+ raw_answer: bool = False,
+ vector_only_answer: bool = True,
+ graph_only_answer: bool = False,
+ graph_vector_answer: bool = False,
+ answer_prompt: Optional[str] = None,
):
+ """
+ Add an answer synthesis operator to the pipeline.
+
+ :param raw_answer: Whether to return raw answers.
+ :param vector_only_answer: Whether to return vector-only answers.
+ :param graph_only_answer: Whether to return graph-only answers.
+ :param graph_vector_answer: Whether to return graph-vector combined answers.
+ :param answer_prompt: Template for the answer synthesis prompt.
+ :return: Self-instance for chaining.
+ """
self._operators.append(
AnswerSynthesize(
- raw_answer = raw_answer,
- vector_only_answer = vector_only_answer,
- graph_only_answer = graph_only_answer,
- graph_vector_answer = graph_vector_answer,
- prompt_template=prompt_template,
+ raw_answer=raw_answer,
+ vector_only_answer=vector_only_answer,
+ graph_only_answer=graph_only_answer,
+ graph_vector_answer=graph_vector_answer,
+ prompt_template=answer_prompt,
)
)
return self
def print_result(self):
+ """
+ Add a print result operator to the pipeline.
+
+ :return: Self-instance for chaining.
+ """
self._operators.append(PrintResult())
return self
def run(self, **kwargs) -> Dict[str, Any]:
+ """
+ Execute all operators in the pipeline in sequence.
+
+ :param kwargs: Additional context to pass to operators.
+ :return: Final context after all operators have been executed.
+ """
if len(self._operators) == 0:
self.extract_keyword().query_graph_for_rag().synthesize_answer()
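
Each builder method appends one operator and returns self, so a full graph-RAG
pass reads as a single chain; a sketch assuming default models and a reachable
graph (match_keyword_to_id is the semantic-ID step shown above):

    pipeline = RAGPipeline()  # falls back to LLMs().get_llm() / Embeddings().get_embedding()
    result = (
        pipeline.extract_keyword(max_keywords=5, language="english")
        .match_keyword_to_id()
        .query_graph_for_rag()
        .merge_dedup_rerank()
        .synthesize_answer(graph_only_answer=True)
        .run(query="your question here")  # kwargs seed the context shared by operators
    )
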
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index f3803c7..a52fdb6 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -22,20 +22,25 @@ from typing import Any, Dict, Optional
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.models.llms.init_llm import LLMs
-# TODO: we need enhance the template to answer the question
-DEFAULT_ANSWER_SYNTHESIZE_TEMPLATE_TMPL = (
- "Context information is below.\n"
- "---------------------\n"
- "{context_str}\n"
- "---------------------\n"
- "You need to refer to the context based on the following priority:\n"
- "1. Graph recall > vector recall\n"
- "2. Exact recall > Fuzzy recall\n"
- "3. Independent vertex > 1-depth neighbor> 2-depth neighbors\n"
- "Given the context information and not prior knowledge, answer the
query.\n"
- "Query: {query_str}\n"
- "Answer: "
-)
+# TODO: we need to enhance the template to answer the question (put it in a separate file)
+DEFAULT_ANSWER_TEMPLATE = f"""
+You are an expert in knowledge graphs and natural language processing.
+Your task is to provide a precise and accurate answer based on the given context.
+
+Context information is below.
+---------------------
+{{context_str}}
+---------------------
+Please refer to the context based on the following priority:
+1. Graph data > vector data
+2. Precise data > fuzzy data
+3. One-depth neighbors > two-depth neighbors
+
+Given the context information and without using fictive knowledge,
+answer the following query in a concise and professional manner.
+Query: {{query_str}}
+Answer:
+"""
class AnswerSynthesize:
@@ -53,7 +58,7 @@ class AnswerSynthesize:
graph_vector_answer: bool = False,
):
self._llm = llm
- self._prompt_template = prompt_template or DEFAULT_ANSWER_SYNTHESIZE_TEMPLATE_TMPL
+ self._prompt_template = prompt_template or DEFAULT_ANSWER_TEMPLATE
self._question = question
self._context_body = context_body
self._context_head = context_head
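
Since any template passed as answer_prompt replaces DEFAULT_ANSWER_TEMPLATE, a
user-defined search template is just a string; the sketch below assumes (based
on the default template above) that the {context_str} and {query_str}
placeholders must be kept for AnswerSynthesize to fill in at run time:

    from hugegraph_llm.operators.graph_rag_task import RAGPipeline

    MY_TEMPLATE = """
    Context information is below.
    ---------------------
    {context_str}
    ---------------------
    Answer strictly from the context, preferring graph data over vector data.
    Query: {query_str}
    Answer:
    """

    RAGPipeline().extract_keyword().match_keyword_to_id().query_graph_for_rag() \
        .merge_dedup_rerank() \
        .synthesize_answer(graph_only_answer=True, answer_prompt=MY_TEMPLATE) \
        .run(query="your question here")
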
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py
index 171b48a..e4d6dd2 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py
@@ -24,6 +24,7 @@ from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.document.chunk_split import ChunkSplitter
from hugegraph_llm.utils.log import log
+# TODO: put in a separate file for users to customize the content
SCHEMA_EXAMPLE_PROMPT = """## Main Task
Given the following graph schema and a piece of text, your task is to analyze the text and extract information that fits into the schema's structure, formatting the information into vertices and edges as specified.
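
Until the TODO above is resolved (moving the prompt into a separate,
user-editable file), one stopgap is to override the module-level constant
before the build step reads it; a sketch of a workaround, not an official API:

    from hugegraph_llm.operators.llm_op import property_graph_extract

    # Hypothetical customization: replace the shipped prompt wholesale; this
    # must happen before any code binds SCHEMA_EXAMPLE_PROMPT by name.
    property_graph_extract.SCHEMA_EXAMPLE_PROMPT = """## Main Task
    Given the following graph schema and a piece of text, extract only the
    vertices and edges that fit the schema, using your domain's formatting rules.
    """
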