This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch bak-api
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git

commit a629571dfa6fe860e167fb4f302691c613acc069
Author: imbajin <[email protected]>
AuthorDate: Thu Sep 5 16:25:42 2024 +0800

    refactor some params
    
    feat(llm): support graph_rag_recall api
    
    refactor: fix some bugs & query logic
    
    bugfix
    
    refactor(llm): enhance the graph/gremlin query phase
---
 hugegraph-llm/README.md                            |  11 +-
 .../src/hugegraph_llm/api/models/rag_requests.py   |  15 +-
 hugegraph-llm/src/hugegraph_llm/api/rag_api.py     |  51 ++++++-
 .../src/hugegraph_llm/demo/rag_web_demo.py         |  73 +++++----
 .../src/hugegraph_llm/models/llms/init_llm.py      |   4 +-
 .../src/hugegraph_llm/models/llms/openai.py        |   2 +-
 .../operators/document_op/word_extract.py          |   3 +-
 .../src/hugegraph_llm/operators/graph_rag_task.py  |  36 ++---
 .../operators/hugegraph_op/graph_rag_query.py      | 169 +++++++++++----------
 .../operators/hugegraph_op/schema_manager.py       |   6 +-
 .../operators/index_op/semantic_id_query.py        |  10 +-
 .../operators/llm_op/answer_synthesize.py          |  21 +--
 .../operators/llm_op/disambiguate_data.py          |   2 +-
 .../operators/llm_op/keyword_extract.py            |  48 +++---
 .../src/hugegraph_llm/utils/decorators.py          |   2 +-
 .../src/hugegraph_llm/utils/graph_index_utils.py   |   6 +-
 16 files changed, 248 insertions(+), 211 deletions(-)

diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md
index 25de21b..8df48bb 100644
--- a/hugegraph-llm/README.md
+++ b/hugegraph-llm/README.md
@@ -130,22 +130,19 @@ The methods of the `KgBuilder` class can be chained together to perform a sequen
 
 Run example like `python3 ./hugegraph_llm/examples/graph_rag_test.py`
 
-The `GraphRAG` class is used to integrate HugeGraph with large language models to provide retrieval-augmented generation capabilities.
+The `RAGPipeline` class is used to integrate HugeGraph with large language models to provide retrieval-augmented generation capabilities.
 Here is a brief usage guide:
 
 1. **Extract Keyword**: Extract keywords and expand synonyms.
-    
+
     ```python
-    graph_rag.extract_keyword(text="Tell me about Al Pacino.").print_result()
+    graph_rag.extract_keywords(text="Tell me about Al Pacino.").print_result()
     ```
 
 2. **Query Graph for Rag**: Retrieve the corresponding keywords and their multi-degree associated relationships from HugeGraph.
 
      ```python
-     graph_rag.query_graph_for_rag(
-        max_deep=2,
-        max_items=30
-     ).print_result()
+     graph_rag.query_graph_db(max_deep=2, max_items=30).print_result()
      ```
 3. **Synthesize Answer**: Summarize the results and organize the language to answer the question.
 
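Putting the renamed steps together, here is a minimal end-to-end sketch; the method names follow this commit, while the query string, the `graph_search` flag, and the result key are illustrative assumptions:

```python
# Hedged sketch of the renamed RAGPipeline chain described in the README above.
from hugegraph_llm.operators.graph_rag_task import RAGPipeline

rag = RAGPipeline()
(rag.extract_keywords(text="Tell me about Al Pacino.")
    .keywords_to_vid()
    .query_graph_db(max_deep=2, max_items=30)
    .synthesize_answer())
context = rag.run(verbose=True, query="Tell me about Al Pacino.", graph_search=True)
print(context.get("answer"))  # result key assumed from AnswerSynthesize
```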
diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index a211bb8..0e7c666 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -15,20 +15,31 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Optional
+from typing import Optional, Literal
 
 from pydantic import BaseModel
 
 
 class RAGRequest(BaseModel):
     query: str
-    raw_llm: Optional[bool] = True
+    raw_llm: Optional[bool] = False
     vector_only: Optional[bool] = False
     graph_only: Optional[bool] = False
     graph_vector: Optional[bool] = False
+    graph_ratio: float = 0.5
+    rerank_method: Literal["bleu", "reranker"] = "bleu"
+    near_neighbor_first: bool = False
+    custom_related_information: str = None
     answer_prompt: Optional[str] = None
 
 
+class GraphRAGRequest(BaseModel):
+    query: str
+    rerank_method: Literal["bleu", "reranker"] = "bleu"
+    near_neighbor_first: bool = False
+    custom_related_information: str = None
+
+
 class GraphConfigRequest(BaseModel):
     ip: str = "127.0.0.1"
     port: str = "8080"
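For reference, a hedged sketch of validating a request body with the new `GraphRAGRequest` model; the field names come from the diff above, while the pydantic v2 `model_dump()` call is an assumption (use `.dict()` on v1):

```python
# Hedged usage sketch for the GraphRAGRequest model added above.
from hugegraph_llm.api.models.rag_requests import GraphRAGRequest

req = GraphRAGRequest(query="Tell me about Al Pacino.", rerank_method="reranker")
# rerank_method is constrained to Literal["bleu", "reranker"]; any other
# value raises a pydantic ValidationError.
print(req.model_dump())  # assumes pydantic v2; use req.dict() on v1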
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index 64daf70..3ec4580 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -14,22 +14,42 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import json
+from typing import Literal
 
-from fastapi import status, APIRouter
+from fastapi import status, APIRouter, HTTPException
 
 from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
 from hugegraph_llm.api.models.rag_requests import (
     RAGRequest,
     GraphConfigRequest,
     LLMConfigRequest,
-    RerankerConfigRequest,
+    RerankerConfigRequest, GraphRAGRequest,
 )
 from hugegraph_llm.api.models.rag_response import RAGResponse
 from hugegraph_llm.config import settings
+from hugegraph_llm.utils.log import log
+
+
+def graph_rag_recall(
+        text: str,
+        rerank_method: Literal["bleu", "reranker"],
+        near_neighbor_first: bool,
+        custom_related_information: str
+) -> dict:
+    from hugegraph_llm.operators.graph_rag_task import RAGPipeline
+    rag = RAGPipeline()
+    rag.extract_keywords().keywords_to_vid().query_graph_db().merge_dedup_rerank(
+        rerank_method=rerank_method,
+        near_neighbor_first=near_neighbor_first,
+        custom_related_information=custom_related_information,
+    )
+    context = rag.run(verbose=True, query=text, graph_search=True)
+    return context
 
 
 def rag_http_api(
-    router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
+        router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
 ):
     @router.post("/rag", status_code=status.HTTP_200_OK)
     def rag_answer_api(req: RAGRequest):
@@ -41,6 +61,31 @@ def rag_http_api(
             if getattr(req, key)
         }
 
+    @router.post("/rag/graph", status_code=status.HTTP_200_OK)
+    def graph_rag_recall_api(req: GraphRAGRequest):
+        try:
+            result = graph_rag_recall(
+                text=req.query,
+                rerank_method=req.rerank_method,
+                near_neighbor_first=req.near_neighbor_first,
+                custom_related_information=req.custom_related_information
+            )
+            # TODO: handle QianFanClient error (not dict..)
+            # log.critical(f"## {type(result)}, {json.dumps(result)}")
+            if isinstance(result, dict):
+                log.critical(f"1## {type(result)}")
+                return {"graph_recall": result}
+            else:
+                log.critical(f"2## {type(result)}")
+                return {"graph_recall": json.dumps(result)}
+
+        except TypeError as e:
+            log.error(f"TypeError in graph_rag_recall_api: {e}")
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+        except Exception as e:
+            log.error(f"Unexpected error occurred: {e}")
+            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
+
     @router.post("/config/graph", status_code=status.HTTP_201_CREATED)
     def graph_config_api(req: GraphConfigRequest):
         # Accept status code
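A hedged client-side example for the new `/rag/graph` route; the host and port assume the demo's default (`8001`, see `rag_web_demo.py` below) and no auth:

```python
# Illustrative call to the /rag/graph endpoint added above; the URL and the
# sample payload values are assumptions, not part of the commit.
import requests

resp = requests.post(
    "http://127.0.0.1:8001/rag/graph",
    json={
        "query": "Tell me about Al Pacino.",
        "rerank_method": "bleu",
        "near_neighbor_first": False,
        "custom_related_information": "",
    },
)
resp.raise_for_status()
print(resp.json()["graph_recall"])
```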
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index c4c68c0..f145d4d 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -68,37 +68,40 @@ def rag_answer(
     graph_only_answer: bool,
     graph_vector_answer: bool,
     graph_ratio: float,
-    rerank_method: Literal["bleu", "reranker"],
-    near_neighbor_first: bool,
-    custom_related_information: str,
-    answer_prompt: str,
+    rerank_method: Literal["bleu", "reranker"] = "bleu",
+    near_neighbor_first: bool = False,
+    custom_related_information: str = None,
+    answer_prompt: str = None,
 ) -> Tuple:
+    """
+    Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
+
+    1. Initialize the RAGPipeline.
+    2. Select vector search or graph search based on parameters.
+    3. Merge, deduplicate, and rerank the results.
+    4. Synthesize the final answer.
+    5. Run the pipeline and return the results.
+    """
     vector_search = vector_only_answer or graph_vector_answer
     graph_search = graph_only_answer or graph_vector_answer
-
     if raw_answer is False and not vector_search and not graph_search:
         gr.Warning("Please select at least one generate mode.")
         return "", "", "", ""
-    searcher = RAGPipeline()
+
+    rag = RAGPipeline()
     if vector_search:
-        searcher.query_vector_index_for_rag()
+        rag.query_vector_index()
     if graph_search:
-        searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag()
+        rag.extract_keywords().keywords_to_vid().query_graph_db()
+
     # TODO: add more user-defined search strategies
-    searcher.merge_dedup_rerank(
-        graph_ratio, rerank_method, near_neighbor_first, custom_related_information
-    ).synthesize_answer(
-        raw_answer=raw_answer,
-        vector_only_answer=vector_only_answer,
-        graph_only_answer=graph_only_answer,
-        graph_vector_answer=graph_vector_answer,
-        answer_prompt=answer_prompt,
-    )
+    rag.merge_dedup_rerank(graph_ratio, rerank_method, near_neighbor_first, custom_related_information)
+    rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt)
 
     try:
-        context = searcher.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
+        context = rag.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
         if context.get("switch_to_bleu"):
-            gr.Warning("Online reranker fails, automatically switches to local 
bleu method.")
+            gr.Warning("Online reranker fails, automatically switches to local 
bleu rerank.")
         return (
             context.get("raw_answer", ""),
             context.get("vector_only_answer", ""),
@@ -106,10 +109,10 @@ def rag_answer(
             context.get("graph_vector_answer", ""),
         )
     except ValueError as e:
-        log.error(e)
+        log.critical(e)
         raise gr.Error(str(e))
     except Exception as e:
-        log.error(e)
+        log.critical(e)
         raise gr.Error(f"An unexpected error occurred: {str(e)}")
 
 
@@ -529,12 +532,14 @@ def init_rag_ui() -> gr.Interface:
                     input_file = gr.File(
                         value=[os.path.join(resource_path, "demo", "test.txt")],
                         label="Docs (multi-files can be selected together)",
-                        file_count="multiple"
+                        file_count="multiple",
                     )
                 with gr.Tab("text") as tab_upload_text:
-                    input_text = gr.Textbox(value="", label="Doc(s)")
-            input_schema = gr.Textbox(value=schema, label="Schema")
-            info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head")
+                    input_text = gr.Textbox(value="", label="Doc(s)", lines=20, show_copy_button=True)
+            input_schema = gr.Textbox(value=schema, label="Schema", lines=15, show_copy_button=True)
+            info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head", lines=15,
+                                               show_copy_button=True)
+            out = gr.Textbox(label="Output", lines=15, show_copy_button=True)
         with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
             with gr.Row():
                 import_btn = gr.Button("Import Into Graph")
@@ -552,8 +557,8 @@ def init_rag_ui() -> gr.Interface:
             graph_index_btn1 = gr.Button("Clear Graph Index")
             graph_index_btn2 = gr.Button("Extract Graph", variant="primary")
             graph_index_btn3 = gr.Button("Fit Vid Index")
-        with gr.Row():
-            out = gr.Textbox(label="Output", show_copy_button=True)
+        # with gr.Row():
+        #     out = gr.Textbox(label="Output", show_copy_button=True)
         vector_index_btn0.click(get_vector_index_info, outputs=out)  # pylint: disable=no-member
         vector_index_btn1.click(clean_vector_index)  # pylint: disable=no-member
         vector_index_btn2.click(build_vector_index, inputs=[input_file, input_text], outputs=out)  # pylint: disable=no-member
@@ -663,23 +668,15 @@ if __name__ == "__main__":
     parser.add_argument("--port", type=int, default=8001, help="port")
     args = parser.parse_args()
     app = FastAPI()
-    app_auth = APIRouter(dependencies=[Depends(authenticate)])
+    auth = APIRouter(dependencies=[Depends(authenticate)])
 
     hugegraph_llm = init_rag_ui()
-    rag_http_api(
-        app_auth,
-        rag_answer,
-        apply_graph_config,
-        apply_llm_config,
-        apply_embedding_config,
-        apply_reranker_config,
-    )
+    rag_http_api(auth, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config, apply_reranker_config)
 
-    app.include_router(app_auth)
+    app.include_router(auth)
     auth_enabled = os.getenv("ENABLE_LOGIN", "False").lower() == "true"
     log.info("Authentication is %s.", "enabled" if auth_enabled else 
"disabled")
     # TODO: support multi-user login when need
-
     app = gr.mount_gradio_app(app, hugegraph_llm, path="/", auth=("rag", os.getenv("TOKEN")) if auth_enabled else None)
 
     # TODO: we can't use reload now due to the config 'app' of uvicorn.run
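For context, a minimal standalone sketch of the FastAPI + Gradio wiring used here; `gr.Blocks()` stands in for the real `init_rag_ui()`, and everything else mirrors the demo code above:

```python
# Minimal sketch of the mounting pattern above, assuming gradio and uvicorn
# as in the demo; the Blocks object is a stand-in for init_rag_ui().
import os

import gradio as gr
import uvicorn
from fastapi import FastAPI

app = FastAPI()
demo = gr.Blocks()  # stand-in for the real UI
auth_enabled = os.getenv("ENABLE_LOGIN", "False").lower() == "true"
app = gr.mount_gradio_app(
    app, demo, path="/", auth=("rag", os.getenv("TOKEN")) if auth_enabled else None
)
uvicorn.run(app, host="0.0.0.0", port=8001)
```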
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
index 22a82cb..2c90748 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py
@@ -17,7 +17,7 @@
 
 
 from hugegraph_llm.models.llms.ollama import OllamaClient
-from hugegraph_llm.models.llms.openai import OpenAIChat
+from hugegraph_llm.models.llms.openai import OpenAIClient
 from hugegraph_llm.models.llms.qianfan import QianfanClient
 from hugegraph_llm.config import settings
 
@@ -34,7 +34,7 @@ class LLMs:
                 secret_key=settings.qianfan_secret_key
             )
         if self.llm_type == "openai":
-            return OpenAIChat(
+            return OpenAIClient(
                 api_key=settings.openai_api_key,
                 api_base=settings.openai_api_base,
                 model_name=settings.openai_language_model,
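Because call sites go through the `LLMs` factory, the `OpenAIChat` to `OpenAIClient` rename is transparent to them; a hedged sketch:

```python
# Hedged sketch: get_llm() returns OpenAIClient when settings.llm_type is
# "openai"; the BaseLLM.generate(prompt=...) interface is assumed unchanged.
from hugegraph_llm.models.llms.init_llm import LLMs

llm = LLMs().get_llm()
print(llm.generate(prompt="Say hi in one word."))
```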
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
index 30ac805..263f4f2 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
@@ -27,7 +27,7 @@ from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.utils.log import log
 
 
-class OpenAIChat(BaseLLM):
+class OpenAIClient(BaseLLM):
     """Wrapper around OpenAI Chat large language models."""
 
     def __init__(
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
index 43d0a29..b8f40df 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
@@ -62,7 +62,8 @@ class WordExtract:
 
         verbose = context.get("verbose") or False
         if verbose:
-            print(f"\033[92mKEYWORDS: {context['keywords']}\033[0m")
+            from hugegraph_llm.utils.log import log
+            log.info(f"KEYWORDS: {context['keywords']}")
 
         return context
 
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index 91bc7b3..573b00c 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -50,11 +50,7 @@ class RAGPipeline:
         self._embedding = embedding or Embeddings().get_embedding()
         self._operators: List[Any] = []
 
-    def extract_word(
-            self,
-            text: Optional[str] = None,
-            language: str = "english",
-    ):
-    def extract_word(self, text: Optional[str] = None, language: str = "english"):
         """
         Add a word extraction operator to the pipeline.
 
@@ -62,15 +58,10 @@ class RAGPipeline:
         :param language: Language of the text.
         :return: Self-instance for chaining.
         """
-        self._operators.append(
-            WordExtract(
-                text=text,
-                language=language,
-            )
-        )
+        self._operators.append(WordExtract(text=text, language=language))
         return self
 
-    def extract_keyword(
+    def extract_keywords(
             self,
             text: Optional[str] = None,
             max_keywords: int = 5,
@@ -99,7 +90,7 @@ class RAGPipeline:
         )
         return self
 
-    def match_keyword_to_id(
+    def keywords_to_vid(
         self,
         by: Literal["query", "keywords"] = "keywords",
         topk_per_keyword: int = 1,
@@ -108,6 +99,8 @@ class RAGPipeline:
         """
         Add a semantic ID query operator to the pipeline.
 
+        :param topk_per_query: Top K results per query.
+        :param by: Match by query or keywords.
         :param topk_per_keyword: Top K results per keyword.
         :return: Self-instance for chaining.
         """
@@ -121,7 +114,7 @@ class RAGPipeline:
         )
         return self
 
-    def query_graph_for_rag(
+    def query_graph_db(
         self,
         max_deep: int = 2,
         max_items: int = 30,
@@ -136,15 +129,11 @@ class RAGPipeline:
         :return: Self-instance for chaining.
         """
         self._operators.append(
-            GraphRAGQuery(
-                max_deep=max_deep,
-                max_items=max_items,
-                prop_to_match=prop_to_match,
-            )
+            GraphRAGQuery(max_deep=max_deep, max_items=max_items, prop_to_match=prop_to_match)
         )
         return self
 
-    def query_vector_index_for_rag(self, max_items: int = 3):
+    def query_vector_index(self, max_items: int = 3):
         """
         Add a vector index query operator to the pipeline.
 
@@ -152,10 +141,7 @@ class RAGPipeline:
         :return: Self-instance for chaining.
         """
         self._operators.append(
-            VectorIndexQuery(
-                embedding=self._embedding,
-                topk=max_items,
-            )
+            VectorIndexQuery(embedding=self._embedding, topk=max_items)
         )
         return self
 
@@ -230,7 +216,7 @@ class RAGPipeline:
         :return: Final context after all operators have been executed.
         """
         if len(self._operators) == 0:
-            self.extract_keyword().query_graph_for_rag().synthesize_answer()
+            self.extract_keywords().query_graph_db().synthesize_answer()
 
         context = kwargs
         context["llm"] = self._llm
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index fe225c2..01cb0fa 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -20,64 +20,63 @@ import re
 from typing import Any, Dict, Optional, List, Set, Tuple
 
 from hugegraph_llm.config import settings
+from hugegraph_llm.utils.log import log
 from pyhugegraph.client import PyHugeClient
 
 
-class GraphRAGQuery:
-    VERTEX_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').toList()"
-    # ID_RAG_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').repeat(bothE({edge_labels}).as('rel').otherV(
-    # ).as('obj')).times({max_deep}).path().by(project('label', 'id', 'props').by(label()).by(id()).by(valueMap().by(
-    # unfold()))).by(project('label', 'inV', 'outV', 'props').by(label()).by(inV().id()).by(outV().id()).by(valueMap(
-    # ).by(unfold()))).limit({max_items}).toList()"
+VERTEX_QUERY_TPL = "g.V({keywords}).as('subj').toList()"
+# ID_RAG_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').repeat(bothE({edge_labels}).as('rel').otherV(
+# ).as('obj')).times({max_deep}).path().by(project('label', 'id', 'props').by(label()).by(id()).by(valueMap().by(
+# unfold()))).by(project('label', 'inV', 'outV', 'props').by(label()).by(inV().id()).by(outV().id()).by(valueMap(
+# ).by(unfold()))).limit({max_items}).toList()"
+
+# TODO: we could use a simpler query (like kneighbor-api to get the edges)
+# TODO: use dedup() to filter duplicate paths
+ID_QUERY_NEIGHBOR_TPL = """
+g.V({keywords}).as('subj')
+.repeat(
+   bothE({edge_labels}).as('rel').otherV().as('obj')
+).times({max_deep})
+.path()
+.by(project('label', 'id', 'props')
+   .by(label())
+   .by(id())
+   .by(valueMap().by(unfold()))
+)
+.by(project('label', 'inV', 'outV', 'props')
+   .by(label())
+   .by(inV().id())
+   .by(outV().id())
+   .by(valueMap().by(unfold()))
+)
+.limit({max_items})
+.toList()
+"""
 
-    # TODO: we could use a simpler query (like kneighbor-api to get the edges)
-    ID_RAG_GREMLIN_QUERY_TEMPL = """
-    g.V().hasId({keywords}).as('subj')
-    .repeat(
-       bothE({edge_labels}).as('rel').otherV().as('obj')
-    ).times({max_deep})
-    .path()
-    .by(project('label', 'id', 'props')
-       .by(label())
-       .by(id())
-       .by(valueMap().by(unfold()))
-    )
-    .by(project('label', 'inV', 'outV', 'props')
-       .by(label())
-       .by(inV().id())
-       .by(outV().id())
-       .by(valueMap().by(unfold()))
-    )
-    .limit({max_items})
-    .toList()
-    """
+PROPERTY_QUERY_NEIGHBOR_TPL = """
+g.V().has('{prop}', within({keywords})).as('subj')
+.repeat(
+   bothE({edge_labels}).as('rel').otherV().as('obj')
+).times({max_deep})
+.path()
+.by(project('label', 'props')
+   .by(label())
+   .by(valueMap().by(unfold()))
+)
+.by(project('label', 'inV', 'outV', 'props')
+   .by(label())
+   .by(inV().values('{prop}'))
+   .by(outV().values('{prop}'))
+   .by(valueMap().by(unfold()))
+)
+.limit({max_items})
+.toList()
+"""
 
-    PROP_RAG_GREMLIN_QUERY_TEMPL = """
-    g.V().has('{prop}', within({keywords})).as('subj')
-    .repeat(
-       bothE({edge_labels}).as('rel').otherV().as('obj')
-    ).times({max_deep})
-    .path()
-    .by(project('label', 'props')
-       .by(label())
-       .by(valueMap().by(unfold()))
-    )
-    .by(project('label', 'inV', 'outV', 'props')
-       .by(label())
-       .by(inV().values('{prop}'))
-       .by(outV().values('{prop}'))
-       .by(valueMap().by(unfold()))
-    )
-    .limit({max_items})
-    .toList()
-    """
 
-    def __init__(
-        self,
-        max_deep: int = 2,
-        max_items: int = 30,
-        prop_to_match: Optional[str] = None,
-    ):
+class GraphRAGQuery:
+
+    def __init__(self, max_deep: int = 2, max_items: int = 30, prop_to_match: Optional[str] = None):
         self._client = PyHugeClient(
             settings.graph_ip,
             settings.graph_port,
@@ -106,7 +105,7 @@ class GraphRAGQuery:
         assert self._client is not None, "No valid graph to search."
 
         keywords = context.get("keywords")
-        entrance_vids = context.get("entrance_vids")
+        match_vids = context.get("match_vids")
 
         if isinstance(context.get("max_deep"), int):
             self._max_deep = context["max_deep"]
@@ -119,40 +118,48 @@ class GraphRAGQuery:
         edge_labels_str = ",".join("'" + label + "'" for label in edge_labels)
 
         use_id_to_match = self._prop_to_match is None
+        if use_id_to_match:
+            if not match_vids: return context
 
-        if not use_id_to_match:
-            assert keywords is not None, "No keywords for graph query."
-            keywords_str = ",".join("'" + kw + "'" for kw in keywords)
-            rag_gremlin_query = self.PROP_RAG_GREMLIN_QUERY_TEMPL.format(
-                prop=self._prop_to_match,
-                keywords=keywords_str,
+            gremlin_query = VERTEX_QUERY_TPL.format(keywords=match_vids)
+            result: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
+            log.debug(f"Vids query: {gremlin_query}")
+
+            vertex_knowledge = self._format_graph_from_vertex(query_result=result)
+            gremlin_query = ID_QUERY_NEIGHBOR_TPL.format(
+                keywords=match_vids,
                 max_deep=self._max_deep,
                 max_items=self._max_items,
                 edge_labels=edge_labels_str,
             )
-            result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
-            graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_knowledge_from_query_result(
+            log.debug(f"Kneighbor query: {gremlin_query}")
+
+            result: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
+            graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_graph_from_query_result(
                 query_result=result
             )
+            graph_chain_knowledge.update(vertex_knowledge)
+            if vertex_degree_list:
+                vertex_degree_list[0].update(vertex_knowledge)
+            else:
+                vertex_degree_list.append(vertex_knowledge)
         else:
-            assert entrance_vids is not None, "No entrance vertices for query."
-            rag_gremlin_query = self.VERTEX_GREMLIN_QUERY_TEMPL.format(
-                keywords=entrance_vids,
-            )
-            result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
-            vertex_knowledge = self._format_knowledge_from_vertex(query_result=result)
-            rag_gremlin_query = self.ID_RAG_GREMLIN_QUERY_TEMPL.format(
-                keywords=entrance_vids,
+            # When will the query enter here?
+            assert keywords, "No related property(keywords) for graph query."
+            keywords_str = ",".join("'" + kw + "'" for kw in keywords)
+            gremlin_query = PROPERTY_QUERY_NEIGHBOR_TPL.format(
+                prop=self._prop_to_match,
+                keywords=keywords_str,
                 max_deep=self._max_deep,
                 max_items=self._max_items,
                 edge_labels=edge_labels_str,
             )
-            result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
-            graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_knowledge_from_query_result(
+            log.warning("Unable to find vid, downgraded to property query, please confirm if it meets expectation.")
+
+            result: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
+            graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_graph_from_query_result(
                 query_result=result
             )
-            graph_chain_knowledge.update(vertex_knowledge)
-            vertex_degree_list[0].update(vertex_knowledge)
 
         context["graph_result"] = list(graph_chain_knowledge)
         context["vertex_degree_list"] = [list(vertex_degree) for vertex_degree 
in vertex_degree_list]
@@ -172,7 +179,7 @@ class GraphRAGQuery:
 
         return context
 
-    def _format_knowledge_from_vertex(self, query_result: List[Any]) -> Set[str]:
+    def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]:
         knowledge = set()
         for item in query_result:
             props_str = ", ".join(f"{k}: {v}" for k, v in 
item["properties"].items())
@@ -180,8 +187,8 @@ class GraphRAGQuery:
             knowledge.add(node_str)
         return knowledge
 
-    def _format_knowledge_from_query_result(
-        self, query_result: List[Any]
+    def _format_graph_from_query_result(
+            self, query_result: List[Any]
     ) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]:
         use_id_to_match = self._prop_to_match is None
         knowledge = set()
@@ -234,18 +241,14 @@ class GraphRAGQuery:
     def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]:
         schema = self._get_graph_schema()
         node_props_str, edge_props_str = schema.split("\n")[:2]
-        node_props_str = node_props_str[len("Node properties: ") 
:].strip("[").strip("]")
-        edge_props_str = edge_props_str[len("Edge properties: ") 
:].strip("[").strip("]")
+        node_props_str = node_props_str[len("Node properties: 
"):].strip("[").strip("]")
+        edge_props_str = edge_props_str[len("Edge properties: 
"):].strip("[").strip("]")
         node_labels = self._extract_label_names(node_props_str)
         edge_labels = self._extract_label_names(edge_props_str)
         return node_labels, edge_labels
 
     @staticmethod
-    def _extract_label_names(
-        source: str,
-        head: str = "name: ",
-        tail: str = ", ",
-    ) -> List[str]:
+    def _extract_label_names(source: str, head: str = "name: ", tail: str = ", ") -> List[str]:
         result = []
         for s in source.split(head):
             end = s.find(tail)
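To see what the hoisted module-level templates expand to, a hedged example rendering `ID_QUERY_NEIGHBOR_TPL` (template text copied from the diff above); the vid and edge labels are invented:

```python
# Illustrative rendering of ID_QUERY_NEIGHBOR_TPL; the template text comes
# from the diff above, the vid and edge labels are made up for the example.
ID_QUERY_NEIGHBOR_TPL = """
g.V({keywords}).as('subj')
.repeat(
   bothE({edge_labels}).as('rel').otherV().as('obj')
).times({max_deep})
.path()
.by(project('label', 'id', 'props')
   .by(label())
   .by(id())
   .by(valueMap().by(unfold()))
)
.by(project('label', 'inV', 'outV', 'props')
   .by(label())
   .by(inV().id())
   .by(outV().id())
   .by(valueMap().by(unfold()))
)
.limit({max_items})
.toList()
"""

gremlin_query = ID_QUERY_NEIGHBOR_TPL.format(
    keywords="'1:Al Pacino'",
    edge_labels="'acted_in', 'directed'",
    max_deep=2,
    max_items=30,
)
print(gremlin_query)
```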
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py
index 5c002ae..a215bb4 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+from typing import Dict, Any
 
 from hugegraph_llm.config import settings
 from pyhugegraph.client import PyHugeClient
@@ -33,7 +33,9 @@ class SchemaManager:
         )
         self.schema = self.client.schema()
 
-    def run(self):
+    # FIXME: This method is not working as expected
+    # def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
+    def run(self) -> Dict[str, Any]:
         schema = self.schema.getSchema()
         vertices = []
         for vl in schema["vertexlabels"]:
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
index ef036d2..60be886 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py
@@ -40,19 +40,19 @@ class SemanticIdQuery:
         self.topk_per_keyword = topk_per_keyword
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
-        graph_query_entrance = []
+        graph_query_list = []
         if self.by == "query":
             query = context["query"]
             query_vector = self.embedding.get_text_embedding(query)
             results = self.vector_index.search(query_vector, top_k=self.topk_per_query)
             if results:
-                graph_query_entrance.extend(results[:self.topk_per_query])
-        else:  # by keywords
+                graph_query_list.extend(results[:self.topk_per_query])
+        else:
             keywords = context["keywords"]
             for keyword in keywords:
                 keyword_vector = self.embedding.get_text_embedding(keyword)
                 results = self.vector_index.search(keyword_vector, top_k=self.topk_per_keyword)
                 if results:
-                    graph_query_entrance.extend(results[:self.topk_per_keyword])
-        context["entrance_vids"] = list(set(graph_query_entrance))
+                    graph_query_list.extend(results[:self.topk_per_keyword])
+        context["match_vids"] = list(set(graph_query_list))
         return context
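The rename from `entrance_vids` to `match_vids` keeps the pool-and-dedup behavior; a tiny hedged illustration (the sample vids are invented, and `set()` does not preserve order):

```python
# Hedged illustration of the dedup step above; the vids are invented.
graph_query_list = ["1:Al Pacino", "2:The Godfather", "1:Al Pacino"]
match_vids = list(set(graph_query_list))  # duplicates removed, order not preserved
print(match_vids)
```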
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index 2cb7264..a960d24 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -21,6 +21,7 @@ from typing import Any, Dict, Optional
 
 from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.models.llms.init_llm import LLMs
+from hugegraph_llm.utils.log import log
 
# TODO: we need enhance the template to answer the question (put it in a separate file)
 DEFAULT_ANSWER_TEMPLATE = """You are an expert in knowledge graphs and natural 
language processing. 
@@ -88,25 +89,27 @@ class AnswerSynthesize:
             response = self._llm.generate(prompt=prompt)
             return {"answer": response}
 
-        vector_result = context.get("vector_result", [])
-        if len(vector_result) == 0:
-            vector_result_context = "No (vector)phrase related to the query."
-        else:
+        vector_result = context.get("vector_result")
+        if vector_result:
             vector_result_context = "Phrases related to the query:\n" + 
"\n".join(
                 f"{i + 1}. {res}" for i, res in enumerate(vector_result)
             )
-        graph_result = context.get("graph_result", [])
-        if len(graph_result) == 0:
-            graph_result_context = "No knowledge found in HugeGraph for the 
query."
         else:
+            vector_result_context = "No (vector)phrase related to the query."
+
+        graph_result = context.get("graph_result")
+        if graph_result:
             graph_context_head = context.get("graph_context_head",
-                                             "The following are knowledge from HugeGraph related to the query:\n")
+                                             "Knowledge from graph related to the query:\n")
             graph_result_context = graph_context_head + "\n".join(
                 f"{i + 1}. {res}" for i, res in enumerate(graph_result)
             )
+        else:
+            graph_result_context = "No related graph data found for current 
query."
+            log.warning(graph_result_context)
+
         context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str,
                                                   vector_result_context, graph_result_context))
-
         return context
 
     async def async_generate(self, context: Dict[str, Any], context_head_str: str,
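The refactor above replaces `len(...) == 0` checks with plain truthiness, which also covers a missing key from `context.get(...)`; a hedged standalone mirror:

```python
# Hedged mirror of the truthiness check above: None and [] both fall back to
# the "no result" message, matching context.get("vector_result") semantics.
vector_result = None  # could also be [] or a list of phrases
if vector_result:
    vector_result_context = "Phrases related to the query:\n" + "\n".join(
        f"{i + 1}. {res}" for i, res in enumerate(vector_result)
    )
else:
    vector_result_context = "No (vector)phrase related to the query."
print(vector_result_context)
```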
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py
index fffa03c..dcc0ab6 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py
@@ -46,7 +46,7 @@ class DisambiguateData:
         # only disambiguate triples
         if "triples" in data:
             # TODO: ensure the logic here
-            log.debug(data)
+            # log.debug(data)
             triples = data["triples"]
             prompt = generate_disambiguate_prompt(triples)
             llm_output = self.llm.generate(prompt=prompt)
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
index e4854c1..b9f75db 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
@@ -23,20 +23,18 @@ from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.models.llms.init_llm import LLMs
 from hugegraph_llm.operators.common_op.nltk_helper import NLTKHelper
 
-DEFAULT_KEYWORDS_EXTRACT_TEMPLATE_TMPL = """extract {max_keywords} keywords 
from the text:
-    {question}
-    Provide keywords in the following comma-separated format: 'KEYWORDS: 
<keywords>'
-    """
-
-DEFAULT_KEYWORDS_EXPAND_TEMPLATE_TMPL = (
-    "Generate synonyms or possible form of keywords up to {max_keywords} in 
total,\n"
-    "considering possible cases of capitalization, pluralization, common 
expressions, etc.\n"
-    "Provide all synonyms of keywords in comma-separated format: 'SYNONYMS: 
<keywords>'\n"
-    "Note, result should be in one-line with only one 'SYNONYMS: ' prefix\n"
-    "----\n"
-    "KEYWORDS: {question}\n"
-    "----"
-)
+KEYWORDS_EXTRACT_TPL = """extract {max_keywords} keywords from the text:
+{question}
+Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'
+"""
+
+KEYWORDS_EXPAND_TPL = """Generate synonyms or possible form of keywords up to {max_keywords} in total,
+considering possible cases of capitalization, pluralization, common expressions, etc.
+Provide all synonyms of keywords in comma-separated format: 'SYNONYMS: <keywords>'
+Note, result should be in one-line with only one 'SYNONYMS: ' prefix
+----
+KEYWORDS: {question}
+----"""
 
 
 class KeywordExtract:
@@ -53,8 +51,8 @@ class KeywordExtract:
         self._query = text
         self._language = language.lower()
         self._max_keywords = max_keywords
-        self._extract_template = extract_template or DEFAULT_KEYWORDS_EXTRACT_TEMPLATE_TMPL
-        self._expand_template = expand_template or DEFAULT_KEYWORDS_EXPAND_TEMPLATE_TMPL
+        self._extract_template = extract_template or KEYWORDS_EXTRACT_TPL
+        self._expand_template = expand_template or KEYWORDS_EXPAND_TPL
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         if self._query is None:
@@ -84,14 +82,15 @@ class KeywordExtract:
         response = self._llm.generate(prompt=prompt)
 
         keywords = self._extract_keywords_from_response(
-            response=response, lowercase=False, start_token="KEYWORDS:"
+            response=response, start_token="KEYWORDS:"
         )
         keywords.union(self._expand_synonyms(keywords=keywords))
         context["keywords"] = list(keywords)
 
         verbose = context.get("verbose") or False
         if verbose:
-            print(f"\033[92mKEYWORDS: {context['keywords']}\033[0m")
+            from hugegraph_llm.utils.log import log
+            log.info(f"KEYWORDS: {context['keywords']}")
 
         # extracting keywords & expanding synonyms increase the call count by 2
         context["call_count"] = context.get("call_count", 0) + 2
@@ -104,16 +103,11 @@ class KeywordExtract:
         )
         response = self._llm.generate(prompt=prompt)
         keywords = self._extract_keywords_from_response(
-            response=response, lowercase=False, start_token="SYNONYMS:"
+            response=response, start_token="SYNONYMS:"
         )
         return keywords
 
-    def _extract_keywords_from_response(
-        self,
-        response: str,
-        lowercase: bool = True,
-        start_token: str = "",
-    ) -> Set[str]:
+    def _extract_keywords_from_response(self, response: str, start_token: str = "", ) -> Set[str]:
         keywords = []
         matches = re.findall(rf'{start_token}[^\n]+\n?', response)
         for match in matches:
@@ -130,8 +124,6 @@ class KeywordExtract:
             results.add(token)
             sub_tokens = re.findall(r"\w+", token)
             if len(sub_tokens) > 1:
-                results.update(
-                    {w for w in sub_tokens if w not in NLTKHelper().stopwords(lang=self._language)}
-                )
+                results.update({w for w in sub_tokens if w not in NLTKHelper().stopwords(lang=self._language)})
 
         return results
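A hedged standalone demo of the simplified `_extract_keywords_from_response` parsing after the `lowercase` parameter was dropped; the regex and `start_token` mirror the diff, the sample LLM response is invented:

```python
# Standalone demo of the 'KEYWORDS:' parsing used above; sample text invented.
import re

response = "KEYWORDS: Al Pacino, Godfather, actor\n"
start_token = "KEYWORDS:"
keywords = set()
for match in re.findall(rf"{start_token}[^\n]+\n?", response):
    keywords.update(k.strip() for k in match[len(start_token):].split(","))
print(keywords)  # {'Al Pacino', 'Godfather', 'actor'} (order varies)
```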
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/decorators.py b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
index 173c1f7..8cbe475 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py
@@ -72,7 +72,7 @@ def log_operator_time(func: Callable) -> Callable:
         # Only record time ≥ 0.01s (10ms)
         if op_time >= 0.01:
             log.debug("Operator %s finished in %.2f seconds", 
operator.__class__.__name__, op_time)
-            log.debug("Context:\n%s", result)
+            log.debug("Current context: %s", result)
         return result
     return wrapper
 
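The message tweak keeps lazy `%`-style logging arguments, so the context is only stringified when DEBUG is actually enabled; a minimal general illustration (the logger name is an assumption):

```python
# Minimal illustration of lazy %-formatting in the stdlib logging module.
import logging

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger("hugegraph_llm")  # logger name assumed
log.debug("Current context: %s", {"keywords": ["Al Pacino"]})  # rendered lazily
```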
diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
index 0db406e..10bf72e 100644
--- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
+++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py
@@ -64,7 +64,7 @@ def extract_graph(input_file, input_text, schema, example_prompt):
     (builder
      .chunk_split(texts, "paragraph", "zh")
      .extract_info(example_prompt, "property_graph"))
-    log.debug(builder.operators)
+    log.debug("Operators: %s", builder.operators)
     try:
         context = builder.run()
         return (
@@ -80,7 +80,7 @@ def extract_graph(input_file, input_text, schema, example_prompt):
 def fit_vid_index():
     builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client())
     builder.fetch_graph_data().build_vertex_id_semantic_index()
-    log.debug(builder.operators)
+    log.debug("Operators: %s", builder.operators)
     try:
         context = builder.run()
         removed_num = context["removed_vid_vector_num"]
@@ -109,7 +109,7 @@ def build_graph_index(input_file, input_text, schema, example_prompt):
      .extract_info(example_prompt, "property_graph")
      .commit_to_hugegraph()
      .build_vertex_id_semantic_index())
-    log.debug(builder.operators)
+    log.debug("Operators: %s", builder.operators)
     try:
         context = builder.run()
         return json.dumps(context, ensure_ascii=False, indent=2)

