This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5439b494b0 Add helper function for CRUD operations on weaviate's 
schema and class objects (#35919)
5439b494b0 is described below

commit 5439b494b00daf0bb62d8f1f8a0f4d71c39f4923
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Tue Dec 12 02:01:58 2023 +0530

    Add helper function for CRUD operations on weaviate's schema and class 
objects (#35919)
    
    * Add CRUD operations around schema and class objects
    
    * Change assert methods
    
    * Add contains_schema to check for subset of schema
    
    * Handle cases when the properties are not in the same order
    
    * Change the method names and docstrings
    
    * Resolve conflicts
    
    * Make sure the retrying is working as expected
    
    * Address PR comments
    
    * Remove retry logic from
    
    * Remove vector_col params and dataframe support
    
    * Remove unwanted retry logic
    
    * Address PR comments
    
    * Resolve ruff-lint issues
    
    * Remove unwanted changes
    
    * Remove unwanted changes
    
    * Change the exception to retry on
---
 airflow/providers/weaviate/hooks/weaviate.py    | 241 ++++++++++++++++++++++--
 tests/providers/weaviate/hooks/test_weaviate.py | 230 +++++++++++++++++++++-
 2 files changed, 453 insertions(+), 18 deletions(-)

diff --git a/airflow/providers/weaviate/hooks/weaviate.py 
b/airflow/providers/weaviate/hooks/weaviate.py
index 8a68b1c3f0..6c9b1b6787 100644
--- a/airflow/providers/weaviate/hooks/weaviate.py
+++ b/airflow/providers/weaviate/hooks/weaviate.py
@@ -24,7 +24,7 @@ from functools import cached_property
 from typing import TYPE_CHECKING, Any, Dict, List, cast
 
 import requests
-from tenacity import Retrying, retry_if_exception, stop_after_attempt
+from tenacity import Retrying, retry, retry_if_exception, 
retry_if_exception_type, stop_after_attempt
 from weaviate import Client as WeaviateClient
 from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, 
AuthClientPassword
 from weaviate.exceptions import ObjectAlreadyExistsException
@@ -34,12 +34,31 @@ from airflow.exceptions import 
AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
 
 if TYPE_CHECKING:
-    from typing import Sequence
+    from typing import Literal, Sequence
 
     import pandas as pd
     from weaviate import ConsistencyLevel
     from weaviate.types import UUID
 
+    ExitingSchemaOptions = Literal["replace", "fail", "ignore"]
+
+HTTP_RETRY_STATUS_CODE = [429, 500, 503, 504]
+REQUESTS_EXCEPTIONS_TYPES = (
+    requests.RequestException,
+    requests.exceptions.ConnectionError,
+    requests.exceptions.HTTPError,
+    requests.exceptions.ConnectTimeout,
+)
+
+
+def check_http_error_is_retryable(exc: BaseException):
+    return (
+        isinstance(exc, requests.exceptions.RequestException)
+        and exc.response
+        and exc.response.status_code
+        and exc.response.status_code in HTTP_RETRY_STATUS_CODE
+    )
+
 
 class WeaviateHook(BaseHook):
     """
@@ -53,7 +72,13 @@ class WeaviateHook(BaseHook):
     conn_type = "weaviate"
     hook_name = "Weaviate"
 
-    def __init__(self, conn_id: str = default_conn_name, *args: Any, **kwargs: 
Any) -> None:
+    def __init__(
+        self,
+        conn_id: str = default_conn_name,
+        retry_status_codes: list[int] | None = None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         super().__init__(*args, **kwargs)
         self.conn_id = conn_id
 
@@ -137,21 +162,25 @@ class WeaviateHook(BaseHook):
         client = self.conn
         client.schema.create_class(class_json)
 
-    def create_schema(self, schema_json: dict[str, Any]) -> None:
+    @retry(
+        reraise=True,
+        stop=stop_after_attempt(3),
+        retry=(
+            retry_if_exception(lambda exc: check_http_error_is_retryable(exc))
+            | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES)
+        ),
+    )
+    def create_schema(self, schema_json: dict[str, Any] | str) -> None:
         """
         Create a new Schema.
 
         Instead of adding classes one by one , you can upload a full schema in 
JSON format at once.
 
-        :param schema_json: The schema to create
+        :param schema_json: Schema as a Python dict or the path to a JSON 
file, or the URL of a JSON file.
         """
         client = self.conn
         client.schema.create(schema_json)
 
-    @staticmethod
-    def check_http_error_should_retry(exc: BaseException):
-        return isinstance(exc, requests.HTTPError) and not exc.response.ok
-
     @staticmethod
     def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame) 
-> list[dict[str, Any]]:
         """Helper function to convert dataframe to list of dicts.
@@ -166,6 +195,190 @@ class WeaviateHook(BaseHook):
                 data = json.loads(data.to_json(orient="records"))
         return cast(List[Dict[str, Any]], data)
 
+    @retry(
+        reraise=True,
+        stop=stop_after_attempt(3),
+        retry=(
+            retry_if_exception(lambda exc: check_http_error_is_retryable(exc))
+            | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES)
+        ),
+    )
+    def get_schema(self, class_name: str | None = None):
+        """Get the schema from Weaviate.
+
+        :param class_name: The class for which to return the schema. If NOT 
provided the whole schema is
+            returned, otherwise only the schema of this class is returned. By 
default None.
+        """
+        client = self.get_client()
+        return client.schema.get(class_name)
+
+    def delete_classes(self, class_names: list[str] | str, if_error: str = 
"stop") -> list[str] | None:
+        """Deletes all or specific classes if class_names are provided.
+
+        :param class_names: list of class names to be deleted.
+        :param if_error: define the actions to be taken if there is an error 
while deleting a class, possible
+         options are `stop` and `continue`
+        :return: if `if_error=continue` return list of classes which we failed 
to delete.
+            if `if_error=stop` returns None.
+        """
+        client = self.get_client()
+        class_names = [class_names] if class_names and isinstance(class_names, 
str) else class_names
+
+        failed_class_list = []
+        for class_name in class_names:
+            try:
+                for attempt in Retrying(
+                    stop=stop_after_attempt(3),
+                    retry=(
+                        retry_if_exception(lambda exc: 
check_http_error_is_retryable(exc))
+                        | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES)
+                    ),
+                ):
+                    with attempt:
+                        print(attempt)
+                        client.schema.delete_class(class_name)
+            except Exception as e:
+                if if_error == "continue":
+                    self.log.error(e)
+                    failed_class_list.append(class_name)
+                elif if_error == "stop":
+                    raise e
+
+        if if_error == "continue":
+            return failed_class_list
+        return None
+
+    def delete_all_schema(self):
+        """Remove the entire schema from the Weaviate instance and all data 
associated with it."""
+        client = self.get_client()
+        return client.schema.delete_all()
+
+    def update_config(self, class_name: str, config: dict):
+        """Update a schema configuration for a specific class."""
+        client = self.get_client()
+        client.schema.update_config(class_name=class_name, config=config)
+
+    def create_or_replace_classes(
+        self, schema_json: dict[str, Any] | str, existing: 
ExitingSchemaOptions = "ignore"
+    ):
+        """
+        Create or replace the classes in schema of Weaviate database.
+
+        :param schema_json: Json containing the schema. Format {"class_name": 
"class_dict"}
+            .. seealso:: `example of class_dict 
<https://weaviate-python-client.readthedocs.io/en/v3.25.2/weaviate.schema.html#weaviate.schema.Schema.create>`_.
+        :param existing: Options to handle the case when the classes exist, 
possible options
+            'replace', 'fail', 'ignore'.
+        """
+        existing_schema_options = ["replace", "fail", "ignore"]
+        if existing not in existing_schema_options:
+            raise ValueError(f"Param 'existing' should be one of the 
{existing_schema_options} values.")
+        if isinstance(schema_json, str):
+            schema_json = cast(dict, json.load(open(schema_json)))
+        set__exiting_classes = {class_object["class"] for class_object in 
self.get_schema()["classes"]}
+        set__to_be_added_classes = {key for key, _ in schema_json.items()}
+        intersection_classes = 
set__exiting_classes.intersection(set__to_be_added_classes)
+        classes_to_create = set()
+        if existing == "fail" and intersection_classes:
+            raise ValueError(
+                f"Trying to create class {intersection_classes}" f" but this 
class already exists."
+            )
+        elif existing == "ignore":
+            classes_to_create = set__to_be_added_classes - set__exiting_classes
+        elif existing == "replace":
+            error_list = 
self.delete_classes(class_names=list(intersection_classes))
+            if error_list:
+                raise ValueError(error_list)
+            classes_to_create = 
intersection_classes.union(set__to_be_added_classes)
+        classes_to_create_list = [schema_json[item] for item in 
sorted(list(classes_to_create))]
+        self.create_schema({"classes": classes_to_create_list})
+
+    def _compare_schema_subset(self, subset_object: Any, superset_object: Any) 
-> bool:
+        """
+        Recursively check if requested subset_object is a subset of the 
superset_object.
+
+        Example 1:
+        superset_object = {"a": {"b": [1, 2, 3], "c": "d"}}
+        subset_object = {"a": {"c": "d"}}
+        _compare_schema_subset(subset_object, superset_object) # will result 
in True
+
+        superset_object = {"a": {"b": [1, 2, 3], "c": "d"}}
+        subset_object = {"a": {"d": "e"}}
+        _compare_schema_subset(subset_object, superset_object) # will result 
in False
+
+        :param subset_object: The object to be checked
+        :param superset_object: The object to check against
+        """
+        # Direct equality check
+        if subset_object == superset_object:
+            return True
+
+        # Type mismatch early return
+        if type(subset_object) != type(superset_object):
+            return False
+
+        # Dictionary comparison
+        if isinstance(subset_object, dict):
+            for k, v in subset_object.items():
+                if (k not in superset_object) or (not 
self._compare_schema_subset(v, superset_object[k])):
+                    return False
+            return True
+
+        # List or Tuple comparison
+        if isinstance(subset_object, (list, tuple)):
+            for sub, sup in zip(subset_object, superset_object):
+                if len(subset_object) > len(superset_object) or not 
self._compare_schema_subset(sub, sup):
+                    return False
+            return True
+
+        # Default case for non-matching types or unsupported types
+        return False
+
+    @staticmethod
+    def _convert_properties_to_dict(classes_objects, key_property: str = 
"name"):
+        """
+        Helper function to convert list of class properties into dict by using 
a `key_property` as key.
+
+        This is done to avoid class properties comparison as list of 
properties.
+
+        Case 1:
+        A = [1, 2, 3]
+        B = [1, 2]
+        When comparing list we check for the length, but it's not suitable for 
subset check.
+
+        Case 2:
+        A = [1, 2, 3]
+        B = [1, 3, 2]
+        When we compare two lists, we compare item 1 of list A with item 1 of 
list B and
+         pass if the two are same, but there can be scenarios when the 
properties are not in same order.
+        """
+        for cls in classes_objects:
+            cls["properties"] = {p[key_property]: p for p in cls["properties"]}
+        return classes_objects
+
+    def check_subset_of_schema(self, classes_objects: list) -> bool:
+        """Check if the class_objects is a subset of existing schema.
+
+        Note - weaviate client's `contains()` don't handle the class 
properties mismatch, if you want to
+         compare `Class A` with `Class B` they must have exactly same 
properties. If `Class A` has fewer
+          numbers of properties than Class B, `contains()` will result in 
False.
+
+        .. seealso:: `contains 
<https://weaviate-python-client.readthedocs.io/en/v3.25.3/weaviate.schema.html#weaviate.schema.Schema.contains>`_.
+        """
+        # When the class properties are not in same order or not the same 
length. We convert them to dicts
+        # with property `name` as the key. This way we ensure, the subset is 
checked.
+        classes_objects = self._convert_properties_to_dict(classes_objects)
+        exiting_classes_list = 
self._convert_properties_to_dict(self.get_schema()["classes"])
+
+        exiting_classes = {cls["class"]: cls for cls in exiting_classes_list}
+        exiting_classes_set = set(exiting_classes.keys())
+        input_classes_set = {cls["class"] for cls in classes_objects}
+        if not input_classes_set.issubset(exiting_classes_set):
+            return False
+        for cls in classes_objects:
+            if not self._compare_schema_subset(cls, 
exiting_classes[cls["class"]]):
+                return False
+        return True
+
     def batch_data(
         self,
         class_name: str,
@@ -194,7 +407,10 @@ class WeaviateHook(BaseHook):
             for index, data_obj in enumerate(data):
                 for attempt in Retrying(
                     stop=stop_after_attempt(retry_attempts_per_object),
-                    
retry=retry_if_exception(self.check_http_error_should_retry),
+                    retry=(
+                        retry_if_exception(lambda exc: 
check_http_error_is_retryable(exc))
+                        | retry_if_exception_type(REQUESTS_EXCEPTIONS_TYPES)
+                    ),
                 ):
                     with attempt:
                         self.log.debug(
@@ -203,11 +419,6 @@ class WeaviateHook(BaseHook):
                         vector = data_obj.pop(vector_col, None)
                         batch.add_data_object(data_obj, class_name, 
vector=vector)
 
-    def delete_class(self, class_name: str) -> None:
-        """Delete an existing class."""
-        client = self.conn
-        client.schema.delete_class(class_name)
-
     def query_with_vector(
         self,
         embeddings: list[float],
diff --git a/tests/providers/weaviate/hooks/test_weaviate.py 
b/tests/providers/weaviate/hooks/test_weaviate.py
index acda7e9c2e..fc6d7db6ae 100644
--- a/tests/providers/weaviate/hooks/test_weaviate.py
+++ b/tests/providers/weaviate/hooks/test_weaviate.py
@@ -16,12 +16,14 @@
 # under the License.
 from __future__ import annotations
 
+from contextlib import ExitStack
 from unittest import mock
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, Mock
 
 import pandas as pd
 import pytest
 import requests
+import weaviate
 from weaviate import ObjectAlreadyExistsException
 
 from airflow.models import Connection
@@ -38,7 +40,7 @@ def weaviate_hook():
     mock_conn = Mock()
 
     # Patch the WeaviateHook get_connection method to return the mock 
connection
-    with patch.object(WeaviateHook, "get_connection", return_value=mock_conn):
+    with mock.patch.object(WeaviateHook, "get_connection", 
return_value=mock_conn):
         hook = WeaviateHook(conn_id=TEST_CONN_ID)
     return hook
 
@@ -434,7 +436,7 @@ def test_batch_data(data, expected_length, weaviate_hook):
     assert mock_batch_context.add_data_object.call_count == expected_length
 
 
-@patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn")
[email protected]("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn")
 def test_batch_data_retry(get_conn, weaviate_hook):
     """Test to ensure retrying working as expected"""
     data = [{"name": "chandler"}, {"name": "joey"}, {"name": "ross"}]
@@ -446,3 +448,225 @@ def test_batch_data_retry(get_conn, weaviate_hook):
     
get_conn.return_value.batch.__enter__.return_value.add_data_object.side_effect 
= side_effect
     weaviate_hook.batch_data("TestClass", data)
     assert 
get_conn.return_value.batch.__enter__.return_value.add_data_object.call_count 
== len(side_effect)
+
+
[email protected](
+    argnames=["get_schema_value", "existing", "expected_value"],
+    argvalues=[
+        ({"classes": [{"class": "A"}, {"class": "B"}]}, "ignore", [{"class": 
"C"}]),
+        ({"classes": [{"class": "A"}, {"class": "B"}]}, "replace", [{"class": 
"B"}, {"class": "C"}]),
+        ({"classes": [{"class": "A"}, {"class": "B"}]}, "fail", {}),
+        ({"classes": [{"class": "A"}, {"class": "B"}]}, "invalid_option", {}),
+    ],
+    ids=["ignore", "replace", "fail", "invalid_option"],
+)
[email protected]("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.delete_classes")
[email protected]("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.create_schema")
[email protected]("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_schema")
+def test_upsert_schema_scenarios(
+    get_schema, create_schema, delete_classes, get_schema_value, existing, 
expected_value, weaviate_hook
+):
+    schema_json = {
+        "B": {"class": "B"},
+        "C": {"class": "C"},
+    }
+    with ExitStack() as stack:
+        delete_classes.return_value = None
+        if existing in ["fail", "invalid_option"]:
+            stack.enter_context(pytest.raises(ValueError))
+        get_schema.return_value = get_schema_value
+        weaviate_hook.create_or_replace_classes(schema_json=schema_json, 
existing=existing)
+        create_schema.assert_called_once_with({"classes": expected_value})
+        if existing == "replace":
+            delete_classes.assert_called_once_with(class_names=["B"])
+
+
[email protected]("builtins.open")
[email protected]("json.load")
[email protected]("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.create_schema")
[email protected]("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_schema")
+def test_upsert_schema_json_file_param(get_schema, create_schema, load, open, 
weaviate_hook):
+    """Test if schema_json is path to a json file"""
+    get_schema.return_value = {"classes": [{"class": "A"}, {"class": "B"}]}
+    load.return_value = {
+        "B": {"class": "B"},
+        "C": {"class": "C"},
+    }
+    
weaviate_hook.create_or_replace_classes(schema_json="/tmp/some_temp_file.json", 
existing="ignore")
+    create_schema.assert_called_once_with({"classes": [{"class": "C"}]})
+
+
[email protected]("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_client")
+def test_delete_classes(get_client, weaviate_hook):
+    class_names = ["class_a", "class_b"]
+    get_client.return_value.schema.delete_class.side_effect = [
+        weaviate.UnexpectedStatusCodeException("something failed", 
requests.Response()),
+        None,
+    ]
+    error_list = weaviate_hook.delete_classes(class_names, if_error="continue")
+    assert error_list == ["class_a"]
+
+    get_client.return_value.schema.delete_class.side_effect = 
weaviate.UnexpectedStatusCodeException(
+        "something failed", requests.Response()
+    )
+    with pytest.raises(weaviate.UnexpectedStatusCodeException):
+        weaviate_hook.delete_classes("class_a", if_error="stop")
+
+
[email protected]("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_client")
+def test_http_errors_of_delete_classes(get_client, weaviate_hook):
+    class_names = ["class_a", "class_b"]
+    resp = requests.Response()
+    resp.status_code = 429
+    get_client.return_value.schema.delete_class.side_effect = [
+        requests.exceptions.HTTPError(response=resp),
+        None,
+        requests.exceptions.ConnectionError,
+        None,
+    ]
+    error_list = weaviate_hook.delete_classes(class_names, if_error="continue")
+    assert error_list == []
+    assert get_client.return_value.schema.delete_class.call_count == 4
+
+
[email protected](
+    argnames=["classes_to_test", "expected_result"],
+    argvalues=[
+        (
+            [
+                {
+                    "class": "Author",
+                    "description": "Authors info",
+                    "properties": [
+                        {
+                            "name": "last_name",
+                            "description": "Last name of the author",
+                            "dataType": ["text"],
+                        },
+                    ],
+                },
+            ],
+            True,
+        ),
+        (
+            [
+                {
+                    "class": "Author",
+                    "description": "Authors info",
+                    "properties": [
+                        {
+                            "name": "last_name",
+                            "description": "Last name of the author",
+                            "dataType": ["text"],
+                        },
+                    ],
+                },
+            ],
+            True,
+        ),
+        (
+            [
+                {
+                    "class": "Author",
+                    "description": "Authors info",
+                    "properties": [
+                        {
+                            "name": "invalid_property",
+                            "description": "Last name of the author",
+                            "dataType": ["text"],
+                        }
+                    ],
+                },
+            ],
+            False,
+        ),
+        (
+            [
+                {
+                    "class": "invalid_class",
+                    "description": "Authors info",
+                    "properties": [
+                        {
+                            "name": "last_name",
+                            "description": "Last name of the author",
+                            "dataType": ["text"],
+                        },
+                    ],
+                },
+            ],
+            False,
+        ),
+        (
+            [
+                {
+                    "class": "Author",
+                    "description": "Authors info",
+                    "properties": [
+                        {
+                            "name": "last_name",
+                            "description": "Last name of the author",
+                            "dataType": ["text"],
+                        },
+                        {
+                            "name": "name",
+                            "description": "Name of the author",
+                            "dataType": ["text"],
+                            "extra_key": "some_value",
+                        },
+                    ],
+                },
+            ],
+            True,
+        ),
+    ],
+    ids=(
+        "property_level_check",
+        "class_level_check",
+        "invalid_property",
+        "invalid_class",
+        "swapped_properties",
+    ),
+)
[email protected]("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_schema")
+def test_contains_schema(get_schema, classes_to_test, expected_result, 
weaviate_hook):
+    get_schema.return_value = {
+        "classes": [
+            {
+                "class": "Author",
+                "description": "Authors info",
+                "properties": [
+                    {
+                        "name": "name",
+                        "description": "Name of the author",
+                        "dataType": ["text"],
+                        "extra_key": "some_value",
+                    },
+                    {
+                        "name": "last_name",
+                        "description": "Last name of the author",
+                        "dataType": ["text"],
+                        "extra_key": "some_value",
+                    },
+                ],
+            },
+            {
+                "class": "Article",
+                "description": "An article written by an Author",
+                "properties": [
+                    {
+                        "name": "name",
+                        "description": "Name of the author",
+                        "dataType": ["text"],
+                        "extra_key": "some_value",
+                    },
+                    {
+                        "name": "last_name",
+                        "description": "Last name of the author",
+                        "dataType": ["text"],
+                        "extra_key": "some_value",
+                    },
+                ],
+            },
+        ]
+    }
+    assert weaviate_hook.check_subset_of_schema(classes_to_test) == 
expected_result

Reply via email to