This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v1-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 4d15311003bd7d8b4e9a9637dd5b56e25dc9dbfb Author: Chao-Han Tsai <[email protected]> AuthorDate: Fri Jun 12 21:03:17 2020 -0700 Add task instance mutation hook (#8852) (cherry picked from commit bacb05df38532f81a9480f3c3439c6a75e580567) --- airflow/models/dagrun.py | 7 +++++-- airflow/settings.py | 23 +++++++++++++-------- docs/concepts.rst | 50 ++++++++++++++++++++++++++++++++++++--------- tests/models/test_dagrun.py | 29 ++++++++++++++++++++++++++ 4 files changed, 89 insertions(+), 20 deletions(-) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index c5fcde5..ec9ecc8 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -27,8 +27,8 @@ from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import synonym from sqlalchemy.orm.session import Session from airflow.exceptions import AirflowException -from airflow.models.base import Base, ID_LEN -from airflow.settings import Stats +from airflow.models.base import ID_LEN, Base +from airflow.settings import Stats, task_instance_mutation_hook from airflow.ti_deps.dep_context import DepContext from airflow.utils import timezone from airflow.utils.db import provide_session @@ -364,6 +364,7 @@ class DagRun(Base, LoggingMixin): # check for removed or restored tasks task_ids = [] for ti in tis: + task_instance_mutation_hook(ti) task_ids.append(ti.task_id) task = None try: @@ -385,6 +386,7 @@ class DagRun(Base, LoggingMixin): "removed from DAG '{}'".format(ti, dag)) Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1) ti.state = State.NONE + session.merge(ti) # check for missing tasks for task in six.itervalues(dag.task_dict): @@ -396,6 +398,7 @@ class DagRun(Base, LoggingMixin): "task_instance_created-{}".format(task.__class__.__name__), 1, 1) ti = TaskInstance(task, self.execution_date) + task_instance_mutation_hook(ti) session.add(ti) session.commit() diff --git a/airflow/settings.py b/airflow/settings.py index 8b33de7..2195617 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -170,16 +170,11 @@ STATE_COLORS = { def policy(task): """ - This policy setting allows altering tasks right before they - are executed. It allows administrator to rewire some task parameters. - - Note that the ``Task`` object has a reference to the DAG - object. So you can use the attributes of all of these to define your - policy. + This policy setting allows altering tasks after they are loaded in + the DagBag. It allows administrator to rewire some task parameters. To define policy, add a ``airflow_local_settings`` module - to your PYTHONPATH that defines this ``policy`` function. It receives - a ``Task`` object and can alter it where needed. + to your PYTHONPATH that defines this ``policy`` function. Here are a few examples of how this can be useful: @@ -192,6 +187,18 @@ def policy(task): """ +def task_instance_mutation_hook(task_instance): + """ + This setting allows altering task instances before they are queued by + the Airflow scheduler. + + To define task_instance_mutation_hook, add a ``airflow_local_settings`` module + to your PYTHONPATH that defines this ``task_instance_mutation_hook`` function. + + This could be used, for instance, to modify the task instance during retries. + """ + + def pod_mutation_hook(pod): """ This setting allows altering ``kubernetes.client.models.V1Pod`` object diff --git a/docs/concepts.rst b/docs/concepts.rst index 89479d4..3c15ba9 100644 --- a/docs/concepts.rst +++ b/docs/concepts.rst @@ -998,10 +998,18 @@ state. Cluster Policy ============== -Your local Airflow settings file can define a ``policy`` function that -has the ability to mutate task attributes based on other task or DAG -attributes. It receives a single argument as a reference to task objects, -and is expected to alter its attributes. +In case you want to apply cluster-wide mutations to the Airflow tasks, +you can either mutate the task right after the DAG is loaded or +mutate the task instance before task execution. + +Mutate tasks after DAG loaded +----------------------------- + +To mutate the task right after the DAG is parsed, you can define +a ``policy`` function in ``airflow_local_settings.py`` that mutates the +task based on other task or DAG attributes (through ``task.dag``). +It receives a single argument as a reference to the task object and you can alter +its attributes. For example, this function could apply a specific queue property when using a specific operator, or enforce a task timeout policy, making sure @@ -1017,13 +1025,35 @@ may look like inside your ``airflow_local_settings.py``: if task.timeout > timedelta(hours=48): task.timeout = timedelta(hours=48) -To define policy, add a ``airflow_local_settings`` module to your PYTHONPATH -or to AIRFLOW_HOME/config folder that defines this ``policy`` function. It receives a ``TaskInstance`` -object and can alter it where needed. -Please note, cluster policy currently applies to task only though you can access DAG via ``task.dag`` property. -Also, cluster policy will have precedence over task attributes defined in DAG -meaning if ``task.sla`` is defined in dag and also mutated via cluster policy then later will have precedence. +Please note, cluster policy will have precedence over task +attributes defined in DAG meaning if ``task.sla`` is defined +in dag and also mutated via cluster policy then later will have precedence. + + +Mutate task instances before task execution +------------------------------------------- + +To mutate the task instance before the task execution, you can define a +``task_instance_mutation_hook`` function in ``airflow_local_settings.py`` +that mutates the task instance. + +For example, this function re-routes the task to execute in a different +queue during retries: + +.. code:: python + + def task_instance_mutation_hook(ti): + if ti.try_number >= 1: + ti.queue = 'retry_queue' + + +Where to put ``airflow_local_settings.py``? +------------------------------------------- + +Add a ``airflow_local_settings.py`` file to your ``$PYTHONPATH`` +or to ``$AIRFLOW_HOME/config`` folder. + Documentation & Notes ===================== diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index c90563d..c149c00 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -20,6 +20,8 @@ import datetime import unittest +from parameterized import parameterized + from airflow import settings, models from airflow.jobs import BackfillJob from airflow.models import DAG, DagRun, clear_task_instances @@ -29,6 +31,7 @@ from airflow.operators.python_operator import ShortCircuitOperator from airflow.utils import timezone from airflow.utils.state import State from airflow.utils.trigger_rule import TriggerRule +from tests.compat import mock from tests.models import DEFAULT_DATE @@ -560,3 +563,29 @@ class DagRunTest(unittest.TestCase): dagrun.verify_integrity() flaky_ti.refresh_from_db() self.assertEqual(State.NONE, flaky_ti.state) + + @parameterized.expand([(state,) for state in State.task_states]) + @mock.patch('airflow.models.dagrun.task_instance_mutation_hook') + def test_task_instance_mutation_hook(self, state, mock_hook): + def mutate_task_instance(task_instance): + if task_instance.queue == 'queue1': + task_instance.queue = 'queue2' + else: + task_instance.queue = 'queue1' + + mock_hook.side_effect = mutate_task_instance + + dag = DAG('test_task_instance_mutation_hook', start_date=DEFAULT_DATE) + dag.add_task(DummyOperator(task_id='task_to_mutate', owner='test', queue='queue1')) + + dagrun = self.create_dag_run(dag) + task = dagrun.get_task_instances()[0] + session = settings.Session() + task.state = state + session.merge(task) + session.commit() + assert task.queue == 'queue2' + + dagrun.verify_integrity() + task = dagrun.get_task_instances()[0] + assert task.queue == 'queue1'
