This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v2-10-test by this push:
new 6af7256318 Fix tests/models/test_variable.py for database isolation mode (#41414) (#41952)
6af7256318 is described below
commit 6af72563187749622943dde5bce5723807ed2bc7
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Mon Sep 2 15:25:29 2024 +0100
Fix tests/models/test_variable.py for database isolation mode (#41414) (#41952)
* Fix tests/models/test_variable.py for database isolation mode
* Review feedback
(cherry picked from commit 736ebfe3fe2bd67406d5a50dacbfa1e43767d4ce)
Co-authored-by: Jens Scheffler <[email protected]>
---
airflow/api_internal/endpoints/rpc_api_endpoint.py | 9 +--
airflow/api_internal/internal_api_call.py | 2 +-
airflow/models/variable.py | 66 +++++++++++++++++++++-
airflow/serialization/enums.py | 1 +
airflow/serialization/serialized_objects.py | 16 +++++-
tests/models/test_variable.py | 8 ++-
6 files changed, 90 insertions(+), 12 deletions(-)
diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index a85964af4f..c3d8b671fb 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -126,9 +126,9 @@ def initialize_method_map() -> dict[str, Callable]:
# XCom.get_many, # Not supported because it returns query
XCom.clear,
XCom.set,
- Variable.set,
- Variable.update,
- Variable.delete,
+ Variable._set,
+ Variable._update,
+ Variable._delete,
DAG.fetch_callback,
DAG.fetch_dagrun,
DagRun.fetch_task_instances,
@@ -237,7 +237,8 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
response = json.dumps(output_json) if output_json is not None else None
log.debug("Sending response: %s", response)
return Response(response=response, headers={"Content-Type": "application/json"})
- except AirflowException as e: # In case of AirflowException transport the exception class back to caller
+ # In case of AirflowException or other selective known types, transport the exception class back to caller
+ except (KeyError, AttributeError, AirflowException) as e:
exception_json = BaseSerialization.serialize(e, use_pydantic_models=True)
response = json.dumps(exception_json)
log.debug("Sending exception response: %s", response)
diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py
index fc0945b3c0..8838377877 100644
--- a/airflow/api_internal/internal_api_call.py
+++ b/airflow/api_internal/internal_api_call.py
@@ -159,7 +159,7 @@ def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]:
if result is None or result == b"":
return None
result = BaseSerialization.deserialize(json.loads(result), use_pydantic_models=True)
- if isinstance(result, AirflowException):
+ if isinstance(result, (KeyError, AttributeError, AirflowException)):
raise result
return result
diff --git a/airflow/models/variable.py b/airflow/models/variable.py
index 63b71303bc..563cac46e8 100644
--- a/airflow/models/variable.py
+++ b/airflow/models/variable.py
@@ -154,7 +154,6 @@ class Variable(Base, LoggingMixin):
@staticmethod
@provide_session
- @internal_api_call
def set(
key: str,
value: Any,
@@ -167,6 +166,35 @@ class Variable(Base, LoggingMixin):
This operation overwrites an existing variable.
+ :param key: Variable Key
+ :param value: Value to set for the Variable
+ :param description: Description of the Variable
+ :param serialize_json: Serialize the value to a JSON string
+ :param session: Session
+ """
+ Variable._set(
+ key=key, value=value, description=description, serialize_json=serialize_json, session=session
+ )
+ # invalidate key in cache for faster propagation
+ # we cannot save the value set because it's possible that it's shadowed by a custom backend
+ # (see call to check_for_write_conflict above)
+ SecretCache.invalidate_variable(key)
+
+ @staticmethod
+ @provide_session
+ @internal_api_call
+ def _set(
+ key: str,
+ value: Any,
+ description: str | None = None,
+ serialize_json: bool = False,
+ session: Session = None,
+ ) -> None:
+ """
+ Set a value for an Airflow Variable with a given Key.
+
+ This operation overwrites an existing variable.
+
:param key: Variable Key
:param value: Value to set for the Variable
:param description: Description of the Variable
@@ -190,7 +218,6 @@ class Variable(Base, LoggingMixin):
@staticmethod
@provide_session
- @internal_api_call
def update(
key: str,
value: Any,
@@ -200,6 +227,27 @@ class Variable(Base, LoggingMixin):
"""
Update a given Airflow Variable with the Provided value.
+ :param key: Variable Key
+ :param value: Value to set for the Variable
+ :param serialize_json: Serialize the value to a JSON string
+ :param session: Session
+ """
+ Variable._update(key=key, value=value, serialize_json=serialize_json, session=session)
+ # We need to invalidate the cache for internal API cases on the client side
+ SecretCache.invalidate_variable(key)
+
+ @staticmethod
+ @provide_session
+ @internal_api_call
+ def _update(
+ key: str,
+ value: Any,
+ serialize_json: bool = False,
+ session: Session = None,
+ ) -> None:
+ """
+ Update a given Airflow Variable with the Provided value.
+
:param key: Variable Key
:param value: Value to set for the Variable
:param serialize_json: Serialize the value to a JSON string
@@ -219,11 +267,23 @@ class Variable(Base, LoggingMixin):
@staticmethod
@provide_session
- @internal_api_call
def delete(key: str, session: Session = None) -> int:
"""
Delete an Airflow Variable for a given key.
+ :param key: Variable Keys
+ """
+ rows = Variable._delete(key=key, session=session)
+ SecretCache.invalidate_variable(key)
+ return rows
+
+ @staticmethod
+ @provide_session
+ @internal_api_call
+ def _delete(key: str, session: Session = None) -> int:
+ """
+ Delete an Airflow Variable for a given key.
+
:param key: Variable Keys
"""
rows = session.execute(delete(Variable).where(Variable.key == key)).rowcount
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index a5bd5e3646..f216ce7316 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -46,6 +46,7 @@ class DagAttributeTypes(str, Enum):
RELATIVEDELTA = "relativedelta"
BASE_TRIGGER = "base_trigger"
AIRFLOW_EXC_SER = "airflow_exc_ser"
+ BASE_EXC_SER = "base_exc_ser"
DICT = "dict"
SET = "set"
TUPLE = "tuple"
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 6d0bbd4e23..84ad567918 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -692,6 +692,15 @@ class BaseSerialization:
),
type_=DAT.AIRFLOW_EXC_SER,
)
+ elif isinstance(var, (KeyError, AttributeError)):
+ return cls._encode(
+ cls.serialize(
+ {"exc_cls_name": var.__class__.__name__, "args": [var.args], "kwargs": {}},
+ use_pydantic_models=use_pydantic_models,
+ strict=strict,
+ ),
+ type_=DAT.BASE_EXC_SER,
+ )
elif isinstance(var, BaseTrigger):
return cls._encode(
cls.serialize(var.serialize(), use_pydantic_models=use_pydantic_models, strict=strict),
@@ -834,13 +843,16 @@ class BaseSerialization:
return decode_timezone(var)
elif type_ == DAT.RELATIVEDELTA:
return decode_relativedelta(var)
- elif type_ == DAT.AIRFLOW_EXC_SER:
+ elif type_ == DAT.AIRFLOW_EXC_SER or type_ == DAT.BASE_EXC_SER:
deser = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
exc_cls_name = deser["exc_cls_name"]
args = deser["args"]
kwargs = deser["kwargs"]
del deser
- exc_cls = import_string(exc_cls_name)
+ if type_ == DAT.AIRFLOW_EXC_SER:
+ exc_cls = import_string(exc_cls_name)
+ else:
+ exc_cls = import_string(f"builtins.{exc_cls_name}")
return exc_cls(*args, **kwargs)
elif type_ == DAT.BASE_TRIGGER:
tr_cls_name, kwargs = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
diff --git a/tests/models/test_variable.py b/tests/models/test_variable.py
index e3d5c023a2..6fb6fa15f2 100644
--- a/tests/models/test_variable.py
+++ b/tests/models/test_variable.py
@@ -47,6 +47,7 @@ class TestVariable:
db.clear_db_variables()
crypto._fernet = None
+ @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other fernet
@conf_vars({("core", "fernet_key"): "", ("core", "unit_test_mode"): "True"})
def test_variable_no_encryption(self, session):
"""
@@ -60,6 +61,7 @@ class TestVariable:
# should mask anything. That logic is tested in test_secrets_masker.py
self.mask_secret.assert_called_once_with("value", "key")
+ @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other fernet
@conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
def test_variable_with_encryption(self, session):
"""
@@ -70,6 +72,7 @@ class TestVariable:
assert test_var.is_encrypted
assert test_var.val == "value"
+ @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other fernet
@pytest.mark.parametrize("test_value", ["value", ""])
def test_var_with_encryption_rotate_fernet_key(self, test_value, session):
"""
@@ -152,6 +155,7 @@ class TestVariable:
Variable.update(key="test_key", value="value2", session=session)
assert "value2" == Variable.get("test_key")
+ @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, API server has other ENV
def test_variable_update_fails_on_non_metastore_variable(self, session):
with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="env-value"):
with pytest.raises(AttributeError):
@@ -281,6 +285,7 @@ class TestVariable:
mock_backend.get_variable.assert_called_once() # second call was not made because of cache
assert first == second
+ @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other env
def test_cache_invalidation_on_set(self, session):
with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="from_env"):
a = Variable.get("key") # value is saved in cache
@@ -316,7 +321,7 @@ def test_masking_only_secret_values(variable_value, deserialize_json, expected_m
val=variable_value,
)
session.add(var)
- session.flush()
+ session.commit()
# Make sure we re-load it, not just get the cached object back
session.expunge(var)
_secrets_masker().patterns = set()
@@ -326,5 +331,4 @@ def test_masking_only_secret_values(variable_value, deserialize_json, expected_m
for expected_masked_value in expected_masked_values:
assert expected_masked_value in _secrets_masker().patterns
finally:
- session.rollback()
db.clear_db_variables()