Repository: incubator-airflow Updated Branches: refs/heads/master fbba5ef7c -> dd2bc8cb9
[AIRFLOW-192] Add weight_rule param to BaseOperator Improved task generation performance significantly by using sets of task_ids and dag_ids instead of lists when calculating total priority weight. Closes #2941 from wongwill86/performance-latest Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/dd2bc8cb Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/dd2bc8cb Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/dd2bc8cb Branch: refs/heads/master Commit: dd2bc8cb971d25087a35db16d12592f759ecbc6a Parents: fbba5ef Author: wongwill86 <[email protected]> Authored: Thu Jan 18 16:09:40 2018 +0100 Committer: Bolke de Bruin <[email protected]> Committed: Thu Jan 18 16:09:46 2018 +0100 ---------------------------------------------------------------------- airflow/models.py | 102 +++++++++++++++++++++++++++++++------- airflow/utils/weight_rule.py | 33 ++++++++++++ tests/models.py | 93 ++++++++++++++++++++++++++++++++++ 3 files changed, 210 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dd2bc8cb/airflow/models.py ---------------------------------------------------------------------- diff --git a/airflow/models.py b/airflow/models.py index 08c4b52..c5233ec 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -19,7 +19,6 @@ from __future__ import unicode_literals from future.standard_library import install_aliases -install_aliases() from builtins import str from builtins import object, bytes import copy @@ -84,8 +83,11 @@ from airflow.utils.operator_resources import Resources from airflow.utils.state import State from airflow.utils.timeout import timeout from airflow.utils.trigger_rule import TriggerRule +from airflow.utils.weight_rule import WeightRule from airflow.utils.log.logging_mixin import LoggingMixin +install_aliases() + Base 
= declarative_base() ID_LEN = 250 XCOM_RETURN_KEY = 'return_value' @@ -2073,6 +2075,29 @@ class BaseOperator(LoggingMixin): This allows the executor to trigger higher priority tasks before others when things get backed up. :type priority_weight: int + :param weight_rule: weighting method used for the effective total + priority weight of the task. Options are: + ``{ downstream | upstream | absolute }`` default is ``downstream`` + When set to ``downstream`` the effective weight of the task is the + aggregate sum of all downstream descendants. As a result, upstream + tasks will have higher weight and will be scheduled more aggressively + when using positive weight values. This is useful when you have + multiple dag run instances and desire to have all upstream tasks to + complete for all runs before each dag can continue processing + downstream tasks. When set to ``upstream`` the effective weight is the + aggregate sum of all upstream ancestors. This is the opposite where + downstream tasks have higher weight and will be scheduled more + aggressively when using positive weight values. This is useful when you + have multiple dag run instances and prefer to have each dag complete + before starting upstream tasks of other dags. When set to + ``absolute``, the effective weight is the exact ``priority_weight`` + specified without additional weighting. You may want to do this when + you know exactly what priority weight each task should have. + Additionally, when set to ``absolute``, there is a bonus effect of + significantly speeding up the task creation process for very large + DAGs. 
Options can be set as string or using the constants defined in + the static class ``airflow.utils.WeightRule`` + :type weight_rule: str :param pool: the slot pool this task should run in, slot pools are a way to limit concurrency for certain tasks :type pool: str @@ -2150,6 +2175,7 @@ class BaseOperator(LoggingMixin): default_args=None, adhoc=False, priority_weight=1, + weight_rule=WeightRule.DOWNSTREAM, queue=configuration.get('celery', 'default_queue'), pool=None, sla=None, @@ -2190,7 +2216,7 @@ class BaseOperator(LoggingMixin): "The trigger_rule must be one of {all_triggers}," "'{d}.{t}'; received '{tr}'." .format(all_triggers=TriggerRule.all_triggers, - d=dag.dag_id, t=task_id, tr=trigger_rule)) + d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule)) self.trigger_rule = trigger_rule self.depends_on_past = depends_on_past @@ -2224,6 +2250,14 @@ class BaseOperator(LoggingMixin): self.params = params or {} # Available in templates! self.adhoc = adhoc self.priority_weight = priority_weight + if not WeightRule.is_valid(weight_rule): + raise AirflowException( + "The weight_rule must be one of {all_weight_rules}," + "'{d}.{t}'; received '{tr}'." 
+ .format(all_weight_rules=WeightRule.all_weight_rules, + d=dag.dag_id if dag else "", t=task_id, tr=weight_rule)) + self.weight_rule = weight_rule + self.resources = Resources(**(resources or {})) self.run_as_user = run_as_user self.task_concurrency = task_concurrency @@ -2402,10 +2436,19 @@ class BaseOperator(LoggingMixin): @property def priority_weight_total(self): - return sum([ - t.priority_weight - for t in self.get_flat_relatives(upstream=False) - ]) + self.priority_weight + if self.weight_rule == WeightRule.ABSOLUTE: + return self.priority_weight + elif self.weight_rule == WeightRule.DOWNSTREAM: + upstream = False + elif self.weight_rule == WeightRule.UPSTREAM: + upstream = True + else: + upstream = False + + return self.priority_weight + sum( + map(lambda task_id: self._dag.task_dict[task_id].priority_weight, + self.get_flat_relative_ids(upstream=upstream)) + ) def pre_execute(self, context): """ @@ -2608,17 +2651,30 @@ class BaseOperator(LoggingMixin): TI.execution_date <= end_date, ).order_by(TI.execution_date).all() - def get_flat_relatives(self, upstream=False, l=None): + def get_flat_relative_ids(self, upstream=False, found_descendants=None): + """ + Get a flat list of relatives' ids, either upstream or downstream. + """ + + if not found_descendants: + found_descendants = set() + relative_ids = self.get_direct_relative_ids(upstream) + + for relative_id in relative_ids: + if relative_id not in found_descendants: + found_descendants.add(relative_id) + relative_task = self._dag.task_dict[relative_id] + relative_task.get_flat_relative_ids(upstream, + found_descendants) + + return found_descendants + + def get_flat_relatives(self, upstream=False): """ Get a flat list of relatives, either upstream or downstream. 
""" - if not l: - l = [] - for t in self.get_direct_relatives(upstream): - if not is_in(t, l): - l.append(t) - t.get_flat_relatives(upstream, l) - return l + return list(map(lambda task_id: self._dag.task_dict[task_id], + self.get_flat_relative_ids(upstream))) def detect_downstream_cycle(self, task=None): """ @@ -2664,6 +2720,16 @@ class BaseOperator(LoggingMixin): self.log.info('Rendering template for %s', attr) self.log.info(content) + def get_direct_relative_ids(self, upstream=False): + """ + Get the direct relative ids to the current task, upstream or + downstream. + """ + if upstream: + return self._upstream_task_ids + else: + return self._downstream_task_ids + def get_direct_relatives(self, upstream=False): """ Get the direct relatives to the current task, upstream or @@ -2704,14 +2770,14 @@ class BaseOperator(LoggingMixin): # relationships can only be set if the tasks share a single DAG. Tasks # without a DAG are assigned to that DAG. - dags = set(t.dag for t in [self] + task_list if t.has_dag()) + dags = {t._dag.dag_id: t.dag for t in [self] + task_list if t.has_dag()} if len(dags) > 1: raise AirflowException( 'Tried to set relationships between tasks in ' - 'more than one DAG: {}'.format(dags)) + 'more than one DAG: {}'.format(dags.values())) elif len(dags) == 1: - dag = list(dags)[0] + dag = dags.popitem()[1] else: raise AirflowException( "Tried to create relationships between tasks that don't have " @@ -4739,7 +4805,7 @@ class DagRun(Base, LoggingMixin): ti.state = State.REMOVED # check for missing tasks - for task in dag.tasks: + for task in six.itervalues(dag.task_dict): if task.adhoc: continue http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dd2bc8cb/airflow/utils/weight_rule.py ---------------------------------------------------------------------- diff --git a/airflow/utils/weight_rule.py b/airflow/utils/weight_rule.py new file mode 100644 index 0000000..fde0d90 --- /dev/null +++ b/airflow/utils/weight_rule.py @@ -0,0 +1,33 @@ +# -*- 
coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import unicode_literals + +from builtins import object + + +class WeightRule(object): + DOWNSTREAM = 'downstream' + UPSTREAM = 'upstream' + ABSOLUTE = 'absolute' + + @classmethod + def is_valid(cls, weight_rule): + return weight_rule in cls.all_weight_rules() + + @classmethod + def all_weight_rules(cls): + return [getattr(cls, attr) + for attr in dir(cls) + if not attr.startswith("__") and not callable(getattr(cls, attr))] http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dd2bc8cb/tests/models.py ---------------------------------------------------------------------- diff --git a/tests/models.py b/tests/models.py index 3bab3cf..f0879eb 100644 --- a/tests/models.py +++ b/tests/models.py @@ -23,6 +23,8 @@ import os import pendulum import unittest import time +import six +import re from airflow import configuration, models, settings, AirflowException from airflow.exceptions import AirflowSkipException @@ -39,6 +41,7 @@ from airflow.operators.python_operator import PythonOperator from airflow.operators.python_operator import ShortCircuitOperator from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils import timezone +from airflow.utils.weight_rule import WeightRule from airflow.utils.state import State from airflow.utils.trigger_rule import TriggerRule from mock import patch @@ -201,6 +204,96 @@ class DagTest(unittest.TestCase): 
self.assertEquals(tuple(), dag.topological_sort()) + def test_dag_task_priority_weight_total(self): + width = 5 + depth = 5 + weight = 5 + pattern = re.compile('stage(\\d*).(\\d*)') + # Fully connected parallel tasks. i.e. every task at each parallel + # stage is dependent on every task in the previous stage. + # Default weight should be calculated using downstream descendants + with DAG('dag', start_date=DEFAULT_DATE, + default_args={'owner': 'owner1'}) as dag: + pipeline = [ + [DummyOperator( + task_id='stage{}.{}'.format(i, j), priority_weight=weight) + for j in range(0, width)] for i in range(0, depth) + ] + for d, stage in enumerate(pipeline): + if d == 0: + continue + for current_task in stage: + for prev_task in pipeline[d - 1]: + current_task.set_upstream(prev_task) + + for task in six.itervalues(dag.task_dict): + match = pattern.match(task.task_id) + task_depth = int(match.group(1)) + # the sum of each stages after this task + itself + correct_weight = ((depth - (task_depth + 1)) * width + 1) * weight + + calculated_weight = task.priority_weight_total + self.assertEquals(calculated_weight, correct_weight) + + # Same test as above except use 'upstream' for weight calculation + weight = 3 + with DAG('dag', start_date=DEFAULT_DATE, + default_args={'owner': 'owner1'}) as dag: + pipeline = [ + [DummyOperator( + task_id='stage{}.{}'.format(i, j), priority_weight=weight, + weight_rule=WeightRule.UPSTREAM) + for j in range(0, width)] for i in range(0, depth) + ] + for d, stage in enumerate(pipeline): + if d == 0: + continue + for current_task in stage: + for prev_task in pipeline[d - 1]: + current_task.set_upstream(prev_task) + + for task in six.itervalues(dag.task_dict): + match = pattern.match(task.task_id) + task_depth = int(match.group(1)) + # the sum of each stages after this task + itself + correct_weight = ((task_depth) * width + 1) * weight + + calculated_weight = task.priority_weight_total + self.assertEquals(calculated_weight, correct_weight) + + # Same 
test as above except use 'absolute' for weight calculation + weight = 10 + with DAG('dag', start_date=DEFAULT_DATE, + default_args={'owner': 'owner1'}) as dag: + pipeline = [ + [DummyOperator( + task_id='stage{}.{}'.format(i, j), priority_weight=weight, + weight_rule=WeightRule.ABSOLUTE) + for j in range(0, width)] for i in range(0, depth) + ] + for d, stage in enumerate(pipeline): + if d == 0: + continue + for current_task in stage: + for prev_task in pipeline[d - 1]: + current_task.set_upstream(prev_task) + + for task in six.itervalues(dag.task_dict): + match = pattern.match(task.task_id) + task_depth = int(match.group(1)) + # the sum of each stages after this task + itself + correct_weight = weight + + calculated_weight = task.priority_weight_total + self.assertEquals(calculated_weight, correct_weight) + + # Test if we enter an invalid weight rule + with DAG('dag', start_date=DEFAULT_DATE, + default_args={'owner': 'owner1'}) as dag: + with self.assertRaises(AirflowException): + DummyOperator(task_id='should_fail', weight_rule='no rule') + + def test_get_num_task_instances(self): test_dag_id = 'test_get_num_task_instances_dag' test_task_id = 'task_1'
