This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
     new 30fbba0  [integration][python] Integrate OpenAI chat model. (#128)
30fbba0 is described below

commit 30fbba06905191559fd73e206e0719f39ecec732
Author: Wenjin Xie <166717626+wenjin...@users.noreply.github.com>
AuthorDate: Wed Sep 3 14:51:47 2025 +0800

    [integration][python] Integrate OpenAI chat model. (#128)

    * [hotfix] Move to_openai_tool to common utils of chat model.

    * [integration][python] Integrate openai chat model.
---
 python/flink_agents/api/tools/tool.py              |  20 +-
 .../integrations/chat_models/chat_model_utils.py   |  40 +++
 .../integrations/chat_models/ollama_chat_model.py  |   3 +-
 .../integrations/chat_models/openai/__init__.py    |  17 ++
 .../chat_models/openai/openai_chat_model.py        | 274 +++++++++++++++++++++
 .../chat_models/openai/openai_utils.py             | 131 ++++++++++
 .../chat_models/openai/tests/__init__.py           |  17 ++
 .../openai/tests/test_openai_chat_model.py         |  98 ++++++++
 python/pyproject.toml                              |   1 +
 9 files changed, 582 insertions(+), 19 deletions(-)

diff --git a/python/flink_agents/api/tools/tool.py b/python/flink_agents/api/tools/tool.py
index 0fa0958..d48b7a7 100644
--- a/python/flink_agents/api/tools/tool.py
+++ b/python/flink_agents/api/tools/tool.py
@@ -89,7 +89,8 @@ class ToolMetadata(BaseModel):
                 == self.args_schema.model_json_schema()
             )
 
-    def __get_parameters_dict(self) -> dict:
+    def get_parameters_dict(self) -> dict:
+        """Get the parameters of the tool."""
         parameters = self.args_schema.model_json_schema()
         parameters = {
             k: v
@@ -98,23 +99,6 @@ class ToolMetadata(BaseModel):
         }
         return parameters
 
-    def to_openai_tool(self, skip_length_check: bool = False) -> typing.Dict[str, Any]:  # noqa:FBT001
-        """To OpenAI tool."""
-        if not skip_length_check and len(self.description) > 1024:
-            msg = (
-                "Tool description exceeds maximum length of 1024 characters. "
-                "Please shorten your description or move it to the prompt."
-            )
-            raise ValueError(msg)
-        return {
-            "type": "function",
-            "function": {
-                "name": self.name,
-                "description": self.description,
-                "parameters": self.__get_parameters_dict(),
-            },
-        }
-
 
 class BaseTool(SerializableResource, ABC):
     """Base abstract class of all kinds of tools.
diff --git a/python/flink_agents/integrations/chat_models/chat_model_utils.py b/python/flink_agents/integrations/chat_models/chat_model_utils.py
new file mode 100644
index 0000000..08198f2
--- /dev/null
+++ b/python/flink_agents/integrations/chat_models/chat_model_utils.py
@@ -0,0 +1,40 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+from typing import Any, Dict
+
+from flink_agents.api.tools.tool import ToolMetadata
+
+
+def to_openai_tool(
+    *, metadata: ToolMetadata, skip_length_check: bool = False
+) -> Dict[str, Any]:
+    """To OpenAI tool."""
+    if not skip_length_check and len(metadata.description) > 1024:
+        msg = (
+            "Tool description exceeds maximum length of 1024 characters. "
+            "Please shorten your description or move it to the prompt."
+        )
+        raise ValueError(msg)
+    return {
+        "type": "function",
+        "function": {
+            "name": metadata.name,
+            "description": metadata.description,
+            "parameters": metadata.get_parameters_dict(),
+        },
+    }
diff --git a/python/flink_agents/integrations/chat_models/ollama_chat_model.py b/python/flink_agents/integrations/chat_models/ollama_chat_model.py
index bced33a..d63cf21 100644
--- a/python/flink_agents/integrations/chat_models/ollama_chat_model.py
+++ b/python/flink_agents/integrations/chat_models/ollama_chat_model.py
@@ -27,6 +27,7 @@ from flink_agents.api.chat_models.chat_model import (
     BaseChatModelSetup,
 )
 from flink_agents.api.tools.tool import BaseTool
+from flink_agents.integrations.chat_models.chat_model_utils import to_openai_tool
 
 DEFAULT_CONTEXT_WINDOW = 2048
 DEFAULT_REQUEST_TIMEOUT = 30.0
@@ -97,7 +98,7 @@ class OllamaChatModelConnection(BaseChatModelConnection):
         # Convert tool format
         ollama_tools = None
         if tools is not None:
-            ollama_tools = [tool.metadata.to_openai_tool() for tool in tools]
+            ollama_tools = [to_openai_tool(metadata=tool.metadata) for tool in tools]
 
         response = self.client.chat(
             model=self.model,
diff --git a/python/flink_agents/integrations/chat_models/openai/__init__.py b/python/flink_agents/integrations/chat_models/openai/__init__.py
new file mode 100644
index 0000000..e154fad
--- /dev/null
+++ b/python/flink_agents/integrations/chat_models/openai/__init__.py
@@ -0,0 +1,17 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
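Note that the relocated helper takes its metadata as a keyword-only argument, so
call sites now pass it explicitly. A minimal usage sketch, assuming ToolMetadata
is constructed directly from a name, a description, and a pydantic schema (the
AddArgs model below is a hypothetical stand-in):

    from pydantic import BaseModel

    from flink_agents.api.tools.tool import ToolMetadata
    from flink_agents.integrations.chat_models.chat_model_utils import to_openai_tool

    class AddArgs(BaseModel):  # hypothetical argument schema for a two-int tool
        a: int
        b: int

    metadata = ToolMetadata(name="add", description="Add two integers.", args_schema=AddArgs)
    spec = to_openai_tool(metadata=metadata)
    # spec["function"]["name"] == "add"
    # spec["function"]["parameters"] holds the JSON schema derived from AddArgs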
diff --git a/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py
new file mode 100644
index 0000000..6c6ab9c
--- /dev/null
+++ b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py
@@ -0,0 +1,274 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+from typing import Any, Dict, List, Literal, Optional, Sequence
+
+import httpx
+from openai import NOT_GIVEN, OpenAI
+from pydantic import Field, PrivateAttr
+
+from flink_agents.api.chat_message import ChatMessage
+from flink_agents.api.chat_models.chat_model import (
+    BaseChatModelConnection,
+    BaseChatModelSetup,
+)
+from flink_agents.api.tools.tool import BaseTool
+from flink_agents.integrations.chat_models.chat_model_utils import to_openai_tool
+from flink_agents.integrations.chat_models.openai.openai_utils import (
+    convert_from_openai_message,
+    convert_to_openai_messages,
+    resolve_openai_credentials,
+)
+
+DEFAULT_OPENAI_MODEL = "gpt-3.5-turbo"
+
+
+class OpenAIChatModelConnection(BaseChatModelConnection):
+    """The connection to the OpenAI LLM.
+
+    Attributes:
+    ----------
+    api_key : Optional[str]
+        The OpenAI API key.
+    api_base_url : str
+        The base URL for OpenAI API.
+    max_retries : int
+        The maximum number of API retries.
+    timeout : float
+        How long to wait, in seconds, for an API call before failing.
+    default_headers : Optional[Dict[str, str]]
+        The default headers for API requests.
+    reuse_client : bool
+        Whether to reuse the OpenAI client between requests.
+    """
+
+    api_key: Optional[str] = Field(default=None, description="The OpenAI API key.")
+    api_base_url: str = Field(description="The base URL for OpenAI API.")
+    max_retries: int = Field(
+        default=3,
+        description="The maximum number of API retries.",
+        ge=0,
+    )
+    timeout: float = Field(
+        default=60.0,
+        description="The timeout, in seconds, for API requests.",
+        ge=0,
+    )
+    default_headers: Optional[Dict[str, str]] = Field(
+        default=None, description="The default headers for API requests."
+    )
+    reuse_client: bool = Field(
+        default=True,
+        description=(
+            "Reuse the OpenAI client between requests. When doing anything with large "
+            "volumes of async API calls, setting this to false can improve stability."
+        ),
+    )
+
+    _client: Optional[OpenAI] = PrivateAttr(default=None)
+    _http_client: Optional[httpx.Client] = PrivateAttr(default=None)
+    _async_http_client: Optional[httpx.AsyncClient] = PrivateAttr(default=None)
+
+    def __init__(
+        self,
+        *,
+        api_key: Optional[str] = None,
+        api_base_url: Optional[str] = None,
+        max_retries: int = 3,
+        timeout: float = 60.0,
+        reuse_client: bool = True,
+        http_client: Optional[httpx.Client] = None,
+        async_http_client: Optional[httpx.AsyncClient] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Init method."""
+        api_key, api_base_url = resolve_openai_credentials(
+            api_key=api_key,
+            api_base_url=api_base_url,
+        )
+        super().__init__(
+            api_key=api_key,
+            api_base_url=api_base_url,
+            max_retries=max_retries,
+            timeout=timeout,
+            reuse_client=reuse_client,
+            **kwargs,
+        )
+
+        self._http_client = http_client
+        self._async_http_client = async_http_client
+
+    @property
+    def client(self) -> OpenAI:
+        """Get OpenAI client."""
+        config = self.__get_client_kwargs()
+
+        if not self.reuse_client:
+            return OpenAI(**config)
+
+        if self._client is None:
+            self._client = OpenAI(**config)
+        return self._client
+
+    def __get_client_kwargs(self) -> Dict[str, Any]:
+        return {
+            "api_key": self.api_key,
+            "base_url": self.api_base_url,
+            "max_retries": self.max_retries,
+            "timeout": self.timeout,
+            "default_headers": self.default_headers,
+            "http_client": self._http_client,
+        }
+
+    def chat(
+        self,
+        messages: Sequence[ChatMessage],
+        tools: Optional[List[BaseTool]] = None,
+        **kwargs: Any,
+    ) -> ChatMessage:
+        """Direct communication with model service for chat conversation.
+
+        Parameters
+        ----------
+        messages : Sequence[ChatMessage]
+            Input message sequence
+        tools : Optional[List]
+            List of tools that can be called by the model
+        **kwargs : Any
+            Additional parameters passed to the model service (e.g., temperature,
+            max_tokens, etc.)
+
+        Returns:
+        -------
+        ChatMessage
+            Model response message
+        """
+        tool_specs = None
+        if tools is not None:
+            tool_specs = [to_openai_tool(metadata=tool.metadata) for tool in tools]
+            # Pop "strict" so it is not forwarded to the completions API,
+            # which does not accept it as a top-level argument.
+            strict = kwargs.pop("strict", False)
+            for tool_spec in tool_specs:
+                if tool_spec["type"] == "function":
+                    tool_spec["function"]["strict"] = strict
+                    tool_spec["function"]["parameters"]["additionalProperties"] = False
+
+        response = self.client.chat.completions.create(
+            messages=convert_to_openai_messages(messages),
+            tools=tool_specs or NOT_GIVEN,
+            **kwargs,
+        )
+
+        response = response.choices[0].message
+
+        return convert_from_openai_message(response)
+
+
+DEFAULT_TEMPERATURE = 0.1
+
+
+class OpenAIChatModelSetup(BaseChatModelSetup):
+    """The settings for the OpenAI LLM.
+
+    Attributes:
+    ----------
+    model : str
+        The OpenAI model to use.
+    temperature : float
+        The temperature to use during generation.
+    max_tokens : Optional[int]
+        The maximum number of tokens to generate.
+    logprobs : Optional[bool]
+        Whether to return logprobs per token.
+    top_logprobs : int
+        The number of top token log probs to return.
+    additional_kwargs : Dict[str, Any]
+        Additional kwargs for the OpenAI API.
+    strict : bool
+        Whether to use strict mode for invoking tools/using schemas.
+    reasoning_effort : Optional[Literal["low", "medium", "high"]]
+        The effort to use for reasoning models.
+    """
+
+    model: str = Field(
+        default=DEFAULT_OPENAI_MODEL, description="The OpenAI model to use."
+    )
+    temperature: float = Field(
+        default=DEFAULT_TEMPERATURE,
+        description="The temperature to use during generation.",
+        ge=0.0,
+        le=2.0,
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate.",
+        gt=0,
+    )
+    logprobs: Optional[bool] = Field(
+        description="Whether to return logprobs per token.",
+        default=None,
+    )
+    top_logprobs: int = Field(
+        description="The number of top token log probs to return.",
+        default=0,
+        ge=0,
+        le=20,
+    )
+    additional_kwargs: Dict[str, Any] = Field(
+        default_factory=dict, description="Additional kwargs for the OpenAI API."
+    )
+    strict: bool = Field(
+        default=False,
+        description="Whether to use strict mode for invoking tools/using schemas.",
+    )
+    reasoning_effort: Optional[Literal["low", "medium", "high"]] = Field(
+        default=None,
+        description="The effort to use for reasoning models.",
+    )
+
+    def __init__(
+        self,
+        *,
+        model: str = DEFAULT_OPENAI_MODEL,
+        temperature: float = DEFAULT_TEMPERATURE,
+        max_tokens: Optional[int] = None,
+        additional_kwargs: Optional[Dict[str, Any]] = None,
+        strict: bool = False,
+        reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Init method."""
+        additional_kwargs = additional_kwargs or {}
+        super().__init__(
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            additional_kwargs=additional_kwargs,
+            strict=strict,
+            reasoning_effort=reasoning_effort,
+            **kwargs,
+        )
+
+    @property
+    def model_kwargs(self) -> Dict[str, Any]:
+        """Return chat model settings."""
+        base_kwargs = {"model": self.model, "temperature": self.temperature}
+        if self.max_tokens is not None:
+            base_kwargs["max_tokens"] = self.max_tokens
+        if self.logprobs:
+            base_kwargs["logprobs"] = self.logprobs
+            base_kwargs["top_logprobs"] = self.top_logprobs
+
+        all_kwargs = {**base_kwargs, **self.additional_kwargs}
+        return all_kwargs
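For orientation, a sketch of what model_kwargs forwards to the connection, based
on the property above (the model name, temperature, token limit, and the
get_resource wiring are illustrative; get_resource is assumed to resolve the
connection name as in the tests further down):

    setup = OpenAIChatModelSetup(
        name="openai",
        model="gpt-4o-mini",        # illustrative model name
        connection="openai",
        temperature=0.2,
        max_tokens=256,
        get_resource=get_resource,  # assumed to resolve "openai" to a connection
    )
    # With logprobs unset and no additional_kwargs, only these three are forwarded:
    assert setup.model_kwargs == {"model": "gpt-4o-mini", "temperature": 0.2, "max_tokens": 256}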
diff --git a/python/flink_agents/integrations/chat_models/openai/openai_utils.py b/python/flink_agents/integrations/chat_models/openai/openai_utils.py
new file mode 100644
index 0000000..b896608
--- /dev/null
+++ b/python/flink_agents/integrations/chat_models/openai/openai_utils.py
@@ -0,0 +1,131 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import json
+import os
+from typing import List, Optional, Sequence, Tuple, cast
+
+import openai
+from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam
+
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+
+DEFAULT_OPENAI_API_BASE_URL = "https://api.openai.com/v1"
+
+
+def resolve_openai_credentials(
+    api_key: Optional[str] = None,
+    api_base_url: Optional[str] = None,
+) -> Tuple[Optional[str], str]:
+    """Resolve OpenAI credentials.
+
+    The order of precedence is:
+    1. param
+    2. env
+    3. openai module
+    4. default
+    """
+    # resolve from param or env
+    api_key = _get_from_param_or_env("api_key", api_key, "OPENAI_API_KEY", "")
+    api_base_url = _get_from_param_or_env(
+        "api_base_url", api_base_url, "OPENAI_API_BASE_URL", ""
+    )
+
+    # resolve from openai module or default
+    final_api_key = api_key or openai.api_key or ""
+    final_api_base_url = api_base_url or openai.base_url or DEFAULT_OPENAI_API_BASE_URL
+
+    return final_api_key, str(final_api_base_url)
+
+
+def _get_from_param_or_env(
+    param_name: str,
+    value_from_args: Optional[str] = None,
+    env_var_name: Optional[str] = None,
+    default_value: Optional[str] = None,
+) -> str:
+    """Get a value from a param or an environment variable.
+
+    The order of precedence is:
+    1. param
+    2. env
+    3. default
+    """
+    if value_from_args is not None:
+        return value_from_args
+    elif env_var_name and env_var_name in os.environ and os.environ[env_var_name]:
+        return os.environ[env_var_name]
+    elif default_value is not None:
+        return default_value
+    else:
+        msg = (
+            f"Did not find {param_name}, please add an environment variable"
+            f" `{env_var_name}` which contains it, or pass"
+            f" `{param_name}` as a named parameter."
+        )
+        raise ValueError(msg)
+
+
+def convert_to_openai_messages(
+    messages: Sequence[ChatMessage],
+) -> List[ChatCompletionMessageParam]:
+    """Convert chat messages to OpenAI messages."""
+    return [convert_to_openai_message(message) for message in messages]
+
+
+def convert_to_openai_message(message: ChatMessage) -> ChatCompletionMessageParam:
+    """Convert a chat message to an OpenAI message."""
+    content_txt = message.content
+    # OpenAI expects null content on assistant messages that only carry tool calls.
+    content_txt = (
+        None
+        if content_txt == ""
+        and message.role == MessageRole.ASSISTANT
+        and len(message.tool_calls) > 0
+        else content_txt
+    )
+    if len(message.tool_calls) > 0:
+        openai_message = {
+            "role": message.role.value,
+            "content": content_txt,
+            "tool_calls": message.tool_calls,
+        }
+    else:
+        openai_message = {"role": message.role.value, "content": content_txt}
+    openai_message.update(message.extra_args)
+    return cast("ChatCompletionMessageParam", openai_message)
+
+
+def convert_from_openai_message(message: ChatCompletionMessage) -> ChatMessage:
+    """Convert an OpenAI message to a chat message."""
+    tool_calls = []
+    if message.tool_calls:
+        tool_calls = [
+            {
+                "id": tool_call.id,
+                "type": tool_call.type,
+                "function": {
+                    "name": tool_call.function.name,
+                    "arguments": json.loads(tool_call.function.arguments),
+                },
+            }
+            for tool_call in message.tool_calls
+        ]
+    return ChatMessage(
+        role=MessageRole(message.role),
+        content=message.content or "",
+        tool_calls=tool_calls,
+    )
diff --git a/python/flink_agents/integrations/chat_models/openai/tests/__init__.py b/python/flink_agents/integrations/chat_models/openai/tests/__init__.py
new file mode 100644
index 0000000..e154fad
--- /dev/null
+++ b/python/flink_agents/integrations/chat_models/openai/tests/__init__.py
@@ -0,0 +1,17 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
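As a quick illustration of the converters above, assuming a ChatMessage with no
tool calls and no extra_args (a minimal sketch, not part of the commit):

    from flink_agents.api.chat_message import ChatMessage, MessageRole
    from flink_agents.integrations.chat_models.openai.openai_utils import (
        convert_to_openai_messages,
    )

    msgs = [ChatMessage(role=MessageRole.USER, content="Hello!")]
    print(convert_to_openai_messages(msgs))
    # expected: [{'role': 'user', 'content': 'Hello!'}]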
diff --git a/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py
new file mode 100644
index 0000000..e69700a
--- /dev/null
+++ b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py
@@ -0,0 +1,98 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import os
+
+import pytest
+
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.resource import Resource, ResourceType
+from flink_agents.integrations.chat_models.openai.openai_chat_model import (
+    OpenAIChatModelConnection,
+    OpenAIChatModelSetup,
+)
+from flink_agents.plan.tools.function_tool import from_callable
+
+test_model = os.environ.get("TEST_MODEL")
+api_key = os.environ.get("TEST_API_KEY")
+api_base_url = os.environ.get("TEST_API_BASE_URL")
+
+
+@pytest.mark.skipif(api_key is None, reason="TEST_API_KEY is not set")
+def test_openai_chat_model() -> None:  # noqa: D103
+    connection = OpenAIChatModelConnection(
+        name="openai", api_key=api_key, api_base_url=api_base_url
+    )
+
+    def get_resource(name: str, type: ResourceType) -> Resource:
+        if type == ResourceType.CHAT_MODEL_CONNECTION:
+            return connection
+        msg = f"Unexpected resource type: {type}"
+        raise ValueError(msg)
+
+    chat_model = OpenAIChatModelSetup(
+        name="openai", model=test_model, connection="openai", get_resource=get_resource
+    )
+    response = chat_model.chat([ChatMessage(role=MessageRole.USER, content="Hello!")])
+    assert response is not None
+    assert str(response).strip() != ""
+
+
+def add(a: int, b: int) -> int:
+    """Calculate the sum of a and b.
+
+    Parameters
+    ----------
+    a : int
+        The first operand
+    b : int
+        The second operand
+
+    Returns:
+    -------
+    int:
+        The sum of a and b
+    """
+    return a + b
+
+
+@pytest.mark.skipif(api_key is None, reason="TEST_API_KEY is not set")
+def test_openai_chat_with_tools() -> None:  # noqa: D103
+    connection = OpenAIChatModelConnection(
+        name="openai", api_key=api_key, api_base_url=api_base_url
+    )
+
+    def get_resource(name: str, type: ResourceType) -> Resource:
+        if type == ResourceType.CHAT_MODEL_CONNECTION:
+            return connection
+        else:
+            return from_callable(name=name, func=add)
+
+    chat_model = OpenAIChatModelSetup(
+        name="openai",
+        model=test_model,
+        connection="openai",
+        tools=["add"],
+        get_resource=get_resource,
+    )
+    response = chat_model.chat(
+        [ChatMessage(role=MessageRole.USER, content="What is 377 + 688?")]
+    )
+    tool_calls = response.tool_calls
+    assert len(tool_calls) == 1
+    tool_call = tool_calls[0]
+    assert add(**tool_call["function"]["arguments"]) == 1065
diff --git a/python/pyproject.toml b/python/pyproject.toml
index b6a7561..d8c7ffd 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -50,6 +50,7 @@ dependencies = [
     #TODO: Separate integration dependencies from project
     "ollama==0.4.8",
     "dashscope~=1.24.2",
+    "openai>=1.66.3"
 ]
 
 # Optional dependencies (dependency groups)
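For reference, an end-to-end sketch that mirrors the new tests (requires
OPENAI_API_KEY to be set in the environment; the model name is illustrative):

    from flink_agents.api.chat_message import ChatMessage, MessageRole
    from flink_agents.api.resource import Resource, ResourceType
    from flink_agents.integrations.chat_models.openai.openai_chat_model import (
        OpenAIChatModelConnection,
        OpenAIChatModelSetup,
    )

    # api_key/api_base_url fall back to OPENAI_API_KEY / OPENAI_API_BASE_URL.
    connection = OpenAIChatModelConnection(name="openai")

    def get_resource(name: str, type: ResourceType) -> Resource:
        assert type == ResourceType.CHAT_MODEL_CONNECTION
        return connection

    chat_model = OpenAIChatModelSetup(
        name="openai",
        model="gpt-4o-mini",  # illustrative
        connection="openai",
        get_resource=get_resource,
    )
    response = chat_model.chat([ChatMessage(role=MessageRole.USER, content="Hello!")])
    print(response.content)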