This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5439b494b0 Add helper function for CRUD operations on weaviate's
schema and class objects (#35919)
5439b494b0 is described below
commit 5439b494b00daf0bb62d8f1f8a0f4d71c39f4923
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Tue Dec 12 02:01:58 2023 +0530
Add helper function for CRUD operations on weaviate's schema and class
objects (#35919)
* Add CRUD operations around schema and class objects
* Change assert methods
* Add contains_schema to check for subset of schema
* Handle cases when the properties are not in the same order
* Change the methods name and docstring
* Resolve conflicts
* Make sure the retrying is working as expected
* Address PR comments
* Remove retry logic from
* Remove vector_col params and dataframe support
* Remove unwanted retry logic
* Address PR comments
* Resolve ruff-lint issues
* Remove unwanted changes
* Remove unwanted changes
* Change the exception to retry on
---
airflow/providers/weaviate/hooks/weaviate.py | 241 ++++++++++++++++++++++--
tests/providers/weaviate/hooks/test_weaviate.py | 230 +++++++++++++++++++++-
2 files changed, 453 insertions(+), 18 deletions(-)
diff --git a/airflow/providers/weaviate/hooks/weaviate.py
b/airflow/providers/weaviate/hooks/weaviate.py
index 8a68b1c3f0..6c9b1b6787 100644
--- a/airflow/providers/weaviate/hooks/weaviate.py
+++ b/airflow/providers/weaviate/hooks/weaviate.py
@@ -24,7 +24,7 @@ from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, List, cast
import requests
-from tenacity import Retrying, retry_if_exception, stop_after_attempt
+from tenacity import Retrying, retry, retry_if_exception,
retry_if_exception_type, stop_after_attempt
from weaviate import Client as WeaviateClient
from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials,
AuthClientPassword
from weaviate.exceptions import ObjectAlreadyExistsException
@@ -34,12 +34,31 @@ from airflow.exceptions import
AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
if TYPE_CHECKING:
- from typing import Sequence
+ from typing import Literal, Sequence
import pandas as pd
from weaviate import ConsistencyLevel
from weaviate.types import UUID
+ ExitingSchemaOptions = Literal["replace", "fail", "ignore"]
+
+# HTTP status codes on which a failed request is retried (see check_http_error_is_retryable).
+HTTP_RETRY_STATUS_CODE = [429, 500, 503, 504]
+# Exception types that are retried regardless of the HTTP status code (used with retry_if_exception_type).
+# NOTE(review): ``requests.RequestException`` is the base class of the other three entries, so this tuple
+# matches every requests error (including non-retryable 4xx HTTPErrors) - confirm this is intended.
+REQUESTS_EXCEPTIONS_TYPES = (
+    requests.RequestException,
+    requests.exceptions.ConnectionError,
+    requests.exceptions.HTTPError,
+    requests.exceptions.ConnectTimeout,
+)
+
+
+def check_http_error_is_retryable(exc: BaseException) -> bool:
+    """Return True if ``exc`` is a requests error whose HTTP status code is in ``HTTP_RETRY_STATUS_CODE``.
+
+    :param exc: The exception raised by the attempted call.
+    :return: True when the request should be retried because of its response status code.
+    """
+    if not isinstance(exc, requests.exceptions.RequestException):
+        return False
+    # ``requests.Response.__bool__`` returns ``Response.ok``, which is False for every 4xx/5xx status,
+    # so the response must be compared against None rather than tested for truthiness.
+    response = exc.response
+    return response is not None and response.status_code in HTTP_RETRY_STATUS_CODE
+
class WeaviateHook(BaseHook):
"""
@@ -53,7 +72,13 @@ class WeaviateHook(BaseHook):
conn_type = "weaviate"
hook_name = "Weaviate"
- def __init__(self, conn_id: str = default_conn_name, *args: Any, **kwargs:
Any) -> None:
+ def __init__(
+ self,
+ conn_id: str = default_conn_name,
+ retry_status_codes: list[int] | None = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> None:
super().__init__(*args, **kwargs)
self.conn_id = conn_id
@@ -137,21 +162,25 @@ class WeaviateHook(BaseHook):
client = self.conn
client.schema.create_class(class_json)
- def create_schema(self, schema_json: dict[str, Any]) -> None:
+    # Retry up to 3 attempts when any requests exception is raised (``requests.RequestException`` is part of
+    # REQUESTS_EXCEPTIONS_TYPES) or the HTTP status is retryable; ``reraise=True`` re-raises the last error.
+    # NOTE(review): creating a schema is not obviously idempotent - if a failed attempt actually created some
+    # classes on the server, the retry may fail because they already exist; confirm this is intended.
+    @retry(
+        reraise=True,
+        stop=stop_after_attempt(3),
+        retry=(
+            retry_if_exception(lambda exc: check_http_error_is_retryable(exc))
+            | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES)
+        ),
+    )
+    def create_schema(self, schema_json: dict[str, Any] | str) -> None:
        """
        Create a new Schema.
        Instead of adding classes one by one , you can upload a full schema in JSON format at once.
-        :param schema_json: The schema to create
+        :param schema_json: Schema as a Python dict or the path to a JSON file, or the URL of a JSON file.
        """
        client = self.conn
        client.schema.create(schema_json)
- @staticmethod
- def check_http_error_should_retry(exc: BaseException):
- return isinstance(exc, requests.HTTPError) and not exc.response.ok
-
@staticmethod
def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame)
-> list[dict[str, Any]]:
"""Helper function to convert dataframe to list of dicts.
@@ -166,6 +195,190 @@ class WeaviateHook(BaseHook):
data = json.loads(data.to_json(orient="records"))
return cast(List[Dict[str, Any]], data)
+    @retry(
+        retry=(
+            retry_if_exception(lambda exc: check_http_error_is_retryable(exc))
+            | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES)
+        ),
+        stop=stop_after_attempt(3),
+        reraise=True,
+    )
+    def get_schema(self, class_name: str | None = None):
+        """Fetch the schema from Weaviate, retrying failed requests up to 3 attempts.
+
+        :param class_name: Name of a single class whose schema should be returned. When omitted
+            (the default, None) the schema of every class is returned.
+        """
+        return self.get_client().schema.get(class_name)
+
+    def delete_classes(self, class_names: list[str] | str, if_error: str = "stop") -> list[str] | None:
+        """Delete the given classes from the Weaviate schema.
+
+        Each deletion is attempted up to 3 times when a requests error is raised or the HTTP status
+        is retryable; after the last attempt the original error (not a tenacity wrapper) is raised.
+
+        :param class_names: A class name or a list of class names to be deleted.
+        :param if_error: Action to take if deleting a class fails, possible options are
+            `stop` (re-raise the error) and `continue` (log it and move on to the next class).
+        :return: If `if_error=continue`, the list of classes which failed to be deleted.
+            If `if_error=stop`, None.
+        :raises ValueError: If `if_error` is not one of `stop` or `continue`.
+        """
+        # Validate up front: an unknown value used to swallow every deletion error silently.
+        if if_error not in ("stop", "continue"):
+            raise ValueError(f"Param 'if_error' should be one of ['stop', 'continue'], got {if_error!r}.")
+        client = self.get_client()
+        if isinstance(class_names, str):
+            # An empty string means "nothing to delete", as before.
+            class_names = [class_names] if class_names else []
+
+        failed_class_list: list[str] = []
+        for class_name in class_names:
+            try:
+                for attempt in Retrying(
+                    # Re-raise the last underlying error instead of tenacity.RetryError, like the
+                    # @retry(reraise=True) decorators used elsewhere in this hook.
+                    reraise=True,
+                    stop=stop_after_attempt(3),
+                    retry=(
+                        retry_if_exception(lambda exc: check_http_error_is_retryable(exc))
+                        | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES)
+                    ),
+                ):
+                    with attempt:
+                        client.schema.delete_class(class_name)
+            except Exception as e:
+                if if_error == "stop":
+                    raise
+                self.log.error("Failed to delete class %s: %s", class_name, e)
+                failed_class_list.append(class_name)
+
+        if if_error == "continue":
+            return failed_class_list
+        return None
+
+    def delete_all_schema(self):
+        """Drop the whole schema of the Weaviate instance, including all data associated with it."""
+        return self.get_client().schema.delete_all()
+
+    def update_config(self, class_name: str, config: dict):
+        """Update the schema configuration of the class ``class_name`` with ``config``."""
+        self.get_client().schema.update_config(class_name=class_name, config=config)
+
+    def create_or_replace_classes(
+        self, schema_json: dict[str, Any] | str, existing: ExitingSchemaOptions = "ignore"
+    ):
+        """
+        Create or replace the classes in schema of Weaviate database.
+
+        :param schema_json: Json containing the schema, or the path to a JSON file containing it.
+            Format {"class_name": "class_dict"}
+            .. seealso:: `example of class_dict <https://weaviate-python-client.readthedocs.io/en/v3.25.2/weaviate.schema.html#weaviate.schema.Schema.create>`_.
+        :param existing: Options to handle the case when the classes exist, possible options
+            'replace', 'fail', 'ignore'.
+
+            * 'ignore' - only the requested classes that do not exist yet are created.
+            * 'fail' - raise ValueError if any requested class already exists, otherwise create all of them.
+            * 'replace' - delete the requested classes that already exist, then create all requested classes.
+        :raises ValueError: If `existing` is invalid, or if `existing='fail'` and a requested class
+            already exists. Errors raised by `delete_classes` while replacing are propagated.
+        """
+        existing_schema_options = ["replace", "fail", "ignore"]
+        if existing not in existing_schema_options:
+            raise ValueError(f"Param 'existing' should be one of the {existing_schema_options} values.")
+        if isinstance(schema_json, str):
+            # A string is a path to a JSON file; use a context manager so the handle is closed.
+            with open(schema_json) as schema_file:
+                schema_json = cast(dict, json.load(schema_file))
+        existing_classes = {class_object["class"] for class_object in self.get_schema()["classes"]}
+        requested_classes = set(schema_json)
+        conflicting_classes = existing_classes.intersection(requested_classes)
+        if existing == "ignore":
+            classes_to_create = requested_classes - existing_classes
+        elif existing == "fail":
+            if conflicting_classes:
+                raise ValueError(
+                    f"Trying to create class {conflicting_classes}" f" but this class already exists."
+                )
+            # No conflicts: every requested class has to be created.
+            classes_to_create = requested_classes
+        else:  # existing == "replace"
+            # delete_classes defaults to if_error="stop", so a failed deletion raises there;
+            # the error_list check only matters if that default ever changes.
+            error_list = self.delete_classes(class_names=list(conflicting_classes))
+            if error_list:
+                raise ValueError(error_list)
+            classes_to_create = requested_classes
+        classes_to_create_list = [schema_json[item] for item in sorted(classes_to_create)]
+        self.create_schema({"classes": classes_to_create_list})
+
+    def _compare_schema_subset(self, subset_object: Any, superset_object: Any) -> bool:
+        """
+        Recursively check if requested subset_object is a subset of the superset_object.
+
+        Dicts match when every key of the subset exists in the superset with a matching (subset) value;
+        lists/tuples are compared position by position and the subset may not be longer than the
+        superset; any other values must be equal.
+
+        Example 1:
+            superset_object = {"a": {"b": [1, 2, 3], "c": "d"}}
+            subset_object = {"a": {"c": "d"}}
+            _compare_schema_subset(subset_object, superset_object) # will result in True
+
+            superset_object = {"a": {"b": [1, 2, 3], "c": "d"}}
+            subset_object = {"a": {"d": "e"}}
+            _compare_schema_subset(subset_object, superset_object) # will result in False
+
+        :param subset_object: The object to be checked
+        :param superset_object: The object to check against
+        :return: True if subset_object is a subset of superset_object.
+        """
+        # Direct equality check
+        if subset_object == superset_object:
+            return True
+
+        # Type mismatch early return
+        if type(subset_object) is not type(superset_object):
+            return False
+
+        # Dictionary comparison
+        if isinstance(subset_object, dict):
+            return all(
+                key in superset_object and self._compare_schema_subset(value, superset_object[key])
+                for key, value in subset_object.items()
+            )
+
+        # List or Tuple comparison
+        if isinstance(subset_object, (list, tuple)):
+            # Check the length before zipping: zip stops at the shorter sequence, so a non-empty
+            # subset compared with an empty superset would otherwise be reported as a subset.
+            if len(subset_object) > len(superset_object):
+                return False
+            return all(
+                self._compare_schema_subset(sub, sup) for sub, sup in zip(subset_object, superset_object)
+            )
+
+        # Default case for non-matching types or unsupported types
+        return False
+
+    @staticmethod
+    def _convert_properties_to_dict(classes_objects, key_property: str = "name"):
+        """
+        Helper function to convert list of class properties into dict by using a `key_property` as key.
+
+        This is done to avoid class properties comparison as list of properties.
+
+        Case 1:
+            A = [1, 2, 3]
+            B = [1, 2]
+            When comparing list we check for the length, but it's not suitable for subset check.
+
+        Case 2:
+            A = [1, 2, 3]
+            B = [1, 3, 2]
+            When we compare two lists, we compare item 1 of list A with item 1 of list B and
+            pass if the two are same, but there can be scenarios when the properties are not in same order.
+
+        The input is left untouched: a new list of shallow-copied class dicts is returned, so the
+        caller's objects keep their property lists and the helper can be applied to them again.
+        Classes without a ``properties`` key are copied unchanged.
+
+        :param classes_objects: Iterable of class dicts (e.g. ``get_schema()["classes"]``).
+        :param key_property: Property attribute used as the dict key.
+        :return: New list of class dicts whose ``properties`` are keyed by ``key_property``.
+        """
+        converted_classes = []
+        for class_object in classes_objects:
+            # Shallow copy so the caller's dict is not mutated.
+            class_copy = dict(class_object)
+            if "properties" in class_copy:
+                class_copy["properties"] = {prop[key_property]: prop for prop in class_copy["properties"]}
+            converted_classes.append(class_copy)
+        return converted_classes
+
+    def check_subset_of_schema(self, classes_objects: list) -> bool:
+        """Return True if every class in ``classes_objects`` is contained in the current Weaviate schema.
+
+        A requested class matches when a class with the same name exists and every attribute of the
+        requested class is present there with a matching value. Properties are matched by their
+        ``name`` regardless of order, and extra properties/attributes in the existing schema are allowed.
+
+        Note - weaviate client's `contains()` doesn't handle the class properties mismatch: to compare
+        `Class A` with `Class B` they must have exactly the same properties, so if `Class A` has fewer
+        properties than `Class B`, `contains()` results in False.
+
+        .. seealso:: `contains <https://weaviate-python-client.readthedocs.io/en/v3.25.3/weaviate.schema.html#weaviate.schema.Schema.contains>`_.
+        """
+        # Key property lists by name so order and extra existing properties don't affect the check.
+        requested_classes = self._convert_properties_to_dict(classes_objects)
+        existing_by_name = {
+            existing_cls["class"]: existing_cls
+            for existing_cls in self._convert_properties_to_dict(self.get_schema()["classes"])
+        }
+        requested_names = {requested["class"] for requested in requested_classes}
+        # Every requested class must exist before the contents are compared.
+        if not requested_names.issubset(existing_by_name):
+            return False
+        return all(
+            self._compare_schema_subset(requested, existing_by_name[requested["class"]])
+            for requested in requested_classes
+        )
+
def batch_data(
self,
class_name: str,
@@ -194,7 +407,10 @@ class WeaviateHook(BaseHook):
for index, data_obj in enumerate(data):
for attempt in Retrying(
stop=stop_after_attempt(retry_attempts_per_object),
-
retry=retry_if_exception(self.check_http_error_should_retry),
+ retry=(
+ retry_if_exception(lambda exc:
check_http_error_is_retryable(exc))
+ | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES)
+ ),
):
with attempt:
self.log.debug(
@@ -203,11 +419,6 @@ class WeaviateHook(BaseHook):
vector = data_obj.pop(vector_col, None)
batch.add_data_object(data_obj, class_name,
vector=vector)
- def delete_class(self, class_name: str) -> None:
- """Delete an existing class."""
- client = self.conn
- client.schema.delete_class(class_name)
-
def query_with_vector(
self,
embeddings: list[float],
diff --git a/tests/providers/weaviate/hooks/test_weaviate.py
b/tests/providers/weaviate/hooks/test_weaviate.py
index acda7e9c2e..fc6d7db6ae 100644
--- a/tests/providers/weaviate/hooks/test_weaviate.py
+++ b/tests/providers/weaviate/hooks/test_weaviate.py
@@ -16,12 +16,14 @@
# under the License.
from __future__ import annotations
+from contextlib import ExitStack
from unittest import mock
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, Mock
import pandas as pd
import pytest
import requests
+import weaviate
from weaviate import ObjectAlreadyExistsException
from airflow.models import Connection
@@ -38,7 +40,7 @@ def weaviate_hook():
mock_conn = Mock()
# Patch the WeaviateHook get_connection method to return the mock
connection
- with patch.object(WeaviateHook, "get_connection", return_value=mock_conn):
+ with mock.patch.object(WeaviateHook, "get_connection",
return_value=mock_conn):
hook = WeaviateHook(conn_id=TEST_CONN_ID)
return hook
@@ -434,7 +436,7 @@ def test_batch_data(data, expected_length, weaviate_hook):
assert mock_batch_context.add_data_object.call_count == expected_length
-@patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn")
[email protected]("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn")
def test_batch_data_retry(get_conn, weaviate_hook):
"""Test to ensure retrying working as expected"""
data = [{"name": "chandler"}, {"name": "joey"}, {"name": "ross"}]
@@ -446,3 +448,225 @@ def test_batch_data_retry(get_conn, weaviate_hook):
get_conn.return_value.batch.__enter__.return_value.add_data_object.side_effect
= side_effect
weaviate_hook.batch_data("TestClass", data)
assert
get_conn.return_value.batch.__enter__.return_value.add_data_object.call_count
== len(side_effect)
+
+
+@pytest.mark.parametrize(
+    argnames=["get_schema_value", "existing", "expected_value"],
+    argvalues=[
+        ({"classes": [{"class": "A"}, {"class": "B"}]}, "ignore", [{"class": "C"}]),
+        ({"classes": [{"class": "A"}, {"class": "B"}]}, "replace", [{"class": "B"}, {"class": "C"}]),
+        ({"classes": [{"class": "A"}, {"class": "B"}]}, "fail", {}),
+        ({"classes": [{"class": "A"}, {"class": "B"}]}, "invalid_option", {}),
+    ],
+    ids=["ignore", "replace", "fail", "invalid_option"],
+)
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.delete_classes")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.create_schema")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_schema")
+def test_upsert_schema_scenarios(
+    get_schema, create_schema, delete_classes, get_schema_value, existing, expected_value, weaviate_hook
+):
+    """Check which classes create_or_replace_classes creates for each `existing` option.
+
+    Stacked mock.patch decorators are applied bottom-up, so get_schema is the first mock argument.
+    """
+    schema_json = {
+        "B": {"class": "B"},
+        "C": {"class": "C"},
+    }
+    with ExitStack() as stack:
+        delete_classes.return_value = None
+        if existing in ["fail", "invalid_option"]:
+            # The call below raises, so the assertions after it are skipped for these cases.
+            stack.enter_context(pytest.raises(ValueError))
+        get_schema.return_value = get_schema_value
+        weaviate_hook.create_or_replace_classes(schema_json=schema_json, existing=existing)
+        create_schema.assert_called_once_with({"classes": expected_value})
+        if existing == "replace":
+            delete_classes.assert_called_once_with(class_names=["B"])
+
+
+@mock.patch("builtins.open")
+@mock.patch("json.load")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.create_schema")
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_schema")
+def test_upsert_schema_json_file_param(get_schema, create_schema, load, mock_open_file, weaviate_hook):
+    """Test if schema_json is path to a json file.
+
+    `open` and `json.load` are patched, so no file is read; json.load returns the schema below.
+    """
+    get_schema.return_value = {"classes": [{"class": "A"}, {"class": "B"}]}
+    load.return_value = {
+        "B": {"class": "B"},
+        "C": {"class": "C"},
+    }
+    weaviate_hook.create_or_replace_classes(schema_json="/tmp/some_temp_file.json", existing="ignore")
+    create_schema.assert_called_once_with({"classes": [{"class": "C"}]})
+
+
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_client")
+def test_delete_classes(get_client, weaviate_hook):
+    """Failures are collected with if_error="continue" and re-raised with if_error="stop"."""
+    class_names = ["class_a", "class_b"]
+    # Only two side effects are queued: the test expects this exception not to be retried,
+    # so class_a fails on its single attempt and class_b succeeds.
+    get_client.return_value.schema.delete_class.side_effect = [
+        weaviate.UnexpectedStatusCodeException("something failed", requests.Response()),
+        None,
+    ]
+    error_list = weaviate_hook.delete_classes(class_names, if_error="continue")
+    assert error_list == ["class_a"]
+
+    get_client.return_value.schema.delete_class.side_effect = weaviate.UnexpectedStatusCodeException(
+        "something failed", requests.Response()
+    )
+    with pytest.raises(weaviate.UnexpectedStatusCodeException):
+        weaviate_hook.delete_classes("class_a", if_error="stop")
+
+
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_client")
+def test_http_errors_of_delete_classes(get_client, weaviate_hook):
+    """Retryable HTTP and connection errors are retried, so every class is eventually deleted."""
+    class_names = ["class_a", "class_b"]
+    resp = requests.Response()
+    resp.status_code = 429
+    # class_a: HTTPError(429) then success; class_b: ConnectionError then success.
+    get_client.return_value.schema.delete_class.side_effect = [
+        requests.exceptions.HTTPError(response=resp),
+        None,
+        requests.exceptions.ConnectionError,
+        None,
+    ]
+    error_list = weaviate_hook.delete_classes(class_names, if_error="continue")
+    assert error_list == []
+    assert get_client.return_value.schema.delete_class.call_count == 4
+
+
+@pytest.mark.parametrize(
+    argnames=["classes_to_test", "expected_result"],
+    argvalues=[
+        (
+            [
+                {
+                    "class": "Author",
+                    "description": "Authors info",
+                    "properties": [
+                        {
+                            "name": "last_name",
+                            "description": "Last name of the author",
+                            "dataType": ["text"],
+                        },
+                    ],
+                },
+            ],
+            True,
+        ),
+        (
+            # Class-level check: no property constraints, only the class attributes are compared.
+            [
+                {
+                    "class": "Author",
+                    "description": "Authors info",
+                    "properties": [],
+                },
+            ],
+            True,
+        ),
+        (
+            [
+                {
+                    "class": "Author",
+                    "description": "Authors info",
+                    "properties": [
+                        {
+                            "name": "invalid_property",
+                            "description": "Last name of the author",
+                            "dataType": ["text"],
+                        }
+                    ],
+                },
+            ],
+            False,
+        ),
+        (
+            [
+                {
+                    "class": "invalid_class",
+                    "description": "Authors info",
+                    "properties": [
+                        {
+                            "name": "last_name",
+                            "description": "Last name of the author",
+                            "dataType": ["text"],
+                        },
+                    ],
+                },
+            ],
+            False,
+        ),
+        (
+            # Properties listed in the opposite order from the existing schema below.
+            [
+                {
+                    "class": "Author",
+                    "description": "Authors info",
+                    "properties": [
+                        {
+                            "name": "last_name",
+                            "description": "Last name of the author",
+                            "dataType": ["text"],
+                        },
+                        {
+                            "name": "name",
+                            "description": "Name of the author",
+                            "dataType": ["text"],
+                            "extra_key": "some_value",
+                        },
+                    ],
+                },
+            ],
+            True,
+        ),
+    ],
+    ids=(
+        "property_level_check",
+        "class_level_check",
+        "invalid_property",
+        "invalid_class",
+        "swapped_properties",
+    ),
+)
+@mock.patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_schema")
+def test_contains_schema(get_schema, classes_to_test, expected_result, weaviate_hook):
+    """check_subset_of_schema tolerates missing/reordered properties but not unknown classes or properties."""
+    get_schema.return_value = {
+        "classes": [
+            {
+                "class": "Author",
+                "description": "Authors info",
+                "properties": [
+                    {
+                        "name": "name",
+                        "description": "Name of the author",
+                        "dataType": ["text"],
+                        "extra_key": "some_value",
+                    },
+                    {
+                        "name": "last_name",
+                        "description": "Last name of the author",
+                        "dataType": ["text"],
+                        "extra_key": "some_value",
+                    },
+                ],
+            },
+            {
+                "class": "Article",
+                "description": "An article written by an Author",
+                "properties": [
+                    {
+                        "name": "name",
+                        "description": "Name of the author",
+                        "dataType": ["text"],
+                        "extra_key": "some_value",
+                    },
+                    {
+                        "name": "last_name",
+                        "description": "Last name of the author",
+                        "dataType": ["text"],
+                        "extra_key": "some_value",
+                    },
+                ],
+            },
+        ]
+    }
+    assert weaviate_hook.check_subset_of_schema(classes_to_test) == expected_result