This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push: new 39e8d48 refactor(llm): remove enable_gql logic in api & rag block (#148) 39e8d48 is described below commit 39e8d486b31f6a2dd99f8a78dbd6402e91e7ce18 Author: HaoJin Yang <1454...@gmail.com> AuthorDate: Fri Dec 27 16:53:37 2024 +0800 refactor(llm): remove enable_gql logic in api & rag block (#148) * feat(llm): support choose num of examples in rag --------- Co-authored-by: imbajin <j...@apache.org> --- .../src/hugegraph_llm/api/models/rag_requests.py | 6 +++--- hugegraph-llm/src/hugegraph_llm/api/rag_api.py | 10 +++++----- .../src/hugegraph_llm/demo/rag_demo/rag_block.py | 15 ++++++--------- .../src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py | 11 +++++------ .../src/hugegraph_llm/operators/graph_rag_task.py | 4 +--- .../operators/hugegraph_op/graph_rag_query.py | 10 ++++------ .../operators/index_op/gremlin_example_index_query.py | 11 ++++++----- 7 files changed, 30 insertions(+), 37 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py index 489ce0a..de47aa0 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py +++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py @@ -32,7 +32,6 @@ class RAGRequest(BaseModel): graph_ratio: float = Query(0.5, description="The ratio of GraphRAG ans & vector ans") rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.") near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.") - with_gremlin_tmpl: bool = Query(True, description="Use example template in text2gremlin") custom_priority_info: str = Query("", description="Custom information to prioritize certain results.") answer_prompt: Optional[str] = Query(prompt.answer_prompt, description="Prompt to guide the answer generation.") keywords_extract_prompt: Optional[str] = Query( @@ -49,8 +48,9 @@ 
class RAGRequest(BaseModel): # TODO: import the default value of prompt.* dynamically class GraphRAGRequest(BaseModel): query: str = Query("", description="Query you want to ask") - gremlin_tmpl_num: int = Query(1, description="Number of Gremlin templates to use.") - with_gremlin_tmpl: bool = Query(True, description="Use example template in text2gremlin") + gremlin_tmpl_num: int = Query( + 1, description="Number of Gremlin templates to use. If num <=0 means template is not provided" + ) rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.") near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.") custom_priority_info: str = Query("", description="Custom information to prioritize certain results.") diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py index 4036496..d851fd1 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py +++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py @@ -49,7 +49,6 @@ def rag_http_api( vector_only_answer=req.vector_only, graph_only_answer=req.graph_only, graph_vector_answer=req.graph_vector_answer, - with_gremlin_template=req.with_gremlin_tmpl, graph_ratio=req.graph_ratio, rerank_method=req.rerank_method, near_neighbor_first=req.near_neighbor_first, @@ -62,9 +61,11 @@ def rag_http_api( # TODO: we need more info in the response for users to understand the query logic return { "query": req.query, - **{key: value - for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result) - if getattr(req, key)} + **{ + key: value + for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result) + if getattr(req, key) + }, } @router.post("/rag/graph", status_code=status.HTTP_200_OK) @@ -73,7 +74,6 @@ def rag_http_api( result = graph_rag_recall_func( query=req.query, gremlin_tmpl_num=req.gremlin_tmpl_num, - 
with_gremlin_tmpl=req.with_gremlin_tmpl, rerank_method=req.rerank_method, near_neighbor_first=req.near_neighbor_first, custom_related_information=req.custom_priority_info, diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py index 070f3b3..c10f84b 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py @@ -35,7 +35,6 @@ def rag_answer( vector_only_answer: bool, graph_only_answer: bool, graph_vector_answer: bool, - with_gremlin_template: bool, graph_ratio: float, rerank_method: Literal["bleu", "reranker"], near_neighbor_first: bool, @@ -80,7 +79,6 @@ def rag_answer( rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema( huge_settings.graph_name ).query_graphdb( - with_gremlin_template=with_gremlin_template, num_gremlin_generate_example=gremlin_tmpl_num, gremlin_prompt=gremlin_prompt, ) @@ -125,7 +123,10 @@ def create_rag_block(): value=prompt.answer_prompt, label="Query Prompt", show_copy_button=True, lines=7 ) keywords_extract_prompt_input = gr.Textbox( - value=prompt.keywords_extract_prompt, label="Keywords Extraction Prompt", show_copy_button=True, lines=7 + value=prompt.keywords_extract_prompt, + label="Keywords Extraction Prompt", + show_copy_button=True, + lines=7, ) with gr.Column(scale=1): with gr.Row(): @@ -134,8 +135,6 @@ def create_rag_block(): with gr.Row(): graph_only_radio = gr.Radio(choices=[True, False], value=True, label="Graph-only Answer") graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer") - with gr.Row(): - with_gremlin_template_radio = gr.Radio(choices=[True, False], value=True, label="With Gremlin Template") def toggle_slider(enable): return gr.update(interactive=enable) @@ -148,6 +147,7 @@ def create_rag_block(): value="reranker" if online_rerank else "bleu", label="Rerank method", ) + example_num = 
gr.Number(value=2, label="Template Num (0 to disable it) ", precision=0) graph_ratio = gr.Slider(0, 1, 0.6, label="Graph Ratio", step=0.1, interactive=False) graph_vector_radio.change( @@ -172,13 +172,13 @@ def create_rag_block(): vector_only_radio, graph_only_radio, graph_vector_radio, - with_gremlin_template_radio, graph_ratio, rerank_method, near_neighbor_first, custom_related_information, answer_prompt_input, keywords_extract_prompt_input, + example_num, ], outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out], ) @@ -237,7 +237,6 @@ def create_rag_block(): graph_ratio: float, rerank_method: Literal["bleu", "reranker"], near_neighbor_first: bool, - with_gremlin_template: bool, custom_related_information: str, answer_prompt: str, keywords_extract_prompt: str, @@ -257,7 +256,6 @@ def create_rag_block(): graph_ratio, rerank_method, near_neighbor_first, - with_gremlin_template, custom_related_information, answer_prompt, keywords_extract_prompt, @@ -291,7 +289,6 @@ def create_rag_block(): graph_ratio, rerank_method, near_neighbor_first, - with_gremlin_template_radio, custom_related_information, answer_prompt_input, keywords_extract_prompt_input, diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py index c47ce7c..46e2e9e 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py @@ -34,9 +34,9 @@ from hugegraph_llm.utils.log import log def store_schema(schema, question, gremlin_prompt): if ( - prompt.text2gql_graph_schema != schema or - prompt.default_question != question or - prompt.gremlin_generate_prompt != gremlin_prompt + prompt.text2gql_graph_schema != schema + or prompt.default_question != question + or prompt.gremlin_generate_prompt != gremlin_prompt ): prompt.text2gql_graph_schema = schema prompt.default_question = question @@ -90,7 +90,8 @@ def 
gremlin_generate( updated_schema = sm.simple_schema(schema) if short_schema else schema store_schema(str(updated_schema), inp, gremlin_prompt) context = ( - generator.example_index_query(example_num).gremlin_generate_synthesize(updated_schema, gremlin_prompt) + generator.example_index_query(example_num) + .gremlin_generate_synthesize(updated_schema, gremlin_prompt) .run(query=inp) ) try: @@ -183,7 +184,6 @@ def create_text2gremlin_block() -> Tuple: def graph_rag_recall( query: str, gremlin_tmpl_num: int, - with_gremlin_tmpl: bool, rerank_method: Literal["bleu", "reranker"], near_neighbor_first: bool, custom_related_information: str, @@ -193,7 +193,6 @@ def graph_rag_recall( rag = RAGPipeline() rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb( - with_gremlin_template=with_gremlin_tmpl, num_gremlin_generate_example=gremlin_tmpl_num, gremlin_prompt=gremlin_prompt, ).merge_dedup_rerank( diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py index 03ac9ae..399864a 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py @@ -125,8 +125,7 @@ class RAGPipeline: max_v_prop_len: int = 2048, max_e_prop_len: int = 256, prop_to_match: Optional[str] = None, - with_gremlin_template: bool = True, - num_gremlin_generate_example: int = 1, + num_gremlin_generate_example: Optional[int] = 1, gremlin_prompt: Optional[str] = prompt.gremlin_generate_prompt, ): """ @@ -146,7 +145,6 @@ class RAGPipeline: max_v_prop_len=max_v_prop_len, max_e_prop_len=max_e_prop_len, prop_to_match=prop_to_match, - with_gremlin_template=with_gremlin_template, num_gremlin_generate_example=num_gremlin_generate_example, gremlin_prompt=gremlin_prompt, ) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py 
b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py index a4186c8..e213c37 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py @@ -82,10 +82,9 @@ class GraphRAGQuery: prop_to_match: Optional[str] = None, llm: Optional[BaseLLM] = None, embedding: Optional[BaseEmbedding] = None, - max_v_prop_len: int = 2048, - max_e_prop_len: int = 256, - with_gremlin_template: bool = True, - num_gremlin_generate_example: int = 1, + max_v_prop_len: Optional[int] = 2048, + max_e_prop_len: Optional[int] = 256, + num_gremlin_generate_example: Optional[int] = 1, gremlin_prompt: Optional[str] = None, ): self._client = PyHugeClient( @@ -108,7 +107,6 @@ class GraphRAGQuery: embedding=embedding, ) self._num_gremlin_generate_example = num_gremlin_generate_example - self._with_gremlin_template = with_gremlin_template self._gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt def run(self, context: Dict[str, Any]) -> Dict[str, Any]: @@ -138,7 +136,7 @@ class GraphRAGQuery: gremlin_response = self._gremlin_generator.gremlin_generate_synthesize( context["simple_schema"], vertices=vertices, gremlin_prompt=self._gremlin_prompt ).run(query=query, query_embedding=query_embedding) - if self._with_gremlin_template: + if self._num_gremlin_generate_example > 0: gremlin = gremlin_response["result"] else: gremlin = gremlin_response["raw_result"] diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py index a95e7da..4029995 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py @@ -38,13 +38,15 @@ class GremlinExampleIndexQuery: self.vector_index = VectorIndex.from_index_file(self.index_dir) def 
_ensure_index_exists(self): - if not (os.path.exists(os.path.join(self.index_dir, "index.faiss")) - and os.path.exists(os.path.join(self.index_dir, "properties.pkl"))): + if not ( + os.path.exists(os.path.join(self.index_dir, "index.faiss")) + and os.path.exists(os.path.join(self.index_dir, "properties.pkl")) + ): log.warning("No gremlin example index found, will generate one.") self._build_default_example_index() def _get_match_result(self, context: Dict[str, Any], query: str) -> List[Dict[str, Any]]: - if self.num_examples == 0: + if self.num_examples <= 0: return [] query_embedding = context.get("query_embedding") @@ -53,8 +55,7 @@ class GremlinExampleIndexQuery: return self.vector_index.search(query_embedding, self.num_examples, dis_threshold=1.8) def _build_default_example_index(self): - properties = pd.read_csv(os.path.join(resource_path, "demo", - "text2gremlin.csv")).to_dict(orient="records") + properties = pd.read_csv(os.path.join(resource_path, "demo", "text2gremlin.csv")).to_dict(orient="records") embeddings = [self.embedding.get_text_embedding(row["query"]) for row in tqdm(properties)] vector_index = VectorIndex(len(embeddings[0])) vector_index.add(embeddings, properties)