This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 615e1eceff Apply task instance mutation hook consistently (#38440)
615e1eceff is described below
commit 615e1eceffcb5c3f30b7f137d4f9d2b482fffcbc
Author: Jens Scheffler <[email protected]>
AuthorDate: Tue Mar 26 01:04:07 2024 +0100
Apply task instance mutation hook consistently (#38440)
* Apply task instance mutation hook consistently
* Add test for cluster policy applied in pytest
---
airflow/models/taskinstance.py | 3 +++
tests/models/test_taskinstance.py | 16 +++++++++++++---
2 files changed, 16 insertions(+), 3 deletions(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 7619d06989..9968da5898 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -98,6 +98,7 @@ from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import LazyXComAccess, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sentry import Sentry
+from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
from airflow.templates import SandboxedEnvironment
from airflow.ti_deps.dep_context import DepContext
@@ -943,6 +944,8 @@ def _refresh_from_task(
task_instance.executor_config = task.executor_config
task_instance.operator = task.task_type
task_instance.custom_operator_name = getattr(task, "custom_operator_name", None)
+ # Re-apply cluster policy here so that task defaults do not override previously set data
+ task_instance_mutation_hook(task_instance)
def _record_task_map_for_downstreams(
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 74a803f941..02069d382c 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3348,10 +3348,20 @@ class TestTaskInstance:
@pytest.mark.parametrize("pool_override", [None, "test_pool2"])
-def test_refresh_from_task(pool_override):
[email protected]("queue_by_policy", [None, "forced_queue"])
+def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):
+ default_queue = "test_queue"
+ expected_queue = queue_by_policy or default_queue
+ if queue_by_policy:
+ # Apply a dummy cluster policy to check if it is always applied
+ def mock_policy(task_instance: TaskInstance):
+ task_instance.queue = queue_by_policy
+
+
monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", mock_policy)
+
task = EmptyOperator(
task_id="empty",
- queue="test_queue",
+ queue=default_queue,
pool="test_pool1",
pool_slots=3,
priority_weight=10,
@@ -3362,7 +3372,7 @@ def test_refresh_from_task(pool_override):
ti = TI(task, run_id=None)
ti.refresh_from_task(task, pool_override=pool_override)
- assert ti.queue == task.queue
+ assert ti.queue == expected_queue
if pool_override:
assert ti.pool == pool_override