This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v2-10-test by this push:
     new 6af7256318 Fix tests/models/test_variable.py for database isolation mode (#41414) (#41952)
6af7256318 is described below

commit 6af72563187749622943dde5bce5723807ed2bc7
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Mon Sep 2 15:25:29 2024 +0100

    Fix tests/models/test_variable.py for database isolation mode (#41414) (#41952)
    
    * Fix tests/models/test_variable.py for database isolation mode
    
    * Review feedback
    
    (cherry picked from commit 736ebfe3fe2bd67406d5a50dacbfa1e43767d4ce)
    
    Co-authored-by: Jens Scheffler <[email protected]>
---
 airflow/api_internal/endpoints/rpc_api_endpoint.py |  9 +--
 airflow/api_internal/internal_api_call.py          |  2 +-
 airflow/models/variable.py                         | 66 +++++++++++++++++++++-
 airflow/serialization/enums.py                     |  1 +
 airflow/serialization/serialized_objects.py        | 16 +++++-
 tests/models/test_variable.py                      |  8 ++-
 6 files changed, 90 insertions(+), 12 deletions(-)

diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index a85964af4f..c3d8b671fb 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -126,9 +126,9 @@ def initialize_method_map() -> dict[str, Callable]:
         # XCom.get_many, # Not supported because it returns query
         XCom.clear,
         XCom.set,
-        Variable.set,
-        Variable.update,
-        Variable.delete,
+        Variable._set,
+        Variable._update,
+        Variable._delete,
         DAG.fetch_callback,
         DAG.fetch_dagrun,
         DagRun.fetch_task_instances,
@@ -237,7 +237,8 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
             response = json.dumps(output_json) if output_json is not None else None
             log.debug("Sending response: %s", response)
             return Response(response=response, headers={"Content-Type": "application/json"})
-    except AirflowException as e:  # In case of AirflowException transport the exception class back to caller
+    # In case of AirflowException or other selective known types, transport the exception class back to caller
+    except (KeyError, AttributeError, AirflowException) as e:
         exception_json = BaseSerialization.serialize(e, use_pydantic_models=True)
         response = json.dumps(exception_json)
         log.debug("Sending exception response: %s", response)
diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py
index fc0945b3c0..8838377877 100644
--- a/airflow/api_internal/internal_api_call.py
+++ b/airflow/api_internal/internal_api_call.py
@@ -159,7 +159,7 @@ def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]:
         if result is None or result == b"":
             return None
         result = BaseSerialization.deserialize(json.loads(result), use_pydantic_models=True)
-        if isinstance(result, AirflowException):
+        if isinstance(result, (KeyError, AttributeError, AirflowException)):
             raise result
         return result
 
diff --git a/airflow/models/variable.py b/airflow/models/variable.py
index 63b71303bc..563cac46e8 100644
--- a/airflow/models/variable.py
+++ b/airflow/models/variable.py
@@ -154,7 +154,6 @@ class Variable(Base, LoggingMixin):
 
     @staticmethod
     @provide_session
-    @internal_api_call
     def set(
         key: str,
         value: Any,
@@ -167,6 +166,35 @@ class Variable(Base, LoggingMixin):
 
         This operation overwrites an existing variable.
 
+        :param key: Variable Key
+        :param value: Value to set for the Variable
+        :param description: Description of the Variable
+        :param serialize_json: Serialize the value to a JSON string
+        :param session: Session
+        """
+        Variable._set(
+            key=key, value=value, description=description, serialize_json=serialize_json, session=session
+        )
+        # invalidate key in cache for faster propagation
+        # we cannot save the value set because it's possible that it's shadowed by a custom backend
+        # (see call to check_for_write_conflict above)
+        SecretCache.invalidate_variable(key)
+
+    @staticmethod
+    @provide_session
+    @internal_api_call
+    def _set(
+        key: str,
+        value: Any,
+        description: str | None = None,
+        serialize_json: bool = False,
+        session: Session = None,
+    ) -> None:
+        """
+        Set a value for an Airflow Variable with a given Key.
+
+        This operation overwrites an existing variable.
+
         :param key: Variable Key
         :param value: Value to set for the Variable
         :param description: Description of the Variable
@@ -190,7 +218,6 @@ class Variable(Base, LoggingMixin):
 
     @staticmethod
     @provide_session
-    @internal_api_call
     def update(
         key: str,
         value: Any,
@@ -200,6 +227,27 @@ class Variable(Base, LoggingMixin):
         """
         Update a given Airflow Variable with the Provided value.
 
+        :param key: Variable Key
+        :param value: Value to set for the Variable
+        :param serialize_json: Serialize the value to a JSON string
+        :param session: Session
+        """
+        Variable._update(key=key, value=value, serialize_json=serialize_json, session=session)
+        # We need to invalidate the cache for internal API cases on the client side
+        SecretCache.invalidate_variable(key)
+
+    @staticmethod
+    @provide_session
+    @internal_api_call
+    def _update(
+        key: str,
+        value: Any,
+        serialize_json: bool = False,
+        session: Session = None,
+    ) -> None:
+        """
+        Update a given Airflow Variable with the Provided value.
+
         :param key: Variable Key
         :param value: Value to set for the Variable
         :param serialize_json: Serialize the value to a JSON string
@@ -219,11 +267,23 @@ class Variable(Base, LoggingMixin):
 
     @staticmethod
     @provide_session
-    @internal_api_call
     def delete(key: str, session: Session = None) -> int:
         """
         Delete an Airflow Variable for a given key.
 
+        :param key: Variable Keys
+        """
+        rows = Variable._delete(key=key, session=session)
+        SecretCache.invalidate_variable(key)
+        return rows
+
+    @staticmethod
+    @provide_session
+    @internal_api_call
+    def _delete(key: str, session: Session = None) -> int:
+        """
+        Delete an Airflow Variable for a given key.
+
         :param key: Variable Keys
         """
         rows = session.execute(delete(Variable).where(Variable.key == key)).rowcount
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index a5bd5e3646..f216ce7316 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -46,6 +46,7 @@ class DagAttributeTypes(str, Enum):
     RELATIVEDELTA = "relativedelta"
     BASE_TRIGGER = "base_trigger"
     AIRFLOW_EXC_SER = "airflow_exc_ser"
+    BASE_EXC_SER = "base_exc_ser"
     DICT = "dict"
     SET = "set"
     TUPLE = "tuple"
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 6d0bbd4e23..84ad567918 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -692,6 +692,15 @@ class BaseSerialization:
                 ),
                 type_=DAT.AIRFLOW_EXC_SER,
             )
+        elif isinstance(var, (KeyError, AttributeError)):
+            return cls._encode(
+                cls.serialize(
+                    {"exc_cls_name": var.__class__.__name__, "args": [var.args], "kwargs": {}},
+                    use_pydantic_models=use_pydantic_models,
+                    strict=strict,
+                ),
+                type_=DAT.BASE_EXC_SER,
+            )
         elif isinstance(var, BaseTrigger):
             return cls._encode(
                 cls.serialize(var.serialize(), use_pydantic_models=use_pydantic_models, strict=strict),
@@ -834,13 +843,16 @@ class BaseSerialization:
             return decode_timezone(var)
         elif type_ == DAT.RELATIVEDELTA:
             return decode_relativedelta(var)
-        elif type_ == DAT.AIRFLOW_EXC_SER:
+        elif type_ == DAT.AIRFLOW_EXC_SER or type_ == DAT.BASE_EXC_SER:
             deser = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
             exc_cls_name = deser["exc_cls_name"]
             args = deser["args"]
             kwargs = deser["kwargs"]
             del deser
-            exc_cls = import_string(exc_cls_name)
+            if type_ == DAT.AIRFLOW_EXC_SER:
+                exc_cls = import_string(exc_cls_name)
+            else:
+                exc_cls = import_string(f"builtins.{exc_cls_name}")
             return exc_cls(*args, **kwargs)
         elif type_ == DAT.BASE_TRIGGER:
             tr_cls_name, kwargs = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
diff --git a/tests/models/test_variable.py b/tests/models/test_variable.py
index e3d5c023a2..6fb6fa15f2 100644
--- a/tests/models/test_variable.py
+++ b/tests/models/test_variable.py
@@ -47,6 +47,7 @@ class TestVariable:
         db.clear_db_variables()
         crypto._fernet = None
 
+    @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode, internal API has other fernet
     @conf_vars({("core", "fernet_key"): "", ("core", "unit_test_mode"): "True"})
     def test_variable_no_encryption(self, session):
         """
@@ -60,6 +61,7 @@ class TestVariable:
         # should mask anything. That logic is tested in test_secrets_masker.py
         self.mask_secret.assert_called_once_with("value", "key")
 
+    @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode, internal API has other fernet
     @conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
     def test_variable_with_encryption(self, session):
         """
@@ -70,6 +72,7 @@ class TestVariable:
         assert test_var.is_encrypted
         assert test_var.val == "value"
 
+    @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode, internal API has other fernet
     @pytest.mark.parametrize("test_value", ["value", ""])
     def test_var_with_encryption_rotate_fernet_key(self, test_value, session):
         """
@@ -152,6 +155,7 @@ class TestVariable:
         Variable.update(key="test_key", value="value2", session=session)
         assert "value2" == Variable.get("test_key")
 
+    @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode, API server has other ENV
     def test_variable_update_fails_on_non_metastore_variable(self, session):
         with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="env-value"):
             with pytest.raises(AttributeError):
@@ -281,6 +285,7 @@ class TestVariable:
         mock_backend.get_variable.assert_called_once()  # second call was not made because of cache
         assert first == second
 
+    @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode, internal API has other env
     def test_cache_invalidation_on_set(self, session):
         with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="from_env"):
             a = Variable.get("key")  # value is saved in cache
@@ -316,7 +321,7 @@ def test_masking_only_secret_values(variable_value, deserialize_json, expected_m
             val=variable_value,
         )
         session.add(var)
-        session.flush()
+        session.commit()
         # Make sure we re-load it, not just get the cached object back
         session.expunge(var)
         _secrets_masker().patterns = set()
@@ -326,5 +331,4 @@ def test_masking_only_secret_values(variable_value, deserialize_json, expected_m
         for expected_masked_value in expected_masked_values:
             assert expected_masked_value in _secrets_masker().patterns
     finally:
-        session.rollback()
         db.clear_db_variables()

Reply via email to