This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2919abe5b3 Add more ways to connect to weaviate (#35864)
2919abe5b3 is described below
commit 2919abe5b3f2d186c896aebbc51acf98d554ef33
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Tue Nov 28 20:30:06 2023 +0100
Add more ways to connect to weaviate (#35864)
* Add more ways to connect to weaviate
There are other options for connecting to weaviate. This commit adds
these other options and also improves the imports/typing
* fixup! Add more ways to connect to weaviate
* fixup! fixup! Add more ways to connect to weaviate
* add deprecation
* remove mark as dbtest
---
airflow/providers/weaviate/hooks/weaviate.py | 50 ++++---
.../connections.rst | 19 +++
tests/providers/weaviate/hooks/test_weaviate.py | 148 ++++++++++++++++++++-
3 files changed, 198 insertions(+), 19 deletions(-)
diff --git a/airflow/providers/weaviate/hooks/weaviate.py
b/airflow/providers/weaviate/hooks/weaviate.py
index c8b0ed05d4..151aaabea6 100644
--- a/airflow/providers/weaviate/hooks/weaviate.py
+++ b/airflow/providers/weaviate/hooks/weaviate.py
@@ -17,10 +17,13 @@
from __future__ import annotations
+import warnings
from typing import Any
-import weaviate
+from weaviate import Client as WeaviateClient
+from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials,
AuthClientPassword
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
@@ -40,19 +43,19 @@ class WeaviateHook(BaseHook):
super().__init__(*args, **kwargs)
self.conn_id = conn_id
- @staticmethod
- def get_connection_form_widgets() -> dict[str, Any]:
+ @classmethod
+ def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget
from flask_babel import lazy_gettext
from wtforms import PasswordField
return {
- "token": PasswordField(lazy_gettext("Weaviate API Token"),
widget=BS3PasswordFieldWidget()),
+ "token": PasswordField(lazy_gettext("Weaviate API Key"),
widget=BS3PasswordFieldWidget()),
}
- @staticmethod
- def get_ui_field_behaviour() -> dict[str, Any]:
+ @classmethod
+ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["port", "schema"],
@@ -62,28 +65,43 @@ class WeaviateHook(BaseHook):
},
}
- def get_client(self) -> weaviate.Client:
+ def get_conn(self) -> WeaviateClient:
conn = self.get_connection(self.conn_id)
url = conn.host
username = conn.login or ""
password = conn.password or ""
extras = conn.extra_dejson
- token = extras.pop("token", "")
+ access_token = extras.get("access_token", None)
+ refresh_token = extras.get("refresh_token", None)
+ expires_in = extras.get("expires_in", 60)
+ # previously token was used as api_key(backwards compatibility)
+ api_key = extras.get("api_key", None) or extras.get("token", None)
+ client_secret = extras.get("client_secret", None)
additional_headers = extras.pop("additional_headers", {})
- scope = conn.extra_dejson.get("oidc_scope", "offline_access")
-
- if token == "" and username != "":
- auth_client_secret = weaviate.AuthClientPassword(
- username=username, password=password, scope=scope
+ scope = extras.get("scope", None) or extras.get("oidc_scope", None)
+ if api_key:
+ auth_client_secret = AuthApiKey(api_key)
+ elif access_token:
+ auth_client_secret = AuthBearerToken(
+ access_token, expires_in=expires_in,
refresh_token=refresh_token
)
+ elif client_secret:
+ auth_client_secret =
AuthClientCredentials(client_secret=client_secret, scope=scope)
else:
- auth_client_secret = weaviate.AuthApiKey(token)
+ auth_client_secret = AuthClientPassword(username=username,
password=password, scope=scope)
- client = weaviate.Client(
+ return WeaviateClient(
url=url, auth_client_secret=auth_client_secret,
additional_headers=additional_headers
)
- return client
+ def get_client(self) -> WeaviateClient:
+ # Keeping this for backwards compatibility
+ warnings.warn(
+ "The `get_client` method has been renamed to `get_conn`",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ return self.get_conn()
def test_connection(self) -> tuple[bool, str]:
try:
diff --git a/docs/apache-airflow-providers-weaviate/connections.rst
b/docs/apache-airflow-providers-weaviate/connections.rst
index 5e16164ff6..081fe14d92 100644
--- a/docs/apache-airflow-providers-weaviate/connections.rst
+++ b/docs/apache-airflow-providers-weaviate/connections.rst
@@ -42,6 +42,8 @@ OIDC Password (optional)
Extra (optional)
Specify the extra parameters (as json dictionary) that can be used in the
connection. All parameters are optional.
+ The extras are those parameters that are acceptable in the different
authentication methods
+ here: `Authentication
<https://weaviate-python-client.readthedocs.io/en/stable/weaviate.auth.html>`__
* If you'd like to use Vectorizers for your class, configure the API keys
to use the corresponding
embedding API. The extras accepts a key ``additional_headers``
containing the dictionary
@@ -50,3 +52,20 @@ Extra (optional)
Weaviate API Token (optional)
Specify your Weaviate API Key to connect when API Key option is to be used
for authentication.
+
+Supported Authentication Methods
+--------------------------------
+* API Key Authentication: This method uses the Weaviate API Key to
authenticate the connection. You can either have the
+ API key in the ``Weaviate API Token`` field or in the extra field as a
dictionary with key ``token`` or ``api_key`` and
+ value as the API key.
+
+* Bearer Token Authentication: This method uses the Access Token to
authenticate the connection. You need to
+ have the Access Token in the extra field as a dictionary with key
``access_token`` and value as the Access Token. Other
+ parameters such as ``expires_in`` and ``refresh_token`` are optional.
+
+* Client Credentials Authentication: This method uses the Client Credentials
to authenticate the connection. You need to
+ have the Client Credentials in the extra field as a dictionary with key
``client_secret`` and value as the Client Credentials.
+ The ``scope`` is optional.
+
+* Password Authentication: This method uses the username and password to
authenticate the connection. You can specify the
+ scope in the extra field as a dictionary with key ``scope`` and value as the
scope. The ``scope`` is optional.
diff --git a/tests/providers/weaviate/hooks/test_weaviate.py
b/tests/providers/weaviate/hooks/test_weaviate.py
index 56f57ebc9b..0274004fc0 100644
--- a/tests/providers/weaviate/hooks/test_weaviate.py
+++ b/tests/providers/weaviate/hooks/test_weaviate.py
@@ -16,10 +16,12 @@
# under the License.
from __future__ import annotations
+from unittest import mock
from unittest.mock import MagicMock, Mock, patch
import pytest
+from airflow.models import Connection
from airflow.providers.weaviate.hooks.weaviate import WeaviateHook
TEST_CONN_ID = "test_weaviate_conn"
@@ -38,13 +40,153 @@ def weaviate_hook():
return hook
[email protected]
+def mock_auth_api_key():
+ with mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthApiKey") as
m:
+ yield m
+
+
[email protected]
+def mock_auth_bearer_token():
+ with
mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthBearerToken") as m:
+ yield m
+
+
[email protected]
+def mock_auth_client_credentials():
+ with
mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientCredentials")
as m:
+ yield m
+
+
[email protected]
+def mock_auth_client_password():
+ with
mock.patch("airflow.providers.weaviate.hooks.weaviate.AuthClientPassword") as m:
+ yield m
+
+
+class TestWeaviateHook:
+ """
+ Test the WeaviateHook Hook.
+ """
+
+ @pytest.fixture(autouse=True)
+ def setup_method(self, monkeypatch):
+ """Set up the test method."""
+ self.weaviate_api_key1 = "weaviate_api_key1"
+ self.weaviate_api_key2 = "weaviate_api_key2"
+ self.api_key = "api_key"
+ self.weaviate_client_credentials = "weaviate_client_credentials"
+ self.client_secret = "client_secret"
+ self.scope = "scope1 scope2"
+ self.client_password = "client_password"
+ self.client_bearer_token = "client_bearer_token"
+ self.host = "http://localhost:8080"
+ conns = (
+ Connection(
+ conn_id=self.weaviate_api_key1,
+ host=self.host,
+ conn_type="weaviate",
+ extra={"api_key": self.api_key},
+ ),
+ Connection(
+ conn_id=self.weaviate_api_key2,
+ host=self.host,
+ conn_type="weaviate",
+ extra={"token": self.api_key},
+ ),
+ Connection(
+ conn_id=self.weaviate_client_credentials,
+ host=self.host,
+ conn_type="weaviate",
+ extra={"client_secret": self.client_secret, "scope":
self.scope},
+ ),
+ Connection(
+ conn_id=self.client_password,
+ host=self.host,
+ conn_type="weaviate",
+ login="login",
+ password="password",
+ ),
+ Connection(
+ conn_id=self.client_bearer_token,
+ host=self.host,
+ conn_type="weaviate",
+ extra={
+ "access_token": self.client_bearer_token,
+ "expires_in": 30,
+ "refresh_token": "refresh_token",
+ },
+ ),
+ )
+ for conn in conns:
+ monkeypatch.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}",
conn.get_uri())
+
+ @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
+ def test_get_conn_with_api_key_in_extra(self, mock_client,
mock_auth_api_key):
+ hook = WeaviateHook(conn_id=self.weaviate_api_key1)
+ hook.get_conn()
+ mock_auth_api_key.assert_called_once_with(self.api_key)
+ mock_client.assert_called_once_with(
+ url=self.host,
auth_client_secret=mock_auth_api_key(api_key=self.api_key),
additional_headers={}
+ )
+
+ @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
+ def test_get_conn_with_token_in_extra(self, mock_client,
mock_auth_api_key):
+ # when token is passed in extra
+ hook = WeaviateHook(conn_id=self.weaviate_api_key2)
+ hook.get_conn()
+ mock_auth_api_key.assert_called_once_with(self.api_key)
+ mock_client.assert_called_once_with(
+ url=self.host,
auth_client_secret=mock_auth_api_key(api_key=self.api_key),
additional_headers={}
+ )
+
+ @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
+ def test_get_conn_with_access_token_in_extra(self, mock_client,
mock_auth_bearer_token):
+ hook = WeaviateHook(conn_id=self.client_bearer_token)
+ hook.get_conn()
+ mock_auth_bearer_token.assert_called_once_with(
+ self.client_bearer_token, expires_in=30,
refresh_token="refresh_token"
+ )
+ mock_client.assert_called_once_with(
+ url=self.host,
+ auth_client_secret=mock_auth_bearer_token(
+ access_token=self.client_bearer_token, expires_in=30,
refresh_token="refresh_token"
+ ),
+ additional_headers={},
+ )
+
+ @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
+ def test_get_conn_with_client_secret_in_extra(self, mock_client,
mock_auth_client_credentials):
+ hook = WeaviateHook(conn_id=self.weaviate_client_credentials)
+ hook.get_conn()
+ mock_auth_client_credentials.assert_called_once_with(
+ client_secret=self.client_secret, scope=self.scope
+ )
+ mock_client.assert_called_once_with(
+ url=self.host,
+
auth_client_secret=mock_auth_client_credentials(api_key=self.client_secret,
scope=self.scope),
+ additional_headers={},
+ )
+
+ @mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateClient")
+ def test_get_conn_with_client_password_in_extra(self, mock_client,
mock_auth_client_password):
+ hook = WeaviateHook(conn_id=self.client_password)
+ hook.get_conn()
+ mock_auth_client_password.assert_called_once_with(username="login",
password="password", scope=None)
+ mock_client.assert_called_once_with(
+ url=self.host,
+ auth_client_secret=mock_auth_client_password(username="login",
password="password", scope=None),
+ additional_headers={},
+ )
+
+
def test_create_class(weaviate_hook):
"""
Test the create_class method of WeaviateHook.
"""
# Mock the Weaviate Client
mock_client = MagicMock()
- weaviate_hook.get_client = MagicMock(return_value=mock_client)
+ weaviate_hook.get_conn = MagicMock(return_value=mock_client)
# Define test class JSON
test_class_json = {
@@ -65,7 +207,7 @@ def test_create_schema(weaviate_hook):
"""
# Mock the Weaviate Client
mock_client = MagicMock()
- weaviate_hook.get_client = MagicMock(return_value=mock_client)
+ weaviate_hook.get_conn = MagicMock(return_value=mock_client)
# Define test schema JSON
test_schema_json = {
@@ -90,7 +232,7 @@ def test_batch_data(weaviate_hook):
"""
# Mock the Weaviate Client
mock_client = MagicMock()
- weaviate_hook.get_client = MagicMock(return_value=mock_client)
+ weaviate_hook.get_conn = MagicMock(return_value=mock_client)
# Define test data
test_class_name = "TestClass"