This is an automated email from the ASF dual-hosted git repository.
pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d2514b408c Bump up openai version to >=1.0 & use get_conn (#36014)
d2514b408c is described below
commit d2514b408cb98f792289a5d032aaf85fe605350d
Author: Pankaj Koti <[email protected]>
AuthorDate: Wed Dec 6 15:57:05 2023 +0530
Bump up openai version to >=1.0 & use get_conn (#36014)
---
airflow/providers/openai/hooks/openai.py | 42 ++++---
airflow/providers/openai/provider.yaml | 2 +-
.../connections.rst | 17 +++
generated/provider_dependencies.json | 2 +-
tests/providers/openai/hooks/test_openai.py | 127 +++++++++++++++++----
5 files changed, 144 insertions(+), 46 deletions(-)
diff --git a/airflow/providers/openai/hooks/openai.py b/airflow/providers/openai/hooks/openai.py
index fac725b5be..f57c41c2b9 100644
--- a/airflow/providers/openai/hooks/openai.py
+++ b/airflow/providers/openai/hooks/openai.py
@@ -17,9 +17,10 @@
from __future__ import annotations
+from functools import cached_property
from typing import Any
-import openai
+from openai import OpenAI
from airflow.hooks.base import BaseHook
@@ -41,37 +42,40 @@ class OpenAIHook(BaseHook):
def __init__(self, conn_id: str = default_conn_name, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.conn_id = conn_id
- openai.api_key = self._get_api_key()
- api_base = self._get_api_base()
- if api_base:
- openai.api_base = api_base
- @staticmethod
- def get_ui_field_behaviour() -> dict[str, Any]:
+ @classmethod
+ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Return custom field behaviour."""
return {
- "hidden_fields": ["schema", "port", "login", "extra"],
+ "hidden_fields": ["schema", "port", "login"],
"relabeling": {"password": "API Key"},
"placeholders": {},
}
def test_connection(self) -> tuple[bool, str]:
try:
- openai.Model.list()
+ self.conn.models.list()
return True, "Connection established!"
except Exception as e:
return False, str(e)
- def _get_api_key(self) -> str:
- """Get the OpenAI API key from the connection."""
- conn = self.get_connection(self.conn_id)
- if not conn.password:
- raise ValueError("OpenAI API key not found in connection")
- return str(conn.password)
+ @cached_property
+ def conn(self) -> OpenAI:
+ """Return an OpenAI connection object."""
+ return self.get_conn()
- def _get_api_base(self) -> None | str:
+ def get_conn(self) -> OpenAI:
+ """Return an OpenAI connection object."""
conn = self.get_connection(self.conn_id)
- return conn.host
+ extras = conn.extra_dejson
+ openai_client_kwargs = extras.get("openai_client_kwargs", {})
+ api_key = openai_client_kwargs.pop("api_key", None) or conn.password
+ base_url = openai_client_kwargs.pop("base_url", None) or conn.host or None
+ return OpenAI(
+ api_key=api_key,
+ base_url=base_url,
+ **openai_client_kwargs,
+ )
def create_embeddings(
self,
@@ -84,6 +88,6 @@ class OpenAIHook(BaseHook):
:param text: The text to generate embeddings for.
:param model: The model to use for generating embeddings.
"""
- response = openai.Embedding.create(model=model, input=text, **kwargs)
- embeddings: list[float] = response["data"][0]["embedding"]
+ response = self.conn.embeddings.create(model=model, input=text, **kwargs)
+ embeddings: list[float] = response.data[0].embedding
return embeddings
diff --git a/airflow/providers/openai/provider.yaml b/airflow/providers/openai/provider.yaml
index 86226aa3f0..0f9d830a61 100644
--- a/airflow/providers/openai/provider.yaml
+++ b/airflow/providers/openai/provider.yaml
@@ -39,7 +39,7 @@ integrations:
dependencies:
- apache-airflow>=2.5.0
- - openai[datalib]>=0.28.1,<1.0
+ - openai[datalib]>=1.0
hooks:
- integration-name: OpenAI
diff --git a/docs/apache-airflow-providers-openai/connections.rst b/docs/apache-airflow-providers-openai/connections.rst
index 88e79df59d..8ef7ee456b 100644
--- a/docs/apache-airflow-providers-openai/connections.rst
+++ b/docs/apache-airflow-providers-openai/connections.rst
@@ -35,3 +35,20 @@ API Key (required)
Host (optional)
The host address of the OpenAI instance.
+
+Extra (optional)
+ Specify the extra parameters (as json dictionary) that can be used in the
+ connection. All parameters are optional.
+ This ``extra`` field accepts a nested dictionary with key ``openai_client_kwargs`` as key-value pairs that
+ are passed to the `OpenAI client <https://github.com/openai/openai-python/blob/main/src/openai/_client.py>`__
+ on instantiation. For example, to set the timeout for the client, you can pass the following dictionary
+ as the ``extra`` field:
+
+ .. code-block:: json
+
+ {
+ "openai_client_kwargs": {
+ "timeout": 10,
+ "api_key": "YOUR_API_KEY"
+ }
+ }
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index e6405cfe77..e53b5c2054 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -661,7 +661,7 @@
"openai": {
"deps": [
"apache-airflow>=2.5.0",
- "openai[datalib]>=0.28.1,<1.0"
+ "openai[datalib]>=1.0"
],
"cross-providers-deps": [],
"excluded-python-versions": []
diff --git a/tests/providers/openai/hooks/test_openai.py b/tests/providers/openai/hooks/test_openai.py
index cd811107f7..a80be35dfb 100644
--- a/tests/providers/openai/hooks/test_openai.py
+++ b/tests/providers/openai/hooks/test_openai.py
@@ -16,48 +16,125 @@
# under the License.
from __future__ import annotations
-from unittest.mock import Mock, patch
+import os
+from unittest.mock import patch
import pytest
+from openai.types import CreateEmbeddingResponse, Embedding
+from airflow.models import Connection
from airflow.providers.openai.hooks.openai import OpenAIHook
@pytest.fixture
-def openai_hook():
- with patch("airflow.providers.openai.hooks.openai.OpenAIHook._get_api_key"), patch(
- "airflow.providers.openai.hooks.openai.OpenAIHook._get_api_base"
- ) as _:
- yield OpenAIHook(conn_id="test_conn_id")
+def mock_openai_connection():
+ conn_id = "openai_conn"
+ conn = Connection(
+ conn_id=conn_id,
+ conn_type="openai",
+ )
+ os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
+ yield conn
@pytest.fixture
-def mock_embeddings_response():
- return {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
+def mock_openai_hook(mock_openai_connection):
+ with patch("airflow.providers.openai.hooks.openai.OpenAI"):
+ yield OpenAIHook(conn_id=mock_openai_connection.conn_id)
@pytest.fixture
-def mock_completions_response():
- return Mock(
- id="completion-id",
- object="completion",
- created=1234567890,
- model="text-davinci-002",
- usage={"prompt_tokens": 15, "completion_tokens": 32, "total_tokens": 47},
- choices=[Mock(text="the quick brown fox", finish_reason="stop", index=0)],
+def mock_embeddings_response():
+ return CreateEmbeddingResponse(
+ data=[Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding")],
+ model="text-embedding-ada-002-v2",
+ object="list",
+ usage={"prompt_tokens": 4, "total_tokens": 4},
)
-def test_create_embeddings(openai_hook, mock_embeddings_response):
+def test_create_embeddings(mock_openai_hook, mock_embeddings_response):
text = "Sample text"
- with patch("openai.Embedding.create", return_value=mock_embeddings_response):
- embeddings = openai_hook.create_embeddings(text)
+ mock_openai_hook.conn.embeddings.create.return_value = mock_embeddings_response
+ embeddings = mock_openai_hook.create_embeddings(text)
assert embeddings == [0.1, 0.2, 0.3]
-def test_get_api_key():
- mock_connection = Mock()
- mock_connection.password = "your_api_key"
- OpenAIHook.get_connection = Mock(return_value=mock_connection)
- api_key = OpenAIHook()._get_api_key()
- assert api_key == "your_api_key"
+def test_openai_hook_test_connection(mock_openai_hook):
+ result, message = mock_openai_hook.test_connection()
+ assert result is True
+ assert message == "Connection established!"
+
+
+@patch("airflow.providers.openai.hooks.openai.OpenAI")
+def test_get_conn_with_api_key_in_extra(mock_client):
+ conn_id = "api_key_in_extra"
+ conn = Connection(
+ conn_id=conn_id,
+ conn_type="openai",
+ extra={"openai_client_kwargs": {"api_key": "api_key_in_extra"}},
+ )
+ os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
+ hook = OpenAIHook(conn_id=conn_id)
+ hook.get_conn()
+ mock_client.assert_called_once_with(
+ api_key="api_key_in_extra",
+ base_url=None,
+ )
+
+
+@patch("airflow.providers.openai.hooks.openai.OpenAI")
+def test_get_conn_with_api_key_in_password(mock_client):
+ conn_id = "api_key_in_password"
+ conn = Connection(
+ conn_id=conn_id,
+ conn_type="openai",
+ password="api_key_in_password",
+ )
+ os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
+ hook = OpenAIHook(conn_id=conn_id)
+ hook.get_conn()
+ mock_client.assert_called_once_with(
+ api_key="api_key_in_password",
+ base_url=None,
+ )
+
+
+@patch("airflow.providers.openai.hooks.openai.OpenAI")
+def test_get_conn_with_base_url_in_extra(mock_client):
+ conn_id = "base_url_in_extra"
+ conn = Connection(
+ conn_id=conn_id,
+ conn_type="openai",
+ extra={"openai_client_kwargs": {"base_url": "base_url_in_extra", "api_key": "api_key_in_extra"}},
+ )
+ os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
+ hook = OpenAIHook(conn_id=conn_id)
+ hook.get_conn()
+ mock_client.assert_called_once_with(
+ api_key="api_key_in_extra",
+ base_url="base_url_in_extra",
+ )
+
+
+@patch("airflow.providers.openai.hooks.openai.OpenAI")
+def test_get_conn_with_openai_client_kwargs(mock_client):
+ conn_id = "openai_client_kwargs"
+ conn = Connection(
+ conn_id=conn_id,
+ conn_type="openai",
+ extra={
+ "openai_client_kwargs": {
+ "api_key": "api_key_in_extra",
+ "organization": "organization_in_extra",
+ }
+ },
+ )
+ os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
+ hook = OpenAIHook(conn_id=conn_id)
+ hook.get_conn()
+ mock_client.assert_called_once_with(
+ api_key="api_key_in_extra",
+ base_url=None,
+ organization="organization_in_extra",
+ )