This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new daccc75b06 Make task.run work with DB isolation (#41154)
daccc75b06 is described below
commit daccc75b064c735636027b20f3edb82da69ffab1
Author: Jarek Potiuk <[email protected]>
AuthorDate: Thu Aug 1 09:22:19 2024 +0200
Make task.run work with DB isolation (#41154)
This PR adds the possibility to handle tests that perform a task.run
call in DB isolation mode. This call is treated specially - i.e.
the run() method is allowed to use the DB directly to set up the
task, but when _run_raw_task below uses the DB, an error is raised.
That allows us to fix `test_core.py` and likely a number of
other tests that rely on the run() method.
This required improving the _set_ti_attrs method to handle
dag_run in the TaskInstancePydantic-to-TaskInstance mapping,
because under the hood we are actually using TaskInstance for
those tests, not TaskInstancePydantic.
Related: #41067
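
For illustration, a minimal sketch (not part of this patch) of the test
pattern this enables. It assumes an initialized metadata DB and a DagRun
already created for DEFAULT_DATE (the tests use fixtures for that); all
ids below are made up:

    import datetime

    from airflow.models.dag import DAG
    from airflow.models.taskinstance import TaskInstance
    from airflow.operators.empty import EmptyOperator

    DEFAULT_DATE = datetime.datetime(2024, 1, 1)

    with DAG(dag_id="example_db_isolation", schedule=None, start_date=DEFAULT_DATE):
        task = EmptyOperator(task_id="noop")

    # run() is recognized as a direct call from test code, so it may use the
    # DB to set up the task instance; DB access attempted deeper inside
    # _run_raw_task still raises.
    task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # Afterwards the TaskInstance is refreshed and inspected, as the updated
    # tests below do (run_id is illustrative):
    ti = TaskInstance(task=task, run_id="manual__2024-01-01")
    ti.refresh_from_db()
    context = ti.get_template_context()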
---
airflow/models/taskinstance.py | 33 ++++++++++++++++++++++++++++++---
airflow/settings.py            | 31 ++++++++++++++++++++++++++-----
tests/core/test_core.py        |  3 +++
tests/datasets/test_manager.py |  3 +++
4 files changed, 62 insertions(+), 8 deletions(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index fb3fb46f6d..9210659d4c 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -242,7 +242,6 @@ def _run_raw_task(
ti.test_mode = test_mode
ti.refresh_from_task(ti.task, pool_override=pool)
ti.refresh_from_db(session=session)
-
ti.job_id = job_id
ti.hostname = get_hostname()
ti.pid = os.getpid()
@@ -800,7 +799,7 @@ def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: C
return result
-def _set_ti_attrs(target, source):
+def _set_ti_attrs(target, source, include_dag_run=False):
# Fields ordered per model definition
target.start_date = source.start_date
target.end_date = source.end_date
@@ -827,6 +826,27 @@ def _set_ti_attrs(target, source):
target.next_method = source.next_method
target.next_kwargs = source.next_kwargs
+ if include_dag_run:
+ target.execution_date = source.execution_date
+ target.dag_run.id = source.dag_run.id
+ target.dag_run.dag_id = source.dag_run.dag_id
+ target.dag_run.queued_at = source.dag_run.queued_at
+ target.dag_run.execution_date = source.dag_run.execution_date
+ target.dag_run.start_date = source.dag_run.start_date
+ target.dag_run.end_date = source.dag_run.end_date
+ target.dag_run.state = source.dag_run.state
+ target.dag_run.run_id = source.dag_run.run_id
+ target.dag_run.creating_job_id = source.dag_run.creating_job_id
+ target.dag_run.external_trigger = source.dag_run.external_trigger
+ target.dag_run.run_type = source.dag_run.run_type
+ target.dag_run.conf = source.dag_run.conf
+ target.dag_run.data_interval_start = source.dag_run.data_interval_start
+ target.dag_run.data_interval_end = source.dag_run.data_interval_end
+ target.dag_run.last_scheduling_decision = source.dag_run.last_scheduling_decision
+ target.dag_run.dag_hash = source.dag_run.dag_hash
+ target.dag_run.updated_at = source.dag_run.updated_at
+ target.dag_run.log_template_id = source.dag_run.log_template_id
+
def _refresh_from_db(
*,
@@ -859,7 +879,14 @@ def _refresh_from_db(
)
if ti:
- _set_ti_attrs(task_instance, ti)
+ from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
+
+ include_dag_run = isinstance(ti, TaskInstancePydantic)
+ # In case of the internal API, what we get is a TaskInstancePydantic value and
+ # we are supposed to also update dag_run information, as it might not be
+ # available. We cannot always do it when ti is a TaskInstance, because it might
+ # be detached/not loaded yet and dag_run might not be available.
+ _set_ti_attrs(task_instance, ti, include_dag_run=include_dag_run)
else:
task_instance.state = None
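
To make the dag_run mirroring above concrete, here is a small self-contained
sketch of the attribute-copy approach _set_ti_attrs takes, using stand-in
objects rather than the real Airflow classes and only a subset of the fields:

    from types import SimpleNamespace

    def set_ti_attrs(target, source, include_dag_run=False):
        # Copy the task-instance fields one by one (a subset of the real list).
        for field in ("start_date", "end_date", "state"):
            setattr(target, field, getattr(source, field))
        if include_dag_run:
            # Mirror the nested dag_run attributes as well.
            for field in ("run_id", "execution_date", "state"):
                setattr(target.dag_run, field, getattr(source.dag_run, field))

    source = SimpleNamespace(
        start_date="2024-08-01", end_date=None, state="running",
        dag_run=SimpleNamespace(run_id="manual__1", execution_date="2024-08-01", state="queued"),
    )
    target = SimpleNamespace(
        start_date=None, end_date=None, state=None,
        dag_run=SimpleNamespace(run_id=None, execution_date=None, state=None),
    )
    set_ti_attrs(target, source, include_dag_run=True)
    assert target.dag_run.run_id == "manual__1"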
diff --git a/airflow/settings.py b/airflow/settings.py
index d76f2c148e..f1a0826dba 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -303,6 +303,8 @@ class TracebackSession:
AIRFLOW_PATH = os.path.dirname(os.path.dirname(__file__))
AIRFLOW_TESTS_PATH = os.path.join(AIRFLOW_PATH, "tests")
AIRFLOW_SETTINGS_PATH = os.path.join(AIRFLOW_PATH, "airflow", "settings.py")
+AIRFLOW_UTILS_SESSION_PATH = os.path.join(AIRFLOW_PATH, "airflow", "utils", "session.py")
+AIRFLOW_MODELS_BASEOPERATOR_PATH = os.path.join(AIRFLOW_PATH, "airflow", "models", "baseoperator.py")
class TracebackSessionForTests:
@@ -352,14 +354,33 @@ class TracebackSessionForTests:
:return: True if the object was created from test code, False otherwise.
"""
self.traceback = traceback.extract_stack()
- if any(filename.endswith("conftest.py") for filename, _, _, _ in self.traceback):
+ airflow_frames = [
+ tb
+ for tb in self.traceback
+ if tb.filename.startswith(AIRFLOW_PATH)
+ and not tb.filename == AIRFLOW_SETTINGS_PATH
+ and not tb.filename == AIRFLOW_UTILS_SESSION_PATH
+ ]
+ if any(filename.endswith("conftest.py") for filename, _, _, _ in airflow_frames):
+ # This is a fixture call
return True, None
- for tb in self.traceback[::-1]:
- # Skip first two settings.py file (will be always here - because we call it from here
- if tb.filename == AIRFLOW_SETTINGS_PATH:
- continue
+ if (
+ len(airflow_frames) >= 2
+ and airflow_frames[-2].filename.startswith(AIRFLOW_TESTS_PATH)
+ and airflow_frames[-1].filename == AIRFLOW_MODELS_BASEOPERATOR_PATH
+ and airflow_frames[-1].name == "run"
+ ):
+ # This is the BaseOperator.run method called directly from the test code - a
+ # usual pattern where we create a session in the test code to create dag_runs
+ # for tests. If the `run` code were executed inside real "airflow" code, the
+ # stack trace would be longer and it would not be called directly from the
+ # test code. Also, if any run_task() method called later from the task code
+ # attempts to execute any DB method, the stack trace will be longer and we
+ # will catch it as an "illegal" call.
+ return True, None
+ for tb in airflow_frames[::-1]:
if tb.filename.startswith(AIRFLOW_PATH):
if tb.filename.startswith(AIRFLOW_TESTS_PATH):
+ # this is a session created directly in the test code
return True, None
else:
return False, tb
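
A simplified, self-contained sketch of the frame-filtering idea above; the
paths are placeholders, not the real Airflow ones:

    import os
    import traceback

    PROJECT_PATH = "/path/to/project"  # stand-in for AIRFLOW_PATH
    TESTS_PATH = os.path.join(PROJECT_PATH, "tests")
    SKIPPED = {os.path.join(PROJECT_PATH, "project", "settings.py")}

    def created_from_test_code() -> bool:
        # Keep only project frames, dropping infrastructure files that are
        # always on the stack because the check itself lives there.
        frames = [
            tb for tb in traceback.extract_stack()
            if tb.filename.startswith(PROJECT_PATH) and tb.filename not in SKIPPED
        ]
        if not frames:
            return False
        # The innermost remaining project frame decides whether the session
        # was created from test code.
        return frames[-1].filename.startswith(TESTS_PATH)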
diff --git a/tests/core/test_core.py b/tests/core/test_core.py
index 17dbdee70c..27b5b33843 100644
--- a/tests/core/test_core.py
+++ b/tests/core/test_core.py
@@ -144,6 +144,7 @@ class TestCore:
task.run(start_date=execution_date, end_date=execution_date)
ti = TI(task=task, run_id=dr.run_id)
+ ti.refresh_from_db()
context = ti.get_template_context()
# next_ds should be the execution date for manually triggered runs
@@ -178,6 +179,8 @@ class TestCore:
task2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
ti1 = TI(task=task1, run_id=dr.run_id)
ti2 = TI(task=task2, run_id=dr.run_id)
+ ti1.refresh_from_db()
+ ti2.refresh_from_db()
context1 = ti1.get_template_context()
context2 = ti2.get_template_context()
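
The next file gates tests/datasets/test_manager.py on pydantic 2. For
reference, a minimal self-contained sketch of how pytest.importorskip behaves
(module and test names here are illustrative):

    import pytest

    # Skips every test in the module at collection time unless pydantic >= 2.0.0
    # is importable; on success it returns the imported module.
    pydantic = pytest.importorskip("pydantic", minversion="2.0.0")

    def test_point_roundtrip():
        class Point(pydantic.BaseModel):
            x: int
            y: int

        assert Point(x=1, y=2).model_dump() == {"x": 1, "y": 2}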
diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py
index 3d3f4dca92..1e7b4fda40 100644
--- a/tests/datasets/test_manager.py
+++ b/tests/datasets/test_manager.py
@@ -35,6 +35,9 @@ from tests.listeners import dataset_listener
pytestmark = pytest.mark.db_test
+pytest.importorskip("pydantic", minversion="2.0.0")
+
+
@pytest.fixture
def mock_task_instance():
return TaskInstancePydantic(