This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new eda4329d65 Fix almost 100 tests for taskinstance for DB isolation mode (#41296)
eda4329d65 is described below

commit eda4329d65a466786522488f159782803f9eeb29
Author: Jarek Potiuk <[email protected]>
AuthorDate: Wed Aug 7 08:43:43 2024 +0200

    Fix almost 100 tests for taskinstance for DB isolation mode (#41296)
    
    Related: #41067
---
 tests/models/test_taskinstance.py | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 32158dc00d..7be4eda365 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -82,7 +82,7 @@ from airflow.operators.python import PythonOperator
 from airflow.sensors.base import BaseSensorOperator
 from airflow.sensors.python import PythonSensor
 from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
-from airflow.settings import TIMEZONE
+from airflow.settings import TIMEZONE, TracebackSessionForTests
 from airflow.stats import Stats
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
@@ -104,7 +104,7 @@ from tests.test_utils.config import conf_vars
 from tests.test_utils.db import clear_db_connections, clear_db_runs
 from tests.test_utils.mock_operators import MockOperator
 
-pytestmark = pytest.mark.db_test
+pytestmark = [pytest.mark.db_test]
 
 
 @pytest.fixture
@@ -288,7 +288,7 @@ class TestTaskInstance:
         assert not ti.test_mode
 
     @patch.object(DAG, "get_concurrency_reached")
-    def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, create_task_instance):
+    def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, create_task_instance, dag_maker):
         mock_concurrency_reached.return_value = True
 
         ti = create_task_instance(
@@ -1403,7 +1403,9 @@ class TestTaskInstance:
             assert task.start_date is not None
             run_date = task.start_date + datetime.timedelta(days=5)
 
-        ti = dag_maker.create_dagrun(execution_date=run_date).get_task_instance(downstream.task_id)
+        dr = dag_maker.create_dagrun(execution_date=run_date)
+        dag_maker.session.commit()
+        ti = dr.get_task_instance(downstream.task_id)
         ti.task = downstream
 
         dep_results = TriggerRuleDep()._evaluate_trigger_rule(
@@ -1413,6 +1415,8 @@ class TestTaskInstance:
         )
         completed = all(dep.passed for dep in dep_results)
 
+        ti = dr.get_task_instance(downstream.task_id)
+
         assert completed == expect_passed
         assert ti.state == expect_state
 
@@ -1511,15 +1515,16 @@ class TestTaskInstance:
             do_something_else.expand(i=nums)
 
         dr = dag_maker.create_dagrun()
-
+        dag_maker.session.commit()
         monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states)
         ti = dr.get_task_instance("do_something_else", session=session)
         ti.map_index = 0
         for map_index in range(1, 5):
-            ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index)
+            ti = TaskInstance(dr.task_instances[-1].task, run_id=dr.run_id, map_index=map_index)
             session.add(ti)
             ti.dag_run = dr
         session.flush()
+        session.commit()
         downstream = ti.task
         ti = dr.get_task_instance(task_id="do_something_else", map_index=3, session=session)
         ti.task = downstream
@@ -1528,7 +1533,10 @@ class TestTaskInstance:
             dep_context=DepContext(flag_upstream_failed=flag_upstream_failed),
             session=session,
         )
+        TracebackSessionForTests.set_allow_db_access(session, True)
         completed = all(dep.passed for dep in dep_results)
+        TracebackSessionForTests.set_allow_db_access(session, False)
+        ti = dr.get_task_instance(task_id="do_something_else", map_index=3, session=session)
 
         assert completed == expect_completed
         assert ti.state == expect_state
@@ -2946,7 +2954,7 @@ class TestTaskInstance:
         dag_maker,
     ) -> list:
         dag_id = "test_previous_dates"
-        with dag_maker(dag_id=dag_id, schedule=schedule_interval, catchup=catchup):
+        with dag_maker(dag_id=dag_id, schedule=schedule_interval, catchup=catchup, serialized=True):
             task = EmptyOperator(task_id="task")
 
         def get_test_ti(execution_date: pendulum.DateTime, state: str) -> TI:

Reply via email to