This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git


The following commit(s) were added to refs/heads/main by this push:
     new 39e8d48  refactor(llm): remove enable_gql logic in api & rag block (#148)
39e8d48 is described below

commit 39e8d486b31f6a2dd99f8a78dbd6402e91e7ce18
Author: HaoJin Yang <1454...@gmail.com>
AuthorDate: Fri Dec 27 16:53:37 2024 +0800

    refactor(llm): remove enable_gql logic in api & rag block (#148)
    
    * feat(llm): support choosing the number of examples in RAG
    
    ---------
    
    Co-authored-by: imbajin <j...@apache.org>
---
 .../src/hugegraph_llm/api/models/rag_requests.py          |  6 +++---
 hugegraph-llm/src/hugegraph_llm/api/rag_api.py            | 10 +++++-----
 .../src/hugegraph_llm/demo/rag_demo/rag_block.py          | 15 ++++++---------
 .../src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py | 11 +++++------
 .../src/hugegraph_llm/operators/graph_rag_task.py         |  4 +---
 .../operators/hugegraph_op/graph_rag_query.py             | 10 ++++------
 .../operators/index_op/gremlin_example_index_query.py     | 11 ++++++-----
 7 files changed, 30 insertions(+), 37 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py 
b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index 489ce0a..de47aa0 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -32,7 +32,6 @@ class RAGRequest(BaseModel):
     graph_ratio: float = Query(0.5, description="The ratio of GraphRAG ans & 
vector ans")
     rerank_method: Literal["bleu", "reranker"] = Query("bleu", 
description="Method to rerank the results.")
     near_neighbor_first: bool = Query(False, description="Prioritize near 
neighbors in the search results.")
-    with_gremlin_tmpl: bool = Query(True, description="Use example template in 
text2gremlin")
     custom_priority_info: str = Query("", description="Custom information to 
prioritize certain results.")
     answer_prompt: Optional[str] = Query(prompt.answer_prompt, 
description="Prompt to guide the answer generation.")
     keywords_extract_prompt: Optional[str] = Query(
@@ -49,8 +48,9 @@ class RAGRequest(BaseModel):
 # TODO: import the default value of prompt.* dynamically
 class GraphRAGRequest(BaseModel):
     query: str = Query("", description="Query you want to ask")
-    gremlin_tmpl_num: int = Query(1, description="Number of Gremlin templates 
to use.")
-    with_gremlin_tmpl: bool = Query(True, description="Use example template in 
text2gremlin")
+    gremlin_tmpl_num: int = Query(
+        1, description="Number of Gremlin templates to use. If num <=0 means 
template is not provided"
+    )
     rerank_method: Literal["bleu", "reranker"] = Query("bleu", 
description="Method to rerank the results.")
     near_neighbor_first: bool = Query(False, description="Prioritize near 
neighbors in the search results.")
     custom_priority_info: str = Query("", description="Custom information to 
prioritize certain results.")
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py 
b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index 4036496..d851fd1 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -49,7 +49,6 @@ def rag_http_api(
             vector_only_answer=req.vector_only,
             graph_only_answer=req.graph_only,
             graph_vector_answer=req.graph_vector_answer,
-            with_gremlin_template=req.with_gremlin_tmpl,
             graph_ratio=req.graph_ratio,
             rerank_method=req.rerank_method,
             near_neighbor_first=req.near_neighbor_first,
@@ -62,9 +61,11 @@ def rag_http_api(
         # TODO: we need more info in the response for users to understand the 
query logic
         return {
             "query": req.query,
-            **{key: value
-               for key, value in zip(["raw_answer", "vector_only", 
"graph_only", "graph_vector_answer"], result)
-               if getattr(req, key)}
+            **{
+                key: value
+                for key, value in zip(["raw_answer", "vector_only", 
"graph_only", "graph_vector_answer"], result)
+                if getattr(req, key)
+            },
         }
 
     @router.post("/rag/graph", status_code=status.HTTP_200_OK)
@@ -73,7 +74,6 @@ def rag_http_api(
             result = graph_rag_recall_func(
                 query=req.query,
                 gremlin_tmpl_num=req.gremlin_tmpl_num,
-                with_gremlin_tmpl=req.with_gremlin_tmpl,
                 rerank_method=req.rerank_method,
                 near_neighbor_first=req.near_neighbor_first,
                 custom_related_information=req.custom_priority_info,
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py 
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
index 070f3b3..c10f84b 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
@@ -35,7 +35,6 @@ def rag_answer(
     vector_only_answer: bool,
     graph_only_answer: bool,
     graph_vector_answer: bool,
-    with_gremlin_template: bool,
     graph_ratio: float,
     rerank_method: Literal["bleu", "reranker"],
     near_neighbor_first: bool,
@@ -80,7 +79,6 @@ def rag_answer(
         
rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema(
             huge_settings.graph_name
         ).query_graphdb(
-            with_gremlin_template=with_gremlin_template,
             num_gremlin_generate_example=gremlin_tmpl_num,
             gremlin_prompt=gremlin_prompt,
         )
@@ -125,7 +123,10 @@ def create_rag_block():
                 value=prompt.answer_prompt, label="Query Prompt", 
show_copy_button=True, lines=7
             )
             keywords_extract_prompt_input = gr.Textbox(
-                value=prompt.keywords_extract_prompt, label="Keywords 
Extraction Prompt", show_copy_button=True, lines=7
+                value=prompt.keywords_extract_prompt,
+                label="Keywords Extraction Prompt",
+                show_copy_button=True,
+                lines=7,
             )
         with gr.Column(scale=1):
             with gr.Row():
@@ -134,8 +135,6 @@ def create_rag_block():
             with gr.Row():
                 graph_only_radio = gr.Radio(choices=[True, False], value=True, 
label="Graph-only Answer")
                 graph_vector_radio = gr.Radio(choices=[True, False], 
value=False, label="Graph-Vector Answer")
-            with gr.Row():
-                with_gremlin_template_radio = gr.Radio(choices=[True, False], 
value=True, label="With Gremlin Template")
 
             def toggle_slider(enable):
                 return gr.update(interactive=enable)
@@ -148,6 +147,7 @@ def create_rag_block():
                         value="reranker" if online_rerank else "bleu",
                         label="Rerank method",
                     )
+                    example_num = gr.Number(value=2, label="Template Num (0 to 
disable it) ", precision=0)
                     graph_ratio = gr.Slider(0, 1, 0.6, label="Graph Ratio", 
step=0.1, interactive=False)
 
                 graph_vector_radio.change(
@@ -172,13 +172,13 @@ def create_rag_block():
             vector_only_radio,
             graph_only_radio,
             graph_vector_radio,
-            with_gremlin_template_radio,
             graph_ratio,
             rerank_method,
             near_neighbor_first,
             custom_related_information,
             answer_prompt_input,
             keywords_extract_prompt_input,
+            example_num,
         ],
         outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
     )
@@ -237,7 +237,6 @@ def create_rag_block():
         graph_ratio: float,
         rerank_method: Literal["bleu", "reranker"],
         near_neighbor_first: bool,
-        with_gremlin_template: bool,
         custom_related_information: str,
         answer_prompt: str,
         keywords_extract_prompt: str,
@@ -257,7 +256,6 @@ def create_rag_block():
                 graph_ratio,
                 rerank_method,
                 near_neighbor_first,
-                with_gremlin_template,
                 custom_related_information,
                 answer_prompt,
                 keywords_extract_prompt,
@@ -291,7 +289,6 @@ def create_rag_block():
             graph_ratio,
             rerank_method,
             near_neighbor_first,
-            with_gremlin_template_radio,
             custom_related_information,
             answer_prompt_input,
             keywords_extract_prompt_input,
diff --git 
a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py 
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
index c47ce7c..46e2e9e 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
@@ -34,9 +34,9 @@ from hugegraph_llm.utils.log import log
 
 def store_schema(schema, question, gremlin_prompt):
     if (
-        prompt.text2gql_graph_schema != schema or
-        prompt.default_question != question or
-        prompt.gremlin_generate_prompt != gremlin_prompt
+        prompt.text2gql_graph_schema != schema
+        or prompt.default_question != question
+        or prompt.gremlin_generate_prompt != gremlin_prompt
     ):
         prompt.text2gql_graph_schema = schema
         prompt.default_question = question
@@ -90,7 +90,8 @@ def gremlin_generate(
     updated_schema = sm.simple_schema(schema) if short_schema else schema
     store_schema(str(updated_schema), inp, gremlin_prompt)
     context = (
-        
generator.example_index_query(example_num).gremlin_generate_synthesize(updated_schema,
 gremlin_prompt)
+        generator.example_index_query(example_num)
+        .gremlin_generate_synthesize(updated_schema, gremlin_prompt)
         .run(query=inp)
     )
     try:
@@ -183,7 +184,6 @@ def create_text2gremlin_block() -> Tuple:
 def graph_rag_recall(
     query: str,
     gremlin_tmpl_num: int,
-    with_gremlin_tmpl: bool,
     rerank_method: Literal["bleu", "reranker"],
     near_neighbor_first: bool,
     custom_related_information: str,
@@ -193,7 +193,6 @@ def graph_rag_recall(
     rag = RAGPipeline()
 
     
rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
-        with_gremlin_template=with_gremlin_tmpl,
         num_gremlin_generate_example=gremlin_tmpl_num,
         gremlin_prompt=gremlin_prompt,
     ).merge_dedup_rerank(
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py 
b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index 03ac9ae..399864a 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -125,8 +125,7 @@ class RAGPipeline:
         max_v_prop_len: int = 2048,
         max_e_prop_len: int = 256,
         prop_to_match: Optional[str] = None,
-        with_gremlin_template: bool = True,
-        num_gremlin_generate_example: int = 1,
+        num_gremlin_generate_example: Optional[int] = 1,
         gremlin_prompt: Optional[str] = prompt.gremlin_generate_prompt,
     ):
         """
@@ -146,7 +145,6 @@ class RAGPipeline:
                 max_v_prop_len=max_v_prop_len,
                 max_e_prop_len=max_e_prop_len,
                 prop_to_match=prop_to_match,
-                with_gremlin_template=with_gremlin_template,
                 num_gremlin_generate_example=num_gremlin_generate_example,
                 gremlin_prompt=gremlin_prompt,
             )
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py 
b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index a4186c8..e213c37 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -82,10 +82,9 @@ class GraphRAGQuery:
         prop_to_match: Optional[str] = None,
         llm: Optional[BaseLLM] = None,
         embedding: Optional[BaseEmbedding] = None,
-        max_v_prop_len: int = 2048,
-        max_e_prop_len: int = 256,
-        with_gremlin_template: bool = True,
-        num_gremlin_generate_example: int = 1,
+        max_v_prop_len: Optional[int] = 2048,
+        max_e_prop_len: Optional[int] = 256,
+        num_gremlin_generate_example: Optional[int] = 1,
         gremlin_prompt: Optional[str] = None,
     ):
         self._client = PyHugeClient(
@@ -108,7 +107,6 @@ class GraphRAGQuery:
             embedding=embedding,
         )
         self._num_gremlin_generate_example = num_gremlin_generate_example
-        self._with_gremlin_template = with_gremlin_template
         self._gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
@@ -138,7 +136,7 @@ class GraphRAGQuery:
         gremlin_response = self._gremlin_generator.gremlin_generate_synthesize(
             context["simple_schema"], vertices=vertices, 
gremlin_prompt=self._gremlin_prompt
         ).run(query=query, query_embedding=query_embedding)
-        if self._with_gremlin_template:
+        if self._num_gremlin_generate_example > 0:
             gremlin = gremlin_response["result"]
         else:
             gremlin = gremlin_response["raw_result"]
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
 
b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
index a95e7da..4029995 100644
--- 
a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
+++ 
b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py
@@ -38,13 +38,15 @@ class GremlinExampleIndexQuery:
         self.vector_index = VectorIndex.from_index_file(self.index_dir)
 
     def _ensure_index_exists(self):
-        if not (os.path.exists(os.path.join(self.index_dir, "index.faiss"))
-                and os.path.exists(os.path.join(self.index_dir, 
"properties.pkl"))):
+        if not (
+            os.path.exists(os.path.join(self.index_dir, "index.faiss"))
+            and os.path.exists(os.path.join(self.index_dir, "properties.pkl"))
+        ):
             log.warning("No gremlin example index found, will generate one.")
             self._build_default_example_index()
 
     def _get_match_result(self, context: Dict[str, Any], query: str) -> 
List[Dict[str, Any]]:
-        if self.num_examples == 0:
+        if self.num_examples <= 0:
             return []
 
         query_embedding = context.get("query_embedding")
@@ -53,8 +55,7 @@ class GremlinExampleIndexQuery:
         return self.vector_index.search(query_embedding, self.num_examples, 
dis_threshold=1.8)
 
     def _build_default_example_index(self):
-        properties = pd.read_csv(os.path.join(resource_path, "demo",
-                                              
"text2gremlin.csv")).to_dict(orient="records")
+        properties = pd.read_csv(os.path.join(resource_path, "demo", 
"text2gremlin.csv")).to_dict(orient="records")
         embeddings = [self.embedding.get_text_embedding(row["query"]) for row 
in tqdm(properties)]
         vector_index = VectorIndex(len(embeddings[0]))
         vector_index.add(embeddings, properties)

Reply via email to