This is an automated email from the ASF dual-hosted git repository.
vaughn pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new c519ec0 feat(llm): support graph_rag_recall api (#79)
c519ec0 is described below
commit c519ec0ec13c7d8c12b92234a99228e59e7d3c2f
Author: imbajin <[email protected]>
AuthorDate: Wed Sep 18 18:48:53 2024 +0800
feat(llm): support graph_rag_recall api (#79)
---
hugegraph-llm/README.md | 13 +-
.../src/hugegraph_llm/api/models/rag_requests.py | 15 ++-
hugegraph-llm/src/hugegraph_llm/api/rag_api.py | 52 +++++++-
.../src/hugegraph_llm/demo/rag_web_demo.py | 53 ++++----
.../src/hugegraph_llm/models/llms/init_llm.py | 4 +-
.../src/hugegraph_llm/models/llms/openai.py | 4 +-
.../operators/document_op/word_extract.py | 3 +-
.../src/hugegraph_llm/operators/graph_rag_task.py | 37 ++----
.../operators/hugegraph_op/graph_rag_query.py | 142 ++++++++++-----------
.../operators/hugegraph_op/schema_manager.py | 2 -
.../operators/index_op/semantic_id_query.py | 12 +-
.../operators/llm_op/answer_synthesize.py | 8 +-
.../operators/llm_op/disambiguate_data.py | 2 +-
.../operators/llm_op/keyword_extract.py | 41 +++---
.../src/hugegraph_llm/utils/decorators.py | 2 +-
.../src/hugegraph_llm/utils/graph_index_utils.py | 3 +-
16 files changed, 209 insertions(+), 184 deletions(-)
diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md
index 25de21b..5ca1fca 100644
--- a/hugegraph-llm/README.md
+++ b/hugegraph-llm/README.md
@@ -53,7 +53,7 @@ Refer to [docker-link](https://hub.docker.com/r/hugegraph/hugegraph) & [deploy-d
python3 -m hugegraph_llm.demo.gremlin_generate_web_demo
```
-7. After starting the web demo, the config file `.env` will be automatically generated. You can modify its content in the web page. Or modify the file directly and restart the web application.
+7. After starting the web demo, the config file `.env` will be automatically generated. You can modify its content on the web page. Or modify the file directly and restart the web application.
(Optional) To regenerate the config file, you can use `config.generate` with `-u` or `--update`.
```bash
@@ -130,22 +130,19 @@ The methods of the `KgBuilder` class can be chained together to perform a sequen
Run example like `python3 ./hugegraph_llm/examples/graph_rag_test.py`
-The `GraphRAG` class is used to integrate HugeGraph with large language models to provide retrieval-augmented generation capabilities.
+The `RAGPipeline` class is used to integrate HugeGraph with large language models to provide retrieval-augmented generation capabilities.
Here is a brief usage guide:
1. **Extract Keyword**: Extract keywords and expand synonyms.
-
+
```python
- graph_rag.extract_keyword(text="Tell me about Al Pacino.").print_result()
+ graph_rag.extract_keywords(text="Tell me about Al Pacino.").print_result()
```
2. **Query Graph for Rag**: Retrieve the corresponding keywords and their multi-degree associated relationships from HugeGraph.
```python
- graph_rag.query_graph_for_rag(
- max_deep=2,
- max_items=30
- ).print_result()
+ graph_rag.query_graphdb(max_deep=2, max_items=30).print_result()
```
3. **Synthesize Answer**: Summarize the results and organize the language to answer the question.
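   For completeness, a minimal sketch of step 3 in the same chained style (the flag names follow the `synthesize_answer` call in `rag_web_demo.py` below; the values are illustrative, not from the README):
   ```python
   # Illustrative only: answer-mode flags mirror the rag_web_demo.py call in this commit.
   graph_rag.synthesize_answer(
       raw_answer=False,
       vector_only_answer=False,
       graph_only_answer=True,
       graph_vector_answer=False,
   ).print_result()
   ```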
diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index a211bb8..0e7c666 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -15,20 +15,31 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Optional
+from typing import Optional, Literal
from pydantic import BaseModel
class RAGRequest(BaseModel):
query: str
- raw_llm: Optional[bool] = True
+ raw_llm: Optional[bool] = False
vector_only: Optional[bool] = False
graph_only: Optional[bool] = False
graph_vector: Optional[bool] = False
+ graph_ratio: float = 0.5
+ rerank_method: Literal["bleu", "reranker"] = "bleu"
+ near_neighbor_first: bool = False
+ custom_related_information: str = None
answer_prompt: Optional[str] = None
+class GraphRAGRequest(BaseModel):
+ query: str
+ rerank_method: Literal["bleu", "reranker"] = "bleu"
+ near_neighbor_first: bool = False
+ custom_related_information: str = None
+
+
class GraphConfigRequest(BaseModel):
ip: str = "127.0.0.1"
port: str = "8080"
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index 64daf70..a0268e8 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -14,22 +14,42 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import json
+from typing import Literal
-from fastapi import status, APIRouter
+from fastapi import status, APIRouter, HTTPException
from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
from hugegraph_llm.api.models.rag_requests import (
RAGRequest,
GraphConfigRequest,
LLMConfigRequest,
- RerankerConfigRequest,
+ RerankerConfigRequest, GraphRAGRequest,
)
from hugegraph_llm.api.models.rag_response import RAGResponse
from hugegraph_llm.config import settings
+from hugegraph_llm.utils.log import log
+
+
+def graph_rag_recall(
+ text: str,
+ rerank_method: Literal["bleu", "reranker"],
+ near_neighbor_first: bool,
+ custom_related_information: str
+) -> dict:
+ from hugegraph_llm.operators.graph_rag_task import RAGPipeline
+ rag = RAGPipeline()
+
+ rag.extract_keywords().keywords_to_vid().query_graphdb().merge_dedup_rerank(
+ rerank_method=rerank_method,
+ near_neighbor_first=near_neighbor_first,
+ custom_related_information=custom_related_information,
+ )
+ context = rag.run(verbose=True, query=text, graph_search=True)
+ return context
def rag_http_api(
- router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
+ router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
):
@router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(req: RAGRequest):
@@ -41,6 +61,32 @@ def rag_http_api(
if getattr(req, key)
}
+ @router.post("/rag/graph", status_code=status.HTTP_200_OK)
+ def graph_rag_recall_api(req: GraphRAGRequest):
+ try:
+ result = graph_rag_recall(
+ text=req.query,
+ rerank_method=req.rerank_method,
+ near_neighbor_first=req.near_neighbor_first,
+ custom_related_information=req.custom_related_information
+ )
+ # TODO/FIXME: handle QianFanClient error (not dict..critical)
+ # log.critical(f"## {type(result)}, {json.dumps(result)}")
+ if isinstance(result, dict):
+ log.critical(f"##1. {type(result)}")
+ return {"graph_recall": result}
+ else:
+ log.critical(f"##2. {type(result)}")
+ return {"graph_recall": json.dumps(result)}
+
+ except TypeError as e:
+ log.error(f"TypeError in graph_rag_recall_api: {e}")
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+ except Exception as e:
+ log.error(f"Unexpected error occurred: {e}")
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
+
+
@router.post("/config/graph", status_code=status.HTTP_201_CREATED)
def graph_config_api(req: GraphConfigRequest):
# Accept status code
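To try the new endpoint end to end, here is a sketch of a client call (it assumes the demo server from `rag_web_demo.py` is listening on its default port 8001 with login disabled; URL and values are illustrative):
```python
# Sketch: POST to the new /rag/graph endpoint; the response carries a "graph_recall" key.
import requests

resp = requests.post(
    "http://127.0.0.1:8001/rag/graph",
    json={
        "query": "Tell me about Al Pacino.",
        "rerank_method": "bleu",
        "near_neighbor_first": False,
        "custom_related_information": "",
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["graph_recall"])
```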
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index 6dc60b1..b8c772a 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -68,8 +68,16 @@ def rag_answer(
custom_related_information: str,
answer_prompt: str,
) -> Tuple:
-
- if prompt.default_question != text or prompt.custom_rerank_info != custom_related_information or prompt.answer_prompt != answer_prompt:
+ """
+ Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
+ 1. Initialize the RAGPipeline.
+ 2. Select vector search or graph search based on parameters.
+ 3. Merge, deduplicate, and rerank the results.
+ 4. Synthesize the final answer.
+ 5. Run the pipeline and return the results.
+ """
+ should_update_prompt = prompt.default_question != text or prompt.answer_prompt != answer_prompt
+ if should_update_prompt or prompt.custom_rerank_info != custom_related_information:
prompt.custom_rerank_info = custom_related_information
prompt.default_question = text
prompt.answer_prompt = answer_prompt
@@ -77,30 +85,23 @@ def rag_answer(
vector_search = vector_only_answer or graph_vector_answer
graph_search = graph_only_answer or graph_vector_answer
-
if raw_answer is False and not vector_search and not graph_search:
gr.Warning("Please select at least one generate mode.")
return "", "", "", ""
- searcher = RAGPipeline()
+
+ rag = RAGPipeline()
if vector_search:
- searcher.query_vector_index_for_rag()
+ rag.query_vector_index()
if graph_search:
- searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag()
+ rag.extract_keywords().keywords_to_vid().query_graphdb()
# TODO: add more user-defined search strategies
- searcher.merge_dedup_rerank(
- graph_ratio, rerank_method, near_neighbor_first, custom_related_information
- ).synthesize_answer(
- raw_answer=raw_answer,
- vector_only_answer=vector_only_answer,
- graph_only_answer=graph_only_answer,
- graph_vector_answer=graph_vector_answer,
- answer_prompt=answer_prompt,
- )
+ rag.merge_dedup_rerank(graph_ratio, rerank_method, near_neighbor_first, custom_related_information)
+ rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt)
try:
- context = searcher.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
+ context = rag.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
if context.get("switch_to_bleu"):
- gr.Warning("Online reranker fails, automatically switches to local
bleu method.")
+ gr.Warning("Online reranker fails, automatically switches to local
bleu rerank.")
return (
context.get("raw_answer", ""),
context.get("vector_only_answer", ""),
@@ -108,10 +109,10 @@ def rag_answer(
context.get("graph_vector_answer", ""),
)
except ValueError as e:
- log.error(e)
+ log.critical(e)
raise gr.Error(str(e))
except Exception as e:
- log.error(e)
+ log.critical(e)
raise gr.Error(f"An unexpected error occurred: {str(e)}")
@@ -665,19 +666,13 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=8001, help="port")
args = parser.parse_args()
app = FastAPI()
- app_auth = APIRouter(dependencies=[Depends(authenticate)])
+ api_auth = APIRouter(dependencies=[Depends(authenticate)])
hugegraph_llm = init_rag_ui()
- rag_http_api(
- app_auth,
- rag_answer,
- apply_graph_config,
- apply_llm_config,
- apply_embedding_config,
- apply_reranker_config,
- )
+ rag_http_api(api_auth, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config,
+ apply_reranker_config)
- app.include_router(app_auth)
+ app.include_router(api_auth)
auth_enabled = os.getenv("ENABLE_LOGIN", "False").lower() == "true"
log.info("(Status) Authentication is %s now.", "enabled" if auth_enabled
else "disabled")
# TODO: support multi-user login when needed
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
index 22a82cb..2c90748 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
@@ -17,7 +17,7 @@
from hugegraph_llm.models.llms.ollama import OllamaClient
-from hugegraph_llm.models.llms.openai import OpenAIChat
+from hugegraph_llm.models.llms.openai import OpenAIClient
from hugegraph_llm.models.llms.qianfan import QianfanClient
from hugegraph_llm.config import settings
@@ -34,7 +34,7 @@ class LLMs:
secret_key=settings.qianfan_secret_key
)
if self.llm_type == "openai":
- return OpenAIChat(
+ return OpenAIClient(
api_key=settings.openai_api_key,
api_base=settings.openai_api_base,
model_name=settings.openai_language_model,
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
index 30ac805..bfdb83b 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
@@ -27,8 +27,8 @@ from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.utils.log import log
-class OpenAIChat(BaseLLM):
- """Wrapper around OpenAI Chat large language models."""
+class OpenAIClient(BaseLLM):
+ """Wrapper for OpenAI Client."""
def __init__(
self,
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
index 43d0a29..b8f40df 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
@@ -62,7 +62,8 @@ class WordExtract:
verbose = context.get("verbose") or False
if verbose:
- print(f"\033[92mKEYWORDS: {context['keywords']}\033[0m")
+ from hugegraph_llm.utils.log import log
+ log.info(f"KEYWORDS: {context['keywords']}")
return context
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index 91bc7b3..dd75b18 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -50,11 +50,7 @@ class RAGPipeline:
self._embedding = embedding or Embeddings().get_embedding()
self._operators: List[Any] = []
- def extract_word(
- self,
- text: Optional[str] = None,
- language: str = "english",
- ):
+ def extract_word(self, text: Optional[str] = None, language: str = "english"):
"""
Add a word extraction operator to the pipeline.
@@ -62,15 +58,10 @@ class RAGPipeline:
:param language: Language of the text.
:return: Self-instance for chaining.
"""
- self._operators.append(
- WordExtract(
- text=text,
- language=language,
- )
- )
+ self._operators.append(WordExtract(text=text, language=language))
return self
- def extract_keyword(
+ def extract_keywords(
self,
text: Optional[str] = None,
max_keywords: int = 5,
@@ -99,7 +90,7 @@ class RAGPipeline:
)
return self
- def match_keyword_to_id(
+ def keywords_to_vid(
self,
by: Literal["query", "keywords"] = "keywords",
topk_per_keyword: int = 1,
@@ -107,8 +98,9 @@ class RAGPipeline:
):
"""
Add a semantic ID query operator to the pipeline.
-
+ :param by: Match by query or keywords.
:param topk_per_keyword: Top K results per keyword.
+ :param topk_per_query: Top K results per query.
:return: Self-instance for chaining.
"""
self._operators.append(
@@ -121,7 +113,7 @@ class RAGPipeline:
)
return self
- def query_graph_for_rag(
+ def query_graphdb(
self,
max_deep: int = 2,
max_items: int = 30,
@@ -136,15 +128,11 @@ class RAGPipeline:
:return: Self-instance for chaining.
"""
self._operators.append(
- GraphRAGQuery(
- max_deep=max_deep,
- max_items=max_items,
- prop_to_match=prop_to_match,
- )
+ GraphRAGQuery(max_deep=max_deep, max_items=max_items, prop_to_match=prop_to_match)
)
return self
- def query_vector_index_for_rag(self, max_items: int = 3):
+ def query_vector_index(self, max_items: int = 3):
"""
Add a vector index query operator to the pipeline.
@@ -152,10 +140,7 @@ class RAGPipeline:
:return: Self-instance for chaining.
"""
self._operators.append(
- VectorIndexQuery(
- embedding=self._embedding,
- topk=max_items,
- )
+ VectorIndexQuery(embedding=self._embedding, topk=max_items)
)
return self
@@ -230,7 +215,7 @@ class RAGPipeline:
:return: Final context after all operators have been executed.
"""
if len(self._operators) == 0:
- self.extract_keyword().query_graph_for_rag().synthesize_answer()
+ self.extract_keywords().query_graphdb().synthesize_answer()
context = kwargs
context["llm"] = self._llm
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index 2f08f11..cecb89c 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -20,64 +20,58 @@ import re
from typing import Any, Dict, Optional, List, Set, Tuple
from hugegraph_llm.config import settings
+from hugegraph_llm.utils.log import log
from pyhugegraph.client import PyHugeClient
+VERTEX_QUERY_TPL = "g.V({keywords}).as('subj').toList()"
-class GraphRAGQuery:
- VERTEX_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').toList()"
- # ID_RAG_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').repeat(bothE({edge_labels}).as('rel').otherV(
- # ).as('obj')).times({max_deep}).path().by(project('label', 'id', 'props').by(label()).by(id()).by(valueMap().by(
- # unfold()))).by(project('label', 'inV', 'outV', 'props').by(label()).by(inV().id()).by(outV().id()).by(valueMap(
- # ).by(unfold()))).limit({max_items}).toList()"
+# TODO: we could use a simpler query (like kneighbor-api to get the edges)
+# TODO: use dedup() to filter duplicate paths
+ID_QUERY_NEIGHBOR_TPL = """
+g.V({keywords}).as('subj')
+.repeat(
+ bothE({edge_labels}).as('rel').otherV().as('obj')
+).times({max_deep})
+.path()
+.by(project('label', 'id', 'props')
+ .by(label())
+ .by(id())
+ .by(valueMap().by(unfold()))
+)
+.by(project('label', 'inV', 'outV', 'props')
+ .by(label())
+ .by(inV().id())
+ .by(outV().id())
+ .by(valueMap().by(unfold()))
+)
+.limit({max_items})
+.toList()
+"""
+
+PROPERTY_QUERY_NEIGHBOR_TPL = """
+g.V().has('{prop}', within({keywords})).as('subj')
+.repeat(
+ bothE({edge_labels}).as('rel').otherV().as('obj')
+).times({max_deep})
+.path()
+.by(project('label', 'props')
+ .by(label())
+ .by(valueMap().by(unfold()))
+)
+.by(project('label', 'inV', 'outV', 'props')
+ .by(label())
+ .by(inV().values('{prop}'))
+ .by(outV().values('{prop}'))
+ .by(valueMap().by(unfold()))
+)
+.limit({max_items})
+.toList()
+"""
- # TODO: we could use a simpler query (like kneighbor-api to get the edges)
- ID_RAG_GREMLIN_QUERY_TEMPL = """
- g.V().hasId({keywords}).as('subj')
- .repeat(
- bothE({edge_labels}).as('rel').otherV().as('obj')
- ).times({max_deep})
- .path()
- .by(project('label', 'id', 'props')
- .by(label())
- .by(id())
- .by(valueMap().by(unfold()))
- )
- .by(project('label', 'inV', 'outV', 'props')
- .by(label())
- .by(inV().id())
- .by(outV().id())
- .by(valueMap().by(unfold()))
- )
- .limit({max_items})
- .toList()
- """
- PROP_RAG_GREMLIN_QUERY_TEMPL = """
- g.V().has('{prop}', within({keywords})).as('subj')
- .repeat(
- bothE({edge_labels}).as('rel').otherV().as('obj')
- ).times({max_deep})
- .path()
- .by(project('label', 'props')
- .by(label())
- .by(valueMap().by(unfold()))
- )
- .by(project('label', 'inV', 'outV', 'props')
- .by(label())
- .by(inV().values('{prop}'))
- .by(outV().values('{prop}'))
- .by(valueMap().by(unfold()))
- )
- .limit({max_items})
- .toList()
- """
+class GraphRAGQuery:
- def __init__(
- self,
- max_deep: int = 2,
- max_items: int = 30,
- prop_to_match: Optional[str] = None,
- ):
+ def __init__(self, max_deep: int = 2, max_items: int = 30, prop_to_match: Optional[str] = None):
self._client = PyHugeClient(
settings.graph_ip,
settings.graph_port,
@@ -106,7 +100,7 @@ class GraphRAGQuery:
assert self._client is not None, "No valid graph to search."
keywords = context.get("keywords")
- entrance_vids = context.get("entrance_vids")
+ match_vids = context.get("match_vids")
if isinstance(context.get("max_deep"), int):
self._max_deep = context["max_deep"]
@@ -120,21 +114,24 @@ class GraphRAGQuery:
use_id_to_match = self._prop_to_match is None
if use_id_to_match:
- if not entrance_vids:
+ if not match_vids:
return context
- rag_gremlin_query = self.VERTEX_GREMLIN_QUERY_TEMPL.format(keywords=entrance_vids)
- result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
+ gremlin_query = VERTEX_QUERY_TPL.format(keywords=match_vids)
+ result: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
+ log.debug(f"Vids query: {gremlin_query}")
- vertex_knowledge = self._format_knowledge_from_vertex(query_result=result)
- rag_gremlin_query = self.ID_RAG_GREMLIN_QUERY_TEMPL.format(
- keywords=entrance_vids,
+ vertex_knowledge = self._format_graph_from_vertex(query_result=result)
+ gremlin_query = ID_QUERY_NEIGHBOR_TPL.format(
+ keywords=match_vids,
max_deep=self._max_deep,
max_items=self._max_items,
edge_labels=edge_labels_str,
)
- result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
- graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_knowledge_from_query_result(
+ log.debug(f"Kneighbor query: {gremlin_query}")
+
+ result: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
+ graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_graph_from_query_result(
query_result=result
)
graph_chain_knowledge.update(vertex_knowledge)
@@ -143,17 +140,20 @@ class GraphRAGQuery:
else:
vertex_degree_list.append(vertex_knowledge)
else:
+ # WARN: When will the query enter here?
assert keywords, "No related property(keywords) for graph query."
keywords_str = ",".join("'" + kw + "'" for kw in keywords)
- rag_gremlin_query = self.PROP_RAG_GREMLIN_QUERY_TEMPL.format(
+ gremlin_query = PROPERTY_QUERY_NEIGHBOR_TPL.format(
prop=self._prop_to_match,
keywords=keywords_str,
max_deep=self._max_deep,
max_items=self._max_items,
edge_labels=edge_labels_str,
)
- result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
- graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_knowledge_from_query_result(
+ log.warning("Unable to find vid, downgraded to property query, please confirm it meets expectations.")
+
+ result: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
+ graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_graph_from_query_result(
query_result=result
)
@@ -175,7 +175,7 @@ class GraphRAGQuery:
return context
- def _format_knowledge_from_vertex(self, query_result: List[Any]) ->
Set[str]:
+ def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]:
knowledge = set()
for item in query_result:
props_str = ", ".join(f"{k}: {v}" for k, v in
item["properties"].items())
@@ -183,7 +183,7 @@ class GraphRAGQuery:
knowledge.add(node_str)
return knowledge
- def _format_knowledge_from_query_result(
+ def _format_graph_from_query_result(
self, query_result: List[Any]
) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]:
use_id_to_match = self._prop_to_match is None
@@ -237,18 +237,14 @@ class GraphRAGQuery:
def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]:
schema = self._get_graph_schema()
node_props_str, edge_props_str = schema.split("\n")[:2]
- node_props_str = node_props_str[len("Node properties: ")
:].strip("[").strip("]")
- edge_props_str = edge_props_str[len("Edge properties: ")
:].strip("[").strip("]")
+ node_props_str = node_props_str[len("Node properties:
"):].strip("[").strip("]")
+ edge_props_str = edge_props_str[len("Edge properties:
"):].strip("[").strip("]")
node_labels = self._extract_label_names(node_props_str)
edge_labels = self._extract_label_names(edge_props_str)
return node_labels, edge_labels
@staticmethod
- def _extract_label_names(
- source: str,
- head: str = "name: ",
- tail: str = ", ",
- ) -> List[str]:
+ def _extract_label_names(source: str, head: str = "name: ", tail: str = ",
") -> List[str]:
result = []
for s in source.split(head):
end = s.find(tail)
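For orientation, the new module-level templates are ordinary `str.format` strings; below is a sketch of rendering `ID_QUERY_NEIGHBOR_TPL` (the vid literal is hypothetical, and an empty `edge_labels` makes `bothE()` traverse all edge labels):
```python
# Hypothetical rendering of ID_QUERY_NEIGHBOR_TPL; real callers pass match_vids from context.
gremlin_query = ID_QUERY_NEIGHBOR_TPL.format(
    keywords="'1:Al Pacino'",  # hypothetical vertex id literal
    edge_labels="",            # empty -> bothE() over all edge labels
    max_deep=2,
    max_items=30,
)
```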
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py
index a61063f..57da3ef 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py
@@ -33,8 +33,6 @@ class SchemaManager:
)
self.schema = self.client.schema()
- # FIXME: This method is not working as expected
- # def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
def run(self, context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
if context is None:
context = {}
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
index 6cd8620..d22d00a 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
@@ -51,19 +51,19 @@ class SemanticIdQuery:
)
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
- graph_query_entrance = []
+ graph_query_list = []
if self.by == "query":
query = context["query"]
query_vector = self.embedding.get_text_embedding(query)
results = self.vector_index.search(query_vector, top_k=self.topk_per_query)
if results:
- graph_query_entrance.extend(results[:self.topk_per_query])
+ graph_query_list.extend(results[:self.topk_per_query])
else: # by keywords
exact_match_vids, unmatched_vids = self._exact_match_vids(context["keywords"])
- graph_query_entrance.extend(exact_match_vids)
+ graph_query_list.extend(exact_match_vids)
fuzzy_match_vids = self._fuzzy_match_vids(unmatched_vids)
- graph_query_entrance.extend(fuzzy_match_vids)
- context["entrance_vids"] = list(set(graph_query_entrance))
+ graph_query_list.extend(fuzzy_match_vids)
+ context["match_vids"] = list(set(graph_query_list))
return context
def _exact_match_vids(self, keywords: List[str]) -> Tuple[List[str],
List[str]]:
@@ -89,4 +89,4 @@ class SemanticIdQuery:
results = self.vector_index.search(keyword_vector, top_k=self.topk_per_keyword)
if results:
fuzzy_match_result.extend(results[:self.topk_per_keyword])
- return fuzzy_match_result
+ return fuzzy_match_result  # FIXME: type mismatch, got 'list[dict[str, Any]]' instead
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index fda6eb7..3d1e018 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -22,7 +22,7 @@ from typing import Any, Dict, Optional
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.models.llms.init_llm import LLMs
from hugegraph_llm.config import prompt
-
+from hugegraph_llm.utils.log import log
DEFAULT_ANSWER_TEMPLATE = prompt.answer_prompt
@@ -87,13 +87,13 @@ class AnswerSynthesize:
graph_result = context.get("graph_result")
if graph_result:
- graph_context_head = context.get("graph_context_head",
- "The following are knowledge from
HugeGraph related to the query:\n")
+ graph_context_head = context.get("graph_context_head", "Knowledge
from graphdb for the query:\n")
graph_result_context = graph_context_head + "\n".join(
f"{i + 1}. {res}" for i, res in enumerate(graph_result)
)
else:
- graph_result_context = "No related knowledge found in graph for
the query."
+ graph_result_context = "No related graph data found for current
query."
+ log.warning(graph_result_context)
context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str,
vector_result_context, graph_result_context))
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py
index fffa03c..dcc0ab6 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py
@@ -46,7 +46,7 @@ class DisambiguateData:
# only disambiguate triples
if "triples" in data:
# TODO: ensure the logic here
- log.debug(data)
+ # log.debug(data)
triples = data["triples"]
prompt = generate_disambiguate_prompt(triples)
llm_output = self.llm.generate(prompt=prompt)
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
index 3249670..5e58c3c 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
@@ -23,20 +23,18 @@ from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.models.llms.init_llm import LLMs
from hugegraph_llm.operators.common_op.nltk_helper import NLTKHelper
-DEFAULT_KEYWORDS_EXTRACT_TEMPLATE_TMPL = """extract {max_keywords} keywords
from the text:
- {question}
- Provide keywords in the following comma-separated format: 'KEYWORDS:
<keywords>'
- """
-
-DEFAULT_KEYWORDS_EXPAND_TEMPLATE_TMPL = (
- "Generate synonyms or possible form of keywords up to {max_keywords} in
total,\n"
- "considering possible cases of capitalization, pluralization, common
expressions, etc.\n"
- "Provide all synonyms of keywords in comma-separated format: 'SYNONYMS:
<keywords>'\n"
- "Note, result should be in one-line with only one 'SYNONYMS: ' prefix\n"
- "----\n"
- "KEYWORDS: {question}\n"
- "----"
-)
+KEYWORDS_EXTRACT_TPL = """extract {max_keywords} keywords from the text:
+{question}
+Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'
+"""
+
+KEYWORDS_EXPAND_TPL = """Generate synonyms or possible form of keywords up to
{max_keywords} in total,
+considering possible cases of capitalization, pluralization, common
expressions, etc.
+Provide all synonyms of keywords in comma-separated format: 'SYNONYMS:
<keywords>'
+Note, result should be in one-line with only one 'SYNONYMS: ' prefix
+----
+KEYWORDS: {question}
+----"""
class KeywordExtract:
@@ -53,8 +51,8 @@ class KeywordExtract:
self._query = text
self._language = language.lower()
self._max_keywords = max_keywords
- self._extract_template = extract_template or DEFAULT_KEYWORDS_EXTRACT_TEMPLATE_TMPL
- self._expand_template = expand_template or DEFAULT_KEYWORDS_EXPAND_TEMPLATE_TMPL
+ self._extract_template = extract_template or KEYWORDS_EXTRACT_TPL
+ self._expand_template = expand_template or KEYWORDS_EXPAND_TPL
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
if self._query is None:
@@ -91,7 +89,8 @@ class KeywordExtract:
verbose = context.get("verbose") or False
if verbose:
- print(f"\033[92mKEYWORDS: {context['keywords']}\033[0m")
+ from hugegraph_llm.utils.log import log
+ log.info(f"KEYWORDS: {context['keywords']}")
# extracting keywords & expanding synonyms increase the call count by 2
context["call_count"] = context.get("call_count", 0) + 2
@@ -126,15 +125,11 @@ class KeywordExtract:
else:
keywords.append(k)
- # if the keyword consists of multiple words, split into sub-words
- # (removing stopwords)
+ # if the keyword consists of multiple words, split into sub-words (removing stopwords)
results = set()
for token in keywords:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
- results.update(
- {w for w in sub_tokens if w not in NLTKHelper().stopwords(lang=self._language)}
- )
-
+ results.update({w for w in sub_tokens if w not in NLTKHelper().stopwords(lang=self._language)})
return results
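Likewise, the reflowed extract template renders with plain `str.format`; a quick sketch (the question text is illustrative):
```python
# Sketch: rendering KEYWORDS_EXTRACT_TPL the way the operator does, with max_keywords and the query.
prompt_text = KEYWORDS_EXTRACT_TPL.format(
    max_keywords=5,
    question="Tell me about Al Pacino.",
)
```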
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/decorators.py b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
index 5f37d47..394acda 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
@@ -72,7 +72,7 @@ def log_operator_time(func: Callable) -> Callable:
# Only record time ≥ 0.01s (10ms)
if op_time >= 0.01:
log.debug("Operator %s finished in %.2f seconds",
operator.__class__.__name__, op_time)
- # log.debug("Context:\n%s", result)
+ # log.debug("Current context:\n%s", result)
return result
return wrapper
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
index 73bb813..fc3ead8 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
@@ -22,6 +22,7 @@ import traceback
from typing import Dict, Any, Union
import gradio as gr
+
from .hugegraph_utils import get_hg_client, clean_hg_data
from .log import log
from .vector_index_utils import read_documents
@@ -86,7 +87,7 @@ def extract_graph(input_file, input_text, schema, example_prompt) -> str:
def fit_vid_index():
builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
builder.fetch_graph_data().build_vertex_id_semantic_index()
- log.debug(builder.operators)
+ log.debug("Operators: %s", builder.operators)
try:
context = builder.run()
removed_num = context["removed_vid_vector_num"]