This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b4b8a31df0 Bugfix/41067 fix tests operators test python (#41257)
b4b8a31df0 is described below

commit b4b8a31df0b2f6438f8b1707662d661c5e0a7ebe
Author: Jens Scheffler <[email protected]>
AuthorDate: Mon Aug 5 08:16:12 2024 +0200

    Bugfix/41067 fix tests operators test python (#41257)
    
    * Ensure that the internal API transports AirflowException (instead of HTTP 500) and 
that it is correctly serialized
    
    * Fix tests/operators/test_python.py for Database Isolation Tests
    
    * Adjust variable in test decorators according to test_python
    
    * Adjust variable in test sensors according to test_python
---
 airflow/api_internal/endpoints/rpc_api_endpoint.py |  6 ++
 airflow/api_internal/internal_api_call.py          |  5 +-
 airflow/exceptions.py                              | 18 ++++-
 airflow/serialization/serialized_objects.py        |  2 +-
 tests/conftest.py                                  | 22 ++++++
 tests/decorators/test_python.py                    | 87 ++++++++++++----------
 tests/operators/test_python.py                     | 86 +++++++++++++--------
 tests/sensors/test_python.py                       |  2 +-
 tests/serialization/test_serialized_objects.py     | 22 +++++-
 9 files changed, 174 insertions(+), 76 deletions(-)

diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py 
b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 4a5fdb0276..eb4ddbad52 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -35,6 +35,7 @@ from jwt import (
 
 from airflow.api_connexion.exceptions import PermissionDenied
 from airflow.configuration import conf
+from airflow.exceptions import AirflowException
 from airflow.jobs.job import Job, most_recent_job
 from airflow.models.dagcode import DagCode
 from airflow.models.taskinstance import _record_task_map_for_downstreams
@@ -234,5 +235,10 @@ def internal_airflow_api(body: dict[str, Any]) -> 
APIResponse:
             response = json.dumps(output_json) if output_json is not None else 
None
             log.info("Sending response: %s", response)
             return Response(response=response, headers={"Content-Type": 
"application/json"})
+    except AirflowException as e:  # In case of AirflowException transport the 
exception class back to caller
+        exception_json = BaseSerialization.serialize(e, 
use_pydantic_models=True)
+        response = json.dumps(exception_json)
+        log.info("Sending exception response: %s", response)
+        return Response(response=response, headers={"Content-Type": 
"application/json"})
     except Exception:
         return log_and_build_error_response(message=f"Error executing method 
'{method_name}'.", status=500)
diff --git a/airflow/api_internal/internal_api_call.py 
b/airflow/api_internal/internal_api_call.py
index 7ad56c876f..fc0945b3c0 100644
--- a/airflow/api_internal/internal_api_call.py
+++ b/airflow/api_internal/internal_api_call.py
@@ -158,6 +158,9 @@ def internal_api_call(func: Callable[PS, RT]) -> 
Callable[PS, RT]:
         result = make_jsonrpc_request(method_name, args_dict)
         if result is None or result == b"":
             return None
-        return BaseSerialization.deserialize(json.loads(result), 
use_pydantic_models=True)
+        result = BaseSerialization.deserialize(json.loads(result), 
use_pydantic_models=True)
+        if isinstance(result, AirflowException):
+            raise result
+        return result
 
     return wrapper
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index dc59f91841..40a62ad208 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -43,6 +43,10 @@ class AirflowException(Exception):
 
     status_code = HTTPStatus.INTERNAL_SERVER_ERROR
 
+    def serialize(self):
+        cls = self.__class__
+        return f"{cls.__module__}.{cls.__name__}", (str(self),), {}
+
 
 class AirflowBadRequest(AirflowException):
     """Raise when the application or server cannot handle the request."""
@@ -76,7 +80,8 @@ class AirflowRescheduleException(AirflowException):
         self.reschedule_date = reschedule_date
 
     def serialize(self):
-        return "AirflowRescheduleException", (), {"reschedule_date": 
self.reschedule_date}
+        cls = self.__class__
+        return f"{cls.__module__}.{cls.__name__}", (), {"reschedule_date": 
self.reschedule_date}
 
 
 class InvalidStatsNameException(AirflowException):
@@ -132,6 +137,14 @@ class XComNotFound(AirflowException):
     def __str__(self) -> str:
         return f'XComArg result from {self.task_id} at {self.dag_id} with 
key="{self.key}" is not found!'
 
+    def serialize(self):
+        cls = self.__class__
+        return (
+            f"{cls.__module__}.{cls.__name__}",
+            (),
+            {"dag_id": self.dag_id, "task_id": self.task_id, "key": self.key},
+        )
+
 
 class UnmappableOperator(AirflowException):
     """Raise when an operator is not implemented to be mappable."""
@@ -396,8 +409,9 @@ class TaskDeferred(BaseException):
             raise ValueError("Timeout value must be a timedelta")
 
     def serialize(self):
+        cls = self.__class__
         return (
-            self.__class__.__name__,
+            f"{cls.__module__}.{cls.__name__}",
             (),
             {
                 "trigger": self.trigger,
diff --git a/airflow/serialization/serialized_objects.py 
b/airflow/serialization/serialized_objects.py
index ecb5757632..94631c993c 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -840,7 +840,7 @@ class BaseSerialization:
             args = deser["args"]
             kwargs = deser["kwargs"]
             del deser
-            exc_cls = import_string(f"airflow.exceptions.{exc_cls_name}")
+            exc_cls = import_string(exc_cls_name)
             return exc_cls(*args, **kwargs)
         elif type_ == DAT.BASE_TRIGGER:
             tr_cls_name, kwargs = cls.deserialize(var, 
use_pydantic_models=use_pydantic_models)
diff --git a/tests/conftest.py b/tests/conftest.py
index 472a42cf28..11b47363b0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1120,6 +1120,28 @@ def create_task_instance(dag_maker, create_dummy_dag):
     return maker
 
 
[email protected]
+def create_serialized_task_instance_of_operator(dag_maker):
+    def _create_task_instance(
+        operator_class,
+        *,
+        dag_id,
+        execution_date=None,
+        session=None,
+        **operator_kwargs,
+    ) -> TaskInstance:
+        with dag_maker(dag_id=dag_id, serialized=True, session=session):
+            operator_class(**operator_kwargs)
+        if execution_date is None:
+            dagrun_kwargs = {}
+        else:
+            dagrun_kwargs = {"execution_date": execution_date}
+        (ti,) = dag_maker.create_dagrun(**dagrun_kwargs).task_instances
+        return ti
+
+    return _create_task_instance
+
+
 @pytest.fixture
 def create_task_instance_of_operator(dag_maker):
     def _create_task_instance(
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 1561b8224d..83e2d86cd7 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -212,7 +212,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
         def identity2(x: int, y: int) -> Tuple[int, int]:
             return x, y
 
-        with self.dag:
+        with self.dag_non_serialized:
             res = identity2(8, 4)
 
         dr = self.create_dag_run()
@@ -230,7 +230,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
         def identity_tuple(x: int, y: int) -> Tuple[int, int]:
             return x, y
 
-        with self.dag:
+        with self.dag_non_serialized:
             ident = identity_tuple(35, 36)
 
         dr = self.create_dag_run()
@@ -284,7 +284,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
         def add_number(num: int):
             return {2: num}
 
-        with self.dag:
+        with self.dag_non_serialized:
             ret = add_number(2)
 
         self.create_dag_run()
@@ -296,7 +296,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
         def add_number(num: int):
             return num
 
-        with self.dag:
+        with self.dag_non_serialized:
             ret = add_number(2)
 
         self.create_dag_run()
@@ -308,7 +308,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
         def empty_dict():
             return {}
 
-        with self.dag:
+        with self.dag_non_serialized:
             ret = empty_dict()
 
         dr = self.create_dag_run()
@@ -321,7 +321,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
         def test_func():
             return
 
-        with self.dag:
+        with self.dag_non_serialized:
             ret = test_func()
 
         dr = self.create_dag_run()
@@ -341,7 +341,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
         Named = namedtuple("Named", ["var1", "var2"])
         named_tuple = Named("{{ ds }}", "unchanged")
 
-        with self.dag:
+        with self.dag_non_serialized:
             ret = arg_task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on 
{{ds}}.", named_tuple)
 
         dr = self.create_dag_run()
@@ -360,7 +360,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
         def kwargs_task(an_int, a_date, a_templated_string):
             raise RuntimeError("Should not executed")
 
-        with self.dag:
+        with self.dag_non_serialized:
             ret = kwargs_task(
                 an_int=4, a_date=date(2019, 1, 1), a_templated_string="dag 
{{dag.dag_id}} ran on {{ds}}."
             )
@@ -379,9 +379,9 @@ class TestAirflowTaskDecorator(BasePythonTest):
         def do_run():
             return 4
 
-        with self.dag:
+        with self.dag_non_serialized:
             do_run()
-            assert ["some_name"] == self.dag.task_ids
+            assert ["some_name"] == self.dag_non_serialized.task_ids
 
     def test_multiple_calls(self):
         """Test calling task multiple times in a DAG"""
@@ -390,12 +390,12 @@ class TestAirflowTaskDecorator(BasePythonTest):
         def do_run():
             return 4
 
-        with self.dag:
+        with self.dag_non_serialized:
             do_run()
-            assert ["do_run"] == self.dag.task_ids
+            assert ["do_run"] == self.dag_non_serialized.task_ids
             do_run_1 = do_run()
             do_run_2 = do_run()
-            assert ["do_run", "do_run__1", "do_run__2"] == self.dag.task_ids
+            assert ["do_run", "do_run__1", "do_run__2"] == 
self.dag_non_serialized.task_ids
 
         assert do_run_1.operator.task_id == "do_run__1"
         assert do_run_2.operator.task_id == "do_run__2"
@@ -408,14 +408,14 @@ class TestAirflowTaskDecorator(BasePythonTest):
             return 4
 
         group_id = "KnightsOfNii"
-        with self.dag:
+        with self.dag_non_serialized:
             with TaskGroup(group_id=group_id):
                 do_run()
-                assert [f"{group_id}.do_run"] == self.dag.task_ids
+                assert [f"{group_id}.do_run"] == 
self.dag_non_serialized.task_ids
                 do_run()
-                assert [f"{group_id}.do_run", f"{group_id}.do_run__1"] == 
self.dag.task_ids
+                assert [f"{group_id}.do_run", f"{group_id}.do_run__1"] == 
self.dag_non_serialized.task_ids
 
-        assert len(self.dag.task_ids) == 2
+        assert len(self.dag_non_serialized.task_ids) == 2
 
     def test_call_20(self):
         """Test calling decorated function 21 times in a DAG"""
@@ -424,12 +424,12 @@ class TestAirflowTaskDecorator(BasePythonTest):
         def __do_run():
             return 4
 
-        with self.dag:
+        with self.dag_non_serialized:
             __do_run()
             for _ in range(20):
                 __do_run()
 
-        assert self.dag.task_ids[-1] == "__do_run__20"
+        assert self.dag_non_serialized.task_ids[-1] == "__do_run__20"
 
     def test_multiple_outputs(self):
         """Tests pushing multiple outputs as a dictionary"""
@@ -439,15 +439,17 @@ class TestAirflowTaskDecorator(BasePythonTest):
             return {"number": number + 1, "43": 43}
 
         test_number = 10
-        with self.dag:
+        with self.dag_non_serialized:
             ret = return_dict(test_number)
 
-        dr = self.dag.create_dagrun(
+        dr = self.dag_non_serialized.create_dagrun(
             run_id=DagRunType.MANUAL,
             start_date=timezone.utcnow(),
             execution_date=DEFAULT_DATE,
             state=State.RUNNING,
-            
data_interval=self.dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE),
+            
data_interval=self.dag_non_serialized.timetable.infer_manual_data_interval(
+                run_after=DEFAULT_DATE
+            ),
         )
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -464,8 +466,8 @@ class TestAirflowTaskDecorator(BasePythonTest):
         def do_run():
             return 4
 
-        self.dag.default_args["owner"] = "airflow"
-        with self.dag:
+        self.dag_non_serialized.default_args["owner"] = "airflow"
+        with self.dag_non_serialized:
             ret = do_run()
         assert ret.operator.owner == "airflow"
 
@@ -474,14 +476,14 @@ class TestAirflowTaskDecorator(BasePythonTest):
             return unknown
 
         with pytest.raises(TypeError):
-            with self.dag:
+            with self.dag_non_serialized:
                 test_apply_default_raise()
 
         @task_decorator
         def test_apply_default(owner):
             return owner
 
-        with self.dag:
+        with self.dag_non_serialized:
             ret = test_apply_default()
         assert "owner" in ret.operator.op_kwargs
 
@@ -498,16 +500,18 @@ class TestAirflowTaskDecorator(BasePythonTest):
 
         test_number = 10
 
-        with self.dag:
+        with self.dag_non_serialized:
             bigger_number = add_2(test_number)
             ret = add_num(bigger_number, XComArg(bigger_number.operator))
 
-        dr = self.dag.create_dagrun(
+        dr = self.dag_non_serialized.create_dagrun(
             run_id=DagRunType.MANUAL,
             start_date=timezone.utcnow(),
             execution_date=DEFAULT_DATE,
             state=State.RUNNING,
-            
data_interval=self.dag.timetable.infer_manual_data_interval(run_after=DEFAULT_DATE),
+            
data_interval=self.dag_non_serialized.timetable.infer_manual_data_interval(
+                run_after=DEFAULT_DATE
+            ),
         )
 
         bigger_number.operator.run(start_date=DEFAULT_DATE, 
end_date=DEFAULT_DATE)
@@ -519,7 +523,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
     def test_dag_task(self):
         """Tests dag.task property to generate task"""
 
-        @self.dag.task
+        @self.dag_non_serialized.task
         def add_2(number: int):
             return number + 2
 
@@ -527,12 +531,12 @@ class TestAirflowTaskDecorator(BasePythonTest):
         res = add_2(test_number)
         add_2(res)
 
-        assert "add_2" in self.dag.task_ids
+        assert "add_2" in self.dag_non_serialized.task_ids
 
     def test_dag_task_multiple_outputs(self):
         """Tests dag.task property to generate task with multiple outputs"""
 
-        @self.dag.task(multiple_outputs=True)
+        @self.dag_non_serialized.task(multiple_outputs=True)
         def add_2(number: int):
             return {"1": number + 2, "2": 42}
 
@@ -540,7 +544,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
         add_2(test_number)
         add_2(test_number)
 
-        assert "add_2" in self.dag.task_ids
+        assert "add_2" in self.dag_non_serialized.task_ids
 
     @pytest.mark.parametrize(
         argnames=["op_doc_attr", "op_doc_value", "expected_doc_md"],
@@ -564,7 +568,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
             return number + 2
 
         test_number = 10
-        with self.dag:
+        with self.dag_non_serialized:
             ret = add_2(test_number)
 
         assert ret.operator.doc_md == expected_doc_md
@@ -579,12 +583,17 @@ class TestAirflowTaskDecorator(BasePythonTest):
             """
             print("Hello world")
 
-        with self.dag:
+        with self.dag_non_serialized:
             for i in range(3):
                 hello.override(task_id=f"my_task_id_{i * 2}")()
             hello()  # This task would have hello_task as the task_id
 
-        assert self.dag.task_ids == ["my_task_id_0", "my_task_id_2", 
"my_task_id_4", "hello_task"]
+        assert self.dag_non_serialized.task_ids == [
+            "my_task_id_0",
+            "my_task_id_2",
+            "my_task_id_4",
+            "hello_task",
+        ]
 
     def test_user_provided_pool_and_priority_weight_works(self):
         """Tests that when looping that user provided pool, priority_weight 
etc is used"""
@@ -596,12 +605,12 @@ class TestAirflowTaskDecorator(BasePythonTest):
             """
             print("Hello world")
 
-        with self.dag:
+        with self.dag_non_serialized:
             for i in range(3):
                 hello.override(pool="my_pool", priority_weight=i)()
 
         weights = []
-        for task in self.dag.tasks:
+        for task in self.dag_non_serialized.tasks:
             assert task.pool == "my_pool"
             weights.append(task.priority_weight)
         assert weights == [0, 1, 2]
@@ -617,7 +626,7 @@ class TestAirflowTaskDecorator(BasePythonTest):
             print("Hello world", x, y)
             return x, y
 
-        with self.dag:
+        with self.dag_non_serialized:
             output = hello.override(task_id="mytask")(x=2, y=3)
             output2 = hello.override()(2, 3)  # nothing overridden but should 
work
 
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 66e3c9a823..893a215d9d 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -95,14 +95,14 @@ class BasePythonTest:
     default_date: datetime = DEFAULT_DATE
 
     @pytest.fixture(autouse=True)
-    def base_tests_setup(self, request, create_task_instance_of_operator, 
dag_maker):
+    def base_tests_setup(self, request, 
create_serialized_task_instance_of_operator, dag_maker):
         self.dag_id = f"dag_{slugify(request.cls.__name__)}"
         self.task_id = f"task_{slugify(request.node.name, max_length=40)}"
         self.run_id = f"run_{slugify(request.node.name, max_length=40)}"
         self.ds_templated = self.default_date.date().isoformat()
-        self.ti_maker = create_task_instance_of_operator
+        self.ti_maker = create_serialized_task_instance_of_operator
         self.dag_maker = dag_maker
-        self.dag = self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH).dag
+        self.dag_non_serialized = self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH).dag
         clear_db_runs()
         yield
         clear_db_runs()
@@ -129,7 +129,7 @@ class BasePythonTest:
         return kwargs
 
     def create_dag_run(self) -> DagRun:
-        return self.dag.create_dagrun(
+        return self.dag_maker.create_dagrun(
             state=DagRunState.RUNNING,
             start_date=self.dag_maker.start_date,
             session=self.dag_maker.session,
@@ -151,10 +151,12 @@ class BasePythonTest:
 
     def run_as_operator(self, fn, **kwargs):
         """Run task by direct call ``run`` method."""
-        with self.dag:
+        clear_db_runs()
+        with self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
             task = self.opcls(task_id=self.task_id, python_callable=fn, 
**self.default_kwargs(**kwargs))
-
+        self.dag_maker.create_dagrun()
         task.run(start_date=self.default_date, end_date=self.default_date)
+        clear_db_runs()
         return task
 
     def run_as_task(self, fn, return_ti=False, **kwargs):
@@ -324,13 +326,13 @@ class TestPythonOperator(BasePythonTest):
         def func():
             return "test_return_value"
 
-        python_operator = PythonOperator(
-            task_id="python_operator",
-            python_callable=func,
-            dag=self.dag,
-            show_return_value_in_logs=False,
-            templates_exts=["test_ext"],
-        )
+        with self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
+            python_operator = PythonOperator(
+                task_id="python_operator",
+                python_callable=func,
+                show_return_value_in_logs=False,
+                templates_exts=["test_ext"],
+            )
 
         assert python_operator.template_ext == ["test_ext"]
 
@@ -369,7 +371,8 @@ class TestBranchOperator(BasePythonTest):
         self.branch_2 = EmptyOperator(task_id="branch_2")
 
     def test_with_dag_run(self):
-        with self.dag:
+        clear_db_runs()
+        with self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
 
             def f():
                 return "branch_1"
@@ -384,7 +387,8 @@ class TestBranchOperator(BasePythonTest):
         )
 
     def test_with_skip_in_branch_downstream_dependencies(self):
-        with self.dag:
+        clear_db_runs()
+        with self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
 
             def f():
                 return "branch_1"
@@ -400,7 +404,8 @@ class TestBranchOperator(BasePythonTest):
         )
 
     def test_with_skip_in_branch_downstream_dependencies2(self):
-        with self.dag:
+        clear_db_runs()
+        with self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
 
             def f():
                 return "branch_2"
@@ -416,7 +421,8 @@ class TestBranchOperator(BasePythonTest):
         )
 
     def test_xcom_push(self):
-        with self.dag:
+        clear_db_runs()
+        with self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
 
             def f():
                 return "branch_1"
@@ -433,12 +439,13 @@ class TestBranchOperator(BasePythonTest):
         else:
             pytest.fail(f"{self.task_id!r} not found.")
 
+    @pytest.mark.skip_if_database_isolation_mode  # tests logic with 
clear_task_instances(), this needs DB access
     def test_clear_skipped_downstream_task(self):
         """
         After a downstream task is skipped by BranchPythonOperator, clearing 
the skipped task
         should not cause it to be executed.
         """
-        with self.dag:
+        with self.dag_non_serialized:
 
             def f():
                 return "branch_1"
@@ -492,6 +499,7 @@ class TestBranchOperator(BasePythonTest):
         with pytest.raises(AirflowException, match="Invalid tasks found: 
{'some_task_id'}"):
             ti.run()
 
+    @pytest.mark.skip_if_database_isolation_mode  # tests pure logic with 
run() method, can not run in isolation mode
     @pytest.mark.parametrize(
         "choice,expected_states",
         [
@@ -503,7 +511,7 @@ class TestBranchOperator(BasePythonTest):
         """
         Tests that BranchPythonOperator handles empty branches properly.
         """
-        with self.dag:
+        with self.dag_non_serialized:
 
             def f():
                 return choice
@@ -521,7 +529,7 @@ class TestBranchOperator(BasePythonTest):
 
         for task_id in task_ids:  # Mimic the specific order the scheduling 
would run the tests.
             task_instance = tis[task_id]
-            task_instance.refresh_from_task(self.dag.get_task(task_id))
+            
task_instance.refresh_from_task(self.dag_non_serialized.get_task(task_id))
             task_instance.run()
 
         def get_state(ti):
@@ -547,6 +555,7 @@ class TestShortCircuitOperator(BasePythonTest):
     }
     all_success_states = {"short_circuit": State.SUCCESS, "op1": 
State.SUCCESS, "op2": State.SUCCESS}
 
+    @pytest.mark.skip_if_database_isolation_mode  # tests pure logic with 
run() method, can not run in isolation mode
     @pytest.mark.parametrize(
         argnames=(
             "callable_return, test_ignore_downstream_trigger_rules, 
test_trigger_rule, expected_task_states"
@@ -645,7 +654,7 @@ class TestShortCircuitOperator(BasePythonTest):
         Checking the behavior of the ShortCircuitOperator in several scenarios 
enabling/disabling the skipping
         of downstream tasks, both short-circuiting modes, and various trigger 
rules of downstream tasks.
         """
-        with self.dag:
+        with self.dag_non_serialized:
             short_circuit = ShortCircuitOperator(
                 task_id="short_circuit",
                 python_callable=lambda: callable_return,
@@ -665,12 +674,13 @@ class TestShortCircuitOperator(BasePythonTest):
         assert self.op2.trigger_rule == test_trigger_rule
         self.assert_expected_task_states(dr, expected_task_states)
 
+    @pytest.mark.skip_if_database_isolation_mode  # tests logic with 
clear_task_instances(), this needs DB access
     def test_clear_skipped_downstream_task(self):
         """
         After a downstream task is skipped by ShortCircuitOperator, clearing 
the skipped task
         should not cause it to be executed.
         """
-        with self.dag:
+        with self.dag_non_serialized:
             short_circuit = ShortCircuitOperator(task_id="short_circuit", 
python_callable=lambda: False)
             short_circuit >> self.op1 >> self.op2
         dr = self.create_dag_run()
@@ -700,7 +710,8 @@ class TestShortCircuitOperator(BasePythonTest):
         self.assert_expected_task_states(dr, expected_states)
 
     def test_xcom_push(self):
-        with self.dag:
+        clear_db_runs()
+        with self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
             short_op_push_xcom = ShortCircuitOperator(
                 task_id="push_xcom_from_shortcircuit", python_callable=lambda: 
"signature"
             )
@@ -716,8 +727,9 @@ class TestShortCircuitOperator(BasePythonTest):
         assert tis[0].xcom_pull(task_ids=short_op_push_xcom.task_id, 
key="return_value") == "signature"
         assert tis[0].xcom_pull(task_ids=short_op_no_push_xcom.task_id, 
key="return_value") is False
 
+    @pytest.mark.skip_if_database_isolation_mode  # tests pure logic with 
run() method, can not run in isolation mode
     def test_xcom_push_skipped_tasks(self):
-        with self.dag:
+        with self.dag_non_serialized:
             short_op_push_xcom = ShortCircuitOperator(
                 task_id="push_xcom_from_shortcircuit", python_callable=lambda: 
False
             )
@@ -730,8 +742,9 @@ class TestShortCircuitOperator(BasePythonTest):
             "skipped": ["empty_task"]
         }
 
+    @pytest.mark.skip_if_database_isolation_mode  # tests pure logic with 
run() method, can not run in isolation mode
     def test_mapped_xcom_push_skipped_tasks(self, session):
-        with self.dag:
+        with self.dag_non_serialized:
 
             @task_group
             def group(x):
@@ -1394,7 +1407,7 @@ class 
TestExternalPythonOperator(BaseTestPythonVirtualenvOperator):
             python_callable=f,
             task_id="task",
             python=sys.executable,
-            dag=self.dag,
+            dag=self.dag_non_serialized,
         )
 
         loads_mock.side_effect = DeserializingResultError
@@ -1491,7 +1504,8 @@ class 
BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
             self.run_as_task(f, do_not_use_caching=True)
 
     def test_with_dag_run(self):
-        with self.dag:
+        clear_db_runs()
+        with self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
 
             def f():
                 return "branch_1"
@@ -1506,7 +1520,8 @@ class 
BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
         )
 
     def test_with_skip_in_branch_downstream_dependencies(self):
-        with self.dag:
+        clear_db_runs()
+        with self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
 
             def f():
                 return "branch_1"
@@ -1522,7 +1537,8 @@ class 
BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
         )
 
     def test_with_skip_in_branch_downstream_dependencies2(self):
-        with self.dag:
+        clear_db_runs()
+        with self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
 
             def f():
                 return "branch_2"
@@ -1538,7 +1554,8 @@ class 
BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
         )
 
     def test_xcom_push(self):
-        with self.dag:
+        clear_db_runs()
+        with self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
 
             def f():
                 return "branch_1"
@@ -1555,12 +1572,14 @@ class 
BaseTestBranchPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
         else:
             pytest.fail(f"{self.task_id!r} not found.")
 
+    @pytest.mark.skip_if_database_isolation_mode  # tests logic with 
clear_task_instances(), this needs DB access
     def test_clear_skipped_downstream_task(self):
         """
         After a downstream task is skipped by BranchPythonOperator, clearing 
the skipped task
         should not cause it to be executed.
         """
-        with self.dag:
+        clear_db_runs()
+        with self.dag_maker(self.dag_id, 
template_searchpath=TEMPLATE_SEARCHPATH, serialized=True):
 
             def f():
                 return "branch_1"
@@ -1721,6 +1740,7 @@ DEFAULT_ARGS = {
 }
 
 
[email protected]_if_database_isolation_mode  # tests pure logic with run() 
method, can not run in isolation mode
 @pytest.mark.usefixtures("clear_db")
 class TestCurrentContextRuntime:
     def test_context_in_task(self):
@@ -1736,6 +1756,7 @@ class TestCurrentContextRuntime:
 
 @pytest.mark.need_serialized_dag(False)
 class TestShortCircuitWithTeardown:
+    @pytest.mark.skip_if_database_isolation_mode  # tests pure logic with 
run() method, mix of pydantic and mock fails
     @pytest.mark.parametrize(
         "ignore_downstream_trigger_rules, with_teardown, should_skip, 
expected",
         [
@@ -1777,6 +1798,7 @@ class TestShortCircuitWithTeardown:
         else:
             op1.skip.assert_not_called()
 
+    @pytest.mark.skip_if_database_isolation_mode  # tests pure logic with 
run() method, mix of pydantic and mock fails
     @pytest.mark.parametrize("config", ["sequence", "parallel"])
     def test_short_circuit_with_teardowns_complicated(self, dag_maker, config):
         with dag_maker():
@@ -1804,6 +1826,7 @@ class TestShortCircuitWithTeardown:
             actual_skipped = set(op1.skip.call_args.kwargs["tasks"])
             assert actual_skipped == {s2, op2}
 
+    @pytest.mark.skip_if_database_isolation_mode  # tests pure logic with 
run() method, mix of pydantic and mock fails
     def test_short_circuit_with_teardowns_complicated_2(self, dag_maker):
         with dag_maker():
             s1 = PythonOperator(task_id="s1", python_callable=print).as_setup()
@@ -1833,6 +1856,7 @@ class TestShortCircuitWithTeardown:
             assert actual_kwargs["execution_date"] == dagrun.logical_date
             assert actual_skipped == {op3}
 
+    @pytest.mark.skip_if_database_isolation_mode  # tests pure logic with 
run() method, mix of pydantic and mock fails
     @pytest.mark.parametrize("level", [logging.DEBUG, logging.INFO])
     def test_short_circuit_with_teardowns_debug_level(self, dag_maker, level, 
clear_db):
         """
diff --git a/tests/sensors/test_python.py b/tests/sensors/test_python.py
index a971b8a146..60814620b3 100644
--- a/tests/sensors/test_python.py
+++ b/tests/sensors/test_python.py
@@ -45,7 +45,7 @@ class TestPythonSensor(BasePythonTest):
             self.run_as_task(lambda: 1 / 0)
 
     def test_python_sensor_xcom(self):
-        with self.dag:
+        with self.dag_non_serialized:
             task = self.opcls(
                 task_id=self.task_id,
                 python_callable=lambda: PokeReturnValue(True, "xcom"),
diff --git a/tests/serialization/test_serialized_objects.py 
b/tests/serialization/test_serialized_objects.py
index e06e7b253b..661ecbf5dc 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -31,7 +31,13 @@ from kubernetes.client import models as k8s
 from pendulum.tz.timezone import Timezone
 
 from airflow.datasets import Dataset, DatasetAlias, DatasetAliasEvent
-from airflow.exceptions import AirflowRescheduleException, SerializationError, 
TaskDeferred
+from airflow.exceptions import (
+    AirflowException,
+    AirflowFailException,
+    AirflowRescheduleException,
+    SerializationError,
+    TaskDeferred,
+)
 from airflow.jobs.job import Job
 from airflow.models.connection import Connection
 from airflow.models.dag import DAG, DagModel, DagTag
@@ -152,6 +158,10 @@ def equal_time(a: datetime, b: datetime) -> bool:
     return a.strftime("%s") == b.strftime("%s")
 
 
+def equal_exception(a: AirflowException, b: AirflowException) -> bool:
+    return a.__class__ == b.__class__ and str(a) == str(b)
+
+
 def equal_outlet_event_accessor(a: OutletEventAccessor, b: 
OutletEventAccessor) -> bool:
     return a.raw_key == b.raw_key and a.extra == b.extra and 
a.dataset_alias_event == b.dataset_alias_event
 
@@ -252,6 +262,16 @@ class MockLazySelectSequence(LazySelectSequence):
             DAT.DATASET_EVENT_ACCESSOR,
             equal_outlet_event_accessor,
         ),
+        (
+            AirflowException("test123 wohoo!"),
+            DAT.AIRFLOW_EXC_SER,
+            equal_exception,
+        ),
+        (
+            AirflowFailException("uuups, failed :-("),
+            DAT.AIRFLOW_EXC_SER,
+            equal_exception,
+        ),
     ],
 )
 def test_serialize_deserialize(input, encoded_type, cmp_func):

Reply via email to