This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 615e1eceff Apply task instance mutation hook consistently (#38440)
615e1eceff is described below

commit 615e1eceffcb5c3f30b7f137d4f9d2b482fffcbc
Author: Jens Scheffler <[email protected]>
AuthorDate: Tue Mar 26 01:04:07 2024 +0100

    Apply task instance mutation hook consistently (#38440)
    
    * Apply task instance mutation hook consistently
    
    * Add test for cluster policy applied in pytest
---
 airflow/models/taskinstance.py    |  3 +++
 tests/models/test_taskinstance.py | 16 +++++++++++++---
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 7619d06989..9968da5898 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -98,6 +98,7 @@ from airflow.models.taskreschedule import TaskReschedule
 from airflow.models.xcom import LazyXComAccess, XCom
 from airflow.plugins_manager import integrate_macros_plugins
 from airflow.sentry import Sentry
+from airflow.settings import task_instance_mutation_hook
 from airflow.stats import Stats
 from airflow.templates import SandboxedEnvironment
 from airflow.ti_deps.dep_context import DepContext
@@ -943,6 +944,8 @@ def _refresh_from_task(
     task_instance.executor_config = task.executor_config
     task_instance.operator = task.task_type
     task_instance.custom_operator_name = getattr(task, "custom_operator_name", None)
+    # Re-apply cluster policy here so that task default do not overload previous data
+    task_instance_mutation_hook(task_instance)
 
 
 def _record_task_map_for_downstreams(
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 74a803f941..02069d382c 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3348,10 +3348,20 @@ class TestTaskInstance:
 
 
 @pytest.mark.parametrize("pool_override", [None, "test_pool2"])
-def test_refresh_from_task(pool_override):
+@pytest.mark.parametrize("queue_by_policy", [None, "forced_queue"])
+def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):
+    default_queue = "test_queue"
+    expected_queue = queue_by_policy or default_queue
+    if queue_by_policy:
+        # Apply a dummy cluster policy to check if it is always applied
+        def mock_policy(task_instance: TaskInstance):
+            task_instance.queue = queue_by_policy
+
+        monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", mock_policy)
+
     task = EmptyOperator(
         task_id="empty",
-        queue="test_queue",
+        queue=default_queue,
         pool="test_pool1",
         pool_slots=3,
         priority_weight=10,
@@ -3362,7 +3372,7 @@ def test_refresh_from_task(pool_override):
     ti = TI(task, run_id=None)
     ti.refresh_from_task(task, pool_override=pool_override)
 
-    assert ti.queue == task.queue
+    assert ti.queue == expected_queue
 
     if pool_override:
         assert ti.pool == pool_override

Reply via email to