This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new eda4329d65 Fix almost 100 tests for taskinstance for DB isolation mode (#41296)
eda4329d65 is described below
commit eda4329d65a466786522488f159782803f9eeb29
Author: Jarek Potiuk <[email protected]>
AuthorDate: Wed Aug 7 08:43:43 2024 +0200
Fix almost 100 tests for taskinstance for DB isolation mode (#41296)
Related: #41067
---
tests/models/test_taskinstance.py | 22 +++++++++++++++-------
1 file changed, 15 insertions(+), 7 deletions(-)
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 32158dc00d..7be4eda365 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -82,7 +82,7 @@ from airflow.operators.python import PythonOperator
from airflow.sensors.base import BaseSensorOperator
from airflow.sensors.python import PythonSensor
from airflow.serialization.serialized_objects import SerializedBaseOperator,
SerializedDAG
-from airflow.settings import TIMEZONE
+from airflow.settings import TIMEZONE, TracebackSessionForTests
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
@@ -104,7 +104,7 @@ from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_connections, clear_db_runs
from tests.test_utils.mock_operators import MockOperator
-pytestmark = pytest.mark.db_test
+pytestmark = [pytest.mark.db_test]
@pytest.fixture
@@ -288,7 +288,7 @@ class TestTaskInstance:
assert not ti.test_mode
@patch.object(DAG, "get_concurrency_reached")
- def test_requeue_over_dag_concurrency(self, mock_concurrency_reached,
create_task_instance):
+ def test_requeue_over_dag_concurrency(self, mock_concurrency_reached,
create_task_instance, dag_maker):
mock_concurrency_reached.return_value = True
ti = create_task_instance(
@@ -1403,7 +1403,9 @@ class TestTaskInstance:
assert task.start_date is not None
run_date = task.start_date + datetime.timedelta(days=5)
- ti =
dag_maker.create_dagrun(execution_date=run_date).get_task_instance(downstream.task_id)
+ dr = dag_maker.create_dagrun(execution_date=run_date)
+ dag_maker.session.commit()
+ ti = dr.get_task_instance(downstream.task_id)
ti.task = downstream
dep_results = TriggerRuleDep()._evaluate_trigger_rule(
@@ -1413,6 +1415,8 @@ class TestTaskInstance:
)
completed = all(dep.passed for dep in dep_results)
+ ti = dr.get_task_instance(downstream.task_id)
+
assert completed == expect_passed
assert ti.state == expect_state
@@ -1511,15 +1515,16 @@ class TestTaskInstance:
do_something_else.expand(i=nums)
dr = dag_maker.create_dagrun()
-
+ dag_maker.session.commit()
monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_:
upstream_states)
ti = dr.get_task_instance("do_something_else", session=session)
ti.map_index = 0
for map_index in range(1, 5):
- ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index)
+ ti = TaskInstance(dr.task_instances[-1].task, run_id=dr.run_id,
map_index=map_index)
session.add(ti)
ti.dag_run = dr
session.flush()
+ session.commit()
downstream = ti.task
ti = dr.get_task_instance(task_id="do_something_else", map_index=3,
session=session)
ti.task = downstream
@@ -1528,7 +1533,10 @@ class TestTaskInstance:
dep_context=DepContext(flag_upstream_failed=flag_upstream_failed),
session=session,
)
+ TracebackSessionForTests.set_allow_db_access(session, True)
completed = all(dep.passed for dep in dep_results)
+ TracebackSessionForTests.set_allow_db_access(session, False)
+ ti = dr.get_task_instance(task_id="do_something_else", map_index=3,
session=session)
assert completed == expect_completed
assert ti.state == expect_state
@@ -2946,7 +2954,7 @@ class TestTaskInstance:
dag_maker,
) -> list:
dag_id = "test_previous_dates"
- with dag_maker(dag_id=dag_id, schedule=schedule_interval,
catchup=catchup):
+ with dag_maker(dag_id=dag_id, schedule=schedule_interval,
catchup=catchup, serialized=True):
task = EmptyOperator(task_id="task")
def get_test_ti(execution_date: pendulum.DateTime, state: str) -> TI: