This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 4b8beabdeb6cd87e63a045bd6c3fd07f75d2f957
Author: WenjinXie <wenjin...@gmail.com>
AuthorDate: Tue Aug 5 17:34:51 2025 +0800

    [integration][python] Introduce ollama chat model in python.
---
 python/flink_agents/api/tools/tool.py              |  28 ++++
 .../flink_agents/examples/chat_ollama_exmaple.py   | 105 ++++++++++++
 .../integrations/chat_models/__init__.py           |  17 ++
 .../integrations/chat_models/ollama_chat_model.py  | 183 +++++++++++++++++++++
 .../integrations/chat_models/tests/__init__.py     |  17 ++
 .../chat_models/tests/start_ollama_server.sh       |  32 ++++
 .../chat_models/tests/test_ollama_chat_model.py    | 102 ++++++++++++
 python/pyproject.toml                              |   2 +
 8 files changed, 486 insertions(+)

diff --git a/python/flink_agents/api/tools/tool.py b/python/flink_agents/api/tools/tool.py
index 891e068..dfa3e26 100644
--- a/python/flink_agents/api/tools/tool.py
+++ b/python/flink_agents/api/tools/tool.py
@@ -89,6 +89,34 @@ class ToolMetadata(BaseModel):
             == self.args_schema.model_json_schema()
         )
 
+    def __get_parameters_dict(self) -> dict:
+        parameters = self.args_schema.model_json_schema()
+        parameters = {
+            k: v
+            for k, v in parameters.items()
+            if k in ["type", "properties", "required", "definitions", "$defs"]
+        }
+        return parameters
+
+    def to_openai_tool(self, skip_length_check: bool = False) -> typing.Dict[str, Any]:  # noqa:FBT001
+        """To OpenAI tool."""
+        if not skip_length_check and len(self.description) > 1024:
+            msg = (
+                "Tool description exceeds maximum length of 1024 characters. "
+                "Please shorten your description or move it to the prompt."
+            )
+            raise ValueError(
+                msg
+            )
+        return {
+            "type": "function",
+            "function": {
+                "name": self.name,
+                "description": self.description,
+                "parameters": self.__get_parameters_dict(),
+            },
+        }
+
 
 class BaseTool(SerializableResource, ABC):
     """Base abstract class of all kinds of tools.
diff --git a/python/flink_agents/examples/chat_ollama_exmaple.py b/python/flink_agents/examples/chat_ollama_exmaple.py
new file mode 100644
index 0000000..d0ca673
--- /dev/null
+++ b/python/flink_agents/examples/chat_ollama_exmaple.py
@@ -0,0 +1,105 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import os
+from typing import Any, Dict, Tuple, Type
+
+from flink_agents.api.agent import Agent
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.chat_models.chat_model import BaseChatModel
+from flink_agents.api.decorators import action, chat_model, tool
+from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent
+from flink_agents.api.events.event import (
+    InputEvent,
+    OutputEvent,
+)
+from flink_agents.api.execution_environment import AgentsExecutionEnvironment
+from flink_agents.api.runner_context import RunnerContext
+from flink_agents.integrations.chat_models.ollama_chat_model import OllamaChatModel
+
+model = os.environ.get("OLLAMA_CHAT_MODEL", "qwen3:8b")
+
+
+class MyAgent(Agent):
+    """Example agent that uses the Ollama chat model in its actions."""
+
+    @chat_model
+    @staticmethod
+    def chat_model() -> Tuple[Type[BaseChatModel], Dict[str, Any]]:
+        """ChatModel which can be used in actions."""
+        return OllamaChatModel, {
+            "name": "chat_model",
+            "model": model,
+            "tools": ["add"],
+        }
+
+    @tool
+    @staticmethod
+    def add(a: int, b: int) -> int:
+        """Calculate the sum of a and b.
+
+        Parameters
+        ----------
+        a : int
+            The first operand
+        b : int
+            The second operand
+
+        Returns:
+        -------
+        int:
+            The sum of a and b
+        """
+        return a + b
+
+    @action(InputEvent)
+    @staticmethod
+    def process_input(event: InputEvent, ctx: RunnerContext) -> None:
+        """User-defined action for processing input.
+
+        In this action, we send a ChatRequestEvent to trigger the built-in actions.
+        """
+        input = event.input
+        ctx.send_event(
+            ChatRequestEvent(
+                model="chat_model",
+                messages=[ChatMessage(role=MessageRole.USER, content=input)],
+            )
+        )
+
+    @action(ChatResponseEvent)
+    @staticmethod
+    def process_chat_response(event: ChatResponseEvent, ctx: RunnerContext) -> None:
+        """User-defined action for processing the chat model response."""
+        input = event.response
+        ctx.send_event(OutputEvent(output=input.content))
+
+# An Ollama server must be started manually before running this example.
+if __name__ == "__main__":
+    env = AgentsExecutionEnvironment.get_execution_environment()
+
+    input_list = []
+    agent = MyAgent()
+
+    output_list = env.from_list(input_list).apply(agent).to_list()
+
+    input_list.append({"key": "0001", "value": "calculate the sum of 1 and 2."})
+
+    env.execute()
+
+    for key, value in output_list[0].items():
+        print(f"{key}: {value}")
diff --git a/python/flink_agents/integrations/chat_models/__init__.py b/python/flink_agents/integrations/chat_models/__init__.py
new file mode 100644
index 0000000..e154fad
--- /dev/null
+++ b/python/flink_agents/integrations/chat_models/__init__.py
@@ -0,0 +1,17 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
diff --git a/python/flink_agents/integrations/chat_models/ollama_chat_model.py b/python/flink_agents/integrations/chat_models/ollama_chat_model.py
new file mode 100644
index 0000000..6b9b937
--- /dev/null
+++ b/python/flink_agents/integrations/chat_models/ollama_chat_model.py
@@ -0,0 +1,183 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import uuid
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
+
+from ollama import Client, Message
+from pydantic import Field
+
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.chat_models.chat_model import BaseChatModel
+from flink_agents.api.resource import ResourceType
+
+DEFAULT_CONTEXT_WINDOW = 2048
+DEFAULT_REQUEST_TIMEOUT = 30.0
+
+
+class OllamaChatModel(BaseChatModel):
+    """Ollama ChatModel.
+
+    Visit https://ollama.com/ to download and install Ollama.
+
+    Run `ollama serve` to start a server.
+
+    Run `ollama pull <name>` to download a model to run.
+ """ + + base_url: str = Field( + default="http://localhost:11434", + description="Base url the model is hosted under.", + ) + model: str = Field(description="Model name to use.") + temperature: float = Field( + default=0.75, + description="The temperature to use for sampling.", + ge=0.0, + le=1.0, + ) + num_ctx: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description="The maximum number of context tokens for the model.", + gt=0, + ) + request_timeout: float = Field( + default=DEFAULT_REQUEST_TIMEOUT, + description="The timeout for making http request to Ollama API server", + ) + additional_kwargs: Dict[str, Any] = Field( + default_factory=dict, + description="Additional model parameters for the Ollama API.", + ) + keep_alive: Optional[Union[float, str]] = Field( + default="5m", + description="controls how long the model will stay loaded into memory following the request(default: 5m)", + ) + + __client: Client = None + __tools: Sequence[Mapping[str, Any]] = [] + + def __init__( + self, + model: str, + base_url: str = "http://localhost:11434", + temperature: float = 0.75, + num_ctx: int = DEFAULT_CONTEXT_WINDOW, + request_timeout: Optional[float] = DEFAULT_REQUEST_TIMEOUT, + additional_kwargs: Optional[Dict[str, Any]] = None, + keep_alive: Optional[Union[float, str]] = None, + **kwargs: Any, + ) -> None: + """Init method.""" + if additional_kwargs is None: + additional_kwargs = {} + super().__init__( + model=model, + base_url=base_url, + temperature=temperature, + num_ctx=num_ctx, + request_timeout=request_timeout, + additional_kwargs=additional_kwargs, + keep_alive=keep_alive, + **kwargs, + ) + # bind tools + if self.tools is not None: + tools = [ + self.get_resource(tool_name, ResourceType.TOOL) + for tool_name in self.tools + ] + self.__tools = [tool.metadata.to_openai_tool() for tool in tools] + # bind prompt + if self.prompt is not None and isinstance(self.prompt, str): + self.prompt = self.get_resource(self.prompt, ResourceType.PROMPT) + + @property + def client(self) -> Client: + """Return ollama client.""" + if self.__client is None: + self.__client = Client(host=self.base_url, timeout=self.request_timeout) + return self.__client + + @property + def model_kwargs(self) -> Dict[str, Any]: + """Return ollama model configuration.""" + base_kwargs = { + "temperature": self.temperature, + "num_ctx": self.num_ctx, + } + return { + **base_kwargs, + **self.additional_kwargs, + } + + def chat(self, messages: Sequence[ChatMessage]) -> ChatMessage: + """Process a sequence of messages, and return a response.""" + if self.prompt is not None: + input_variable = {} + for msg in messages: + input_variable.update(msg.additional_kwargs) + messages = self.prompt.format_messages(**input_variable) + ollama_messages = self.__convert_to_ollama_messages(messages) + response = self.client.chat( + model=self.model, + messages=ollama_messages, + stream=False, + tools=self.__tools, + options=self.model_kwargs, + keep_alive=self.keep_alive, + ) + + ollama_tool_calls = response.message.tool_calls + if ollama_tool_calls is None: + ollama_tool_calls = [] + tool_calls = [] + for ollama_tool_call in ollama_tool_calls: + tool_call = { + "id": uuid.uuid4(), + "type": "function", + "function": { + "name": ollama_tool_call.function.name, + "arguments": ollama_tool_call.function.arguments, + }, + } + tool_calls.append(tool_call) + return ChatMessage( + role=MessageRole(response.message.role), + content=response.message.content, + tool_calls=tool_calls, + ) + + @staticmethod + def 
+    def __convert_to_ollama_messages(messages: Sequence[ChatMessage]) -> List[Message]:
+        ollama_messages = []
+        for message in messages:
+            ollama_message = Message(role=message.role.value, content=message.content)
+            if len(message.tool_calls) > 0:
+                ollama_tool_calls = []
+                for tool_call in message.tool_calls:
+                    name = tool_call["function"]["name"]
+                    arguments = tool_call["function"]["arguments"]
+                    ollama_tool_call = Message.ToolCall(
+                        function=Message.ToolCall.Function(
+                            name=name, arguments=arguments
+                        )
+                    )
+                    ollama_tool_calls.append(ollama_tool_call)
+                ollama_message.tool_calls = ollama_tool_calls
+            ollama_messages.append(ollama_message)
+        return ollama_messages
diff --git a/python/flink_agents/integrations/chat_models/tests/__init__.py b/python/flink_agents/integrations/chat_models/tests/__init__.py
new file mode 100644
index 0000000..e154fad
--- /dev/null
+++ b/python/flink_agents/integrations/chat_models/tests/__init__.py
@@ -0,0 +1,17 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
diff --git a/python/flink_agents/integrations/chat_models/tests/start_ollama_server.sh b/python/flink_agents/integrations/chat_models/tests/start_ollama_server.sh
new file mode 100644
index 0000000..04c252a
--- /dev/null
+++ b/python/flink_agents/integrations/chat_models/tests/start_ollama_server.sh
@@ -0,0 +1,32 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# only works on linux
+os=$(uname -s)
+echo $os
+if [[ $os == "Linux" ]]; then
+  curl -fsSL https://ollama.com/install.sh | sh
+  ret=$?
+  if [ "$ret" != "0" ]
+  then
+    exit $ret
+  fi
+  ollama serve &  # start the server in the background so the commands below can run
+  ollama pull qwen3:0.6b
+  ollama run qwen3:0.6b
+fi
\ No newline at end of file
diff --git a/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py
new file mode 100644
index 0000000..9614b65
--- /dev/null
+++ b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py
@@ -0,0 +1,102 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import os
+import subprocess
+import sys
+from pathlib import Path
+
+import pytest
+from ollama import Client
+
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.resource import ResourceType
+from flink_agents.integrations.chat_models.ollama_chat_model import OllamaChatModel
+from flink_agents.plan.tools.function_tool import FunctionTool, from_callable
+
+test_model = os.environ.get("OLLAMA_CHAT_MODEL", "qwen3:0.6b")
+current_dir = Path(__file__).parent
+
+try:
+    # Only auto-setup Ollama in CI with Python 3.9 to reduce CI cost.
+    if "3.9" in sys.version:
+        subprocess.run(["bash", f"{current_dir}/start_ollama_server.sh"], check=True)
+    client = Client()
+    models = client.list()
+
+    model_found = False
+    for model in models["models"]:
+        if model.model == test_model:
+            model_found = True
+            break
+
+    if not model_found:
+        client = None  # type: ignore
+except Exception:
+    client = None  # type: ignore
+
+
+@pytest.mark.skipif(
+    client is None, reason="Ollama client is not available or test model is missing"
+)
+def test_ollama_chat() -> None:  # noqa :D103
+    llm = OllamaChatModel(name="ollama", model=test_model)
+    response = llm.chat([ChatMessage(role=MessageRole.USER, content="Hello!")])
+    assert response is not None
+    assert str(response).strip() != ""
+
+
+def add(a: int, b: int) -> int:
+    """Calculate the sum of a and b.
+
+    Parameters
+    ----------
+    a : int
+        The first operand
+    b : int
+        The second operand
+
+    Returns:
+    -------
+    int:
+        The sum of a and b
+    """
+    return a + b
+
+
+def get_tool(name: str, type: ResourceType) -> FunctionTool:  # noqa :D103
+    return from_callable(name=name, func=add)
+
+
+@pytest.mark.skipif(
+    client is None, reason="Ollama client is not available or test model is missing"
+)
+def test_ollama_chat_with_tools() -> None:  # noqa :D103
+    llm = OllamaChatModel(
+        name="ollama", model=test_model, tools=["add"], get_resource=get_tool
+    )
+    response = llm.chat([
+        ChatMessage(
+            role=MessageRole.USER,
+            content="Could you help me calculate the sum of 1 and 2?",
+        )
+    ])
+
+    tool_calls = response.tool_calls
+    assert len(tool_calls) == 1
+    tool_call = tool_calls[0]
+    assert add(**tool_call["function"]["arguments"]) == 3
diff --git a/python/pyproject.toml b/python/pyproject.toml
index e6b8f83..1f36950 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -47,6 +47,8 @@ dependencies = [
     "apache-flink==1.20.1",
     "pydantic==2.11.4",
     "docstring-parser==0.16",
+    # TODO: Separate integration dependencies from the project
+    "ollama==0.4.8",
 ]
 
 # Optional dependencies (dependency groups)
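For quick verification outside of an agent pipeline, the chat model added in this commit can also be called directly. The following is a minimal sketch based on the constructor and `chat` API introduced above (mirroring the new tests); it assumes a local Ollama server is running on the default port and that the chosen model (here `qwen3:0.6b`, the default used by the tests) has already been pulled.

    from flink_agents.api.chat_message import ChatMessage, MessageRole
    from flink_agents.integrations.chat_models.ollama_chat_model import OllamaChatModel

    # Assumes `ollama serve` is running locally and `ollama pull qwen3:0.6b` has completed.
    llm = OllamaChatModel(name="ollama", model="qwen3:0.6b")
    response = llm.chat([ChatMessage(role=MessageRole.USER, content="Hello!")])
    print(response.content)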