This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch v2-8-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 01087a7fe788acefed27b38011fb42ac8a2c054a Author: Pankaj Koti <[email protected]> AuthorDate: Wed Dec 6 15:57:05 2023 +0530 Bump up openai version to >=1.0 & use get_conn (#36014) (cherry picked from commit d2514b408cb98f792289a5d032aaf85fe605350d) --- airflow/providers/openai/hooks/openai.py | 42 ++++--- airflow/providers/openai/provider.yaml | 2 +- .../connections.rst | 17 +++ generated/provider_dependencies.json | 2 +- tests/providers/openai/hooks/test_openai.py | 127 +++++++++++++++++---- 5 files changed, 144 insertions(+), 46 deletions(-) diff --git a/airflow/providers/openai/hooks/openai.py b/airflow/providers/openai/hooks/openai.py index fac725b5be..f57c41c2b9 100644 --- a/airflow/providers/openai/hooks/openai.py +++ b/airflow/providers/openai/hooks/openai.py @@ -17,9 +17,10 @@ from __future__ import annotations +from functools import cached_property from typing import Any -import openai +from openai import OpenAI from airflow.hooks.base import BaseHook @@ -41,37 +42,40 @@ class OpenAIHook(BaseHook): def __init__(self, conn_id: str = default_conn_name, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.conn_id = conn_id - openai.api_key = self._get_api_key() - api_base = self._get_api_base() - if api_base: - openai.api_base = api_base - @staticmethod - def get_ui_field_behaviour() -> dict[str, Any]: + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: """Return custom field behaviour.""" return { - "hidden_fields": ["schema", "port", "login", "extra"], + "hidden_fields": ["schema", "port", "login"], "relabeling": {"password": "API Key"}, "placeholders": {}, } def test_connection(self) -> tuple[bool, str]: try: - openai.Model.list() + self.conn.models.list() return True, "Connection established!" except Exception as e: return False, str(e) - def _get_api_key(self) -> str: - """Get the OpenAI API key from the connection.""" - conn = self.get_connection(self.conn_id) - if not conn.password: - raise ValueError("OpenAI API key not found in connection") - return str(conn.password) + @cached_property + def conn(self) -> OpenAI: + """Return an OpenAI connection object.""" + return self.get_conn() - def _get_api_base(self) -> None | str: + def get_conn(self) -> OpenAI: + """Return an OpenAI connection object.""" conn = self.get_connection(self.conn_id) - return conn.host + extras = conn.extra_dejson + openai_client_kwargs = extras.get("openai_client_kwargs", {}) + api_key = openai_client_kwargs.pop("api_key", None) or conn.password + base_url = openai_client_kwargs.pop("base_url", None) or conn.host or None + return OpenAI( + api_key=api_key, + base_url=base_url, + **openai_client_kwargs, + ) def create_embeddings( self, @@ -84,6 +88,6 @@ class OpenAIHook(BaseHook): :param text: The text to generate embeddings for. :param model: The model to use for generating embeddings. """ - response = openai.Embedding.create(model=model, input=text, **kwargs) - embeddings: list[float] = response["data"][0]["embedding"] + response = self.conn.embeddings.create(model=model, input=text, **kwargs) + embeddings: list[float] = response.data[0].embedding return embeddings diff --git a/airflow/providers/openai/provider.yaml b/airflow/providers/openai/provider.yaml index 86226aa3f0..0f9d830a61 100644 --- a/airflow/providers/openai/provider.yaml +++ b/airflow/providers/openai/provider.yaml @@ -39,7 +39,7 @@ integrations: dependencies: - apache-airflow>=2.5.0 - - openai[datalib]>=0.28.1,<1.0 + - openai[datalib]>=1.0 hooks: - integration-name: OpenAI diff --git a/docs/apache-airflow-providers-openai/connections.rst b/docs/apache-airflow-providers-openai/connections.rst index 88e79df59d..8ef7ee456b 100644 --- a/docs/apache-airflow-providers-openai/connections.rst +++ b/docs/apache-airflow-providers-openai/connections.rst @@ -35,3 +35,20 @@ API Key (required) Host (optional) The host address of the OpenAI instance. + +Extra (optional) + Specify the extra parameters (as json dictionary) that can be used in the + connection. All parameters are optional. + This ``extra`` field accepts a nested dictionary with key ``openai_client_kwargs`` as key-value pairs that + are passed to the `OpenAI client <https://github.com/openai/openai-python/blob/main/src/openai/_client.py>`__ + on instantiation. For example, to set the timeout for the client, you can pass the following dictionary + as the ``extra`` field: + + .. code-block:: json + + { + "openai_client_kwargs": { + "timeout": 10, + "api_key": "YOUR_API_KEY" + } + } diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index e6405cfe77..e53b5c2054 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -661,7 +661,7 @@ "openai": { "deps": [ "apache-airflow>=2.5.0", - "openai[datalib]>=0.28.1,<1.0" + "openai[datalib]>=1.0" ], "cross-providers-deps": [], "excluded-python-versions": [] diff --git a/tests/providers/openai/hooks/test_openai.py b/tests/providers/openai/hooks/test_openai.py index cd811107f7..a80be35dfb 100644 --- a/tests/providers/openai/hooks/test_openai.py +++ b/tests/providers/openai/hooks/test_openai.py @@ -16,48 +16,125 @@ # under the License. from __future__ import annotations -from unittest.mock import Mock, patch +import os +from unittest.mock import patch import pytest +from openai.types import CreateEmbeddingResponse, Embedding +from airflow.models import Connection from airflow.providers.openai.hooks.openai import OpenAIHook @pytest.fixture -def openai_hook(): - with patch("airflow.providers.openai.hooks.openai.OpenAIHook._get_api_key"), patch( - "airflow.providers.openai.hooks.openai.OpenAIHook._get_api_base" - ) as _: - yield OpenAIHook(conn_id="test_conn_id") +def mock_openai_connection(): + conn_id = "openai_conn" + conn = Connection( + conn_id=conn_id, + conn_type="openai", + ) + os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri() + yield conn @pytest.fixture -def mock_embeddings_response(): - return {"data": [{"embedding": [0.1, 0.2, 0.3]}]} +def mock_openai_hook(mock_openai_connection): + with patch("airflow.providers.openai.hooks.openai.OpenAI"): + yield OpenAIHook(conn_id=mock_openai_connection.conn_id) @pytest.fixture -def mock_completions_response(): - return Mock( - id="completion-id", - object="completion", - created=1234567890, - model="text-davinci-002", - usage={"prompt_tokens": 15, "completion_tokens": 32, "total_tokens": 47}, - choices=[Mock(text="the quick brown fox", finish_reason="stop", index=0)], +def mock_embeddings_response(): + return CreateEmbeddingResponse( + data=[Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding")], + model="text-embedding-ada-002-v2", + object="list", + usage={"prompt_tokens": 4, "total_tokens": 4}, ) -def test_create_embeddings(openai_hook, mock_embeddings_response): +def test_create_embeddings(mock_openai_hook, mock_embeddings_response): text = "Sample text" - with patch("openai.Embedding.create", return_value=mock_embeddings_response): - embeddings = openai_hook.create_embeddings(text) + mock_openai_hook.conn.embeddings.create.return_value = mock_embeddings_response + embeddings = mock_openai_hook.create_embeddings(text) assert embeddings == [0.1, 0.2, 0.3] -def test_get_api_key(): - mock_connection = Mock() - mock_connection.password = "your_api_key" - OpenAIHook.get_connection = Mock(return_value=mock_connection) - api_key = OpenAIHook()._get_api_key() - assert api_key == "your_api_key" +def test_openai_hook_test_connection(mock_openai_hook): + result, message = mock_openai_hook.test_connection() + assert result is True + assert message == "Connection established!" + + +@patch("airflow.providers.openai.hooks.openai.OpenAI") +def test_get_conn_with_api_key_in_extra(mock_client): + conn_id = "api_key_in_extra" + conn = Connection( + conn_id=conn_id, + conn_type="openai", + extra={"openai_client_kwargs": {"api_key": "api_key_in_extra"}}, + ) + os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri() + hook = OpenAIHook(conn_id=conn_id) + hook.get_conn() + mock_client.assert_called_once_with( + api_key="api_key_in_extra", + base_url=None, + ) + + +@patch("airflow.providers.openai.hooks.openai.OpenAI") +def test_get_conn_with_api_key_in_password(mock_client): + conn_id = "api_key_in_password" + conn = Connection( + conn_id=conn_id, + conn_type="openai", + password="api_key_in_password", + ) + os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri() + hook = OpenAIHook(conn_id=conn_id) + hook.get_conn() + mock_client.assert_called_once_with( + api_key="api_key_in_password", + base_url=None, + ) + + +@patch("airflow.providers.openai.hooks.openai.OpenAI") +def test_get_conn_with_base_url_in_extra(mock_client): + conn_id = "base_url_in_extra" + conn = Connection( + conn_id=conn_id, + conn_type="openai", + extra={"openai_client_kwargs": {"base_url": "base_url_in_extra", "api_key": "api_key_in_extra"}}, + ) + os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri() + hook = OpenAIHook(conn_id=conn_id) + hook.get_conn() + mock_client.assert_called_once_with( + api_key="api_key_in_extra", + base_url="base_url_in_extra", + ) + + +@patch("airflow.providers.openai.hooks.openai.OpenAI") +def test_get_conn_with_openai_client_kwargs(mock_client): + conn_id = "openai_client_kwargs" + conn = Connection( + conn_id=conn_id, + conn_type="openai", + extra={ + "openai_client_kwargs": { + "api_key": "api_key_in_extra", + "organization": "organization_in_extra", + } + }, + ) + os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri() + hook = OpenAIHook(conn_id=conn_id) + hook.get_conn() + mock_client.assert_called_once_with( + api_key="api_key_in_extra", + base_url=None, + organization="organization_in_extra", + )
