imbajin commented on code in PR #73:
URL: 
https://github.com/apache/incubator-hugegraph-ai/pull/73#discussion_r1732062905


##########
hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py:
##########
@@ -228,6 +246,37 @@ def apply_embedding_config(arg1, arg2, arg3, 
origin_call=None) -> int:
     return status_code
 
 
+def apply_reranker_config(arg1, arg2, arg3: str | None = None, 
origin_call=None) -> int:

Review Comment:
   If `arg1`...`argN` are not meant to be generic placeholders, it is better to 
give them descriptive names (e.g. `api_key`, ...).



##########
hugegraph-llm/src/hugegraph_llm/api/rag_api.py:
##########
@@ -45,22 +63,50 @@ def llm_config_api(req: LLMConfigRequest):
 
         if req.llm_type == "openai":
             res = apply_llm_conf(
-                req.api_key, req.api_base, req.language_model, req.max_tokens, 
origin_call="http"
+                req.api_key,
+                req.api_base,
+                req.language_model,
+                req.max_tokens,
+                origin_call="http",
             )
         elif req.llm_type == "qianfan_wenxin":
-            res = apply_llm_conf(req.api_key, req.secret_key, 
req.language_model, None, origin_call="http")
+            res = apply_llm_conf(
+                req.api_key,
+                req.secret_key,
+                req.language_model,
+                None,
+                origin_call="http",
+            )
         else:
-            res = apply_llm_conf(req.host, req.port, req.language_model, None, 
origin_call="http")
+            res = apply_llm_conf(
+                req.host, req.port, req.language_model, None, 
origin_call="http"
+            )
         return generate_response(RAGResponse(status_code=res, message="Missing 
Value"))
 
     @router.post("/config/embedding", status_code=status.HTTP_201_CREATED)
     def embedding_config_api(req: LLMConfigRequest):
         settings.embedding_type = req.llm_type
 
         if req.llm_type == "openai":
-            res = apply_embedding_conf(req.api_key, req.api_base, 
req.language_model, origin_call="http")
+            res = apply_embedding_conf(
+                req.api_key, req.api_base, req.language_model, 
origin_call="http"
+            )
         elif req.llm_type == "qianfan_wenxin":
-            res = apply_embedding_conf(req.api_key, req.api_base, None, 
origin_call="http")
+            res = apply_embedding_conf(
+                req.api_key, req.api_base, None, origin_call="http"
+            )
         else:
-            res = apply_embedding_conf(req.host, req.port, req.language_model, 
origin_call="http")
+            res = apply_embedding_conf(
+                req.host, req.port, req.language_model, origin_call="http"
+            )
+        return generate_response(RAGResponse(status_code=res, message="Missing 
Value"))
+
+    @router.post("/config/rerank", status_code=status.HTTP_201_CREATED)
+    def rerank_config_api(req: RerankerConfigRequest):
+        settings.reranker_type = req.reranker_type
+
+        if req.reranker_type == "cohere":
+            res = apply_reranker_conf(

Review Comment:
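   `res` may be unbound here if none of the branches assigns it (Python has no `Null`; this would raise `UnboundLocalError`). Maybe initialize it before the `if`/`elif` chain?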



##########
hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py:
##########
@@ -72,7 +82,9 @@ def rag_answer(
     if graph_search:
         searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag()
     # TODO: add more user-defined search strategies
-    searcher.merge_dedup_rerank().synthesize_answer(
+    searcher.merge_dedup_rerank(
+        graph_ratio, rerank_method, near_neighbor_first, 
custom_related_information

Review Comment:
   `rerank_method` is typed as `str`, but the expected type is `Literal["bleu", "reranker"]`. Consider annotating or validating it accordingly.



##########
hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py:
##########
@@ -16,59 +16,119 @@
 # under the License.
 
 
-from typing import Dict, Any, List, Literal
+from typing import Literal, Dict, Any, List, Optional, Tuple
 
 import jieba
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
+from hugegraph_llm.models.rerankers.init_reranker import Rerankers
 from nltk.translate.bleu_score import sentence_bleu
 
 
-def get_score(query: str, content: str) -> float:
+def get_bleu_score(query: str, content: str) -> float:
     query_tokens = jieba.lcut(query)
     content_tokens = jieba.lcut(content)
     return sentence_bleu([query_tokens], content_tokens)
 
 
 class MergeDedupRerank:
     def __init__(
-            self,
-            embedding: BaseEmbedding,
-            topk: int = 10,
-            strategy: Literal["bleu", "priority"] = "bleu"
+        self,
+        embedding: BaseEmbedding,
+        topk: int = 20,
+        graph_ratio: float = 0.5,
+        method: Literal["bleu", "reranker"] = "bleu",
+        near_neighbor_first: bool = False,
+        custom_related_information: Optional[str] = None,
     ):
+        assert method in [
+            "bleu",
+            "reranker",
+        ], f"Unimplemented rerank method '{method}'."
         self.embedding = embedding
+        self.graph_ratio = graph_ratio
         self.topk = topk
-        if strategy == "bleu":
-            self.rerank_func = self._bleu_rerank
-        elif strategy == "priority":
-            self.rerank_func = self._priority_rerank
-        else:
-            raise ValueError(f"Unimplemented rerank strategy {strategy}.")
+        self.method = method
+        self.near_neighbor_first = near_neighbor_first
+        self.custom_related_information = custom_related_information
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         query = context.get("query")
+        if self.custom_related_information:
+            query = query + self.custom_related_information
+        context["graph_ratio"] = self.graph_ratio
+        vector_search = context.get("vector_search", False)
+        graph_search = context.get("graph_search", False)
+        if graph_search and vector_search:
+            graph_length = int(self.topk * self.graph_ratio)
+            vector_length = self.topk - graph_length
+        else:
+            graph_length = self.topk
+            vector_length = self.topk
 
         vector_result = context.get("vector_result", [])
-        vector_result = self.rerank_func(query, vector_result)[:self.topk]
+        vector_length = min(len(vector_result), vector_length)
+        vector_result = self._dedup_and_rerank(query, vector_result, 
vector_length)
 
         graph_result = context.get("graph_result", [])
-        graph_result = self.rerank_func(query, graph_result)[:self.topk]
+        graph_length = min(len(graph_result), graph_length)
+        if self.near_neighbor_first:
+            graph_result = self._rerank_with_vertex_degree(
+                query,
+                graph_result,
+                graph_length,
+                context.get("vertex_degree_list"),
+                context.get("knowledge_with_degree"),
+            )
+        else:
+            graph_result = self._dedup_and_rerank(query, graph_result, 
graph_length)
 
         context["vector_result"] = vector_result
         context["graph_result"] = graph_result
 
         return context
 
-    def _bleu_rerank(self, query: str, results: List[str]):
+    def _dedup_and_rerank(self, query: str, results: List[str], topn: int) -> 
List[str]:
         results = list(set(results))
-        result_score_list = [[res, get_score(query, res)] for res in results]
-        result_score_list.sort(key=lambda x: x[1], reverse=True)
-        return [res[0] for res in result_score_list]
-
-    def _priority_rerank(self, query: str, results: List[str]):
-        # TODO: implement
-        # 1. Precise recall > Fuzzy recall
-        # 2. 1-degree neighbors > 2-degree neighbors
-        # 3. The priority of a certain type of point is higher than others,
-        # such as Law being higher than vehicles/people/locations
-        raise NotImplementedError()
+        if self.method == "bleu":
+            result_score_list = [[res, get_bleu_score(query, res)] for res in 
results]
+            result_score_list.sort(key=lambda x: x[1], reverse=True)
+            return [res[0] for res in result_score_list][:topn]
+        if self.method == "reranker":
+            reranker = Rerankers().get_reranker()
+            return reranker.get_rerank_lists(query, results, topn)
+
+    def _rerank_with_vertex_degree(
+        self,
+        query: str,
+        results: List[str],
+        topn: int,
+        vertex_degree_list: List[List[str]] | None,
+        knowledge_with_degree: Dict[str, List[str]] | None,
+    ) -> List[str]:
+        if vertex_degree_list is None or len(vertex_degree_list) == 0:
+            return self._dedup_and_rerank(query, results, topn)
+        if self.method == "bleu":
+            vertex_degree_rerank_result: List[List[str]] = []
+            for vertex_degree in vertex_degree_list:
+                vertex_degree_score_list = [[res, get_bleu_score(query, res)] 
for res in vertex_degree]
+                vertex_degree_score_list.sort(key=lambda x: x[1], reverse=True)
+                vertex_degree = [res[0] for res in vertex_degree_score_list] + 
[""]
+                vertex_degree_rerank_result.append(vertex_degree)
+
+        if self.method == "reranker":
+            reranker = Rerankers().get_reranker()
+            vertex_degree_rerank_result = [
+                reranker.get_rerank_lists(query, vertex_degree) + [""] for 
vertex_degree in vertex_degree_list
+            ]
+        depth = len(vertex_degree_list)
+        for result in results:
+            if result not in knowledge_with_degree:
+                knowledge_with_degree[result] = [result] + [""] * (depth - 1)
+            if len(knowledge_with_degree[result]) < depth:
+                knowledge_with_degree[result] += [""] * (depth - 
len(knowledge_with_degree[result]))
+
+        def sort_key(result: str) -> Tuple[int, ...]:

Review Comment:
   `result` is already used in the enclosing scope (the `for result in results:` loop above), so this parameter shadows it; better to rename it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to