This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push: new 7ae5d6f feat(llm): support async streaming output in RAG answer block (#190) 7ae5d6f is described below commit 7ae5d6fcc1013bc39164672355a8270f078bff8d Author: vichayturen <1073931...@qq.com> AuthorDate: Thu Mar 6 18:21:42 2025 +0800 feat(llm): support async streaming output in RAG answer block (#190) follow #172 In order to achieve asynchronization, we compromised by changing `gremlin_generate_operator` to a synchronous generation mode. This can be changed back to an asynchronous mode after achieving full asynchronization in the subsequent agentization process. --------- Co-authored-by: chenzihong <522023320...@smail.nju.edu.cn> Co-authored-by: chenzihong <58508660+chenzihong-ga...@users.noreply.github.com> Co-authored-by: imbajin <j...@apache.org> --- .../src/hugegraph_llm/demo/rag_demo/admin_block.py | 2 +- .../src/hugegraph_llm/demo/rag_demo/rag_block.py | 164 ++++++++++++++++----- .../src/hugegraph_llm/models/llms/base.py | 15 +- .../src/hugegraph_llm/models/llms/litellm.py | 31 +++- .../src/hugegraph_llm/models/llms/ollama.py | 47 ++++-- .../src/hugegraph_llm/models/llms/openai.py | 104 ++++++++++--- .../src/hugegraph_llm/models/llms/qianfan.py | 37 ++++- .../operators/llm_op/answer_synthesize.py | 156 ++++++++++++++++---- .../operators/llm_op/gremlin_generate.py | 32 +++- 9 files changed, 484 insertions(+), 104 deletions(-) diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py index b8c1852..2d5937a 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py @@ -50,7 +50,7 @@ async def log_stream(log_path: str, lines: int = 125): def read_llm_server_log(lines=250): log_path = "logs/llm-server.log" try: - with open(log_path, "r", encoding='utf-8') as f: + with open(log_path, "r", encoding='utf-8', errors="replace") as f: return ''.join(deque(f, maxlen=lines)) except FileNotFoundError: log.critical("Log file not found: %s", log_path) diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py index 8261887..df82568 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py @@ -18,7 +18,7 @@ # pylint: disable=E1101 import os -from typing import Tuple, Literal, Optional +from typing import AsyncGenerator, Tuple, Literal, Optional import gradio as gr import pandas as pd @@ -26,6 +26,7 @@ from gradio.utils import NamedString from hugegraph_llm.config import resource_path, prompt, huge_settings, llm_settings from hugegraph_llm.operators.graph_rag_task import RAGPipeline +from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize from hugegraph_llm.utils.log import log @@ -56,25 +57,10 @@ def rag_answer( 4. Synthesize the final answer. 5. Run the pipeline and return the results. 
""" - - gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt - should_update_prompt = ( - prompt.default_question != text - or prompt.answer_prompt != answer_prompt - or prompt.keywords_extract_prompt != keywords_extract_prompt - or prompt.gremlin_generate_prompt != gremlin_prompt - or prompt.custom_rerank_info != custom_related_information - ) - if should_update_prompt: - prompt.custom_rerank_info = custom_related_information - prompt.default_question = text - prompt.answer_prompt = answer_prompt - prompt.keywords_extract_prompt = keywords_extract_prompt - prompt.gremlin_generate_prompt = gremlin_prompt - prompt.update_yaml_file() - - vector_search = vector_only_answer or graph_vector_answer - graph_search = graph_only_answer or graph_vector_answer + graph_search, gremlin_prompt, vector_search = update_ui_configs(answer_prompt, custom_related_information, + graph_only_answer, graph_vector_answer, + gremlin_prompt, keywords_extract_prompt, text, + vector_only_answer) if raw_answer is False and not vector_search and not graph_search: gr.Warning("Please select at least one generate mode.") return "", "", "", "" @@ -121,6 +107,106 @@ def rag_answer( raise gr.Error(f"An unexpected error occurred: {str(e)}") +def update_ui_configs(answer_prompt, custom_related_information, graph_only_answer, graph_vector_answer, gremlin_prompt, + keywords_extract_prompt, text, vector_only_answer): + gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt + should_update_prompt = ( + prompt.default_question != text + or prompt.answer_prompt != answer_prompt + or prompt.keywords_extract_prompt != keywords_extract_prompt + or prompt.gremlin_generate_prompt != gremlin_prompt + or prompt.custom_rerank_info != custom_related_information + ) + if should_update_prompt: + prompt.custom_rerank_info = custom_related_information + prompt.default_question = text + prompt.answer_prompt = answer_prompt + prompt.keywords_extract_prompt = keywords_extract_prompt + prompt.gremlin_generate_prompt = gremlin_prompt + prompt.update_yaml_file() + vector_search = vector_only_answer or graph_vector_answer + graph_search = graph_only_answer or graph_vector_answer + return graph_search, gremlin_prompt, vector_search + + +async def rag_answer_streaming( + text: str, + raw_answer: bool, + vector_only_answer: bool, + graph_only_answer: bool, + graph_vector_answer: bool, + graph_ratio: float, + rerank_method: Literal["bleu", "reranker"], + near_neighbor_first: bool, + custom_related_information: str, + answer_prompt: str, + keywords_extract_prompt: str, + gremlin_tmpl_num: Optional[int] = 2, + gremlin_prompt: Optional[str] = None, +) -> AsyncGenerator[Tuple[str, str, str, str], None]: + """ + Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline. + 1. Initialize the RAGPipeline. + 2. Select vector search or graph search based on parameters. + 3. Merge, deduplicate, and rerank the results. + 4. Synthesize the final answer. + 5. Run the pipeline and return the results. 
+ """ + + graph_search, gremlin_prompt, vector_search = update_ui_configs(answer_prompt, custom_related_information, + graph_only_answer, graph_vector_answer, + gremlin_prompt, keywords_extract_prompt, text, + vector_only_answer) + if raw_answer is False and not vector_search and not graph_search: + gr.Warning("Please select at least one generate mode.") + yield "", "", "", "" + return + + rag = RAGPipeline() + if vector_search: + rag.query_vector_index() + if graph_search: + rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema( + huge_settings.graph_name + ).query_graphdb( + num_gremlin_generate_example=gremlin_tmpl_num, + gremlin_prompt=gremlin_prompt, + ) + rag.merge_dedup_rerank( + graph_ratio, + rerank_method, + near_neighbor_first, + ) + # rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt) + + try: + context = rag.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search) + if context.get("switch_to_bleu"): + gr.Warning("Online reranker fails, automatically switches to local bleu rerank.") + answer_synthesize = AnswerSynthesize( + raw_answer=raw_answer, + vector_only_answer=vector_only_answer, + graph_only_answer=graph_only_answer, + graph_vector_answer=graph_vector_answer, + prompt_template=answer_prompt, + ) + async for context in answer_synthesize.run_streaming(context): + if context.get("switch_to_bleu"): + gr.Warning("Online reranker fails, automatically switches to local bleu rerank.") + yield ( + context.get("raw_answer", ""), + context.get("vector_only_answer", ""), + context.get("graph_only_answer", ""), + context.get("graph_vector_answer", ""), + ) + except ValueError as e: + log.critical(e) + raise gr.Error(str(e)) + except Exception as e: + log.critical(e) + raise gr.Error(f"An unexpected error occurred: {str(e)}") + + def create_rag_block(): # pylint: disable=R0915 (too-many-statements),C0301 gr.Markdown("""## 1. HugeGraph RAG Query""") @@ -130,13 +216,17 @@ def create_rag_block(): # TODO: Only support inline formula now. 
Should support block formula gr.Markdown("Basic LLM Answer", elem_classes="output-box-label") - raw_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}]) + raw_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, + latex_delimiters=[{"left": "$", "right": "$", "display": False}]) gr.Markdown("Vector-only Answer", elem_classes="output-box-label") - vector_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}]) + vector_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, + latex_delimiters=[{"left": "$", "right": "$", "display": False}]) gr.Markdown("Graph-only Answer", elem_classes="output-box-label") - graph_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}]) + graph_only_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, + latex_delimiters=[{"left": "$", "right": "$", "display": False}]) gr.Markdown("Graph-Vector Answer", elem_classes="output-box-label") - graph_vector_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, latex_delimiters=[{"left":"$", "right":"$", "display":False}]) + graph_vector_out = gr.Markdown(elem_classes="output-box", show_copy_button=True, + latex_delimiters=[{"left": "$", "right": "$", "display": False}]) answer_prompt_input = gr.Textbox( value=prompt.answer_prompt, label="Query Prompt", show_copy_button=True, lines=7 @@ -184,7 +274,7 @@ def create_rag_block(): btn = gr.Button("Answer Question", variant="primary") btn.click( # pylint: disable=no-member - fn=rag_answer, + fn=rag_answer_streaming, inputs=[ inp, raw_radio, @@ -254,13 +344,13 @@ def create_rag_block(): is_vector_only_answer: bool, is_graph_only_answer: bool, is_graph_vector_answer: bool, - graph_ratio: float, - rerank_method: Literal["bleu", "reranker"], - near_neighbor_first: bool, - custom_related_information: str, + graph_ratio_ui: float, + rerank_method_ui: Literal["bleu", "reranker"], + near_neighbor_first_ui: bool, + custom_related_information_ui: str, answer_prompt: str, keywords_extract_prompt: str, - answer_max_line_count: int = 1, + answer_max_line_count_ui: int = 1, progress=gr.Progress(track_tqdm=True), ): df = pd.read_excel(questions_path, dtype=str) @@ -273,10 +363,10 @@ def create_rag_block(): is_vector_only_answer, is_graph_only_answer, is_graph_vector_answer, - graph_ratio, - rerank_method, - near_neighbor_first, - custom_related_information, + graph_ratio_ui, + rerank_method_ui, + near_neighbor_first_ui, + custom_related_information_ui, answer_prompt, keywords_extract_prompt, ) @@ -285,9 +375,9 @@ def create_rag_block(): df.at[index, "Graph-only Answer"] = graph_only_answer df.at[index, "Graph-Vector Answer"] = graph_vector_answer progress((index + 1, total_rows)) - answers_path = os.path.join(resource_path, "demo", "questions_answers.xlsx") - df.to_excel(answers_path, index=False) - return df.head(answer_max_line_count), answers_path + answers_path_ui = os.path.join(resource_path, "demo", "questions_answers.xlsx") + df.to_excel(answers_path_ui, index=False) + return df.head(answer_max_line_count_ui), answers_path_ui with gr.Row(): with gr.Column(): diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py index f2dd234..c6bfa44 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py +++ 
b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py @@ -16,7 +16,7 @@ # under the License. from abc import ABC, abstractmethod -from typing import Any, List, Optional, Callable, Dict +from typing import Any, AsyncGenerator, Generator, List, Optional, Callable, Dict class BaseLLM(ABC): @@ -43,8 +43,17 @@ class BaseLLM(ABC): self, messages: Optional[List[Dict[str, Any]]] = None, prompt: Optional[str] = None, - on_token_callback: Callable = None, - ) -> List[Any]: + on_token_callback: Optional[Callable] = None, + ) -> Generator[str, None, None]: + """Comment""" + + @abstractmethod + async def agenerate_streaming( + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, + on_token_callback: Optional[Callable] = None, + ) -> AsyncGenerator[str, None]: """Comment""" @abstractmethod diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py index 23a1250..ca5ae60 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Callable, List, Optional, Dict, Any +from typing import Callable, List, Optional, Dict, Any, AsyncGenerator import tiktoken from litellm import completion, acompletion @@ -137,6 +137,35 @@ class LiteLLMClient(BaseLLM): log.error("Error in streaming LiteLLM call: %s", e) return f"Error: {str(e)}" + async def agenerate_streaming( + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, + on_token_callback: Optional[Callable] = None, + ) -> AsyncGenerator[str, None]: + """Generate a response to the query messages/prompt in async streaming mode.""" + if messages is None: + assert prompt is not None, "Messages or prompt must be provided." + messages = [{"role": "user", "content": prompt}] + try: + response = await acompletion( + model=self.model, + messages=messages, + temperature=self.temperature, + max_tokens=self.max_tokens, + api_key=self.api_key, + base_url=self.api_base, + stream=True, + ) + async for chunk in response: + if chunk.choices[0].delta.content: + if on_token_callback: + on_token_callback(chunk) + yield chunk.choices[0].delta.content + except (RateLimitError, BudgetExceededError, APIError) as e: + log.error("Error in async streaming LiteLLM call: %s", e) + yield f"Error: {str(e)}" + def num_tokens_from_string(self, string: str) -> int: """Get token count from string.""" try: diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py index 62f5ef2..58f063b 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py @@ -17,7 +17,7 @@ import json -from typing import Any, List, Optional, Callable, Dict +from typing import Any, AsyncGenerator, Generator, List, Optional, Callable, Dict import ollama from retry import retry @@ -89,22 +89,49 @@ class OllamaClient(BaseLLM): self, messages: Optional[List[Dict[str, Any]]] = None, prompt: Optional[str] = None, - on_token_callback: Callable = None, - ) -> List[Any]: + on_token_callback: Optional[Callable] = None, + ) -> Generator[str, None, None]: """Comment""" if messages is None: assert prompt is not None, "Messages or prompt must be provided." 
messages = [{"role": "user", "content": prompt}] - stream = self.client.chat( + + for chunk in self.client.chat( model=self.model, messages=messages, stream=True - ) - chunks = [] - for chunk in stream: - on_token_callback(chunk["message"]["content"]) - chunks.append(chunk) - return chunks + ): + token = chunk["message"]["content"] + if on_token_callback: + on_token_callback(token) + yield token + + async def agenerate_streaming( + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, + on_token_callback: Optional[Callable] = None, + ) -> AsyncGenerator[str, None]: + """Comment""" + if messages is None: + assert prompt is not None, "Messages or prompt must be provided." + messages = [{"role": "user", "content": prompt}] + + try: + async_generator = await self.async_client.chat( + model=self.model, + messages=messages, + stream=True + ) + async for chunk in async_generator: + token = chunk.get("message", {}).get("content", "") + if on_token_callback: + on_token_callback(token) + yield token + except Exception as e: + print(f"Retrying LLM call {e}") + raise e + def num_tokens_from_string( self, diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py index a020067..45f6d7a 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Callable, List, Optional, Dict, Any +from typing import Callable, List, Optional, Dict, Any, Generator, AsyncGenerator import openai import tiktoken @@ -90,9 +90,9 @@ class OpenAIClient(BaseLLM): retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)), ) async def agenerate( - self, - messages: Optional[List[Dict[str, Any]]] = None, - prompt: Optional[str] = None, + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, ) -> str: """Generate a response to the query messages/prompt.""" if messages is None: @@ -119,31 +119,91 @@ class OpenAIClient(BaseLLM): log.error("Retrying LLM call %s", e) raise e + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)), + ) def generate_streaming( self, messages: Optional[List[Dict[str, Any]]] = None, prompt: Optional[str] = None, - on_token_callback: Callable = None, - ) -> str: - """Generate a response to the query messages/prompt in streaming mode.""" + on_token_callback: Optional[Callable[[str], None]] = None, + ) -> Generator[str, None, None]: + """Generate a response to the query messages/prompt in streaming mode. + + Yields: + Accumulated response string after each new token. + """ if messages is None: assert prompt is not None, "Messages or prompt must be provided." 
messages = [{"role": "user", "content": prompt}] - completions = self.client.chat.completions.create( - model=self.model, - temperature=self.temperature, - max_tokens=self.max_tokens, - messages=messages, - stream=True, - ) - result = "" - for message in completions: - # Process the streamed messages or perform any other desired action - delta = message["choices"][0]["delta"] - if "content" in delta: - result += delta["content"] - on_token_callback(message) - return result + + try: + completions = self.client.chat.completions.create( + model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, + messages=messages, + stream=True, + ) + + for chunk in completions: + delta = chunk.choices[0].delta + if delta.content: + token = delta.content + if on_token_callback: + on_token_callback(token) + yield token + + except openai.BadRequestError as e: + log.critical("Fatal: %s", e) + yield str(f"Error: {e}") + except openai.AuthenticationError: + log.critical("The provided API key is invalid") + yield "Error: The provided API key is invalid" + except Exception as e: + log.error("Error in streaming: %s", e) + raise e + + async def agenerate_streaming( + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, + on_token_callback: Optional[Callable] = None, + ) -> AsyncGenerator[str, None]: + """Comment""" + if messages is None: + assert prompt is not None, "Messages or prompt must be provided." + messages = [{"role": "user", "content": prompt}] + + try: + completions = await self.aclient.chat.completions.create( + model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, + messages=messages, + stream=True + ) + async for chunk in completions: + delta = chunk.choices[0].delta + if delta.content: + token = delta.content + if on_token_callback: + on_token_callback(token) + yield token + # TODO: log.info("Token usage: %s", completions.usage.model_dump_json()) + # catch context length / do not retry + except openai.BadRequestError as e: + log.critical("Fatal: %s", e) + yield str(f"Error: {e}") + # catch authorization errors / do not retry + except openai.AuthenticationError: + log.critical("The provided OpenAI API key is invalid") + yield "Error: The provided OpenAI API key is invalid" + except Exception as e: + log.error("Retrying LLM call %s", e) + raise e def num_tokens_from_string(self, string: str) -> int: """Get token count from string.""" diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py index 967c391..cbca691 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/qianfan.py @@ -16,7 +16,7 @@ # under the License. import json -from typing import Optional, List, Dict, Any, Callable +from typing import AsyncGenerator, Generator, Optional, List, Dict, Any, Callable import qianfan from retry import retry @@ -74,9 +74,38 @@ class QianfanClient(BaseLLM): self, messages: Optional[List[Dict[str, Any]]] = None, prompt: Optional[str] = None, - on_token_callback: Callable = None, - ) -> str: - return self.generate(messages, prompt) + on_token_callback: Optional[Callable] = None, + ) -> Generator[str, None, None]: + if messages is None: + assert prompt is not None, "Messages or prompt must be provided." 
+ messages = [{"role": "user", "content": prompt}] + + for msg in self.chat_comp.do(messages=messages, model=self.chat_model, stream=True): + token = msg.body['result'] + if on_token_callback: + on_token_callback(token) + yield token + + async def agenerate_streaming( + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, + on_token_callback: Optional[Callable] = None, + ) -> AsyncGenerator[str, None]: + if messages is None: + assert prompt is not None, "Messages or prompt must be provided." + messages = [{"role": "user", "content": prompt}] + + try: + async_generator = await self.chat_comp.ado(messages=messages, model=self.chat_model, stream=True) + async for msg in async_generator: + chunk = msg.body['result'] + if on_token_callback: + on_token_callback(chunk) + yield chunk + except Exception as e: + print(f"Retrying LLM call {e}") + raise e def num_tokens_from_string(self, string: str) -> int: return len(string) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py index 666ecf9..5c4ab5f 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py @@ -18,7 +18,7 @@ # pylint: disable=W0621 import asyncio -from typing import Any, Dict, Optional +from typing import Any, AsyncGenerator, Dict, Optional from hugegraph_llm.config import prompt from hugegraph_llm.models.llms.base import BaseLLM @@ -35,17 +35,17 @@ DEFAULT_ANSWER_TEMPLATE = prompt.answer_prompt class AnswerSynthesize: def __init__( - self, - llm: Optional[BaseLLM] = None, - prompt_template: Optional[str] = None, - question: Optional[str] = None, - context_body: Optional[str] = None, - context_head: Optional[str] = None, - context_tail: Optional[str] = None, - raw_answer: bool = False, - vector_only_answer: bool = True, - graph_only_answer: bool = False, - graph_vector_answer: bool = False, + self, + llm: Optional[BaseLLM] = None, + prompt_template: Optional[str] = None, + question: Optional[str] = None, + context_body: Optional[str] = None, + context_head: Optional[str] = None, + context_tail: Optional[str] = None, + raw_answer: bool = False, + vector_only_answer: bool = True, + graph_only_answer: bool = False, + graph_vector_answer: bool = False, ): self._llm = llm self._prompt_template = prompt_template or DEFAULT_ANSWER_TEMPLATE @@ -59,15 +59,7 @@ class AnswerSynthesize: self._graph_vector_answer = graph_vector_answer def run(self, context: Dict[str, Any]) -> Dict[str, Any]: - if self._llm is None: - self._llm = LLMs().get_chat_llm() - - if self._question is None: - self._question = context.get("query") or None - assert self._question is not None, "No question for synthesizing." 
- - context_head_str = context.get("synthesize_context_head") or self._context_head or "" - context_tail_str = context.get("synthesize_context_tail") or self._context_tail or "" + context_head_str, context_tail_str = self.init_llm(context) if self._context_body is not None: context_str = (f"{context_head_str}\n" @@ -78,6 +70,22 @@ class AnswerSynthesize: response = self._llm.generate(prompt=final_prompt) return {"answer": response} + graph_result_context, vector_result_context = self.handle_vector_graph(context) + context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str, + vector_result_context, graph_result_context)) + return context + + def init_llm(self, context): + if self._llm is None: + self._llm = LLMs().get_chat_llm() + if self._question is None: + self._question = context.get("query") or None + assert self._question is not None, "No question for synthesizing." + context_head_str = context.get("synthesize_context_head") or self._context_head or "" + context_tail_str = context.get("synthesize_context_tail") or self._context_tail or "" + return context_head_str, context_tail_str + + def handle_vector_graph(self, context): vector_result = context.get("vector_result") if vector_result: vector_result_context = "Phrases related to the query:\n" + "\n".join( @@ -85,7 +93,6 @@ class AnswerSynthesize: ) else: vector_result_context = "No (vector)phrase related to the query." - graph_result = context.get("graph_result") if graph_result: graph_context_head = context.get("graph_context_head", "Knowledge from graphdb for the query:\n") @@ -95,10 +102,31 @@ class AnswerSynthesize: else: graph_result_context = "No related graph data found for current query." log.warning(graph_result_context) + return graph_result_context, vector_result_context - context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str, - vector_result_context, graph_result_context)) - return context + async def run_streaming(self, context: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]: + context_head_str, context_tail_str = self.init_llm(context) + + if self._context_body is not None: + context_str = (f"{context_head_str}\n" + f"{self._context_body}\n" + f"{context_tail_str}".strip("\n")) + + final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) + response = self._llm.generate(prompt=final_prompt) + yield {"answer": response} + return + + graph_result_context, vector_result_context = self.handle_vector_graph(context) + + async for context in self.async_streaming_generate( + context, + context_head_str, + context_tail_str, + vector_result_context, + graph_result_context + ): + yield context async def async_generate(self, context: Dict[str, Any], context_head_str: str, context_tail_str: str, vector_result_context: str, @@ -151,3 +179,81 @@ class AnswerSynthesize: ops = sum([self._raw_answer, self._vector_only_answer, self._graph_only_answer, self._graph_vector_answer]) context['call_count'] = context.get('call_count', 0) + ops return context + + async def async_streaming_generate(self, context: Dict[str, Any], context_head_str: str, + context_tail_str: str, vector_result_context: str, + graph_result_context: str) -> AsyncGenerator[Dict[str, Any], None]: + # async_tasks stores the async tasks for different answer types + async_generators = [] + auto_id = 0 + if self._raw_answer: + final_prompt = self._question + async_generators.append( + self.__llm_generate_with_meta_info(task_id=auto_id, target_key="raw_answer", 
prompt=final_prompt) + ) + auto_id += 1 + if self._vector_only_answer: + context_str = (f"{context_head_str}\n" + f"{vector_result_context}\n" + f"{context_tail_str}".strip("\n")) + + final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) + async_generators.append( + self.__llm_generate_with_meta_info( + task_id=auto_id, + target_key="vector_only_answer", + prompt=final_prompt + ) + ) + auto_id += 1 + if self._graph_only_answer: + context_str = (f"{context_head_str}\n" + f"{graph_result_context}\n" + f"{context_tail_str}".strip("\n")) + + final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) + async_generators.append( + self.__llm_generate_with_meta_info(task_id=auto_id, target_key="graph_only_answer", prompt=final_prompt) + ) + auto_id += 1 + if self._graph_vector_answer: + context_body_str = f"{vector_result_context}\n{graph_result_context}" + if context.get("graph_ratio", 0.5) < 0.5: + context_body_str = f"{graph_result_context}\n{vector_result_context}" + context_str = (f"{context_head_str}\n" + f"{context_body_str}\n" + f"{context_tail_str}".strip("\n")) + + final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) + async_generators.append( + self.__llm_generate_with_meta_info( + task_id=auto_id, + target_key="graph_vector_answer", + prompt=final_prompt + ) + ) + auto_id += 1 + + ops = sum([self._raw_answer, self._vector_only_answer, self._graph_only_answer, self._graph_vector_answer]) + context['call_count'] = context.get('call_count', 0) + ops + + async_tasks = [asyncio.create_task(anext(gen)) for gen in async_generators] + while True: + done, _ = await asyncio.wait(async_tasks, return_when=asyncio.FIRST_COMPLETED) + stop_task_num = 0 + for task in done: + try: + task_id, target_key, token = task.result() + context[target_key] = context.get(target_key, "") + token + gen = async_generators[task_id] + async_tasks[task_id] = asyncio.create_task(anext(gen)) + except StopAsyncIteration: + stop_task_num += 1 + if stop_task_num == len(async_tasks): + break + yield context + + async def __llm_generate_with_meta_info(self, task_id: int, target_key: str, prompt: str): + # FIXME: Expected type 'AsyncIterable', got 'Coroutine[Any, Any, AsyncGenerator[str, None]]' instead + async for token in self._llm.agenerate_streaming(prompt=prompt): + yield task_id, target_key, token diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py index 9694647..219a358 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py @@ -92,10 +92,40 @@ class GremlinGenerateSynthesize: return context + def sync_generate(self, context: Dict[str, Any]): + query = context.get("query") + raw_example = [{'query': 'who is peter', 'gremlin': "g.V().has('name', 'peter')"}] + raw_prompt = self.gremlin_prompt.format( + query=query, + schema=self.schema, + example=self._format_examples(examples=raw_example), + vertices=self._format_vertices(vertices=self.vertices) + ) + raw_response = self.llm.generate(prompt=raw_prompt) + + examples = context.get("match_result") + init_prompt = self.gremlin_prompt.format( + query=query, + schema=self.schema, + example=self._format_examples(examples=examples), + vertices=self._format_vertices(vertices=self.vertices) + ) + initialized_response = self.llm.generate(prompt=init_prompt) + + 
log.debug("Text2Gremlin with tmpl prompt:\n %s,\n LLM Response: %s", init_prompt, initialized_response) + + context["result"] = self._extract_gremlin(response=initialized_response) + context["raw_result"] = self._extract_gremlin(response=raw_response) + context["call_count"] = context.get("call_count", 0) + 2 + + return context + def run(self, context: Dict[str, Any]) -> Dict[str, Any]: query = context.get("query", "") if not query: raise ValueError("query is required") - context = asyncio.run(self.async_generate(context)) + # TODO: Update to async_generate again + # The best method may be changing all `operator.run(*arg)` to be async function + context = self.sync_generate(context) return context