This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git


The following commit(s) were added to refs/heads/main by this push:
     new c26f5bf  refactor(llm): enhance the regex extraction func (#194)
c26f5bf is described below

commit c26f5bfb44bf6e349adadf071c0497f5d5d6ea95
Author: HaoJin Yang <1454...@gmail.com>
AuthorDate: Fri Mar 7 14:01:40 2025 +0800

    refactor(llm): enhance the regex extraction func (#194)
---
 .../operators/llm_op/gremlin_generate.py           | 34 +++++++++++-----------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
index 219a358..09e01e5 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
@@ -32,7 +32,7 @@ class GremlinGenerateSynthesize:
         llm: BaseLLM = None,
         schema: Optional[Union[dict, str]] = None,
         vertices: Optional[List[str]] = None,
-        gremlin_prompt: Optional[str] = None
+        gremlin_prompt: Optional[str] = None,
     ) -> None:
         self.llm = llm or LLMs().get_text2gql_llm()
         if isinstance(schema, dict):
@@ -41,10 +41,10 @@ class GremlinGenerateSynthesize:
         self.vertices = vertices
         self.gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
 
-    def _extract_gremlin(self, response: str) -> str:
-        match = re.search("```gremlin.*```", response, re.DOTALL)
-        assert match is not None, f"No gremlin found in response: {response}"
-        return match.group()[len("```gremlin"):-len("```")].strip()
+    def _extract_response(self, response: str, label: str = "gremlin") -> str:
+        match = re.search(f"```{label}(.*?)```", response, re.DOTALL)
+        assert match is not None, f"No {label} found in response: {response}"
+        return match.group(1).strip()
 
     def _format_examples(self, examples: Optional[List[Dict[str, str]]]) -> 
Optional[str]:
         if not examples:
@@ -52,8 +52,8 @@ class GremlinGenerateSynthesize:
         example_strings = []
         for example in examples:
             example_strings.append(
-                f"- query: {example['query']}\n"
-                f"- gremlin:\n```gremlin\n{example['gremlin']}\n```")
+                f"- query: {example['query']}\n" f"- 
gremlin:\n```gremlin\n{example['gremlin']}\n```"
+            )
         return "\n\n".join(example_strings)
 
     def _format_vertices(self, vertices: Optional[List[str]]) -> Optional[str]:
@@ -64,12 +64,12 @@ class GremlinGenerateSynthesize:
     async def async_generate(self, context: Dict[str, Any]):
         async_tasks = {}
         query = context.get("query")
-        raw_example = [{'query': 'who is peter', 'gremlin': "g.V().has('name', 
'peter')"}]
+        raw_example = [{"query": "who is peter", "gremlin": "g.V().has('name', 
'peter')"}]
         raw_prompt = self.gremlin_prompt.format(
             query=query,
             schema=self.schema,
             example=self._format_examples(examples=raw_example),
-            vertices=self._format_vertices(vertices=self.vertices)
+            vertices=self._format_vertices(vertices=self.vertices),
         )
         async_tasks["raw_answer"] = 
asyncio.create_task(self.llm.agenerate(prompt=raw_prompt))
 
@@ -78,7 +78,7 @@ class GremlinGenerateSynthesize:
             query=query,
             schema=self.schema,
             example=self._format_examples(examples=examples),
-            vertices=self._format_vertices(vertices=self.vertices)
+            vertices=self._format_vertices(vertices=self.vertices),
         )
         async_tasks["initialized_answer"] = 
asyncio.create_task(self.llm.agenerate(prompt=init_prompt))
 
@@ -86,20 +86,20 @@ class GremlinGenerateSynthesize:
         initialized_response = await async_tasks["initialized_answer"]
         log.debug("Text2Gremlin with tmpl prompt:\n %s,\n LLM Response: %s", 
init_prompt, initialized_response)
 
-        context["result"] = 
self._extract_gremlin(response=initialized_response)
-        context["raw_result"] = self._extract_gremlin(response=raw_response)
+        context["result"] = 
self._extract_response(response=initialized_response)
+        context["raw_result"] = self._extract_response(response=raw_response)
         context["call_count"] = context.get("call_count", 0) + 2
 
         return context
 
     def sync_generate(self, context: Dict[str, Any]):
         query = context.get("query")
-        raw_example = [{'query': 'who is peter', 'gremlin': "g.V().has('name', 
'peter')"}]
+        raw_example = [{"query": "who is peter", "gremlin": "g.V().has('name', 
'peter')"}]
         raw_prompt = self.gremlin_prompt.format(
             query=query,
             schema=self.schema,
             example=self._format_examples(examples=raw_example),
-            vertices=self._format_vertices(vertices=self.vertices)
+            vertices=self._format_vertices(vertices=self.vertices),
         )
         raw_response = self.llm.generate(prompt=raw_prompt)
 
@@ -108,14 +108,14 @@ class GremlinGenerateSynthesize:
             query=query,
             schema=self.schema,
             example=self._format_examples(examples=examples),
-            vertices=self._format_vertices(vertices=self.vertices)
+            vertices=self._format_vertices(vertices=self.vertices),
         )
         initialized_response = self.llm.generate(prompt=init_prompt)
 
         log.debug("Text2Gremlin with tmpl prompt:\n %s,\n LLM Response: %s", 
init_prompt, initialized_response)
 
-        context["result"] = 
self._extract_gremlin(response=initialized_response)
-        context["raw_result"] = self._extract_gremlin(response=raw_response)
+        context["result"] = 
self._extract_response(response=initialized_response)
+        context["raw_result"] = self._extract_response(response=raw_response)
         context["call_count"] = context.get("call_count", 0) + 2
 
         return context

Reply via email to