This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new ef56263 refactor(llm): merge & update keyword extraction logic/prompt (#120)
ef56263 is described below
commit ef56263052883e8da5dbd9629d76931b6a0ecb00
Author: HaoJin Yang <[email protected]>
AuthorDate: Wed Nov 27 17:03:37 2024 +0800
refactor(llm): merge & update keyword extraction logic/prompt (#120)
* update english prompt
---------
Co-authored-by: imbajin <[email protected]>
---
hugegraph-llm/src/hugegraph_llm/config/config.py | 6 ++
.../src/hugegraph_llm/config/config_data.py | 48 ++++++++++++
.../src/hugegraph_llm/demo/rag_demo/app.py | 7 +-
.../hugegraph_llm/demo/rag_demo/configs_block.py | 2 +-
.../src/hugegraph_llm/demo/rag_demo/rag_block.py | 19 +++--
.../operators/common_op/nltk_helper.py | 2 +-
.../src/hugegraph_llm/operators/graph_rag_task.py | 3 -
.../operators/llm_op/keyword_extract.py | 87 ++++++++--------------
8 files changed, 104 insertions(+), 70 deletions(-)
diff --git a/hugegraph-llm/src/hugegraph_llm/config/config.py
b/hugegraph-llm/src/hugegraph_llm/config/config.py
index 499d6a0..8209463 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/config.py
@@ -124,6 +124,9 @@ class PromptConfig(PromptData):
"\n".join([f" {line}" for line in
self.custom_rerank_info.splitlines()])
)
indented_default_answer_template = "\n".join([f" {line}" for line
in self.answer_prompt.splitlines()])
+ indented_keywords_extract_template = (
+ "\n".join([f" {line}" for line in
self.keywords_extract_prompt.splitlines()])
+ )
# This can be extended to add storage fields according to the data
needs to be stored
yaml_content = f"""graph_schema: |
@@ -141,6 +144,9 @@ custom_rerank_info: |
answer_prompt: |
{indented_default_answer_template}
+keywords_extract_prompt: |
+{indented_keywords_extract_template}
+
"""
with open(yaml_file_path, "w", encoding="utf-8") as file:
file.write(yaml_content)
diff --git a/hugegraph-llm/src/hugegraph_llm/config/config_data.py
b/hugegraph-llm/src/hugegraph_llm/config/config_data.py
index 3e5711c..78ff0cf 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/config_data.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/config_data.py
@@ -218,3 +218,51 @@ Meet Sarah, a 30-year-old attorney, and her roommate,
James, whom she's shared a
]
}
"""
+
+ # Extracted from llm_op/keyword_extract.py
+ keywords_extract_prompt = """指令:
+请对以下文本执行以下任务:
+1. 从文本中提取关键词:
+ - 最少 0 个,最多 {max_keywords} 个。
+ - 关键词应为具有完整语义的词语或短语,确保信息完整。
+2. 识别需改写的关键词:
+ - 从提取的关键词中,识别那些在原语境中具有歧义或存在信息缺失的关键词。
+3. 生成同义词:
+ - 对这些需改写的关键词,生成其在给定语境下的同义词或含义相近的词语。
+ - 使用生成的同义词替换原文中的相应关键词。
+ - 如果某个关键词没有合适的同义词,则保留该关键词不变。
+要求:
+- 关键词应为有意义且具体的实体,避免使用无意义或过于宽泛的词语,或单字符的词(例如:“物品”、“动作”、“效果”、“作用”、“的”、“他”)。
+- 优先提取主语、动词和宾语,避免提取虚词或助词。
+- 保持语义完整性: 抽取的关键词应尽量保持关键词在原语境中语义和信息的完整性(例如:“苹果电脑”应作为一个整体被抽取,而不是被分为“苹果”和“电脑”)。
+- 避免泛化: 不要扩展为不相关的泛化类别。
+注意:
+- 仅考虑语境相关的同义词: 只需考虑给定语境下的关键词的语义近义词和具有类似含义的其他词语。
+- 调整关键词长度:
如果关键词相对宽泛,可以根据语境适当增加单个关键词的长度(例如:“违法行为”可以作为一个单独的关键词被抽取,或抽取为“违法”,但不应拆分为“违法”和“行为”)。
+输出格式:
+- 仅输出一行内容, 以 KEYWORDS: 为前缀,后跟所有关键词或对应的同义词,之间用逗号分隔。抽取的关键词中不允许出现空格或空字符
+- 格式示例:
+KEYWORDS:关键词1,关键词2,...,关键词n
+文本:
+{question}
+"""
+
+ # keywords_extract_prompt_EN = """
+# Instruction:
+# Please perform the following tasks on the text below:
+# 1. Extract Keywords and Generate Synonyms from text:
+# - At least 0, at most {max_keywords} keywords.
+# - For each keyword, generate its synonyms or possible variant forms.
+# Requirements:
+# - Keywords should be meaningful and specific entities; avoid using
meaningless or overly broad terms (e.g., “object,” “the,” “he”).
+# - Prioritize extracting subjects, verbs, and objects; avoid extracting
function words or auxiliary words.
+# - Do not expand into unrelated generalized categories.
+# Note:
+# - Only consider semantic synonyms and other words with similar meanings in
the given context.
+# Output Format:
+# - Output only one line, prefixed with KEYWORDS:, followed by all keywords
and synonyms, separated by commas.No spaces or empty characters are allowed in
the extracted keywords.
+# - Format example:
+# KEYWORDS: keyword1, keyword2, ..., keywordn, synonym1, synonym2, ...,
synonymn
+# Text:
+# {question}
+# """
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
index 84a25b8..f5a067a 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
@@ -92,7 +92,7 @@ def init_rag_ui() -> gr.Interface:
with gr.Tab(label="1. Build RAG Index 💡"):
textbox_input_schema, textbox_info_extract_template =
create_vector_graph_block()
with gr.Tab(label="2. (Graph)RAG & User Functions 📖"):
- textbox_inp, textbox_answer_prompt_input = create_rag_block()
+ textbox_inp, textbox_answer_prompt_input,
textbox_keywords_extract_prompt_input = create_rag_block()
with gr.Tab(label="3. Graph Tools 🚧"):
create_other_block()
with gr.Tab(label="4. Admin Tools ⚙️"):
@@ -105,7 +105,7 @@ def init_rag_ui() -> gr.Interface:
return (
settings.graph_ip, settings.graph_port, settings.graph_name,
settings.graph_user,
settings.graph_pwd, settings.graph_space, prompt.graph_schema,
prompt.extract_graph_prompt,
- prompt.default_question, prompt.answer_prompt
+ prompt.default_question, prompt.answer_prompt,
prompt.keywords_extract_prompt
)
hugegraph_llm_ui.load(fn=refresh_ui_config_prompt, outputs=[
@@ -120,7 +120,8 @@ def init_rag_ui() -> gr.Interface:
textbox_info_extract_template,
textbox_inp,
- textbox_answer_prompt_input
+ textbox_answer_prompt_input,
+ textbox_keywords_extract_prompt_input
])
return hugegraph_llm_ui
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py
index 39b036b..c35a33f 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py
@@ -251,7 +251,7 @@ def create_configs_block() -> list:
llm_config_button = gr.Button("Apply configuration")
llm_config_button.click(apply_llm_config_with_chat_op,
inputs=llm_config_input)
- with gr.Tab(label='extract'):
+ with gr.Tab(label='mini_tasks'):
extract_llm_dropdown = gr.Dropdown(choices=["openai",
"qianfan_wenxin", "ollama/local"],
value=getattr(settings, f"extract_llm_type"),
label=f"type")
apply_llm_config_with_extract_op = partial(apply_llm_config,
"extract")
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
index c4263f4..506e19d 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
@@ -40,6 +40,7 @@ def rag_answer(
near_neighbor_first: bool,
custom_related_information: str,
answer_prompt: str,
+ keywords_extract_prompt: str,
) -> Tuple:
"""
Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
@@ -49,11 +50,12 @@ def rag_answer(
4. Synthesize the final answer.
5. Run the pipeline and return the results.
"""
- should_update_prompt = prompt.default_question != text or
prompt.answer_prompt != answer_prompt
+ should_update_prompt = prompt.default_question != text or
prompt.answer_prompt != answer_prompt or prompt.keywords_extract_prompt !=
keywords_extract_prompt
if should_update_prompt or prompt.custom_rerank_info !=
custom_related_information:
prompt.custom_rerank_info = custom_related_information
prompt.default_question = text
prompt.answer_prompt = answer_prompt
+ prompt.keywords_extract_prompt = keywords_extract_prompt
prompt.update_yaml_file()
vector_search = vector_only_answer or graph_vector_answer
@@ -66,7 +68,7 @@ def rag_answer(
if vector_search:
rag.query_vector_index()
if graph_search:
- rag.extract_keywords().keywords_to_vid().query_graphdb()
+
rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().query_graphdb()
# TODO: add more user-defined search strategies
rag.merge_dedup_rerank(graph_ratio, rerank_method, near_neighbor_first,
custom_related_information)
rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer,
graph_vector_answer, answer_prompt)
@@ -101,7 +103,10 @@ def create_rag_block():
graph_vector_out = gr.Textbox(label="Graph-Vector Answer",
show_copy_button=True)
answer_prompt_input = gr.Textbox(
- value=prompt.answer_prompt, label="Custom Prompt",
show_copy_button=True, lines=7
+ value=prompt.answer_prompt, label="Query Prompt",
show_copy_button=True, lines=7
+ )
+ keywords_extract_prompt_input = gr.Textbox(
+ value=prompt.keywords_extract_prompt, label="Keywords
Extraction Prompt", show_copy_button=True, lines=7
)
with gr.Column(scale=1):
with gr.Row():
@@ -134,7 +139,7 @@ def create_rag_block():
)
custom_related_information = gr.Text(
prompt.custom_rerank_info,
- label="Custom related information(Optional)",
+ label="Query related information(Optional)",
)
btn = gr.Button("Answer Question", variant="primary")
@@ -151,6 +156,7 @@ def create_rag_block():
near_neighbor_first,
custom_related_information,
answer_prompt_input,
+ keywords_extract_prompt_input
],
outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
)
@@ -209,6 +215,7 @@ def create_rag_block():
near_neighbor_first: bool,
custom_related_information: str,
answer_prompt: str,
+ keywords_extract_prompt: str,
progress=gr.Progress(track_tqdm=True),
answer_max_line_count: int = 1,
):
@@ -227,6 +234,7 @@ def create_rag_block():
near_neighbor_first,
custom_related_information,
answer_prompt,
+ keywords_extract_prompt,
)
df.at[index, "Basic LLM Answer"] = basic_llm_answer
df.at[index, "Vector-only Answer"] = vector_only_answer
@@ -259,10 +267,11 @@ def create_rag_block():
near_neighbor_first,
custom_related_information,
answer_prompt_input,
+ keywords_extract_prompt_input,
answer_max_line_count,
],
outputs=[qa_dataframe, gr.File(label="Download Answered File",
min_width=40)],
)
questions_file.change(read_file_to_excel, questions_file, [qa_dataframe,
answer_max_line_count])
answer_max_line_count.change(change_showing_excel, answer_max_line_count,
qa_dataframe)
- return inp, answer_prompt_input
+ return inp, answer_prompt_input, keywords_extract_prompt_input
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/nltk_helper.py
b/hugegraph-llm/src/hugegraph_llm/operators/common_op/nltk_helper.py
index 196fc3a..797ea70 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/nltk_helper.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/nltk_helper.py
@@ -33,7 +33,7 @@ class NLTKHelper:
"chinese": None,
}
- def stopwords(self, lang: str = "english") -> List[str]:
+ def stopwords(self, lang: str = "chinese") -> List[str]:
"""Get stopwords."""
nltk.data.path.append(os.path.join(resource_path, "nltk_data"))
if self._stopwords.get(lang) is None:
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index 0dd4d7d..07dc770 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -69,7 +69,6 @@ class RAGPipeline:
max_keywords: int = 5,
language: str = "english",
extract_template: Optional[str] = None,
- expand_template: Optional[str] = None,
):
"""
Add a keyword extraction operator to the pipeline.
@@ -78,7 +77,6 @@ class RAGPipeline:
:param max_keywords: Maximum number of keywords to extract.
:param language: Language of the text.
:param extract_template: Template for keyword extraction.
- :param expand_template: Template for keyword expansion.
:return: Self-instance for chaining.
"""
self._operators.append(
@@ -87,7 +85,6 @@ class RAGPipeline:
max_keywords=max_keywords,
language=language,
extract_template=extract_template,
- expand_template=expand_template,
)
)
return self
diff --git
a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
index c47a79d..85fe995 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
@@ -17,46 +17,32 @@
import re
+import time
from typing import Set, Dict, Any, Optional
+from hugegraph_llm.config import prompt
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.models.llms.init_llm import LLMs
from hugegraph_llm.operators.common_op.nltk_helper import NLTKHelper
from hugegraph_llm.utils.log import log
-KEYWORDS_EXTRACT_TPL = """Extract {max_keywords} keywords from the text:
-{question}
-
-1. Keywords can't contain meaningless/broad words(e.g action/relation/thing),
must represent certain entities,
-2. Better to extract subject/verb/object and don't extract particles, don't
extend to synonyms/general categories.
-Provide keywords in the following comma-separated format: 'KEYWORDS:
<keywords>'
-"""
-
-KEYWORDS_EXPAND_TPL = """Generate synonyms or possible form of keywords up to
{max_keywords} in total,
-considering possible cases of capitalization, pluralization, common
expressions, etc.
-Provide all synonyms of keywords in comma-separated format: 'SYNONYMS:
<keywords>'
-Note, result should be in one-line with only one 'SYNONYMS: ' prefix
-----
-KEYWORDS: {question}
-----"""
+KEYWORDS_EXTRACT_TPL = prompt.keywords_extract_prompt
class KeywordExtract:
def __init__(
- self,
- text: Optional[str] = None,
- llm: Optional[BaseLLM] = None,
- max_keywords: int = 5,
- extract_template: Optional[str] = None,
- expand_template: Optional[str] = None,
- language: str = "english",
+ self,
+ text: Optional[str] = None,
+ llm: Optional[BaseLLM] = None,
+ max_keywords: int = 5,
+ extract_template: Optional[str] = None,
+ language: str = "english",
):
self._llm = llm
self._query = text
self._language = language.lower()
self._max_keywords = max_keywords
self._extract_template = extract_template or KEYWORDS_EXTRACT_TPL
- self._expand_template = expand_template or KEYWORDS_EXPAND_TPL
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
if self._query is None:
@@ -69,61 +55,48 @@ class KeywordExtract:
self._llm = LLMs().get_extract_llm()
assert isinstance(self._llm, BaseLLM), "Invalid LLM Object."
- if isinstance(context.get("language"), str):
- self._language = context["language"].lower()
- else:
- context["language"] = self._language
-
- if isinstance(context.get("max_keywords"), int):
- self._max_keywords = context["max_keywords"]
+ self._language = context.get("language", self._language).lower()
+ self._max_keywords = context.get("max_keywords", self._max_keywords)
- prompt = self._extract_template.format(question=self._query,
max_keywords=self._max_keywords)
+ prompt = f"{self._extract_template.format(question=self._query,
max_keywords=self._max_keywords)}"
+ start_time = time.perf_counter()
response = self._llm.generate(prompt=prompt)
+ end_time = time.perf_counter()
+ log.debug("Keyword extraction time: %.2f seconds", end_time -
start_time)
keywords = self._extract_keywords_from_response(
response=response, lowercase=False, start_token="KEYWORDS:"
)
- keywords.union(self._expand_synonyms(keywords=keywords))
keywords = {k.replace("'", "") for k in keywords}
context["keywords"] = list(keywords)
log.info("User Query: %s\nKeywords: %s", self._query,
context["keywords"])
- # extracting keywords & expanding synonyms increase the call count by 2
- context["call_count"] = context.get("call_count", 0) + 2
+ # extracting keywords & expanding synonyms increase the call count by 1
+ context["call_count"] = context.get("call_count", 0) + 1
return context
- def _expand_synonyms(self, keywords: Set[str]) -> Set[str]:
- prompt = self._expand_template.format(question=str(keywords),
max_keywords=self._max_keywords)
- response = self._llm.generate(prompt=prompt)
- keywords = self._extract_keywords_from_response(
- response=response, lowercase=False, start_token="SYNONYMS:"
- )
- return keywords
-
def _extract_keywords_from_response(
- self,
- response: str,
- lowercase: bool = True,
- start_token: str = "",
+ self,
+ response: str,
+ lowercase: bool = True,
+ start_token: str = "",
) -> Set[str]:
keywords = []
+ # use re.escape(start_token) if start_token contains special chars
like */&/^ etc.
matches = re.findall(rf'{start_token}[^\n]+\n?', response)
for match in matches:
- match = match[len(start_token):]
- for k in re.split(r"[,,]+", match):
- k = k.strip()
- if len(k) > 1:
- if lowercase:
- keywords.append(k.lower())
- else:
- keywords.append(k)
+ match = match[len(start_token):].strip()
+ keywords.extend(
+ k.lower() if lowercase else k
+ for k in re.split(r"[,,]+", match)
+ if len(k.strip()) > 1
+ )
# if the keyword consists of multiple words, split into sub-words
(removing stopwords)
- results = set()
+ results = set(keywords)
for token in keywords:
- results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
- results.update({w for w in sub_tokens if w not in
NLTKHelper().stopwords(lang=self._language)})
+ results.update(w for w in sub_tokens if w not in
NLTKHelper().stopwords(lang=self._language))
return results