This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b4b8a31df0 Bugfix/41067 fix tests operators test python (#41257)
b4b8a31df0 is described below
commit b4b8a31df0b2f6438f8b1707662d661c5e0a7ebe
Author: Jens Scheffler <[email protected]>
AuthorDate: Mon Aug 5 08:16:12 2024 +0200
Bugfix/41067 fix tests operators test python (#41257)
* Ensure that the internal API transports AirflowException (not HTTP 500) and
that it is correctly serialized
* Fix tests/operators/test_python.py for Database Isolation Tests
* Adjust variable in test decorators according to test_python
* Adjust variable in test sensors according to test_python
---
airflow/api_internal/endpoints/rpc_api_endpoint.py | 6 ++
airflow/api_internal/internal_api_call.py | 5 +-
airflow/exceptions.py | 18 ++++-
airflow/serialization/serialized_objects.py | 2 +-
tests/conftest.py | 22 ++++++
tests/decorators/test_python.py | 87 ++++++++++++----------
tests/operators/test_python.py | 86 +++++++++++++--------
tests/sensors/test_python.py | 2 +-
tests/serialization/test_serialized_objects.py | 22 +++++-
9 files changed, 174 insertions(+), 76 deletions(-)
diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py
b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 4a5fdb0276..eb4ddbad52 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -35,6 +35,7 @@ from jwt import (
from airflow.api_connexion.exceptions import PermissionDenied
from airflow.configuration import conf
+from airflow.exceptions import AirflowException
from airflow.jobs.job import Job, most_recent_job
from airflow.models.dagcode import DagCode
from airflow.models.taskinstance import _record_task_map_for_downstreams
@@ -234,5 +235,10 @@ def internal_airflow_api(body: dict[str, Any]) ->
APIResponse:
response = json.dumps(output_json) if output_json is not None else
None
log.info("Sending response: %s", response)
return Response(response=response, headers={"Content-Type":
"application/json"})
+ except AirflowException as e: # In case of AirflowException transport the
exception class back to caller
+ exception_json = BaseSerialization.serialize(e,
use_pydantic_models=True)
+ response = json.dumps(exception_json)
+ log.info("Sending exception response: %s", response)
+ return Response(response=response, headers={"Content-Type":
"application/json"})
except Exception:
return log_and_build_error_response(message=f"Error executing method
'{method_name}'.", status=500)
diff --git a/airflow/api_internal/internal_api_call.py
b/airflow/api_internal/internal_api_call.py
index 7ad56c876f..fc0945b3c0 100644
--- a/airflow/api_internal/internal_api_call.py
+++ b/airflow/api_internal/internal_api_call.py
@@ -158,6 +158,9 @@ def internal_api_call(func: Callable[PS, RT]) ->
Callable[PS, RT]:
result = make_jsonrpc_request(method_name, args_dict)
if result is None or result == b"":
return None
- return BaseSerialization.deserialize(json.loads(result),
use_pydantic_models=True)
+ result = BaseSerialization.deserialize(json.loads(result),
use_pydantic_models=True)
+ if isinstance(result, AirflowException):
+ raise result
+ return result
return wrapper
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index dc59f91841..40a62ad208 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -43,6 +43,10 @@ class AirflowException(Exception):
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
+ def serialize(self):
+ cls = self.__class__
+ return f"{cls.__module__}.{cls.__name__}", (str(self),), {}
+
class AirflowBadRequest(AirflowException):
"""Raise when the application or server cannot handle the request."""
@@ -76,7 +80,8 @@ class AirflowRescheduleException(AirflowException):
self.reschedule_date = reschedule_date
def serialize(self):
- return "AirflowRescheduleException", (), {"reschedule_date":
self.reschedule_date}
+ cls = self.__class__
+ return f"{cls.__module__}.{cls.__name__}", (), {"reschedule_date":
self.reschedule_date}
class InvalidStatsNameException(AirflowException):
@@ -132,6 +137,14 @@ class XComNotFound(AirflowException):
def __str__(self) -> str:
return f'XComArg result from {self.task_id} at {self.dag_id} with
key="{self.key}" is not found!'
+ def serialize(self):
+ cls = self.__class__
+ return (
+ f"{cls.__module__}.{cls.__name__}",
+ (),
+ {"dag_id": self.dag_id, "task_id": self.task_id, "key": self.key},
+ )
+
class UnmappableOperator(AirflowException):
"""Raise when an operator is not implemented to be mappable."""
@@ -396,8 +409,9 @@ class TaskDeferred(BaseException):
raise ValueError("Timeout value must be a timedelta")
def serialize(self):
+ cls = self.__class__
return (
- self.__class__.__name__,
+ f"{cls.__module__}.{cls.__name__}",
(),
{
"trigger": self.trigger,
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index ecb5757632..94631c993c 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -840,7 +840,7 @@ class BaseSerialization:
args = deser["args"]
kwargs = deser["kwargs"]
del deser
- exc_cls = import_string(f"airflow.exceptions.{exc_cls_name}")
+ exc_cls = import_string(exc_cls_name)
return exc_cls(*args, **kwargs)
elif type_ == DAT.BASE_TRIGGER:
tr_cls_name, kwargs = cls.deserialize(var,
use_pydantic_models=use_pydantic_models)
diff --git a/tests/conftest.py b/tests/conftest.py
index 472a42cf28..11b47363b0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1120,6 +1120,28 @@ def create_task_instance(dag_maker, create_dummy_dag):
return maker
[email protected]
+def create_serialized_task_instance_of_operator(dag_maker):
+ def _create_task_instance(
+ operator_class,
+ *,
+ dag_id,
+ execution_date=None,
+ session=None,
+ **operator_kwargs,
+ ) -> TaskInstance:
+ with dag_maker(dag_id=dag_id, serialized=True, session=session):
+ operator_class(**operator_kwargs)
+ if execution_date is None:
+ dagrun_kwargs = {}
+ else:
+ dagrun_kwargs = {"execution_date": execution_date}
+ (ti,) = dag_maker.create_dagrun(**dagrun_kwargs).task_instances
+ return ti
+
+ return _create_task_instance
+
+
@pytest.fixture
def create_task_instance_of_operator(dag_maker):
def _create_task_instance(
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 1561b8224d..83e2d86cd7 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -212,7 +212,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
def identity2(x: int, y: int) -> Tuple[int, int]:
return x, y
- with self.dag:
+ with self.dag_non_serialized:
res = identity2(8, 4)
dr = self.create_dag_run()
@@ -230,7 +230,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
def identity_tuple(x: int, y: int) -> Tuple[int, int]:
return x, y
- with self.dag:
+ with self.dag_non_serialized:
ident = identity_tuple(35, 36)
dr = self.create_dag_run()
@@ -284,7 +284,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
def add_number(num: int):
return {2: num}
- with self.dag:
+ with self.dag_non_serialized:
ret = add_number(2)
self.create_dag_run()
@@ -296,7 +296,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
def add_number(num: int):
return num
- with self.dag:
+ with self.dag_non_serialized:
ret = add_number(2)
self.create_dag_run()
@@ -308,7 +308,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
def empty_dict():
return {}
- with self.dag:
+ with self.dag_non_serialized:
ret = empty_dict()
dr = self.create_dag_run()
@@ -321,7 +321,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
def test_func():
return
- with self.dag:
+ with self.dag_non_serialized:
ret = test_func()
dr = self.create_dag_run()
@@ -341,7 +341,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
Named = namedtuple("Named", ["var1", "var2"])
named_tuple = Named("{{ ds }}", "unchanged")
- with self.dag:
+ with self.dag_non_serialized:
ret = arg_task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on
{{ds}}.", named_tuple)
dr = self.create_dag_run()
@@ -360,7 +360,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
def kwargs_task(an_int, a_date, a_templated_string):
raise RuntimeError("Should not executed")
- with self.dag:
+ with self.dag_non_serialized:
ret = kwargs_task(
an_int=4, a_date=date(2019, 1, 1), a_templated_string="dag
{{dag.dag_id}} ran on {{ds}}."
)
@@ -379,9 +379,9 @@ class TestAirflowTaskDecorator(BasePythonTest):
def do_run():
return 4
- with self.dag:
+ with self.dag_non_serialized:
do_run()
- assert ["some_name"] == self.dag.task_ids
+ assert ["some_name"] == self.dag_non_serialized.task_ids
def test_multiple_calls(self):
"""Test calling task multiple times in a DAG"""
@@ -390,12 +390,12 @@ class TestAirflowTaskDecorator(BasePythonTest):
def do_run():
return 4
- with self.dag:
+ with self.dag_non_serialized:
do_run()
- assert ["do_run"] == self.dag.task_ids
+ assert ["do_run"] == self.dag_non_serialized.task_ids
do_run_1 = do_run()
do_run_2 = do_run()
- assert ["do_run", "do_run__1", "do_run__2"] == self.dag.task_ids
+ assert ["do_run", "do_run__1", "do_run__2"] ==
self.dag_non_serialized.task_ids
assert do_run_1.operator.task_id == "do_run__1"
assert do_run_2.operator.task_id == "do_run__2"
@@ -408,14 +408,14 @@ class TestAirflowTaskDecorator(BasePythonTest):
return 4
group_id = "KnightsOfNii"
- with self.dag:
+ with self.dag_non_serialized:
with TaskGroup(group_id=group_id):
do_run()
- assert [f"{group_id}.do_run"] == self.dag.task_ids
+ assert [f"{group_id}.do_run"] ==
self.dag_non_serialized.task_ids
do_run()
- assert [f"{group_id}.do_run", f"{group_id}.do_run__1"] ==
self.dag.task_ids
+ assert [f"{group_id}.do_run", f"{group_id}.do_run__1"] ==
self.dag_non_serialized.task_ids
- assert len(self.dag.task_ids) == 2
+ assert len(self.dag_non_serialized.task_ids) == 2
def test_call_20(self):
"""Test calling decorated function 21 times in a DAG"""
@@ -424,12 +424,12 @@ class TestAirflowTaskDecorator(BasePythonTest):
def __do_run():
return 4
- with self.dag:
+ with self.dag_non_serialized:
__do_run()
for _ in range(20):
__do_run()
- assert self.dag.task_ids[-1] == "__do_run__20"
+ assert self.dag_non_serialized.task_ids[-1] == "__do_run__20"
def test_multiple_outputs(self):
"""Tests pushing multiple outputs as a dictionary"""
@@ -439,15 +439,17 @@ class TestAirflowTaskDecorator(BasePythonTest):
return {"number": number + 1, "43": 43}
test_number = 10
- with self.dag:
+ with self.dag_non_serialized:
ret = return_dict(test_number)
- dr = self.dag.create_dagrun(
+ dr = self.dag_non_serialized.create_dagrun(
run_id=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
state=State.RUNNING,
-
data_interval=self.dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE),
+
data_interval=self.dag_non_serialized.timetable.infer_manual_data_interval(
+ run_after=DEFAULT_DATE
+ ),
)
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -464,8 +466,8 @@ class TestAirflowTaskDecorator(BasePythonTest):
def do_run():
return 4
- self.dag.default_args["owner"] = "airflow"
- with self.dag:
+ self.dag_non_serialized.default_args["owner"] = "airflow"
+ with self.dag_non_serialized:
ret = do_run()
assert ret.operator.owner == "airflow"
@@ -474,14 +476,14 @@ class TestAirflowTaskDecorator(BasePythonTest):
return unknown
with pytest.raises(TypeError):
- with self.dag:
+ with self.dag_non_serialized:
test_apply_default_raise()
@task_decorator
def test_apply_default(owner):
return owner
- with self.dag:
+ with self.dag_non_serialized:
ret = test_apply_default()
assert "owner" in ret.operator.op_kwargs
@@ -498,16 +500,18 @@ class TestAirflowTaskDecorator(BasePythonTest):
test_number = 10
- with self.dag:
+ with self.dag_non_serialized:
bigger_number = add_2(test_number)
ret = add_num(bigger_number, XComArg(bigger_number.operator))
- dr = self.dag.create_dagrun(
+ dr = self.dag_non_serialized.create_dagrun(
run_id=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
state=State.RUNNING,
-
data_interval=self.dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE),
+
data_interval=self.dag_non_serialized.timetable.infer_manual_data_interval(
+ run_after=DEFAULT_DATE
+ ),
)
bigger_number.operator.run(start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE)
@@ -519,7 +523,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
def test_dag_task(self):
"""Tests dag.task property to generate task"""
- @self.dag.task
+ @self.dag_non_serialized.task
def add_2(number: int):
return number + 2
@@ -527,12 +531,12 @@ class TestAirflowTaskDecorator(BasePythonTest):
res = add_2(test_number)
add_2(res)
- assert "add_2" in self.dag.task_ids
+ assert "add_2" in self.dag_non_serialized.task_ids
def test_dag_task_multiple_outputs(self):
"""Tests dag.task property to generate task with multiple outputs"""
- @self.dag.task(multiple_outputs=True)
+ @self.dag_non_serialized.task(multiple_outputs=True)
def add_2(number: int):
return {"1": number + 2, "2": 42}
@@ -540,7 +544,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
add_2(test_number)
add_2(test_number)
- assert "add_2" in self.dag.task_ids
+ assert "add_2" in self.dag_non_serialized.task_ids
@pytest.mark.parametrize(
argnames=["op_doc_attr", "op_doc_value", "expected_doc_md"],
@@ -564,7 +568,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
return number + 2
test_number = 10
- with self.dag:
+ with self.dag_non_serialized:
ret = add_2(test_number)
assert ret.operator.doc_md == expected_doc_md
@@ -579,12 +583,17 @@ class TestAirflowTaskDecorator(BasePythonTest):
"""
print("Hello world")
- with self.dag:
+ with self.dag_non_serialized:
for i in range(3):
hello.override(task_id=f"my_task_id_{i * 2}")()
hello() # This task would have hello_task as the task_id
- assert self.dag.task_ids == ["my_task_id_0", "my_task_id_2",
"my_task_id_4", "hello_task"]
+ assert self.dag_non_serialized.task_ids == [
+ "my_task_id_0",
+ "my_task_id_2",
+ "my_task_id_4",
+ "hello_task",
+ ]
def test_user_provided_pool_and_priority_weight_works(self):
"""Tests that when looping that user provided pool, priority_weight
etc is used"""
@@ -596,12 +605,12 @@ class TestAirflowTaskDecorator(BasePythonTest):
"""
print("Hello world")
- with self.dag:
+ with self.dag_non_serialized:
for i in range(3):
hello.override(pool="my_pool", priority_weight=i)()
weights = []
- for task in self.dag.tasks:
+ for task in self.dag_non_serialized.tasks:
assert task.pool == "my_pool"
weights.append(task.priority_weight)
assert weights == [0, 1, 2]
@@ -617,7 +626,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
print("Hello world", x, y)
return x, y
- with self.dag:
+ with self.dag_non_serialized:
output = hello.override(task_id="mytask")(x=2, y=3)
output2 = hello.override()(2, 3) # nothing overridden but should
work
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 66e3c9a823..893a215d9d 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -95,14 +95,14 @@ class BasePythonTest:
default_date: datetime = DEFAULT_DATE
@pytest.fixture(autouse=True)
- def base_tests_setup(self, request, create_task_instance_of_operator,
dag_maker):
+ def base_tests_setup(self, request,
create_serialized_task_instance_of_operator, dag_maker):
self.dag_id = f"dag_{slugify(request.cls.__name__)}"
self.task_id = f"task_{slugify(request.node.name, max_length=40)}"
self.run_id = f"run_{slugify(request.node.name, max_length=40)}"
self.ds_templated = self.default_date.date().isoformat()
- self.ti_maker = create_task_instance_of_operator
+ self.ti_maker = create_serialized_task_instance_of_operator
self.dag_maker = dag_maker
- self.dag = self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH).dag
+ self.dag_non_serialized = self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH).dag
clear_db_runs()
yield
clear_db_runs()
@@ -129,7 +129,7 @@ class BasePythonTest:
return kwargs
def create_dag_run(self) -> DagRun:
- return self.dag.create_dagrun(
+ return self.dag_maker.create_dagrun(
state=DagRunState.RUNNING,
start_date=self.dag_maker.start_date,
session=self.dag_maker.session,
@@ -151,10 +151,12 @@ class BasePythonTest:
def run_as_operator(self, fn, **kwargs):
"""Run task by direct call ``run`` method."""
- with self.dag:
+ clear_db_runs()
+ with self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
task = self.opcls(task_id=self.task_id, python_callable=fn,
**self.default_kwargs(**kwargs))
-
+ self.dag_maker.create_dagrun()
task.run(start_date=self.default_date, end_date=self.default_date)
+ clear_db_runs()
return task
def run_as_task(self, fn, return_ti=False, **kwargs):
@@ -324,13 +326,13 @@ class TestPythonOperator(BasePythonTest):
def func():
return "test_return_value"
- python_operator = PythonOperator(
- task_id="python_operator",
- python_callable=func,
- dag=self.dag,
- show_return_value_in_logs=False,
- templates_exts=["test_ext"],
- )
+ with self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
+ python_operator = PythonOperator(
+ task_id="python_operator",
+ python_callable=func,
+ show_return_value_in_logs=False,
+ templates_exts=["test_ext"],
+ )
assert python_operator.template_ext == ["test_ext"]
@@ -369,7 +371,8 @@ class TestBranchOperator(BasePythonTest):
self.branch_2 = EmptyOperator(task_id="branch_2")
def test_with_dag_run(self):
- with self.dag:
+ clear_db_runs()
+ with self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
def f():
return "branch_1"
@@ -384,7 +387,8 @@ class TestBranchOperator(BasePythonTest):
)
def test_with_skip_in_branch_downstream_dependencies(self):
- with self.dag:
+ clear_db_runs()
+ with self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
def f():
return "branch_1"
@@ -400,7 +404,8 @@ class TestBranchOperator(BasePythonTest):
)
def test_with_skip_in_branch_downstream_dependencies2(self):
- with self.dag:
+ clear_db_runs()
+ with self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
def f():
return "branch_2"
@@ -416,7 +421,8 @@ class TestBranchOperator(BasePythonTest):
)
def test_xcom_push(self):
- with self.dag:
+ clear_db_runs()
+ with self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
def f():
return "branch_1"
@@ -433,12 +439,13 @@ class TestBranchOperator(BasePythonTest):
else:
pytest.fail(f"{self.task_id!r} not found.")
+ @pytest.mark.skip_if_database_isolation_mode # tests logic with
clear_task_instances(), this needs DB access
def test_clear_skipped_downstream_task(self):
"""
After a downstream task is skipped by BranchPythonOperator, clearing
the skipped task
should not cause it to be executed.
"""
- with self.dag:
+ with self.dag_non_serialized:
def f():
return "branch_1"
@@ -492,6 +499,7 @@ class TestBranchOperator(BasePythonTest):
with pytest.raises(AirflowException, match="Invalid tasks found:
{'some_task_id'}"):
ti.run()
+ @pytest.mark.skip_if_database_isolation_mode # tests pure logic with
run() method, can not run in isolation mode
@pytest.mark.parametrize(
"choice,expected_states",
[
@@ -503,7 +511,7 @@ class TestBranchOperator(BasePythonTest):
"""
Tests that BranchPythonOperator handles empty branches properly.
"""
- with self.dag:
+ with self.dag_non_serialized:
def f():
return choice
@@ -521,7 +529,7 @@ class TestBranchOperator(BasePythonTest):
for task_id in task_ids: # Mimic the specific order the scheduling
would run the tests.
task_instance = tis[task_id]
- task_instance.refresh_from_task(self.dag.get_task(task_id))
+
task_instance.refresh_from_task(self.dag_non_serialized.get_task(task_id))
task_instance.run()
def get_state(ti):
@@ -547,6 +555,7 @@ class TestShortCircuitOperator(BasePythonTest):
}
all_success_states = {"short_circuit": State.SUCCESS, "op1":
State.SUCCESS, "op2": State.SUCCESS}
+ @pytest.mark.skip_if_database_isolation_mode # tests pure logic with
run() method, can not run in isolation mode
@pytest.mark.parametrize(
argnames=(
"callable_return, test_ignore_downstream_trigger_rules,
test_trigger_rule, expected_task_states"
@@ -645,7 +654,7 @@ class TestShortCircuitOperator(BasePythonTest):
Checking the behavior of the ShortCircuitOperator in several scenarios
enabling/disabling the skipping
of downstream tasks, both short-circuiting modes, and various trigger
rules of downstream tasks.
"""
- with self.dag:
+ with self.dag_non_serialized:
short_circuit = ShortCircuitOperator(
task_id="short_circuit",
python_callable=lambda: callable_return,
@@ -665,12 +674,13 @@ class TestShortCircuitOperator(BasePythonTest):
assert self.op2.trigger_rule == test_trigger_rule
self.assert_expected_task_states(dr, expected_task_states)
+ @pytest.mark.skip_if_database_isolation_mode # tests logic with
clear_task_instances(), this needs DB access
def test_clear_skipped_downstream_task(self):
"""
After a downstream task is skipped by ShortCircuitOperator, clearing
the skipped task
should not cause it to be executed.
"""
- with self.dag:
+ with self.dag_non_serialized:
short_circuit = ShortCircuitOperator(task_id="short_circuit",
python_callable=lambda: False)
short_circuit >> self.op1 >> self.op2
dr = self.create_dag_run()
@@ -700,7 +710,8 @@ class TestShortCircuitOperator(BasePythonTest):
self.assert_expected_task_states(dr, expected_states)
def test_xcom_push(self):
- with self.dag:
+ clear_db_runs()
+ with self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
short_op_push_xcom = ShortCircuitOperator(
task_id="push_xcom_from_shortcircuit", python_callable=lambda:
"signature"
)
@@ -716,8 +727,9 @@ class TestShortCircuitOperator(BasePythonTest):
assert tis[0].xcom_pull(task_ids=short_op_push_xcom.task_id,
key="return_value") == "signature"
assert tis[0].xcom_pull(task_ids=short_op_no_push_xcom.task_id,
key="return_value") is False
+ @pytest.mark.skip_if_database_isolation_mode # tests pure logic with
run() method, can not run in isolation mode
def test_xcom_push_skipped_tasks(self):
- with self.dag:
+ with self.dag_non_serialized:
short_op_push_xcom = ShortCircuitOperator(
task_id="push_xcom_from_shortcircuit", python_callable=lambda:
False
)
@@ -730,8 +742,9 @@ class TestShortCircuitOperator(BasePythonTest):
"skipped": ["empty_task"]
}
+ @pytest.mark.skip_if_database_isolation_mode # tests pure logic with
run() method, can not run in isolation mode
def test_mapped_xcom_push_skipped_tasks(self, session):
- with self.dag:
+ with self.dag_non_serialized:
@task_group
def group(x):
@@ -1394,7 +1407,7 @@ class
TestExternalPythonOperator(BaseTestPythonVirtualenvOperator):
python_callable=f,
task_id="task",
python=sys.executable,
- dag=self.dag,
+ dag=self.dag_non_serialized,
)
loads_mock.side_effect = DeserializingResultError
@@ -1491,7 +1504,8 @@ class
BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
self.run_as_task(f, do_not_use_caching=True)
def test_with_dag_run(self):
- with self.dag:
+ clear_db_runs()
+ with self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
def f():
return "branch_1"
@@ -1506,7 +1520,8 @@ class
BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
)
def test_with_skip_in_branch_downstream_dependencies(self):
- with self.dag:
+ clear_db_runs()
+ with self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
def f():
return "branch_1"
@@ -1522,7 +1537,8 @@ class
BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
)
def test_with_skip_in_branch_downstream_dependencies2(self):
- with self.dag:
+ clear_db_runs()
+ with self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
def f():
return "branch_2"
@@ -1538,7 +1554,8 @@ class
BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
)
def test_xcom_push(self):
- with self.dag:
+ clear_db_runs()
+ with self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
def f():
return "branch_1"
@@ -1555,12 +1572,14 @@ class
BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
else:
pytest.fail(f"{self.task_id!r} not found.")
+ @pytest.mark.skip_if_database_isolation_mode # tests logic with
clear_task_instances(), this needs DB access
def test_clear_skipped_downstream_task(self):
"""
After a downstream task is skipped by BranchPythonOperator, clearing
the skipped task
should not cause it to be executed.
"""
- with self.dag:
+ clear_db_runs()
+ with self.dag_maker(self.dag_id,
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
def f():
return "branch_1"
@@ -1721,6 +1740,7 @@ DEFAULT_ARGS = {
}
[email protected]_if_database_isolation_mode # tests pure logic with run()
method, can not run in isolation mode
@pytest.mark.usefixtures("clear_db")
class TestCurrentContextRuntime:
def test_context_in_task(self):
@@ -1736,6 +1756,7 @@ class TestCurrentContextRuntime:
@pytest.mark.need_serialized_dag(False)
class TestShortCircuitWithTeardown:
+ @pytest.mark.skip_if_database_isolation_mode # tests pure logic with
run() method, mix of pydantic and mock fails
@pytest.mark.parametrize(
"ignore_downstream_trigger_rules, with_teardown, should_skip,
expected",
[
@@ -1777,6 +1798,7 @@ class TestShortCircuitWithTeardown:
else:
op1.skip.assert_not_called()
+ @pytest.mark.skip_if_database_isolation_mode # tests pure logic with
run() method, mix of pydantic and mock fails
@pytest.mark.parametrize("config", ["sequence", "parallel"])
def test_short_circuit_with_teardowns_complicated(self, dag_maker, config):
with dag_maker():
@@ -1804,6 +1826,7 @@ class TestShortCircuitWithTeardown:
actual_skipped = set(op1.skip.call_args.kwargs["tasks"])
assert actual_skipped == {s2, op2}
+ @pytest.mark.skip_if_database_isolation_mode # tests pure logic with
run() method, mix of pydantic and mock fails
def test_short_circuit_with_teardowns_complicated_2(self, dag_maker):
with dag_maker():
s1 = PythonOperator(task_id="s1", python_callable=print).as_setup()
@@ -1833,6 +1856,7 @@ class TestShortCircuitWithTeardown:
assert actual_kwargs["execution_date"] == dagrun.logical_date
assert actual_skipped == {op3}
+ @pytest.mark.skip_if_database_isolation_mode # tests pure logic with
run() method, mix of pydantic and mock fails
@pytest.mark.parametrize("level", [logging.DEBUG, logging.INFO])
def test_short_circuit_with_teardowns_debug_level(self, dag_maker, level,
clear_db):
"""
diff --git a/tests/sensors/test_python.py b/tests/sensors/test_python.py
index a971b8a146..60814620b3 100644
--- a/tests/sensors/test_python.py
+++ b/tests/sensors/test_python.py
@@ -45,7 +45,7 @@ class TestPythonSensor(BasePythonTest):
self.run_as_task(lambda: 1 / 0)
def test_python_sensor_xcom(self):
- with self.dag:
+ with self.dag_non_serialized:
task = self.opcls(
task_id=self.task_id,
python_callable=lambda: PokeReturnValue(True, "xcom"),
diff --git a/tests/serialization/test_serialized_objects.py
b/tests/serialization/test_serialized_objects.py
index e06e7b253b..661ecbf5dc 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -31,7 +31,13 @@ from kubernetes.client import models as k8s
from pendulum.tz.timezone import Timezone
from airflow.datasets import Dataset, DatasetAlias, DatasetAliasEvent
-from airflow.exceptions import AirflowRescheduleException, SerializationError,
TaskDeferred
+from airflow.exceptions import (
+ AirflowException,
+ AirflowFailException,
+ AirflowRescheduleException,
+ SerializationError,
+ TaskDeferred,
+)
from airflow.jobs.job import Job
from airflow.models.connection import Connection
from airflow.models.dag import DAG, DagModel, DagTag
@@ -152,6 +158,10 @@ def equal_time(a: datetime, b: datetime) -> bool:
return a.strftime("%s") == b.strftime("%s")
+def equal_exception(a: AirflowException, b: AirflowException) -> bool:
+ return a.__class__ == b.__class__ and str(a) == str(b)
+
+
def equal_outlet_event_accessor(a: OutletEventAccessor, b:
OutletEventAccessor) -> bool:
return a.raw_key == b.raw_key and a.extra == b.extra and
a.dataset_alias_event == b.dataset_alias_event
@@ -252,6 +262,16 @@ class MockLazySelectSequence(LazySelectSequence):
DAT.DATASET_EVENT_ACCESSOR,
equal_outlet_event_accessor,
),
+ (
+ AirflowException("test123 wohoo!"),
+ DAT.AIRFLOW_EXC_SER,
+ equal_exception,
+ ),
+ (
+ AirflowFailException("uuups, failed :-("),
+ DAT.AIRFLOW_EXC_SER,
+ equal_exception,
+ ),
],
)
def test_serialize_deserialize(input, encoded_type, cmp_func):