This is an automated email from the ASF dual-hosted git repository. jin pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push: new c26f5bf refactor(llm): enhance the regex extraction func (#194) c26f5bf is described below commit c26f5bfb44bf6e349adadf071c0497f5d5d6ea95 Author: HaoJin Yang <1454...@gmail.com> AuthorDate: Fri Mar 7 14:01:40 2025 +0800 refactor(llm): enhance the regex extraction func (#194) --- .../operators/llm_op/gremlin_generate.py | 34 +++++++++++----------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py index 219a358..09e01e5 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py @@ -32,7 +32,7 @@ class GremlinGenerateSynthesize: llm: BaseLLM = None, schema: Optional[Union[dict, str]] = None, vertices: Optional[List[str]] = None, - gremlin_prompt: Optional[str] = None + gremlin_prompt: Optional[str] = None, ) -> None: self.llm = llm or LLMs().get_text2gql_llm() if isinstance(schema, dict): @@ -41,10 +41,10 @@ class GremlinGenerateSynthesize: self.vertices = vertices self.gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt - def _extract_gremlin(self, response: str) -> str: - match = re.search("```gremlin.*```", response, re.DOTALL) - assert match is not None, f"No gremlin found in response: {response}" - return match.group()[len("```gremlin"):-len("```")].strip() + def _extract_response(self, response: str, label: str = "gremlin") -> str: + match = re.search(f"```{label}(.*?)```", response, re.DOTALL) + assert match is not None, f"No {label} found in response: {response}" + return match.group(1).strip() def _format_examples(self, examples: Optional[List[Dict[str, str]]]) -> Optional[str]: if not examples: @@ -52,8 +52,8 @@ class GremlinGenerateSynthesize: example_strings = [] for example in examples: example_strings.append( - f"- query: {example['query']}\n" - f"- gremlin:\n```gremlin\n{example['gremlin']}\n```") + f"- query: {example['query']}\n" f"- gremlin:\n```gremlin\n{example['gremlin']}\n```" + ) return "\n\n".join(example_strings) def _format_vertices(self, vertices: Optional[List[str]]) -> Optional[str]: @@ -64,12 +64,12 @@ class GremlinGenerateSynthesize: async def async_generate(self, context: Dict[str, Any]): async_tasks = {} query = context.get("query") - raw_example = [{'query': 'who is peter', 'gremlin': "g.V().has('name', 'peter')"}] + raw_example = [{"query": "who is peter", "gremlin": "g.V().has('name', 'peter')"}] raw_prompt = self.gremlin_prompt.format( query=query, schema=self.schema, example=self._format_examples(examples=raw_example), - vertices=self._format_vertices(vertices=self.vertices) + vertices=self._format_vertices(vertices=self.vertices), ) async_tasks["raw_answer"] = asyncio.create_task(self.llm.agenerate(prompt=raw_prompt)) @@ -78,7 +78,7 @@ class GremlinGenerateSynthesize: query=query, schema=self.schema, example=self._format_examples(examples=examples), - vertices=self._format_vertices(vertices=self.vertices) + vertices=self._format_vertices(vertices=self.vertices), ) async_tasks["initialized_answer"] = asyncio.create_task(self.llm.agenerate(prompt=init_prompt)) @@ -86,20 +86,20 @@ class GremlinGenerateSynthesize: initialized_response = await async_tasks["initialized_answer"] log.debug("Text2Gremlin with tmpl prompt:\n %s,\n LLM Response: %s", init_prompt, initialized_response) - context["result"] = self._extract_gremlin(response=initialized_response) - context["raw_result"] = self._extract_gremlin(response=raw_response) + context["result"] = self._extract_response(response=initialized_response) + context["raw_result"] = self._extract_response(response=raw_response) context["call_count"] = context.get("call_count", 0) + 2 return context def sync_generate(self, context: Dict[str, Any]): query = context.get("query") - raw_example = [{'query': 'who is peter', 'gremlin': "g.V().has('name', 'peter')"}] + raw_example = [{"query": "who is peter", "gremlin": "g.V().has('name', 'peter')"}] raw_prompt = self.gremlin_prompt.format( query=query, schema=self.schema, example=self._format_examples(examples=raw_example), - vertices=self._format_vertices(vertices=self.vertices) + vertices=self._format_vertices(vertices=self.vertices), ) raw_response = self.llm.generate(prompt=raw_prompt) @@ -108,14 +108,14 @@ class GremlinGenerateSynthesize: query=query, schema=self.schema, example=self._format_examples(examples=examples), - vertices=self._format_vertices(vertices=self.vertices) + vertices=self._format_vertices(vertices=self.vertices), ) initialized_response = self.llm.generate(prompt=init_prompt) log.debug("Text2Gremlin with tmpl prompt:\n %s,\n LLM Response: %s", init_prompt, initialized_response) - context["result"] = self._extract_response(response=initialized_response) - context["raw_result"] = self._extract_response(response=raw_response) + context["result"] = self._extract_response(response=initialized_response) + context["raw_result"] = self._extract_response(response=raw_response) context["call_count"] = context.get("call_count", 0) + 2 return context