This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 407eac4ec7393c7e207fdc9c26cfcf8467cccfd1 Author: WenjinXie <wenjin...@gmail.com> AuthorDate: Fri Sep 12 14:58:23 2025 +0800 [hotfix] Refactor Prompt abstraction. Co-authored-by: Hao Li <1127478+lihao...@users.noreply.github.com> Co-authored-by: yanand0909 <yan...@confluent.io> --- python/flink_agents/api/prompts/prompt.py | 42 ++++++++++++++++++++-------- python/flink_agents/api/tests/test_prompt.py | 14 +++++----- 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/python/flink_agents/api/prompts/prompt.py b/python/flink_agents/api/prompts/prompt.py index 0b9b75e..a6b4bcf 100644 --- a/python/flink_agents/api/prompts/prompt.py +++ b/python/flink_agents/api/prompts/prompt.py @@ -15,39 +15,57 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# +from abc import ABC, abstractmethod from typing import List, Sequence +from typing_extensions import override + from flink_agents.api.chat_message import ChatMessage, MessageRole from flink_agents.api.prompts.utils import format_string from flink_agents.api.resource import ResourceType, SerializableResource -class Prompt(SerializableResource): - """Prompt for a language model. - - Attributes: - ---------- - template : Union[Sequence[ChatMessage], str] - The prompt template. 
- """ - - template: Sequence[ChatMessage] | str +class Prompt(SerializableResource, ABC): + """Base prompt abstract.""" @staticmethod def from_messages(name: str, messages: Sequence[ChatMessage]) -> "Prompt": """Create prompt from sequence of ChatMessage.""" - return Prompt(name=name, template=messages) + return LocalPrompt(name=name, template=messages) @staticmethod def from_text(name: str, text: str) -> "Prompt": """Create prompt from text string.""" - return Prompt(name=name, template=text) + return LocalPrompt(name=name, template=text) + + @abstractmethod + def format_string(self, **kwargs: str) -> str: + """Generate text string from template with additional arguments.""" + + @abstractmethod + def format_messages( + self, role: MessageRole = MessageRole.SYSTEM, **kwargs: str + ) -> List[ChatMessage]: + """Generate list of ChatMessage from template with additional arguments.""" @classmethod + @override def resource_type(cls) -> ResourceType: """Get the resource type.""" return ResourceType.PROMPT + +class LocalPrompt(Prompt): + """Prompt for a language model. + + Attributes: + ---------- + template : Union[Sequence[ChatMessage], str] + The prompt template. 
+ """ + + template: Sequence[ChatMessage] | str + def format_string(self, **kwargs: str) -> str: """Generate text string from template with input arguments.""" if isinstance(self.template, str): diff --git a/python/flink_agents/api/tests/test_prompt.py b/python/flink_agents/api/tests/test_prompt.py index 38342f9..a3831b1 100644 --- a/python/flink_agents/api/tests/test_prompt.py +++ b/python/flink_agents/api/tests/test_prompt.py @@ -18,7 +18,7 @@ import pytest from flink_agents.api.chat_message import ChatMessage, MessageRole -from flink_agents.api.prompts.prompt import Prompt +from flink_agents.api.prompts.prompt import LocalPrompt, Prompt @pytest.fixture(scope="module") @@ -32,7 +32,7 @@ def text_prompt() -> Prompt: # noqa: D103 return Prompt.from_text(name="prompt", text=template) -def test_prompt_from_text_to_string(text_prompt: Prompt) -> None: # noqa: D103 +def test_prompt_from_text_to_string(text_prompt: LocalPrompt) -> None: # noqa: D103 assert text_prompt.format_string( product_id="12345", description="wireless noise-canceling headphones with 20-hour battery life", @@ -45,7 +45,7 @@ def test_prompt_from_text_to_string(text_prompt: Prompt) -> None: # noqa: D103 ) -def test_prompt_from_text_to_messages(text_prompt: Prompt) -> None: # noqa: D103 +def test_prompt_from_text_to_messages(text_prompt: LocalPrompt) -> None: # noqa: D103 assert text_prompt.format_messages( product_id="12345", description="wireless noise-canceling headphones with 20-hour battery life", @@ -78,7 +78,7 @@ def messages_prompt() -> Prompt: # noqa: D103 return Prompt.from_messages(name="prompt", messages=template) -def test_prompt_from_messages_to_string(messages_prompt: Prompt) -> None: # noqa: D103 +def test_prompt_from_messages_to_string(messages_prompt: LocalPrompt) -> None: # noqa: D103 assert messages_prompt.format_string( product_id="12345", description="wireless noise-canceling headphones with 20-hour battery life", @@ -92,7 +92,7 @@ def 
test_prompt_from_messages_to_string(messages_prompt: Prompt) -> None: # noq ) -def test_prompt_from_messages_to_messages(messages_prompt: Prompt) -> None: # noqa: D103 +def test_prompt_from_messages_to_messages(messages_prompt: LocalPrompt) -> None: # noqa: D103 assert messages_prompt.format_messages( product_id="12345", description="wireless noise-canceling headphones with 20-hour battery life", @@ -112,7 +112,7 @@ def test_prompt_from_messages_to_messages(messages_prompt: Prompt) -> None: # n ] -def test_prompt_lack_one_argument(text_prompt: Prompt) -> None: # noqa: D103 +def test_prompt_lack_one_argument(text_prompt: LocalPrompt) -> None: # noqa: D103 assert text_prompt.format_string( product_id="12345", review="The headphones broke after one week of use. Very poor quality", @@ -126,6 +126,6 @@ def test_prompt_lack_one_argument(text_prompt: Prompt) -> None: # noqa: D103 def test_prompt_contain_json_schema() -> None: # noqa: D103 prompt = Prompt.from_text( name="prompt", - text=f"The json schema is {Prompt.model_json_schema(mode='serialization')}", + text=f"The json schema is {LocalPrompt.model_json_schema(mode='serialization')}", ) prompt.format_string()