This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new daccc75b06 Make task.run work with DB isolation (#41154)
daccc75b06 is described below

commit daccc75b064c735636027b20f3edb82da69ffab1
Author: Jarek Potiuk <[email protected]>
AuthorDate: Thu Aug 1 09:22:19 2024 +0200

    Make task.run work with DB isolation (#41154)
    
    This PR makes it possible to handle tests that perform a task.run
    call in DB isolation mode. This call is treated specially, i.e.
    the run() method is allowed to use the DB directly to set up the
    task, but when _run_raw_task (further down the call stack) then
    tries to use the DB, an error will be raised.
    
    This makes it possible to fix `test_core.py` and likely a number of
    other tests that rely on the run() method to run the tests.
    
    This required improving the _set_ti_attrs method to handle
    dag_run in the TaskInstancePydantic-to-TaskInstance mapping,
    because under the hood those tests actually use TaskInstance,
    not TaskInstancePydantic.
    
    Related: #41067
---
 airflow/models/taskinstance.py | 33 ++++++++++++++++++++++++++++++---
 airflow/settings.py            | 31 ++++++++++++++++++++++++++-----
 tests/core/test_core.py        |  3 +++
 tests/datasets/test_manager.py |  3 +++
 4 files changed, 62 insertions(+), 8 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index fb3fb46f6d..9210659d4c 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -242,7 +242,6 @@ def _run_raw_task(
     ti.test_mode = test_mode
     ti.refresh_from_task(ti.task, pool_override=pool)
     ti.refresh_from_db(session=session)
-
     ti.job_id = job_id
     ti.hostname = get_hostname()
     ti.pid = os.getpid()
@@ -800,7 +799,7 @@ def _execute_task(task_instance: TaskInstance | 
TaskInstancePydantic, context: C
     return result
 
 
-def _set_ti_attrs(target, source):
+def _set_ti_attrs(target, source, include_dag_run=False):
     # Fields ordered per model definition
     target.start_date = source.start_date
     target.end_date = source.end_date
@@ -827,6 +826,27 @@ def _set_ti_attrs(target, source):
     target.next_method = source.next_method
     target.next_kwargs = source.next_kwargs
 
+    if include_dag_run:
+        target.execution_date = source.execution_date
+        target.dag_run.id = source.dag_run.id
+        target.dag_run.dag_id = source.dag_run.dag_id
+        target.dag_run.queued_at = source.dag_run.queued_at
+        target.dag_run.execution_date = source.dag_run.execution_date
+        target.dag_run.start_date = source.dag_run.start_date
+        target.dag_run.end_date = source.dag_run.end_date
+        target.dag_run.state = source.dag_run.state
+        target.dag_run.run_id = source.dag_run.run_id
+        target.dag_run.creating_job_id = source.dag_run.creating_job_id
+        target.dag_run.external_trigger = source.dag_run.external_trigger
+        target.dag_run.run_type = source.dag_run.run_type
+        target.dag_run.conf = source.dag_run.conf
+        target.dag_run.data_interval_start = source.dag_run.data_interval_start
+        target.dag_run.data_interval_end = source.dag_run.data_interval_end
+        target.dag_run.last_scheduling_decision = 
source.dag_run.last_scheduling_decision
+        target.dag_run.dag_hash = source.dag_run.dag_hash
+        target.dag_run.updated_at = source.dag_run.updated_at
+        target.dag_run.log_template_id = source.dag_run.log_template_id
+
 
 def _refresh_from_db(
     *,
@@ -859,7 +879,14 @@ def _refresh_from_db(
     )
 
     if ti:
-        _set_ti_attrs(task_instance, ti)
+        from airflow.serialization.pydantic.taskinstance import 
TaskInstancePydantic
+
+        include_dag_run = isinstance(ti, TaskInstancePydantic)
+        # in case of internal API, what we get is TaskInstancePydantic value, 
and we are supposed
+        # to also update dag_run information as it might not be available. We 
cannot always do it in
+        # case ti is TaskInstance, because it might be detached/ not loaded 
yet and dag_run might
+        # not be available.
+        _set_ti_attrs(task_instance, ti, include_dag_run=include_dag_run)
     else:
         task_instance.state = None
 
diff --git a/airflow/settings.py b/airflow/settings.py
index d76f2c148e..f1a0826dba 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -303,6 +303,8 @@ class TracebackSession:
 AIRFLOW_PATH = os.path.dirname(os.path.dirname(__file__))
 AIRFLOW_TESTS_PATH = os.path.join(AIRFLOW_PATH, "tests")
 AIRFLOW_SETTINGS_PATH = os.path.join(AIRFLOW_PATH, "airflow", "settings.py")
+AIRFLOW_UTILS_SESSION_PATH = os.path.join(AIRFLOW_PATH, "airflow", "utils", 
"session.py")
+AIRFLOW_MODELS_BASEOPERATOR_PATH = os.path.join(AIRFLOW_PATH, "airflow", 
"models", "baseoperator.py")
 
 
 class TracebackSessionForTests:
@@ -352,14 +354,33 @@ class TracebackSessionForTests:
         :return: True if the object was created from test code, False 
otherwise.
         """
         self.traceback = traceback.extract_stack()
-        if any(filename.endswith("conftest.py") for filename, _, _, _ in 
self.traceback):
+        airflow_frames = [
+            tb
+            for tb in self.traceback
+            if tb.filename.startswith(AIRFLOW_PATH)
+            and not tb.filename == AIRFLOW_SETTINGS_PATH
+            and not tb.filename == AIRFLOW_UTILS_SESSION_PATH
+        ]
+        if any(filename.endswith("conftest.py") for filename, _, _, _ in 
airflow_frames):
+            # This is a fixture call
             return True, None
-        for tb in self.traceback[::-1]:
-            # Skip first two settings.py file (will be always here - because 
we call it from here
-            if tb.filename == AIRFLOW_SETTINGS_PATH:
-                continue
+        if (
+            len(airflow_frames) >= 2
+            and airflow_frames[-2].filename.startswith(AIRFLOW_TESTS_PATH)
+            and airflow_frames[-1].filename == AIRFLOW_MODELS_BASEOPERATOR_PATH
+            and airflow_frames[-1].name == "run"
+        ):
+            # This is baseoperator run method that is called directly from the 
test code and this is
+            # usual pattern where we create a session in the test code to 
create dag_runs for tests.
+            # If `run` code will be run inside a real "airflow" code the stack 
trace would be longer
+            # and it would not be directly called from the test code. Also if 
subsequently any of the
+            # run_task() method called later from the task code will attempt 
to execute any DB
+            # method, the stack trace will be longer and we will catch it as 
"illegal" call.
+            return True, None
+        for tb in airflow_frames[::-1]:
             if tb.filename.startswith(AIRFLOW_PATH):
                 if tb.filename.startswith(AIRFLOW_TESTS_PATH):
+                    # this is a session created directly in the test code
                     return True, None
                 else:
                     return False, tb
diff --git a/tests/core/test_core.py b/tests/core/test_core.py
index 17dbdee70c..27b5b33843 100644
--- a/tests/core/test_core.py
+++ b/tests/core/test_core.py
@@ -144,6 +144,7 @@ class TestCore:
         task.run(start_date=execution_date, end_date=execution_date)
 
         ti = TI(task=task, run_id=dr.run_id)
+        ti.refresh_from_db()
         context = ti.get_template_context()
 
         # next_ds should be the execution date for manually triggered runs
@@ -178,6 +179,8 @@ class TestCore:
         task2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
         ti1 = TI(task=task1, run_id=dr.run_id)
         ti2 = TI(task=task2, run_id=dr.run_id)
+        ti1.refresh_from_db()
+        ti2.refresh_from_db()
         context1 = ti1.get_template_context()
         context2 = ti2.get_template_context()
 
diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py
index 3d3f4dca92..1e7b4fda40 100644
--- a/tests/datasets/test_manager.py
+++ b/tests/datasets/test_manager.py
@@ -35,6 +35,9 @@ from tests.listeners import dataset_listener
 pytestmark = pytest.mark.db_test
 
 
+pytest.importorskip("pydantic", minversion="2.0.0")
+
+
 @pytest.fixture
 def mock_task_instance():
     return TaskInstancePydantic(

Reply via email to