This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git


The following commit(s) were added to refs/heads/main by this push:
     new d2fcfdb  fix(llm): update prompt to fit prefix cache (#137)
d2fcfdb is described below

commit d2fcfdb2348baea1d5ebf526d6e687271d5cda03
Author: HaoJin Yang <1454...@gmail.com>
AuthorDate: Mon Dec 23 00:32:45 2024 +0800

    fix(llm): update prompt to fit prefix cache (#137)
    
    * fix vid not readable for LLM in gremlin prompt
    
    ---------
    
    Co-authored-by: imbajin <j...@apache.org>
---
 .../src/hugegraph_llm/api/models/rag_requests.py   |   2 +-
 hugegraph-llm/src/hugegraph_llm/api/rag_api.py     |  88 ++++++++----------
 .../src/hugegraph_llm/config/prompt_config.py      |  79 +++++++++-------
 .../src/hugegraph_llm/demo/rag_demo/app.py         |  10 +-
 .../demo/rag_demo/text2gremlin_block.py            | 102 ++++++++++++++-------
 .../operators/llm_op/gremlin_generate.py           |  14 +--
 6 files changed, 166 insertions(+), 129 deletions(-)
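The commit title says the prompts were reordered "to fit prefix cache": serving stacks such as vLLM can reuse the KV cache across requests when prompts share a byte-identical prefix, so static instructions should come first and per-request variables last. A minimal sketch of the idea (illustrative only, not code from this commit; the helper name is made up):

```python
# Illustrative only -- shows why moving per-request variables
# ({context_str}, {query_str}) to the end of the prompt helps prefix
# caching: every request now shares the same static prefix, which a
# serving engine can cache and reuse across calls.

STATIC_PREFIX = (
    "You are an expert in knowledge graphs and natural language processing.\n"
    "Your task is to provide a precise and accurate answer based on the given context.\n\n"
    "Given the context information and without using fictive knowledge,\n"
    "answer the following query in a concise and professional manner.\n"
)

def build_answer_prompt(context_str: str, query_str: str) -> str:
    # Dynamic parts are appended after the shared prefix, so the leading
    # tokens are identical (and cacheable) for all requests.
    return (
        f"{STATIC_PREFIX}\n"
        "Context information is below.\n"
        "---------------------\n"
        f"{context_str}\n"
        "---------------------\n"
        f"Query: {query_str}\n"
        "Answer:\n"
    )

# Two different queries share the same cached prefix:
print(build_answer_prompt("Alice knows Bob.", "Who does Alice know?"))
```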

diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index 1eaa1e2..489ce0a 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -46,11 +46,11 @@ class RAGRequest(BaseModel):
     )
 
 
+# TODO: import the default value of prompt.* dynamically
 class GraphRAGRequest(BaseModel):
     query: str = Query("", description="Query you want to ask")
     gremlin_tmpl_num: int = Query(1, description="Number of Gremlin templates to use.")
     with_gremlin_tmpl: bool = Query(True, description="Use example template in text2gremlin")
-    answer_prompt: Optional[str] = Query(prompt.answer_prompt, description="Prompt to guide the answer generation.")
     rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.")
     near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.")
     custom_priority_info: str = Query("", description="Custom information to prioritize certain results.")
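The removed `answer_prompt` field was unused by the recall endpoint, and the new TODO asks for prompt defaults to be imported dynamically rather than frozen into the model. One possible shape of that TODO, as a hedged sketch (the request model and helper here are hypothetical, not the project's implementation):

```python
# Sketch only: resolve prompt defaults at request time instead of baking
# them into the Pydantic model, so edits to the prompt config take effect
# without redefining the request class. `GremlinRequest` is hypothetical.
from typing import Optional

from pydantic import BaseModel


class GremlinRequest(BaseModel):
    query: str = ""
    gremlin_prompt: Optional[str] = None  # None means "use the configured default"


def resolve_gremlin_prompt(req: GremlinRequest, default_prompt: str) -> str:
    # Mirrors the `req.gremlin_prompt or prompt.gremlin_generate_prompt`
    # fallback pattern used in rag_api.py below.
    return req.gremlin_prompt or default_prompt
```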
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index 8a59c8b..4036496 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -16,7 +16,6 @@
 # under the License.
 
 import json
-from typing import Literal
 
 from fastapi import status, APIRouter, HTTPException
 
@@ -29,72 +28,52 @@ from hugegraph_llm.api.models.rag_requests import (
     GraphRAGRequest,
 )
 from hugegraph_llm.api.models.rag_response import RAGResponse
-from hugegraph_llm.config import llm_settings, huge_settings, prompt
+from hugegraph_llm.config import llm_settings, prompt
 from hugegraph_llm.utils.log import log
 
 
-def graph_rag_recall(
-    query: str,
-    gremlin_tmpl_num: int,
-    with_gremlin_tmpl: bool,
-    answer_prompt: str,  # FIXME: should be used in the query
-    rerank_method: Literal["bleu", "reranker"],
-    near_neighbor_first: bool,
-    custom_related_information: str,
-    gremlin_prompt: str,
-) -> dict:
-    from hugegraph_llm.operators.graph_rag_task import RAGPipeline
-
-    rag = RAGPipeline()
-
-    rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
-        with_gremlin_template=with_gremlin_tmpl,
-        num_gremlin_generate_example=gremlin_tmpl_num,
-        gremlin_prompt=gremlin_prompt,
-    ).merge_dedup_rerank(
-        rerank_method=rerank_method,
-        near_neighbor_first=near_neighbor_first,
-        custom_related_information=custom_related_information,
-    )
-    context = rag.run(verbose=True, query=query, graph_search=True)
-    return context
-
-
 def rag_http_api(
-    router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
+    router: APIRouter,
+    rag_answer_func,
+    graph_rag_recall_func,
+    apply_graph_conf,
+    apply_llm_conf,
+    apply_embedding_conf,
+    apply_reranker_conf,
 ):
     @router.post("/rag", status_code=status.HTTP_200_OK)
     def rag_answer_api(req: RAGRequest):
         result = rag_answer_func(
-            req.query,
-            req.raw_answer,
-            req.vector_only,
-            req.graph_only,
-            req.graph_vector_answer,
-            req.with_gremlin_tmpl,
-            req.graph_ratio,
-            req.rerank_method,
-            req.near_neighbor_first,
-            req.custom_priority_info,
-            req.answer_prompt or prompt.answer_prompt,
-            req.keywords_extract_prompt or prompt.keywords_extract_prompt,
-            req.gremlin_tmpl_num,
-            req.gremlin_prompt or prompt.gremlin_generate_prompt,
+            text=req.query,
+            raw_answer=req.raw_answer,
+            vector_only_answer=req.vector_only,
+            graph_only_answer=req.graph_only,
+            graph_vector_answer=req.graph_vector_answer,
+            with_gremlin_template=req.with_gremlin_tmpl,
+            graph_ratio=req.graph_ratio,
+            rerank_method=req.rerank_method,
+            near_neighbor_first=req.near_neighbor_first,
+            custom_related_information=req.custom_priority_info,
+            answer_prompt=req.answer_prompt or prompt.answer_prompt,
+            keywords_extract_prompt=req.keywords_extract_prompt or prompt.keywords_extract_prompt,
+            gremlin_tmpl_num=req.gremlin_tmpl_num,
+            gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
         )
+        # TODO: we need more info in the response for users to understand the query logic
         return {
-            key: value
-            for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
-            if getattr(req, key)
+            "query": req.query,
+            **{key: value
+               for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
+               if getattr(req, key)}
         }
 
     @router.post("/rag/graph", status_code=status.HTTP_200_OK)
     def graph_rag_recall_api(req: GraphRAGRequest):
         try:
-            result = graph_rag_recall(
+            result = graph_rag_recall_func(
                 query=req.query,
                 gremlin_tmpl_num=req.gremlin_tmpl_num,
                 with_gremlin_tmpl=req.with_gremlin_tmpl,
-                answer_prompt=req.answer_prompt or prompt.answer_prompt,
                 rerank_method=req.rerank_method,
                 near_neighbor_first=req.near_neighbor_first,
                 custom_related_information=req.custom_priority_info,
@@ -102,8 +81,15 @@ def rag_http_api(
             )
 
             if isinstance(result, dict):
-                params = ["query", "keywords", "match_vids", "graph_result_flag", "gremlin", "graph_result",
-                          "vertex_degree_list"]
+                params = [
+                    "query",
+                    "keywords",
+                    "match_vids",
+                    "graph_result_flag",
+                    "gremlin",
+                    "graph_result",
+                    "vertex_degree_list",
+                ]
                 user_result = {key: result[key] for key in params if key in result}
                 return {"graph_recall": user_result}
             # Note: Maybe only for qianfan/wenxin
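Note the refactor above: `graph_rag_recall` no longer lives in rag_api.py; the HTTP layer now receives it as `graph_rag_recall_func`, so it stays free of pipeline imports. A condensed sketch of the wiring (the stub recall body is hypothetical):

```python
# Sketch of the dependency injection introduced here. Names follow the
# diff; the stub implementation is made up for illustration.
from fastapi import APIRouter, FastAPI


def rag_http_api(router: APIRouter, rag_answer_func, graph_rag_recall_func, *conf_funcs):
    @router.post("/rag/graph")
    def graph_rag_recall_api(payload: dict):
        # Delegate to whatever recall implementation was injected.
        return {"graph_recall": graph_rag_recall_func(**payload)}


def fake_recall(**kwargs):  # stand-in for text2gremlin_block.graph_rag_recall
    return {"query": kwargs.get("query"), "gremlin": "g.V().limit(10)"}


app = FastAPI()
router = APIRouter()
rag_http_api(router, None, fake_recall)
app.include_router(router)
```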
diff --git a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
index 4d37ced..f6aeef5 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
@@ -18,18 +18,19 @@
 
 from hugegraph_llm.config.models.base_prompt_config import BasePromptConfig
 
+
 class PromptConfig(BasePromptConfig):
     # Data is detached from llm_op/answer_synthesize.py
     answer_prompt: str = """You are an expert in knowledge graphs and natural language processing.
 Your task is to provide a precise and accurate answer based on the given context.
 
+Given the context information and without using fictive knowledge, 
+answer the following query in a concise and professional manner.
+
 Context information is below.
 ---------------------
 {context_str}
 ---------------------
-
-Given the context information and without using fictive knowledge, 
-answer the following query in a concise and professional manner.
 Query: {query_str}
 Answer:
 """
@@ -131,7 +132,7 @@ Meet Sarah, a 30-year-old attorney, and her roommate, James, whom she's shared a
     keywords_extract_prompt: str = """指令:
 请对以下文本执行以下任务:
 1. 从文本中提取关键词:
-  - 最少 0 个,最多 {max_keywords} 个。
+  - 最少 0 个,最多 MAX_KEYWORDS 个。
   - 关键词应为具有完整语义的词语或短语,确保信息完整。
 2. 识别需改写的关键词:
   - 从提取的关键词中,识别那些在原语境中具有歧义或存在信息缺失的关键词。
@@ -151,47 +152,55 @@ Meet Sarah, a 30-year-old attorney, and her roommate, James, whom she's shared a
 - 仅输出一行内容, 以 KEYWORDS: 为前缀,后跟所有关键词或对应的同义词,之间用逗号分隔。抽取的关键词中不允许出现空格或空字符
 - 格式示例:
 KEYWORDS:关键词1,关键词2,...,关键词n
+
+MAX_KEYWORDS: {max_keywords}
 文本:
 {question}
 """
-#pylint: disable=C0301
+    # pylint: disable=C0301
     # keywords_extract_prompt_EN = """
-# Instruction:
-# Please perform the following tasks on the text below:
-# 1. Extract Keywords and Generate Synonyms from text:
-#   - At least 0, at most {max_keywords} keywords.
-#   - For each keyword, generate its synonyms or possible variant forms.
-# Requirements:
-# - Keywords should be meaningful and specific entities; avoid using meaningless or overly broad terms (e.g., “object,” “the,” “he”).
-# - Prioritize extracting subjects, verbs, and objects; avoid extracting function words or auxiliary words.
-# - Do not expand into unrelated generalized categories.
-# Note:
-# - Only consider semantic synonyms and other words with similar meanings in the given context.
-# Output Format:
-# - Output only one line, prefixed with KEYWORDS:, followed by all keywords and synonyms, separated by commas.No spaces or empty characters are allowed in the extracted keywords.
-# - Format example:
-# KEYWORDS: keyword1, keyword2, ..., keywordn, synonym1, synonym2, ..., synonymn
-# Text:
-# {question}
-# """
-
-    gremlin_generate_prompt = """\
-Given the example query-gremlin pairs:
-{example}
-
-Given the graph schema:
+    # Instruction:
+    # Please perform the following tasks on the text below:
+    # 1. Extract Keywords and Generate Synonyms from the text:
+    #   - At least 0, at most {max_keywords} keywords.
+    #   - For each keyword, generate its synonyms or possible variant forms.
+    # Requirements:
+    # - Keywords should be meaningful and specific entities; avoid using meaningless or overly broad terms (e.g., “object,” “the,” “he”).
+    # - Prioritize extracting subjects, verbs, and objects; avoid extracting function words or auxiliary words.
+    # - Do not expand into unrelated generalized categories.
+    # Note:
+    # - Only consider semantic synonyms and other words with similar meanings in the given context.
+    # Output Format:
+    # - Output only one line, prefixed with KEYWORDS:, followed by all keywords and synonyms, separated by commas. No spaces or empty characters are allowed in the extracted keywords.
+    # - Format example:
+    # KEYWORDS: keyword1, keyword2, ..., keywordN, synonym1, synonym2, ..., synonymN
+    # Text:
+    # {question}
+    # """
+
+    gremlin_generate_prompt = """
+You are an expert in graph query language (Gremlin), your role is to understand the schema of the graph and generate 
+accurate Gremlin code based on the given instructions.
+
+# Graph Schema:
 ```json
 {schema}
 ```
+# Rule:
+1. Could use the vertex ID directly if it's given in the context.
+2. The output format must be like:
+```gremlin
+g.V().limit(10)
+```
 
-Given the extracted vertex vid:
+# Extracted vertex vid:
 {vertices}
 
-Generate gremlin from the following user query.
+# Given the example query-gremlin pairs:
+{example}
+
+# Generate gremlin from the following user query.
 {query}
-The output format must be like:
-```gremlin
-g.V().limit(10)
-```
+
 The generated gremlin is:
 """
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
index 3fe6c0f..700e60b 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
@@ -35,7 +35,7 @@ from hugegraph_llm.demo.rag_demo.configs_block import (
     apply_graph_config,
 )
 from hugegraph_llm.demo.rag_demo.other_block import create_other_block
-from hugegraph_llm.demo.rag_demo.text2gremlin_block import create_text2gremlin_block
+from hugegraph_llm.demo.rag_demo.text2gremlin_block import create_text2gremlin_block, graph_rag_recall
 from hugegraph_llm.demo.rag_demo.rag_block import create_rag_block, rag_answer
 from hugegraph_llm.demo.rag_demo.vector_graph_block import create_vector_graph_block
 from hugegraph_llm.resources.demo.css import CSS
@@ -171,7 +171,13 @@ if __name__ == "__main__":
     hugegraph_llm = init_rag_ui()
 
     rag_http_api(
-        api_auth, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config, apply_reranker_config
+        api_auth,
+        rag_answer,
+        graph_rag_recall,
+        apply_graph_config,
+        apply_llm_config,
+        apply_embedding_config,
+        apply_reranker_config,
     )
     admin_http_api(api_auth, log_stream)
 
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
index eaa37bc..c47ce7c 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py
@@ -17,14 +17,15 @@
 
 import json
 import os
-from typing import Any, Tuple, Dict, Union
+from typing import Any, Tuple, Dict, Union, Literal
 
 import gradio as gr
 import pandas as pd
 
-from hugegraph_llm.config import prompt, resource_path
+from hugegraph_llm.config import prompt, resource_path, huge_settings
 from hugegraph_llm.models.embeddings.init_embedding import Embeddings
 from hugegraph_llm.models.llms.init_llm import LLMs
+from hugegraph_llm.operators.graph_rag_task import RAGPipeline
 from hugegraph_llm.operators.gremlin_generate_task import GremlinGenerator
 from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager
 from hugegraph_llm.utils.hugegraph_utils import run_gremlin_query
@@ -32,8 +33,11 @@ from hugegraph_llm.utils.log import log
 
 
 def store_schema(schema, question, gremlin_prompt):
-    if (prompt.text2gql_graph_schema != schema or prompt.default_question != question or
-            prompt.gremlin_generate_prompt != gremlin_prompt):
+    if (
+        prompt.text2gql_graph_schema != schema or
+        prompt.default_question != question or
+        prompt.gremlin_generate_prompt != gremlin_prompt
+    ):
         prompt.text2gql_graph_schema = schema
         prompt.default_question = question
         prompt.gremlin_generate_prompt = gremlin_prompt
@@ -49,7 +53,7 @@ def build_example_vector_index(temp_file) -> dict:
         with open(full_path, "r", encoding="utf-8") as f:
             examples = json.load(f)
     elif full_path.endswith(".csv"):
-        examples = pd.read_csv(full_path).to_dict('records')
+        examples = pd.read_csv(full_path).to_dict("records")
     else:
         log.critical("Unsupported file format. Please input a JSON or CSV file.")
         return {"error": "Unsupported file format. Please input a JSON or CSV file."}
@@ -60,8 +64,9 @@ def build_example_vector_index(temp_file) -> dict:
     return builder.example_index_build(examples).run()
 
 
-def gremlin_generate(inp, example_num, schema, gremlin_prompt) -> Union[
-    tuple[str, str], tuple[str, Any, Any, Any, Any]]:
+def gremlin_generate(
+    inp, example_num, schema, gremlin_prompt
+) -> Union[tuple[str, str], tuple[str, Any, Any, Any, Any]]:
     generator = GremlinGenerator(llm=LLMs().get_text2gql_llm(), embedding=Embeddings().get_embedding())
     sm = SchemaManager(graph_name=schema)
     short_schema = False
@@ -83,19 +88,28 @@ def gremlin_generate(inp, example_num, schema, gremlin_prompt) -> Union[
                 return "Invalid JSON schema, please check the format 
carefully.", ""
     # FIXME: schema is not used in gremlin_generate() step, no context for it (enhance the logic here)
     updated_schema = sm.simple_schema(schema) if short_schema else schema
-    context = generator.example_index_query(example_num).gremlin_generate_synthesize(updated_schema,
-                                                                                     gremlin_prompt).run(query=inp)
+    store_schema(str(updated_schema), inp, gremlin_prompt)
+    context = (
+        generator.example_index_query(example_num).gremlin_generate_synthesize(updated_schema, gremlin_prompt)
+        .run(query=inp)
+    )
     try:
         context["template_exec_res"] = 
run_gremlin_query(query=context["result"])
-    except Exception as e: # pylint: disable=broad-except
+    except Exception as e:  # pylint: disable=broad-except
         context["template_exec_res"] = f"{e}"
     try:
         context["raw_exec_res"] = 
run_gremlin_query(query=context["raw_result"])
-    except Exception as e: # pylint: disable=broad-except
+    except Exception as e:  # pylint: disable=broad-except
         context["raw_exec_res"] = f"{e}"
 
     match_result = json.dumps(context.get("match_result", "No Results"), ensure_ascii=False, indent=2)
-    return match_result, context["result"], context["raw_result"], context["template_exec_res"], context["raw_exec_res"]
+    return (
+        match_result,
+        context["result"],
+        context["raw_result"],
+        context["template_exec_res"],
+        context["raw_exec_res"],
+    )
 
 
 def simple_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
@@ -112,24 +126,24 @@ def simple_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
     if "edgelabels" in schema:
         mini_schema["edgelabels"] = []
         for edge in schema["edgelabels"]:
-            new_edge = {key: edge[key] for key in
-                        ["name", "source_label", "target_label", "properties"] if key in edge}
+            new_edge = {key: edge[key] for key in ["name", "source_label", "target_label", "properties"] if key in edge}
             mini_schema["edgelabels"].append(new_edge)
 
     return mini_schema
 
 
 def create_text2gremlin_block() -> Tuple:
-    gr.Markdown("""## Build Vector Template Index (Optional)
+    gr.Markdown(
+        """## Build Vector Template Index (Optional)
     > Uploaded CSV file should be in `query,gremlin` format below:    
     > e.g. `who is peter?`,`g.V().has('name', 'peter')`    
     > JSON file should be in format below:  
     > e.g. `[{"query":"who is peter", "gremlin":"g.V().has('name', 'peter')"}]`
-    """)
+    """
+    )
     with gr.Row():
         file = gr.File(
-            value=os.path.join(resource_path, "demo", "text2gremlin.csv"),
-            label="Upload Text-Gremlin Pairs File"
+            value=os.path.join(resource_path, "demo", "text2gremlin.csv"), label="Upload Text-Gremlin Pairs File"
         )
         out = gr.Textbox(label="Result Message")
     with gr.Row():
@@ -143,27 +157,49 @@ def create_text2gremlin_block() -> Tuple:
             match = gr.Code(label="Similar Template (TopN)", language="javascript", elem_classes="code-container-show")
             initialized_out = gr.Textbox(label="Gremlin With Template", show_copy_button=True)
             raw_out = gr.Textbox(label="Gremlin Without Template", show_copy_button=True)
-            tmpl_exec_out = gr.Code(label="Query With Template Output", language="json",
-                                    elem_classes="code-container-show")
-            raw_exec_out = gr.Code(label="Query Without Template Output", language="json",
-                                   elem_classes="code-container-show")
+            tmpl_exec_out = gr.Code(
+                label="Query With Template Output", language="json", elem_classes="code-container-show"
+            )
+            raw_exec_out = gr.Code(
+                label="Query Without Template Output", language="json", elem_classes="code-container-show"
+            )
 
         with gr.Column(scale=1):
-            example_num_slider = gr.Slider(
-                minimum=0,
-                maximum=10,
-                step=1,
-                value=2,
-                label="Number of refer examples"
-            )
+            example_num_slider = gr.Slider(minimum=0, maximum=10, step=1, value=2, label="Number of refer examples")
             schema_box = gr.Textbox(value=prompt.text2gql_graph_schema, label="Schema", lines=2, show_copy_button=True)
-            prompt_box = gr.Textbox(value=prompt.gremlin_generate_prompt, label="Prompt", lines=2,
-                                    show_copy_button=True)
+            prompt_box = gr.Textbox(
+                value=prompt.gremlin_generate_prompt, label="Prompt", lines=20, show_copy_button=True
+            )
             btn = gr.Button("Text2Gremlin", variant="primary")
     btn.click(  # pylint: disable=no-member
         fn=gremlin_generate,
         inputs=[input_box, example_num_slider, schema_box, prompt_box],
-        outputs=[match, initialized_out, raw_out, tmpl_exec_out, raw_exec_out]
-    ).then(store_schema, inputs=[schema_box, input_box, prompt_box], )
+        outputs=[match, initialized_out, raw_out, tmpl_exec_out, raw_exec_out],
+    )
 
     return input_box, schema_box, prompt_box
+
+
+def graph_rag_recall(
+    query: str,
+    gremlin_tmpl_num: int,
+    with_gremlin_tmpl: bool,
+    rerank_method: Literal["bleu", "reranker"],
+    near_neighbor_first: bool,
+    custom_related_information: str,
+    gremlin_prompt: str,
+) -> dict:
+    store_schema(prompt.text2gql_graph_schema, query, gremlin_prompt)
+    rag = RAGPipeline()
+
+    rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
+        with_gremlin_template=with_gremlin_tmpl,
+        num_gremlin_generate_example=gremlin_tmpl_num,
+        gremlin_prompt=gremlin_prompt,
+    ).merge_dedup_rerank(
+        rerank_method=rerank_method,
+        near_neighbor_first=near_neighbor_first,
+        custom_related_information=custom_related_information,
+    )
+    context = rag.run(verbose=True, query=query, graph_search=True)
+    return context
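`graph_rag_recall` now lives here in text2gremlin_block.py, so the Gradio demo and the HTTP API share one implementation, and it calls `store_schema` first to keep the text2gql schema and prompt in sync. A hedged usage example (assumes a running HugeGraph instance plus configured LLM and embedding backends; the query string is made up):

```python
# Hypothetical caller; argument values mirror the API defaults above.
from hugegraph_llm.config import prompt
from hugegraph_llm.demo.rag_demo.text2gremlin_block import graph_rag_recall

context = graph_rag_recall(
    query="Who is Sarah's roommate?",
    gremlin_tmpl_num=1,
    with_gremlin_tmpl=True,
    rerank_method="bleu",
    near_neighbor_first=False,
    custom_related_information="",
    gremlin_prompt=prompt.gremlin_generate_prompt,
)
print(context.get("gremlin"))  # generated Gremlin, if the pipeline produced any
```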
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
index 955fc9e..9694647 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
@@ -20,19 +20,19 @@ import json
 import re
 from typing import Optional, List, Dict, Any, Union
 
+from hugegraph_llm.config import prompt
 from hugegraph_llm.models.llms.base import BaseLLM
 from hugegraph_llm.models.llms.init_llm import LLMs
 from hugegraph_llm.utils.log import log
-from hugegraph_llm.config import prompt
 
 
 class GremlinGenerateSynthesize:
     def __init__(
-            self,
-            llm: BaseLLM = None,
-            schema: Optional[Union[dict, str]] = None,
-            vertices: Optional[List[str]] = None,
-            gremlin_prompt: Optional[str] = None
+        self,
+        llm: BaseLLM = None,
+        schema: Optional[Union[dict, str]] = None,
+        vertices: Optional[List[str]] = None,
+        gremlin_prompt: Optional[str] = None
     ) -> None:
         self.llm = llm or LLMs().get_text2gql_llm()
         if isinstance(schema, dict):
@@ -59,7 +59,7 @@ class GremlinGenerateSynthesize:
     def _format_vertices(self, vertices: Optional[List[str]]) -> Optional[str]:
         if not vertices:
             return None
-        return "\n".join([f"- {vid}" for vid in vertices])
+        return "\n".join([f"- '{vid}'" for vid in vertices])
 
     async def async_generate(self, context: Dict[str, Any]):
         async_tasks = {}
