This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git


The following commit(s) were added to refs/heads/main by this push:
     new 7ae5d6f  feat(llm): support async streaming output in RAG answer block (#190)
7ae5d6f is described below

commit 7ae5d6fcc1013bc39164672355a8270f078bff8d
Author: vichayturen <1073931...@qq.com>
AuthorDate: Thu Mar 6 18:21:42 2025 +0800

    feat(llm): support async streaming output in RAG answer block (#190)
    
    Follow-up to #172.
    
    To support asynchronous streaming, we made a compromise:
    `gremlin_generate_operator` now runs in synchronous generation mode. It can
    be switched back to asynchronous mode once the pipeline becomes fully
    asynchronous during the upcoming agentization work.
    
    ---------
    
    Co-authored-by: chenzihong <522023320...@smail.nju.edu.cn>
    Co-authored-by: chenzihong <58508660+chenzihong-ga...@users.noreply.github.com>
    Co-authored-by: imbajin <j...@apache.org>
---
 .../src/hugegraph_llm/demo/rag_demo/admin_block.py |   2 +-
 .../src/hugegraph_llm/demo/rag_demo/rag_block.py   | 164 ++++++++++++++++-----
 .../src/hugegraph_llm/models/llms/base.py          |  15 +-
 .../src/hugegraph_llm/models/llms/litellm.py       |  31 +++-
 .../src/hugegraph_llm/models/llms/ollama.py        |  47 ++++--
 .../src/hugegraph_llm/models/llms/openai.py        | 104 ++++++++++---
 .../src/hugegraph_llm/models/llms/qianfan.py       |  37 ++++-
 .../operators/llm_op/answer_synthesize.py          | 156 ++++++++++++++++----
 .../operators/llm_op/gremlin_generate.py           |  32 +++-
 9 files changed, 484 insertions(+), 104 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py 
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py
index b8c1852..2d5937a 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py
@@ -50,7 +50,7 @@ async def log_stream(log_path: str, lines: int = 125):
 def read_llm_server_log(lines=250):
     log_path = "logs/llm-server.log"
     try:
-        with open(log_path, "r", encoding='utf-8') as f:
+        with open(log_path, "r", encoding='utf-8', errors="replace") as f:
             return ''.join(deque(f, maxlen=lines))
     except FileNotFoundError:
         log.critical("Log file not found: %s", log_path)
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py 
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
index 8261887..df82568 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
@@ -18,7 +18,7 @@
 # pylint: disable=E1101
 
 import os
-from typing import Tuple, Literal, Optional
+from typing import AsyncGenerator, Tuple, Literal, Optional
 
 import gradio as gr
 import pandas as pd
@@ -26,6 +26,7 @@ from gradio.utils import NamedString
 
 from hugegraph_llm.config import resource_path, prompt, huge_settings, 
llm_settings
 from hugegraph_llm.operators.graph_rag_task import RAGPipeline
+from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
 from hugegraph_llm.utils.log import log
 
 
@@ -56,25 +57,10 @@ def rag_answer(
     4. Synthesize the final answer.
     5. Run the pipeline and return the results.
     """
-
-    gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
-    should_update_prompt = (
-        prompt.default_question != text
-        or prompt.answer_prompt != answer_prompt
-        or prompt.keywords_extract_prompt != keywords_extract_prompt
-        or prompt.gremlin_generate_prompt != gremlin_prompt
-        or prompt.custom_rerank_info != custom_related_information
-    )
-    if should_update_prompt:
-        prompt.custom_rerank_info = custom_related_information
-        prompt.default_question = text
-        prompt.answer_prompt = answer_prompt
-        prompt.keywords_extract_prompt = keywords_extract_prompt
-        prompt.gremlin_generate_prompt = gremlin_prompt
-        prompt.update_yaml_file()
-
-    vector_search = vector_only_answer or graph_vector_answer
-    graph_search = graph_only_answer or graph_vector_answer
+    graph_search, gremlin_prompt, vector_search = 
update_ui_configs(answer_prompt, custom_related_information,
+                                                                    
graph_only_answer, graph_vector_answer,
+                                                                    
gremlin_prompt, keywords_extract_prompt, text,
+                                                                    
vector_only_answer)
     if raw_answer is False and not vector_search and not graph_search:
         gr.Warning("Please select at least one generate mode.")
         return "", "", "", ""
@@ -121,6 +107,106 @@ def rag_answer(
         raise gr.Error(f"An unexpected error occurred: {str(e)}")
 
 
+def update_ui_configs(answer_prompt, custom_related_information, 
graph_only_answer, graph_vector_answer, gremlin_prompt,
+                      keywords_extract_prompt, text, vector_only_answer):
+    gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
+    should_update_prompt = (
+        prompt.default_question != text
+        or prompt.answer_prompt != answer_prompt
+        or prompt.keywords_extract_prompt != keywords_extract_prompt
+        or prompt.gremlin_generate_prompt != gremlin_prompt
+        or prompt.custom_rerank_info != custom_related_information
+    )
+    if should_update_prompt:
+        prompt.custom_rerank_info = custom_related_information
+        prompt.default_question = text
+        prompt.answer_prompt = answer_prompt
+        prompt.keywords_extract_prompt = keywords_extract_prompt
+        prompt.gremlin_generate_prompt = gremlin_prompt
+        prompt.update_yaml_file()
+    vector_search = vector_only_answer or graph_vector_answer
+    graph_search = graph_only_answer or graph_vector_answer
+    return graph_search, gremlin_prompt, vector_search
+
+
+async def rag_answer_streaming(
+    text: str,
+    raw_answer: bool,
+    vector_only_answer: bool,
+    graph_only_answer: bool,
+    graph_vector_answer: bool,
+    graph_ratio: float,
+    rerank_method: Literal["bleu", "reranker"],
+    near_neighbor_first: bool,
+    custom_related_information: str,
+    answer_prompt: str,
+    keywords_extract_prompt: str,
+    gremlin_tmpl_num: Optional[int] = 2,
+    gremlin_prompt: Optional[str] = None,
+) -> AsyncGenerator[Tuple[str, str, str, str], None]:
+    """
+    Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
+    1. Initialize the RAGPipeline.
+    2. Select vector search or graph search based on parameters.
+    3. Merge, deduplicate, and rerank the results.
+    4. Synthesize the final answer.
+    5. Run the pipeline and return the results.
+    """
+
+    graph_search, gremlin_prompt, vector_search = 
update_ui_configs(answer_prompt, custom_related_information,
+                                                                    
graph_only_answer, graph_vector_answer,
+                                                                    
gremlin_prompt, keywords_extract_prompt, text,
+                                                                    
vector_only_answer)
+    if raw_answer is False and not vector_search and not graph_search:
+        gr.Warning("Please select at least one generate mode.")
+        yield "", "", "", ""
+        return
+
+    rag = RAGPipeline()
+    if vector_search:
+        rag.query_vector_index()
+    if graph_search:
+        
rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema(
+            huge_settings.graph_name
+        ).query_graphdb(
+            num_gremlin_generate_example=gremlin_tmpl_num,
+            gremlin_prompt=gremlin_prompt,
+        )
+    rag.merge_dedup_rerank(
+        graph_ratio,
+        rerank_method,
+        near_neighbor_first,
+    )
+    # rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, 
graph_vector_answer, answer_prompt)
+
+    try:
+        context = rag.run(verbose=True, query=text, 
vector_search=vector_search, graph_search=graph_search)
+        if context.get("switch_to_bleu"):
+            gr.Warning("Online reranker fails, automatically switches to local 
bleu rerank.")
+        answer_synthesize = AnswerSynthesize(
+            raw_answer=raw_answer,
+            vector_only_answer=vector_only_answer,
+            graph_only_answer=graph_only_answer,
+            graph_vector_answer=graph_vector_answer,
+            prompt_template=answer_prompt,
+        )
+        async for context in answer_synthesize.run_streaming(context):
+            if context.get("switch_to_bleu"):
+                gr.Warning("Online reranker fails, automatically switches to 
local bleu rerank.")
+            yield (
+                context.get("raw_answer", ""),
+                context.get("vector_only_answer", ""),
+                context.get("graph_only_answer", ""),
+                context.get("graph_vector_answer", ""),
+            )
+    except ValueError as e:
+        log.critical(e)
+        raise gr.Error(str(e))
+    except Exception as e:
+        log.critical(e)
+        raise gr.Error(f"An unexpected error occurred: {str(e)}")
+
+
 def create_rag_block():
     # pylint: disable=R0915 (too-many-statements),C0301
     gr.Markdown("""## 1. HugeGraph RAG Query""")
@@ -130,13 +216,17 @@ def create_rag_block():
 
             # TODO: Only support inline formula now. Should support block 
formula
             gr.Markdown("Basic LLM Answer", elem_classes="output-box-label")
-            raw_out = gr.Markdown(elem_classes="output-box", 
show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", 
"display":False}])
+            raw_out = gr.Markdown(elem_classes="output-box", 
show_copy_button=True,
+                                  latex_delimiters=[{"left": "$", "right": 
"$", "display": False}])
             gr.Markdown("Vector-only Answer", elem_classes="output-box-label")
-            vector_only_out = gr.Markdown(elem_classes="output-box", 
show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", 
"display":False}])
+            vector_only_out = gr.Markdown(elem_classes="output-box", 
show_copy_button=True,
+                                          latex_delimiters=[{"left": "$", 
"right": "$", "display": False}])
             gr.Markdown("Graph-only Answer", elem_classes="output-box-label")
-            graph_only_out = gr.Markdown(elem_classes="output-box", 
show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", 
"display":False}])
+            graph_only_out = gr.Markdown(elem_classes="output-box", 
show_copy_button=True,
+                                         latex_delimiters=[{"left": "$", 
"right": "$", "display": False}])
             gr.Markdown("Graph-Vector Answer", elem_classes="output-box-label")
-            graph_vector_out = gr.Markdown(elem_classes="output-box", 
show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", 
"display":False}])
+            graph_vector_out = gr.Markdown(elem_classes="output-box", 
show_copy_button=True,
+                                           latex_delimiters=[{"left": "$", 
"right": "$", "display": False}])
 
             answer_prompt_input = gr.Textbox(
                 value=prompt.answer_prompt, label="Query Prompt", 
show_copy_button=True, lines=7
@@ -184,7 +274,7 @@ def create_rag_block():
                 btn = gr.Button("Answer Question", variant="primary")
 
     btn.click(  # pylint: disable=no-member
-        fn=rag_answer,
+        fn=rag_answer_streaming,
         inputs=[
             inp,
             raw_radio,
@@ -254,13 +344,13 @@ def create_rag_block():
         is_vector_only_answer: bool,
         is_graph_only_answer: bool,
         is_graph_vector_answer: bool,
-        graph_ratio: float,
-        rerank_method: Literal["bleu", "reranker"],
-        near_neighbor_first: bool,
-        custom_related_information: str,
+        graph_ratio_ui: float,
+        rerank_method_ui: Literal["bleu", "reranker"],
+        near_neighbor_first_ui: bool,
+        custom_related_information_ui: str,
         answer_prompt: str,
         keywords_extract_prompt: str,
-        answer_max_line_count: int = 1,
+        answer_max_line_count_ui: int = 1,
         progress=gr.Progress(track_tqdm=True),
     ):
         df = pd.read_excel(questions_path, dtype=str)
@@ -273,10 +363,10 @@ def create_rag_block():
                 is_vector_only_answer,
                 is_graph_only_answer,
                 is_graph_vector_answer,
-                graph_ratio,
-                rerank_method,
-                near_neighbor_first,
-                custom_related_information,
+                graph_ratio_ui,
+                rerank_method_ui,
+                near_neighbor_first_ui,
+                custom_related_information_ui,
                 answer_prompt,
                 keywords_extract_prompt,
             )
@@ -285,9 +375,9 @@ def create_rag_block():
             df.at[index, "Graph-only Answer"] = graph_only_answer
             df.at[index, "Graph-Vector Answer"] = graph_vector_answer
             progress((index + 1, total_rows))
-        answers_path = os.path.join(resource_path, "demo", 
"questions_answers.xlsx")
-        df.to_excel(answers_path, index=False)
-        return df.head(answer_max_line_count), answers_path
+        answers_path_ui = os.path.join(resource_path, "demo", 
"questions_answers.xlsx")
+        df.to_excel(answers_path_ui, index=False)
+        return df.head(answer_max_line_count_ui), answers_path_ui
 
     with gr.Row():
         with gr.Column():
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py 
b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py
index f2dd234..c6bfa44 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py
@@ -16,7 +16,7 @@
 # under the License.
 
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Callable, Dict
+from typing import Any, AsyncGenerator, Generator, List, Optional, Callable, 
Dict
 
 
 class BaseLLM(ABC):
@@ -43,8 +43,17 @@ class BaseLLM(ABC):
             self,
             messages: Optional[List[Dict[str, Any]]] = None,
             prompt: Optional[str] = None,
-            on_token_callback: Callable = None,
-    ) -> List[Any]:
+            on_token_callback: Optional[Callable] = None,
+    ) -> Generator[str, None, None]:
+        """Comment"""
+
+    @abstractmethod
+    async def agenerate_streaming(
+            self,
+            messages: Optional[List[Dict[str, Any]]] = None,
+            prompt: Optional[str] = None,
+            on_token_callback: Optional[Callable] = None,
+    ) -> AsyncGenerator[str, None]:
         """Comment"""
 
     @abstractmethod
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py 
b/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py
index 23a1250..ca5ae60 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Callable, List, Optional, Dict, Any
+from typing import Callable, List, Optional, Dict, Any, AsyncGenerator
 
 import tiktoken
 from litellm import completion, acompletion
@@ -137,6 +137,35 @@ class LiteLLMClient(BaseLLM):
             log.error("Error in streaming LiteLLM call: %s", e)
             return f"Error: {str(e)}"
 
+    async def agenerate_streaming(
+        self,
+        messages: Optional[List[Dict[str, Any]]] = None,
+        prompt: Optional[str] = None,
+        on_token_callback: Optional[Callable] = None,
+    ) -> AsyncGenerator[str, None]:
+        """Generate a response to the query messages/prompt in async streaming 
mode."""
+        if messages is None:
+            assert prompt is not None, "Messages or prompt must be provided."
+            messages = [{"role": "user", "content": prompt}]
+        try:
+            response = await acompletion(
+                model=self.model,
+                messages=messages,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                api_key=self.api_key,
+                base_url=self.api_base,
+                stream=True,
+            )
+            async for chunk in response:
+                if chunk.choices[0].delta.content:
+                    if on_token_callback:
+                        on_token_callback(chunk)
+                    yield chunk.choices[0].delta.content
+        except (RateLimitError, BudgetExceededError, APIError) as e:
+            log.error("Error in async streaming LiteLLM call: %s", e)
+            yield f"Error: {str(e)}"
+
     def num_tokens_from_string(self, string: str) -> int:
         """Get token count from string."""
         try:
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py 
b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
index 62f5ef2..58f063b 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py
@@ -17,7 +17,7 @@
 
 
 import json
-from typing import Any, List, Optional, Callable, Dict
+from typing import Any, AsyncGenerator, Generator, List, Optional, Callable, 
Dict
 
 import ollama
 from retry import retry
@@ -89,22 +89,49 @@ class OllamaClient(BaseLLM):
         self,
         messages: Optional[List[Dict[str, Any]]] = None,
         prompt: Optional[str] = None,
-        on_token_callback: Callable = None,
-    ) -> List[Any]:
+        on_token_callback: Optional[Callable] = None,
+    ) -> Generator[str, None, None]:
         """Comment"""
         if messages is None:
             assert prompt is not None, "Messages or prompt must be provided."
             messages = [{"role": "user", "content": prompt}]
-        stream = self.client.chat(
+
+        for chunk in self.client.chat(
             model=self.model,
             messages=messages,
             stream=True
-        )
-        chunks = []
-        for chunk in stream:
-            on_token_callback(chunk["message"]["content"])
-            chunks.append(chunk)
-        return chunks
+        ):
+            token = chunk["message"]["content"]
+            if on_token_callback:
+                on_token_callback(token)
+            yield token
+
+    async def agenerate_streaming(
+        self,
+        messages: Optional[List[Dict[str, Any]]] = None,
+        prompt: Optional[str] = None,
+        on_token_callback: Optional[Callable] = None,
+    ) -> AsyncGenerator[str, None]:
+        """Comment"""
+        if messages is None:
+            assert prompt is not None, "Messages or prompt must be provided."
+            messages = [{"role": "user", "content": prompt}]
+
+        try:
+            async_generator = await self.async_client.chat(
+                model=self.model,
+                messages=messages,
+                stream=True
+            )
+            async for chunk in async_generator:
+                token = chunk.get("message", {}).get("content", "")
+                if on_token_callback:
+                    on_token_callback(token)
+                yield token
+        except Exception as e:
+            print(f"Retrying LLM call {e}")
+            raise e
+
 
     def num_tokens_from_string(
         self,
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py 
b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
index a020067..45f6d7a 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Callable, List, Optional, Dict, Any
+from typing import Callable, List, Optional, Dict, Any, Generator, 
AsyncGenerator
 
 import openai
 import tiktoken
@@ -90,9 +90,9 @@ class OpenAIClient(BaseLLM):
         retry=retry_if_exception_type((RateLimitError, APIConnectionError, 
APITimeoutError)),
     )
     async def agenerate(
-            self,
-            messages: Optional[List[Dict[str, Any]]] = None,
-            prompt: Optional[str] = None,
+        self,
+        messages: Optional[List[Dict[str, Any]]] = None,
+        prompt: Optional[str] = None,
     ) -> str:
         """Generate a response to the query messages/prompt."""
         if messages is None:
@@ -119,31 +119,91 @@ class OpenAIClient(BaseLLM):
             log.error("Retrying LLM call %s", e)
             raise e
 
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type((RateLimitError, APIConnectionError, 
APITimeoutError)),
+    )
     def generate_streaming(
         self,
         messages: Optional[List[Dict[str, Any]]] = None,
         prompt: Optional[str] = None,
-        on_token_callback: Callable = None,
-    ) -> str:
-        """Generate a response to the query messages/prompt in streaming 
mode."""
+        on_token_callback: Optional[Callable[[str], None]] = None,
+    ) -> Generator[str, None, None]:
+        """Generate a response to the query messages/prompt in streaming mode.
+
+        Yields:
+            Accumulated response string after each new token.
+        """
         if messages is None:
             assert prompt is not None, "Messages or prompt must be provided."
             messages = [{"role": "user", "content": prompt}]
-        completions = self.client.chat.completions.create(
-            model=self.model,
-            temperature=self.temperature,
-            max_tokens=self.max_tokens,
-            messages=messages,
-            stream=True,
-        )
-        result = ""
-        for message in completions:
-            # Process the streamed messages or perform any other desired action
-            delta = message["choices"][0]["delta"]
-            if "content" in delta:
-                result += delta["content"]
-            on_token_callback(message)
-        return result
+
+        try:
+            completions = self.client.chat.completions.create(
+                model=self.model,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                messages=messages,
+                stream=True,
+            )
+
+            for chunk in completions:
+                delta = chunk.choices[0].delta
+                if delta.content:
+                    token = delta.content
+                    if on_token_callback:
+                        on_token_callback(token)
+                    yield token
+
+        except openai.BadRequestError as e:
+            log.critical("Fatal: %s", e)
+            yield str(f"Error: {e}")
+        except openai.AuthenticationError:
+            log.critical("The provided API key is invalid")
+            yield "Error: The provided API key is invalid"
+        except Exception as e:
+            log.error("Error in streaming: %s", e)
+            raise e
+
+    async def agenerate_streaming(
+        self,
+        messages: Optional[List[Dict[str, Any]]] = None,
+        prompt: Optional[str] = None,
+        on_token_callback: Optional[Callable] = None,
+    ) -> AsyncGenerator[str, None]:
+        """Comment"""
+        if messages is None:
+            assert prompt is not None, "Messages or prompt must be provided."
+            messages = [{"role": "user", "content": prompt}]
+
+        try:
+            completions = await self.aclient.chat.completions.create(
+                model=self.model,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                messages=messages,
+                stream=True
+            )
+            async for chunk in completions:
+                delta = chunk.choices[0].delta
+                if delta.content:
+                    token = delta.content
+                    if on_token_callback:
+                        on_token_callback(token)
+                    yield token
+            # TODO: log.info("Token usage: %s", 
completions.usage.model_dump_json())
+        # catch context length / do not retry
+        except openai.BadRequestError as e:
+            log.critical("Fatal: %s", e)
+            yield str(f"Error: {e}")
+        # catch authorization errors / do not retry
+        except openai.AuthenticationError:
+            log.critical("The provided OpenAI API key is invalid")
+            yield "Error: The provided OpenAI API key is invalid"
+        except Exception as e:
+            log.error("Retrying LLM call %s", e)
+            raise e
 
     def num_tokens_from_string(self, string: str) -> int:
         """Get token count from string."""
diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py 
b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
index 967c391..cbca691 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py
@@ -16,7 +16,7 @@
 # under the License.
 
 import json
-from typing import Optional, List, Dict, Any, Callable
+from typing import AsyncGenerator, Generator, Optional, List, Dict, Any, 
Callable
 
 import qianfan
 from retry import retry
@@ -74,9 +74,38 @@ class QianfanClient(BaseLLM):
             self,
             messages: Optional[List[Dict[str, Any]]] = None,
             prompt: Optional[str] = None,
-            on_token_callback: Callable = None,
-    ) -> str:
-        return self.generate(messages, prompt)
+            on_token_callback: Optional[Callable] = None,
+    ) -> Generator[str, None, None]:
+        if messages is None:
+            assert prompt is not None, "Messages or prompt must be provided."
+            messages = [{"role": "user", "content": prompt}]
+
+        for msg in self.chat_comp.do(messages=messages, model=self.chat_model, 
stream=True):
+            token = msg.body['result']
+            if on_token_callback:
+                on_token_callback(token)
+            yield token
+
+    async def agenerate_streaming(
+            self,
+            messages: Optional[List[Dict[str, Any]]] = None,
+            prompt: Optional[str] = None,
+            on_token_callback: Optional[Callable] = None,
+    ) -> AsyncGenerator[str, None]:
+        if messages is None:
+            assert prompt is not None, "Messages or prompt must be provided."
+            messages = [{"role": "user", "content": prompt}]
+
+        try:
+            async_generator = await self.chat_comp.ado(messages=messages, 
model=self.chat_model, stream=True)
+            async for msg in async_generator:
+                chunk = msg.body['result']
+                if on_token_callback:
+                    on_token_callback(chunk)
+                yield chunk
+        except Exception as e:
+            print(f"Retrying LLM call {e}")
+            raise e
 
     def num_tokens_from_string(self, string: str) -> int:
         return len(string)
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py 
b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index 666ecf9..5c4ab5f 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -18,7 +18,7 @@
 # pylint: disable=W0621
 
 import asyncio
-from typing import Any, Dict, Optional
+from typing import Any, AsyncGenerator, Dict, Optional
 
 from hugegraph_llm.config import prompt
 from hugegraph_llm.models.llms.base import BaseLLM
@@ -35,17 +35,17 @@ DEFAULT_ANSWER_TEMPLATE = prompt.answer_prompt
 
 class AnswerSynthesize:
     def __init__(
-            self,
-            llm: Optional[BaseLLM] = None,
-            prompt_template: Optional[str] = None,
-            question: Optional[str] = None,
-            context_body: Optional[str] = None,
-            context_head: Optional[str] = None,
-            context_tail: Optional[str] = None,
-            raw_answer: bool = False,
-            vector_only_answer: bool = True,
-            graph_only_answer: bool = False,
-            graph_vector_answer: bool = False,
+        self,
+        llm: Optional[BaseLLM] = None,
+        prompt_template: Optional[str] = None,
+        question: Optional[str] = None,
+        context_body: Optional[str] = None,
+        context_head: Optional[str] = None,
+        context_tail: Optional[str] = None,
+        raw_answer: bool = False,
+        vector_only_answer: bool = True,
+        graph_only_answer: bool = False,
+        graph_vector_answer: bool = False,
     ):
         self._llm = llm
         self._prompt_template = prompt_template or DEFAULT_ANSWER_TEMPLATE
@@ -59,15 +59,7 @@ class AnswerSynthesize:
         self._graph_vector_answer = graph_vector_answer
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
-        if self._llm is None:
-            self._llm = LLMs().get_chat_llm()
-
-        if self._question is None:
-            self._question = context.get("query") or None
-        assert self._question is not None, "No question for synthesizing."
-
-        context_head_str = context.get("synthesize_context_head") or 
self._context_head or ""
-        context_tail_str = context.get("synthesize_context_tail") or 
self._context_tail or ""
+        context_head_str, context_tail_str = self.init_llm(context)
 
         if self._context_body is not None:
             context_str = (f"{context_head_str}\n"
@@ -78,6 +70,22 @@ class AnswerSynthesize:
             response = self._llm.generate(prompt=final_prompt)
             return {"answer": response}
 
+        graph_result_context, vector_result_context = 
self.handle_vector_graph(context)
+        context = asyncio.run(self.async_generate(context, context_head_str, 
context_tail_str,
+                                                  vector_result_context, 
graph_result_context))
+        return context
+
+    def init_llm(self, context):
+        if self._llm is None:
+            self._llm = LLMs().get_chat_llm()
+        if self._question is None:
+            self._question = context.get("query") or None
+        assert self._question is not None, "No question for synthesizing."
+        context_head_str = context.get("synthesize_context_head") or 
self._context_head or ""
+        context_tail_str = context.get("synthesize_context_tail") or 
self._context_tail or ""
+        return context_head_str, context_tail_str
+
+    def handle_vector_graph(self, context):
         vector_result = context.get("vector_result")
         if vector_result:
             vector_result_context = "Phrases related to the query:\n" + 
"\n".join(
@@ -85,7 +93,6 @@ class AnswerSynthesize:
             )
         else:
             vector_result_context = "No (vector)phrase related to the query."
-
         graph_result = context.get("graph_result")
         if graph_result:
             graph_context_head = context.get("graph_context_head", "Knowledge 
from graphdb for the query:\n")
@@ -95,10 +102,31 @@ class AnswerSynthesize:
         else:
             graph_result_context = "No related graph data found for current 
query."
             log.warning(graph_result_context)
+        return graph_result_context, vector_result_context
 
-        context = asyncio.run(self.async_generate(context, context_head_str, 
context_tail_str,
-                                                  vector_result_context, 
graph_result_context))
-        return context
+    async def run_streaming(self, context: Dict[str, Any]) -> 
AsyncGenerator[Dict[str, Any], None]:
+        context_head_str, context_tail_str = self.init_llm(context)
+
+        if self._context_body is not None:
+            context_str = (f"{context_head_str}\n"
+                           f"{self._context_body}\n"
+                           f"{context_tail_str}".strip("\n"))
+
+            final_prompt = 
self._prompt_template.format(context_str=context_str, query_str=self._question)
+            response = self._llm.generate(prompt=final_prompt)
+            yield {"answer": response}
+            return
+
+        graph_result_context, vector_result_context = 
self.handle_vector_graph(context)
+
+        async for context in self.async_streaming_generate(
+            context,
+            context_head_str,
+            context_tail_str,
+            vector_result_context,
+            graph_result_context
+        ):
+            yield context
 
     async def async_generate(self, context: Dict[str, Any], context_head_str: 
str,
                              context_tail_str: str, vector_result_context: str,
@@ -151,3 +179,81 @@ class AnswerSynthesize:
         ops = sum([self._raw_answer, self._vector_only_answer, 
self._graph_only_answer, self._graph_vector_answer])
         context['call_count'] = context.get('call_count', 0) + ops
         return context
+
+    async def async_streaming_generate(self, context: Dict[str, Any], 
context_head_str: str,
+                                       context_tail_str: str, 
vector_result_context: str,
+                                       graph_result_context: str) -> 
AsyncGenerator[Dict[str, Any], None]:
+        # async_tasks stores the async tasks for different answer types
+        async_generators = []
+        auto_id = 0
+        if self._raw_answer:
+            final_prompt = self._question
+            async_generators.append(
+                self.__llm_generate_with_meta_info(task_id=auto_id, 
target_key="raw_answer", prompt=final_prompt)
+            )
+            auto_id += 1
+        if self._vector_only_answer:
+            context_str = (f"{context_head_str}\n"
+                           f"{vector_result_context}\n"
+                           f"{context_tail_str}".strip("\n"))
+
+            final_prompt = 
self._prompt_template.format(context_str=context_str, query_str=self._question)
+            async_generators.append(
+                self.__llm_generate_with_meta_info(
+                    task_id=auto_id,
+                    target_key="vector_only_answer",
+                    prompt=final_prompt
+                )
+            )
+            auto_id += 1
+        if self._graph_only_answer:
+            context_str = (f"{context_head_str}\n"
+                           f"{graph_result_context}\n"
+                           f"{context_tail_str}".strip("\n"))
+
+            final_prompt = 
self._prompt_template.format(context_str=context_str, query_str=self._question)
+            async_generators.append(
+                self.__llm_generate_with_meta_info(task_id=auto_id, 
target_key="graph_only_answer", prompt=final_prompt)
+            )
+            auto_id += 1
+        if self._graph_vector_answer:
+            context_body_str = 
f"{vector_result_context}\n{graph_result_context}"
+            if context.get("graph_ratio", 0.5) < 0.5:
+                context_body_str = 
f"{graph_result_context}\n{vector_result_context}"
+            context_str = (f"{context_head_str}\n"
+                           f"{context_body_str}\n"
+                           f"{context_tail_str}".strip("\n"))
+
+            final_prompt = 
self._prompt_template.format(context_str=context_str, query_str=self._question)
+            async_generators.append(
+                self.__llm_generate_with_meta_info(
+                    task_id=auto_id,
+                    target_key="graph_vector_answer",
+                    prompt=final_prompt
+                )
+            )
+            auto_id += 1
+
+        ops = sum([self._raw_answer, self._vector_only_answer, 
self._graph_only_answer, self._graph_vector_answer])
+        context['call_count'] = context.get('call_count', 0) + ops
+
+        async_tasks = [asyncio.create_task(anext(gen)) for gen in 
async_generators]
+        while True:
+            done, _ = await asyncio.wait(async_tasks, 
return_when=asyncio.FIRST_COMPLETED)
+            stop_task_num = 0
+            for task in done:
+                try:
+                    task_id, target_key, token = task.result()
+                    context[target_key] = context.get(target_key, "") + token
+                    gen = async_generators[task_id]
+                    async_tasks[task_id] = asyncio.create_task(anext(gen))
+                except StopAsyncIteration:
+                    stop_task_num += 1
+            if stop_task_num == len(async_tasks):
+                break
+            yield context
+
+    async def __llm_generate_with_meta_info(self, task_id: int, target_key: 
str, prompt: str):
+        # FIXME: Expected type 'AsyncIterable', got 'Coroutine[Any, Any, 
AsyncGenerator[str, None]]' instead
+        async for token in self._llm.agenerate_streaming(prompt=prompt):
+            yield task_id, target_key, token
diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py 
b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
index 9694647..219a358 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py
@@ -92,10 +92,40 @@ class GremlinGenerateSynthesize:
 
         return context
 
+    def sync_generate(self, context: Dict[str, Any]):
+        query = context.get("query")
+        raw_example = [{'query': 'who is peter', 'gremlin': "g.V().has('name', 
'peter')"}]
+        raw_prompt = self.gremlin_prompt.format(
+            query=query,
+            schema=self.schema,
+            example=self._format_examples(examples=raw_example),
+            vertices=self._format_vertices(vertices=self.vertices)
+        )
+        raw_response = self.llm.generate(prompt=raw_prompt)
+
+        examples = context.get("match_result")
+        init_prompt = self.gremlin_prompt.format(
+            query=query,
+            schema=self.schema,
+            example=self._format_examples(examples=examples),
+            vertices=self._format_vertices(vertices=self.vertices)
+        )
+        initialized_response = self.llm.generate(prompt=init_prompt)
+
+        log.debug("Text2Gremlin with tmpl prompt:\n %s,\n LLM Response: %s", 
init_prompt, initialized_response)
+
+        context["result"] = 
self._extract_gremlin(response=initialized_response)
+        context["raw_result"] = self._extract_gremlin(response=raw_response)
+        context["call_count"] = context.get("call_count", 0) + 2
+
+        return context
+
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         query = context.get("query", "")
         if not query:
             raise ValueError("query is required")
 
-        context = asyncio.run(self.async_generate(context))
+        # TODO: Update to async_generate again
+        #       The best method may be changing all `operator.run(*arg)` to be 
async function
+        context = self.sync_generate(context)
         return context


Reply via email to