This is an automated email from the ASF dual-hosted git repository.

pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d2514b408c Bump up openai version to >=1.0 & use get_conn (#36014)
d2514b408c is described below

commit d2514b408cb98f792289a5d032aaf85fe605350d
Author: Pankaj Koti <[email protected]>
AuthorDate: Wed Dec 6 15:57:05 2023 +0530

    Bump up openai version to >=1.0 & use get_conn (#36014)
---
 airflow/providers/openai/hooks/openai.py           |  42 ++++---
 airflow/providers/openai/provider.yaml             |   2 +-
 .../connections.rst                                |  17 +++
 generated/provider_dependencies.json               |   2 +-
 tests/providers/openai/hooks/test_openai.py        | 127 +++++++++++++++++----
 5 files changed, 144 insertions(+), 46 deletions(-)

diff --git a/airflow/providers/openai/hooks/openai.py b/airflow/providers/openai/hooks/openai.py
index fac725b5be..f57c41c2b9 100644
--- a/airflow/providers/openai/hooks/openai.py
+++ b/airflow/providers/openai/hooks/openai.py
@@ -17,9 +17,10 @@
 
 from __future__ import annotations
 
+from functools import cached_property
 from typing import Any
 
-import openai
+from openai import OpenAI
 
 from airflow.hooks.base import BaseHook
 
@@ -41,37 +42,40 @@ class OpenAIHook(BaseHook):
     def __init__(self, conn_id: str = default_conn_name, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.conn_id = conn_id
-        openai.api_key = self._get_api_key()
-        api_base = self._get_api_base()
-        if api_base:
-            openai.api_base = api_base
 
-    @staticmethod
-    def get_ui_field_behaviour() -> dict[str, Any]:
+    @classmethod
+    def get_ui_field_behaviour(cls) -> dict[str, Any]:
         """Return custom field behaviour."""
         return {
-            "hidden_fields": ["schema", "port", "login", "extra"],
+            "hidden_fields": ["schema", "port", "login"],
             "relabeling": {"password": "API Key"},
             "placeholders": {},
         }
 
     def test_connection(self) -> tuple[bool, str]:
         try:
-            openai.Model.list()
+            self.conn.models.list()
             return True, "Connection established!"
         except Exception as e:
             return False, str(e)
 
-    def _get_api_key(self) -> str:
-        """Get the OpenAI API key from the connection."""
-        conn = self.get_connection(self.conn_id)
-        if not conn.password:
-            raise ValueError("OpenAI API key not found in connection")
-        return str(conn.password)
+    @cached_property
+    def conn(self) -> OpenAI:
+        """Return an OpenAI connection object."""
+        return self.get_conn()
 
-    def _get_api_base(self) -> None | str:
+    def get_conn(self) -> OpenAI:
+        """Return an OpenAI connection object."""
         conn = self.get_connection(self.conn_id)
-        return conn.host
+        extras = conn.extra_dejson
+        openai_client_kwargs = extras.get("openai_client_kwargs", {})
+        api_key = openai_client_kwargs.pop("api_key", None) or conn.password
+        base_url = openai_client_kwargs.pop("base_url", None) or conn.host or None
+        return OpenAI(
+            api_key=api_key,
+            base_url=base_url,
+            **openai_client_kwargs,
+        )
 
     def create_embeddings(
         self,
@@ -84,6 +88,6 @@ class OpenAIHook(BaseHook):
         :param text: The text to generate embeddings for.
         :param model: The model to use for generating embeddings.
         """
-        response = openai.Embedding.create(model=model, input=text, **kwargs)
-        embeddings: list[float] = response["data"][0]["embedding"]
+        response = self.conn.embeddings.create(model=model, input=text, **kwargs)
+        embeddings: list[float] = response.data[0].embedding
         return embeddings
diff --git a/airflow/providers/openai/provider.yaml b/airflow/providers/openai/provider.yaml
index 86226aa3f0..0f9d830a61 100644
--- a/airflow/providers/openai/provider.yaml
+++ b/airflow/providers/openai/provider.yaml
@@ -39,7 +39,7 @@ integrations:
 
 dependencies:
   - apache-airflow>=2.5.0
-  - openai[datalib]>=0.28.1,<1.0
+  - openai[datalib]>=1.0
 
 hooks:
   - integration-name: OpenAI
diff --git a/docs/apache-airflow-providers-openai/connections.rst b/docs/apache-airflow-providers-openai/connections.rst
index 88e79df59d..8ef7ee456b 100644
--- a/docs/apache-airflow-providers-openai/connections.rst
+++ b/docs/apache-airflow-providers-openai/connections.rst
@@ -35,3 +35,20 @@ API Key (required)
 
 Host (optional)
     The host address of the OpenAI instance.
+
+Extra (optional)
+    Specify the extra parameters (as json dictionary) that can be used in the
+    connection. All parameters are optional.
+    This ``extra`` field accepts a nested dictionary with key ``openai_client_kwargs`` as key-value pairs that
+    are passed to the `OpenAI client <https://github.com/openai/openai-python/blob/main/src/openai/_client.py>`__
+    on instantiation. For example, to set the timeout for the client, you can pass the following dictionary
+    as the ``extra`` field:
+
+    .. code-block:: json
+
+        {
+          "openai_client_kwargs": {
+            "timeout": 10,
+            "api_key": "YOUR_API_KEY"
+          }
+        }
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index e6405cfe77..e53b5c2054 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -661,7 +661,7 @@
   "openai": {
     "deps": [
       "apache-airflow>=2.5.0",
-      "openai[datalib]>=0.28.1,<1.0"
+      "openai[datalib]>=1.0"
     ],
     "cross-providers-deps": [],
     "excluded-python-versions": []
diff --git a/tests/providers/openai/hooks/test_openai.py b/tests/providers/openai/hooks/test_openai.py
index cd811107f7..a80be35dfb 100644
--- a/tests/providers/openai/hooks/test_openai.py
+++ b/tests/providers/openai/hooks/test_openai.py
@@ -16,48 +16,125 @@
 # under the License.
 from __future__ import annotations
 
-from unittest.mock import Mock, patch
+import os
+from unittest.mock import patch
 
 import pytest
+from openai.types import CreateEmbeddingResponse, Embedding
 
+from airflow.models import Connection
 from airflow.providers.openai.hooks.openai import OpenAIHook
 
 
 @pytest.fixture
-def openai_hook():
-    with patch("airflow.providers.openai.hooks.openai.OpenAIHook._get_api_key"), patch(
-        "airflow.providers.openai.hooks.openai.OpenAIHook._get_api_base"
-    ) as _:
-        yield OpenAIHook(conn_id="test_conn_id")
+def mock_openai_connection():
+    conn_id = "openai_conn"
+    conn = Connection(
+        conn_id=conn_id,
+        conn_type="openai",
+    )
+    os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
+    yield conn
 
 
 @pytest.fixture
-def mock_embeddings_response():
-    return {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
+def mock_openai_hook(mock_openai_connection):
+    with patch("airflow.providers.openai.hooks.openai.OpenAI"):
+        yield OpenAIHook(conn_id=mock_openai_connection.conn_id)
 
 
 @pytest.fixture
-def mock_completions_response():
-    return Mock(
-        id="completion-id",
-        object="completion",
-        created=1234567890,
-        model="text-davinci-002",
-        usage={"prompt_tokens": 15, "completion_tokens": 32, "total_tokens": 47},
-        choices=[Mock(text="the quick brown fox", finish_reason="stop", index=0)],
+def mock_embeddings_response():
+    return CreateEmbeddingResponse(
+        data=[Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding")],
+        model="text-embedding-ada-002-v2",
+        object="list",
+        usage={"prompt_tokens": 4, "total_tokens": 4},
     )
 
 
-def test_create_embeddings(openai_hook, mock_embeddings_response):
+def test_create_embeddings(mock_openai_hook, mock_embeddings_response):
     text = "Sample text"
-    with patch("openai.Embedding.create", return_value=mock_embeddings_response):
-        embeddings = openai_hook.create_embeddings(text)
+    mock_openai_hook.conn.embeddings.create.return_value = mock_embeddings_response
+    embeddings = mock_openai_hook.create_embeddings(text)
     assert embeddings == [0.1, 0.2, 0.3]
 
 
-def test_get_api_key():
-    mock_connection = Mock()
-    mock_connection.password = "your_api_key"
-    OpenAIHook.get_connection = Mock(return_value=mock_connection)
-    api_key = OpenAIHook()._get_api_key()
-    assert api_key == "your_api_key"
+def test_openai_hook_test_connection(mock_openai_hook):
+    result, message = mock_openai_hook.test_connection()
+    assert result is True
+    assert message == "Connection established!"
+
+
+@patch("airflow.providers.openai.hooks.openai.OpenAI")
+def test_get_conn_with_api_key_in_extra(mock_client):
+    conn_id = "api_key_in_extra"
+    conn = Connection(
+        conn_id=conn_id,
+        conn_type="openai",
+        extra={"openai_client_kwargs": {"api_key": "api_key_in_extra"}},
+    )
+    os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
+    hook = OpenAIHook(conn_id=conn_id)
+    hook.get_conn()
+    mock_client.assert_called_once_with(
+        api_key="api_key_in_extra",
+        base_url=None,
+    )
+
+
+@patch("airflow.providers.openai.hooks.openai.OpenAI")
+def test_get_conn_with_api_key_in_password(mock_client):
+    conn_id = "api_key_in_password"
+    conn = Connection(
+        conn_id=conn_id,
+        conn_type="openai",
+        password="api_key_in_password",
+    )
+    os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
+    hook = OpenAIHook(conn_id=conn_id)
+    hook.get_conn()
+    mock_client.assert_called_once_with(
+        api_key="api_key_in_password",
+        base_url=None,
+    )
+
+
+@patch("airflow.providers.openai.hooks.openai.OpenAI")
+def test_get_conn_with_base_url_in_extra(mock_client):
+    conn_id = "base_url_in_extra"
+    conn = Connection(
+        conn_id=conn_id,
+        conn_type="openai",
+        extra={"openai_client_kwargs": {"base_url": "base_url_in_extra", "api_key": "api_key_in_extra"}},
+    )
+    os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
+    hook = OpenAIHook(conn_id=conn_id)
+    hook.get_conn()
+    mock_client.assert_called_once_with(
+        api_key="api_key_in_extra",
+        base_url="base_url_in_extra",
+    )
+
+
+@patch("airflow.providers.openai.hooks.openai.OpenAI")
+def test_get_conn_with_openai_client_kwargs(mock_client):
+    conn_id = "openai_client_kwargs"
+    conn = Connection(
+        conn_id=conn_id,
+        conn_type="openai",
+        extra={
+            "openai_client_kwargs": {
+                "api_key": "api_key_in_extra",
+                "organization": "organization_in_extra",
+            }
+        },
+    )
+    os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
+    hook = OpenAIHook(conn_id=conn_id)
+    hook.get_conn()
+    mock_client.assert_called_once_with(
+        api_key="api_key_in_extra",
+        base_url=None,
+        organization="organization_in_extra",
+    )

Reply via email to