This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-8-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 9be2ffc30411693d069074c6a51a243d1e61fdc0 Author: Hussein Awala <[email protected]> AuthorDate: Tue Nov 28 19:24:21 2023 +0200 Add a public interface for custom weight_rule implementation (#35210) * Add a public interface for custom weight_rule implementation * Remove _weight_strategy attribute * Move priority weight calculation to TI to support advanced strategies * Fix loading the var from mapped operators and simplify loading it from task * Update default value and deprecated the other one * Update task endpoint API spec * fix tests * Update docs and add dag example * Fix serialization test * revert change in spark provider * Update unit tests (cherry picked from commit 3385113e277f86b5f163a3509ba61590cfe7d8cc) --- airflow/api_connexion/openapi/v1.yaml | 7 ++ airflow/api_connexion/schemas/task_schema.py | 1 + airflow/config_templates/config.yml | 11 +++ .../example_priority_weight_strategy.py | 69 ++++++++++++++++ airflow/executors/base_executor.py | 2 +- airflow/executors/debug_executor.py | 2 +- ...2_8_0_add_priority_weight_strategy_to_task_.py} | 42 +++++----- airflow/models/abstractoperator.py | 20 ++++- airflow/models/baseoperator.py | 36 ++++++--- airflow/models/mappedoperator.py | 16 +++- airflow/models/taskinstance.py | 22 +++++- airflow/serialization/pydantic/taskinstance.py | 1 + airflow/task/priority_strategy.py | 91 ++++++++++++++++++++++ airflow/utils/db.py | 2 +- airflow/utils/weight_rule.py | 6 +- airflow/www/static/js/types/api-generated.ts | 10 ++- .../priority-weight.rst | 12 +-- docs/apache-airflow/img/airflow_erd.sha256 | 2 +- docs/apache-airflow/migrations-ref.rst | 4 +- .../api_connexion/endpoints/test_task_endpoint.py | 21 +++-- tests/api_connexion/schemas/test_task_schema.py | 6 +- tests/models/test_baseoperator.py | 12 ++- tests/models/test_dag.py | 20 +++++ tests/models/test_taskinstance.py | 1 + tests/serialization/test_dag_serialization.py | 3 +- tests/www/views/test_views_tasks.py | 7 ++ 26 files changed, 362 insertions(+), 64 deletions(-) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 5d0c58102a..1653470d91 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -3738,6 +3738,8 @@ components: readOnly: true weight_rule: $ref: "#/components/schemas/WeightRule" + priority_weight_strategy: + $ref: "#/components/schemas/PriorityWeightStrategy" ui_color: $ref: "#/components/schemas/Color" ui_fgcolor: @@ -4767,11 +4769,16 @@ components: WeightRule: description: Weight rule. type: string + nullable: true enum: - downstream - upstream - absolute + PriorityWeightStrategy: + description: Priority weight strategy. + type: string + HealthStatus: description: Health status type: string diff --git a/airflow/api_connexion/schemas/task_schema.py b/airflow/api_connexion/schemas/task_schema.py index ac1b465bb2..cd8ccdfd3b 100644 --- a/airflow/api_connexion/schemas/task_schema.py +++ b/airflow/api_connexion/schemas/task_schema.py @@ -57,6 +57,7 @@ class TaskSchema(Schema): retry_exponential_backoff = fields.Boolean(dump_only=True) priority_weight = fields.Number(dump_only=True) weight_rule = WeightRuleField(dump_only=True) + priority_weight_strategy = fields.String(dump_only=True) ui_color = ColorField(dump_only=True) ui_fgcolor = ColorField(dump_only=True) template_fields = fields.List(fields.String(), dump_only=True) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index a25adc7206..072eaea86d 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -306,6 +306,17 @@ core: description: | The weighting method used for the effective total priority weight of the task version_added: 2.2.0 + version_deprecated: 2.8.0 + deprecation_reason: | + This option is deprecated and will be removed in Airflow 3.0. + Please use ``default_task_priority_weight_strategy`` instead. + type: string + example: ~ + default: ~ + default_task_priority_weight_strategy: + description: | + The strategy used for the effective total priority weight of the task + version_added: 2.8.0 type: string example: ~ default: "downstream" diff --git a/airflow/example_dags/example_priority_weight_strategy.py b/airflow/example_dags/example_priority_weight_strategy.py new file mode 100644 index 0000000000..5575d74a37 --- /dev/null +++ b/airflow/example_dags/example_priority_weight_strategy.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example DAG demonstrating the usage of a custom PriorityWeightStrategy class.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pendulum + +from airflow.models.dag import DAG +from airflow.operators.python import PythonOperator +from airflow.task.priority_strategy import PriorityWeightStrategy + +if TYPE_CHECKING: + from airflow.models import TaskInstance + + +def success_on_third_attempt(ti: TaskInstance, **context): + if ti.try_number < 3: + raise Exception("Not yet") + + +class DecreasingPriorityStrategy(PriorityWeightStrategy): + """A priority weight strategy that decreases the priority weight with each attempt.""" + + def get_weight(self, ti: TaskInstance): + return max(3 - ti._try_number + 1, 1) + + +with DAG( + dag_id="example_priority_weight_strategy", + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + catchup=False, + schedule="@daily", + tags=["example"], + default_args={ + "retries": 3, + "retry_delay": pendulum.duration(seconds=10), + }, +) as dag: + fixed_weight_task = PythonOperator( + task_id="fixed_weight_task", + python_callable=success_on_third_attempt, + priority_weight_strategy="downstream", + ) + + decreasing_weight_task = PythonOperator( + task_id="decreasing_weight_task", + python_callable=success_on_third_attempt, + priority_weight_strategy=( + "airflow.example_dags.example_priority_weight_strategy.DecreasingPriorityStrategy" + ), + ) diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 2791c938a4..babfe8e903 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -184,7 +184,7 @@ class BaseExecutor(LoggingMixin): self.queue_command( task_instance, command_list_to_run, - priority=task_instance.task.priority_weight_total, + priority=task_instance.priority_weight, queue=task_instance.task.queue, ) diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py index be2b657b75..b601c2b7c9 100644 --- a/airflow/executors/debug_executor.py +++ b/airflow/executors/debug_executor.py @@ -109,7 +109,7 @@ class DebugExecutor(BaseExecutor): self.queue_command( task_instance, [str(task_instance)], # Just for better logging, it's not used anywhere - priority=task_instance.task.priority_weight_total, + priority=task_instance.priority_weight, queue=task_instance.task.queue, ) # Save params for TaskInstance._run_raw_task diff --git a/airflow/utils/weight_rule.py b/airflow/migrations/versions/0132_2_8_0_add_priority_weight_strategy_to_task_.py similarity index 50% copy from airflow/utils/weight_rule.py copy to airflow/migrations/versions/0132_2_8_0_add_priority_weight_strategy_to_task_.py index f65f2fa77e..8b3d30ba76 100644 --- a/airflow/utils/weight_rule.py +++ b/airflow/migrations/versions/0132_2_8_0_add_priority_weight_strategy_to_task_.py @@ -15,30 +15,34 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations -from enum import Enum +"""add priority_weight_strategy to task_instance -from airflow.compat.functools import cache +Revision ID: 624ecf3b6a5e +Revises: bd5dfbe21f88 +Create Date: 2023-10-29 02:01:34.774596 +""" -class WeightRule(str, Enum): - """Weight rules.""" +import sqlalchemy as sa +from alembic import op - DOWNSTREAM = "downstream" - UPSTREAM = "upstream" - ABSOLUTE = "absolute" - @classmethod - def is_valid(cls, weight_rule: str) -> bool: - """Check if weight rule is valid.""" - return weight_rule in cls.all_weight_rules() +# revision identifiers, used by Alembic. +revision = "624ecf3b6a5e" +down_revision = "bd5dfbe21f88" +branch_labels = None +depends_on = None +airflow_version = "2.8.0" - @classmethod - @cache - def all_weight_rules(cls) -> set[str]: - """Return all weight rules.""" - return set(cls.__members__.values()) - def __str__(self) -> str: - return self.value +def upgrade(): + """Apply add priority_weight_strategy to task_instance""" + with op.batch_alter_table("task_instance") as batch_op: + batch_op.add_column(sa.Column("priority_weight_strategy", sa.String(length=1000))) + + +def downgrade(): + """Unapply add priority_weight_strategy to task_instance""" + with op.batch_alter_table("task_instance") as batch_op: + batch_op.drop_column("priority_weight_strategy") diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index df0e6cb349..0145f7d149 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -19,6 +19,7 @@ from __future__ import annotations import datetime import inspect +import warnings from functools import cached_property from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence @@ -70,8 +71,14 @@ DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta( ) MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60) -DEFAULT_WEIGHT_RULE: WeightRule = WeightRule( - conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM) +DEFAULT_WEIGHT_RULE: WeightRule | None = ( + WeightRule(conf.get("core", "default_task_weight_rule", fallback=None)) + if conf.get("core", "default_task_weight_rule", fallback=None) + else None +) + +DEFAULT_PRIORITY_WEIGHT_STRATEGY: str = conf.get( + "core", "default_task_priority_weight_strategy", fallback=WeightRule.DOWNSTREAM ) DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta( @@ -98,7 +105,8 @@ class AbstractOperator(Templater, DAGNode): operator_class: type[BaseOperator] | dict[str, Any] - weight_rule: str + weight_rule: str | None + priority_weight_strategy: str priority_weight: int # Defines the operator level extra links. @@ -398,6 +406,12 @@ class AbstractOperator(Templater, DAGNode): - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks - WeightRule.UPSTREAM - adds priority weight of all upstream tasks """ + warnings.warn( + "Accessing `priority_weight_total` from AbstractOperator instance is deprecated." + " Please use `priority_weight` from task instance instead.", + DeprecationWarning, + stacklevel=2, + ) if self.weight_rule == WeightRule.ABSOLUTE: return self.priority_weight elif self.weight_rule == WeightRule.DOWNSTREAM: diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 2ba7ec8ad1..ca368555df 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -61,6 +61,7 @@ from airflow.models.abstractoperator import ( DEFAULT_OWNER, DEFAULT_POOL_SLOTS, DEFAULT_PRIORITY_WEIGHT, + DEFAULT_PRIORITY_WEIGHT_STRATEGY, DEFAULT_QUEUE, DEFAULT_RETRIES, DEFAULT_RETRY_DELAY, @@ -76,6 +77,7 @@ from airflow.models.pool import Pool from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.models.taskmixin import DependencyMixin from airflow.serialization.enums import DagAttributeTypes +from airflow.task.priority_strategy import get_priority_weight_strategy from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep @@ -90,7 +92,6 @@ from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import NOTSET -from airflow.utils.weight_rule import WeightRule from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: @@ -207,6 +208,7 @@ _PARTIAL_DEFAULTS = { "retry_exponential_backoff": False, "priority_weight": DEFAULT_PRIORITY_WEIGHT, "weight_rule": DEFAULT_WEIGHT_RULE, + "priority_weight_strategy": DEFAULT_PRIORITY_WEIGHT_STRATEGY, "inlets": [], "outlets": [], } @@ -240,6 +242,7 @@ def partial( retry_exponential_backoff: bool | ArgNotSet = NOTSET, priority_weight: int | ArgNotSet = NOTSET, weight_rule: str | ArgNotSet = NOTSET, + priority_weight_strategy: str | ArgNotSet = NOTSET, sla: timedelta | None | ArgNotSet = NOTSET, max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET, @@ -303,6 +306,7 @@ def partial( "retry_exponential_backoff": retry_exponential_backoff, "priority_weight": priority_weight, "weight_rule": weight_rule, + "priority_weight_strategy": priority_weight_strategy, "sla": sla, "max_active_tis_per_dag": max_active_tis_per_dag, "max_active_tis_per_dagrun": max_active_tis_per_dagrun, @@ -544,9 +548,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): This allows the executor to trigger higher priority tasks before others when things get backed up. Set priority_weight as a higher number for more important tasks. - :param weight_rule: weighting method used for the effective total - priority weight of the task. Options are: - ``{ downstream | upstream | absolute }`` default is ``downstream`` + :param weight_rule: Deprecated field, please use ``priority_weight_strategy`` instead. + weighting method used for the effective total priority weight of the task. Options are: + ``{ downstream | upstream | absolute }`` default is ``None`` When set to ``downstream`` the effective weight of the task is the aggregate sum of all downstream descendants. As a result, upstream tasks will have higher weight and will be scheduled more aggressively @@ -566,6 +570,11 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): significantly speeding up the task creation process as for very large DAGs. Options can be set as string or using the constants defined in the static class ``airflow.utils.WeightRule`` + :param priority_weight_strategy: weighting method used for the effective total priority weight + of the task. You can provide one of the following options: + ``{ downstream | upstream | absolute }`` or the path to a custom + strategy class that extends ``airflow.task.priority_strategy.PriorityWeightStrategy``. + Default is ``downstream``. :param queue: which queue to target when running this job. Not all executors implement queue management, the CeleryExecutor does support targeting specific queues. @@ -754,7 +763,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): params: collections.abc.MutableMapping | None = None, default_args: dict | None = None, priority_weight: int = DEFAULT_PRIORITY_WEIGHT, - weight_rule: str = DEFAULT_WEIGHT_RULE, + weight_rule: str | None = DEFAULT_WEIGHT_RULE, + priority_weight_strategy: str = DEFAULT_PRIORITY_WEIGHT_STRATEGY, queue: str = DEFAULT_QUEUE, pool: str | None = None, pool_slots: int = DEFAULT_POOL_SLOTS, @@ -901,13 +911,17 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): f"received '{type(priority_weight)}'." ) self.priority_weight = priority_weight - if not WeightRule.is_valid(weight_rule): - raise AirflowException( - f"The weight_rule must be one of " - f"{WeightRule.all_weight_rules},'{dag.dag_id if dag else ''}.{task_id}'; " - f"received '{weight_rule}'." - ) self.weight_rule = weight_rule + self.priority_weight_strategy = priority_weight_strategy + if weight_rule: + warnings.warn( + "weight_rule is deprecated. Please use `priority_weight_strategy` instead.", + DeprecationWarning, + stacklevel=2, + ) + self.priority_weight_strategy = weight_rule + # validate the priority weight strategy + get_priority_weight_strategy(self.priority_weight_strategy) self.resources = coerce_resources(resources) if task_concurrency and not max_active_tis_per_dag: # TODO: Remove in Airflow 3.0 diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 8174db145a..480c236758 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -32,6 +32,7 @@ from airflow.models.abstractoperator import ( DEFAULT_OWNER, DEFAULT_POOL_SLOTS, DEFAULT_PRIORITY_WEIGHT, + DEFAULT_PRIORITY_WEIGHT_STRATEGY, DEFAULT_QUEUE, DEFAULT_RETRIES, DEFAULT_RETRY_DELAY, @@ -48,6 +49,7 @@ from airflow.models.expandinput import ( ) from airflow.models.pool import Pool from airflow.serialization.enums import DagAttributeTypes +from airflow.task.priority_strategy import get_priority_weight_strategy from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded from airflow.typing_compat import Literal from airflow.utils.context import context_update_for_unmapped @@ -329,6 +331,8 @@ class MappedOperator(AbstractOperator): f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task " f"{self.task_id!r}." ) + # validate the priority weight strategy + get_priority_weight_strategy(self.priority_weight_strategy) @classmethod @cache @@ -471,8 +475,16 @@ class MappedOperator(AbstractOperator): return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT) @property - def weight_rule(self) -> str: # type: ignore[override] - return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) + def weight_rule(self) -> str | None: # type: ignore[override] + return self.partial_kwargs.get("weight_rule") or DEFAULT_WEIGHT_RULE + + @property + def priority_weight_strategy(self) -> str: # type: ignore[override] + return ( + self.weight_rule # for backward compatibility + or self.partial_kwargs.get("priority_weight_strategy") + or DEFAULT_PRIORITY_WEIGHT_STRATEGY + ) @property def sla(self) -> datetime.timedelta | None: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index f041dcf208..7efc353d94 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -98,6 +98,7 @@ from airflow.models.xcom import LazyXComAccess, XCom from airflow.plugins_manager import integrate_macros_plugins from airflow.sentry import Sentry from airflow.stats import Stats +from airflow.task.priority_strategy import get_priority_weight_strategy from airflow.templates import SandboxedEnvironment from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS @@ -130,7 +131,6 @@ TR = TaskReschedule _CURRENT_CONTEXT: list[Context] = [] log = logging.getLogger(__name__) - if TYPE_CHECKING: from datetime import datetime from pathlib import PurePath @@ -159,7 +159,6 @@ if TYPE_CHECKING: else: from sqlalchemy.ext.hybrid import hybrid_property - PAST_DEPENDS_MET = "past_depends_met" @@ -487,6 +486,7 @@ def _refresh_from_db( task_instance.pool_slots = ti.pool_slots or 1 task_instance.queue = ti.queue task_instance.priority_weight = ti.priority_weight + task_instance.priority_weight_strategy = ti.priority_weight_strategy task_instance.operator = ti.operator task_instance.custom_operator_name = ti.custom_operator_name task_instance.queued_dttm = ti.queued_dttm @@ -881,7 +881,13 @@ def _refresh_from_task( task_instance.queue = task.queue task_instance.pool = pool_override or task.pool task_instance.pool_slots = task.pool_slots - task_instance.priority_weight = task.priority_weight_total + with contextlib.suppress(Exception): + # This method is called from the different places, and sometimes the TI is not fully initialized + task_instance.priority_weight = get_priority_weight_strategy( + task.priority_weight_strategy + ).get_weight( + task_instance # type: ignore + ) task_instance.run_as_user = task.run_as_user # Do not set max_tries to task.retries here because max_tries is a cumulative # value that needs to be stored in the db. @@ -1216,6 +1222,7 @@ class TaskInstance(Base, LoggingMixin): pool_slots = Column(Integer, default=1, nullable=False) queue = Column(String(256)) priority_weight = Column(Integer) + priority_weight_strategy = Column(String(1000)) operator = Column(String(1000)) custom_operator_name = Column(String(1000)) queued_dttm = Column(UtcDateTime) @@ -1384,6 +1391,9 @@ class TaskInstance(Base, LoggingMixin): :meta private: """ + priority_weight = get_priority_weight_strategy(task.priority_weight_strategy).get_weight( + TaskInstance(task=task, run_id=run_id, map_index=map_index) + ) return { "dag_id": task.dag_id, "task_id": task.task_id, @@ -1394,7 +1404,8 @@ class TaskInstance(Base, LoggingMixin): "queue": task.queue, "pool": task.pool, "pool_slots": task.pool_slots, - "priority_weight": task.priority_weight_total, + "priority_weight": priority_weight, + "priority_weight_strategy": task.priority_weight_strategy, "run_as_user": task.run_as_user, "max_tries": task.retries, "executor_config": task.executor_config, @@ -3451,6 +3462,7 @@ class SimpleTaskInstance: key: TaskInstanceKey, run_as_user: str | None = None, priority_weight: int | None = None, + priority_weight_strategy: str | None = None, ): self.dag_id = dag_id self.task_id = task_id @@ -3464,6 +3476,7 @@ class SimpleTaskInstance: self.run_as_user = run_as_user self.pool = pool self.priority_weight = priority_weight + self.priority_weight_strategy = priority_weight_strategy self.queue = queue self.key = key @@ -3504,6 +3517,7 @@ class SimpleTaskInstance: key=ti.key, run_as_user=ti.run_as_user if hasattr(ti, "run_as_user") else None, priority_weight=ti.priority_weight if hasattr(ti, "priority_weight") else None, + priority_weight_strategy=ti.priority_weight_strategy, ) @classmethod diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index 106a31186e..2556027928 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -87,6 +87,7 @@ class TaskInstancePydantic(BaseModelPydantic, LoggingMixin): pool_slots: int queue: str priority_weight: Optional[int] + priority_weight_strategy: Optional[str] operator: str custom_operator_name: Optional[str] queued_dttm: Optional[str] diff --git a/airflow/task/priority_strategy.py b/airflow/task/priority_strategy.py new file mode 100644 index 0000000000..6e061ad706 --- /dev/null +++ b/airflow/task/priority_strategy.py @@ -0,0 +1,91 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Priority weight strategies for task scheduling.""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from airflow.exceptions import AirflowException +from airflow.utils.module_loading import import_string + +if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstance + + +class PriorityWeightStrategy(ABC): + """Priority weight strategy interface.""" + + @abstractmethod + def get_weight(self, ti: TaskInstance): + """Get the priority weight of a task.""" + ... + + +class AbsolutePriorityWeightStrategy(PriorityWeightStrategy): + """Priority weight strategy that uses the task's priority weight directly.""" + + def get_weight(self, ti: TaskInstance): + return ti.task.priority_weight + + +class DownstreamPriorityWeightStrategy(PriorityWeightStrategy): + """Priority weight strategy that uses the sum of the priority weights of all downstream tasks.""" + + def get_weight(self, ti: TaskInstance): + dag = ti.task.get_dag() + if dag is None: + return ti.task.priority_weight + return ti.task.priority_weight + sum( + dag.task_dict[task_id].priority_weight + for task_id in ti.task.get_flat_relative_ids(upstream=False) + ) + + +class UpstreamPriorityWeightStrategy(PriorityWeightStrategy): + """Priority weight strategy that uses the sum of the priority weights of all upstream tasks.""" + + def get_weight(self, ti: TaskInstance): + dag = ti.task.get_dag() + if dag is None: + return ti.task.priority_weight + return ti.task.priority_weight + sum( + dag.task_dict[task_id].priority_weight for task_id in ti.task.get_flat_relative_ids(upstream=True) + ) + + +_airflow_priority_weight_strategies = { + "absolute": AbsolutePriorityWeightStrategy(), + "downstream": DownstreamPriorityWeightStrategy(), + "upstream": UpstreamPriorityWeightStrategy(), +} + + +def get_priority_weight_strategy(strategy_name: str) -> PriorityWeightStrategy: + """Get a priority weight strategy by name or class path.""" + if strategy_name not in _airflow_priority_weight_strategies: + try: + priority_strategy_class = import_string(strategy_name) + if not issubclass(priority_strategy_class, PriorityWeightStrategy): + raise AirflowException( + f"Priority strategy {priority_strategy_class} is not a subclass of PriorityWeightStrategy" + ) + _airflow_priority_weight_strategies[strategy_name] = priority_strategy_class() + except ImportError: + raise AirflowException(f"Unknown priority strategy {strategy_name}") + return _airflow_priority_weight_strategies[strategy_name] diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 5b6bd4757e..b9509cf8d1 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -88,7 +88,7 @@ _REVISION_HEADS_MAP = { "2.6.0": "98ae134e6fff", "2.6.2": "c804e5c76e3e", "2.7.0": "405de8318b3a", - "2.8.0": "bd5dfbe21f88", + "2.8.0": "624ecf3b6a5e", } diff --git a/airflow/utils/weight_rule.py b/airflow/utils/weight_rule.py index f65f2fa77e..dd6c554c67 100644 --- a/airflow/utils/weight_rule.py +++ b/airflow/utils/weight_rule.py @@ -23,7 +23,11 @@ from airflow.compat.functools import cache class WeightRule(str, Enum): - """Weight rules.""" + """ + Weight rules. + + This class is deprecated and will be removed in Airflow 3 + """ DOWNSTREAM = "downstream" UPSTREAM = "upstream" diff --git a/airflow/www/static/js/types/api-generated.ts b/airflow/www/static/js/types/api-generated.ts index 55ade6179d..0771647159 100644 --- a/airflow/www/static/js/types/api-generated.ts +++ b/airflow/www/static/js/types/api-generated.ts @@ -1561,6 +1561,7 @@ export interface components { retry_exponential_backoff?: boolean; priority_weight?: number; weight_rule?: components["schemas"]["WeightRule"]; + priority_weight_strategy?: components["schemas"]["PriorityWeightStrategy"]; ui_color?: components["schemas"]["Color"]; ui_fgcolor?: components["schemas"]["Color"]; template_fields?: string[]; @@ -2234,9 +2235,11 @@ export interface components { | "always"; /** * @description Weight rule. - * @enum {string} + * @enum {string|null} */ - WeightRule: "downstream" | "upstream" | "absolute"; + WeightRule: ("downstream" | "upstream" | "absolute") | null; + /** @description Priority weight strategy. */ + PriorityWeightStrategy: string; /** * @description Health status * @enum {string|null} @@ -4952,6 +4955,9 @@ export type TriggerRule = CamelCasedPropertiesDeep< export type WeightRule = CamelCasedPropertiesDeep< components["schemas"]["WeightRule"] >; +export type PriorityWeightStrategy = CamelCasedPropertiesDeep< + components["schemas"]["PriorityWeightStrategy"] +>; export type HealthStatus = CamelCasedPropertiesDeep< components["schemas"]["HealthStatus"] >; diff --git a/docs/apache-airflow/administration-and-deployment/priority-weight.rst b/docs/apache-airflow/administration-and-deployment/priority-weight.rst index 87a9288ddc..3e064123af 100644 --- a/docs/apache-airflow/administration-and-deployment/priority-weight.rst +++ b/docs/apache-airflow/administration-and-deployment/priority-weight.rst @@ -22,12 +22,9 @@ Priority Weights ``priority_weight`` defines priorities in the executor queue. The default ``priority_weight`` is ``1``, and can be bumped to any integer. Moreover, each task has a true ``priority_weight`` that is calculated based on its -``weight_rule`` which defines weighting method used for the effective total priority weight of the task. +``priority_weight_strategy`` which defines weighting method used for the effective total priority weight of the task. -By default, Airflow's weighting method is ``downstream``. You can find other weighting methods in -:class:`airflow.utils.WeightRule`. - -There are three weighting methods. +Airflow has three weighting strategies: - downstream @@ -57,5 +54,10 @@ There are three weighting methods. significantly speeding up the task creation process as for very large DAGs +You can also implement your own weighting strategy by extending the class +:class:`~airflow.task.priority_strategy.PriorityWeightStrategy` and overriding the method +:meth:`~airflow.task.priority_strategy.PriorityWeightStrategy.get_weight`, the providing the path of your class +to the ``priority_weight_strategy`` parameter. + The ``priority_weight`` parameter can be used in conjunction with :ref:`concepts:pool`. diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256 index 301abf8a84..b2d9dbf5a2 100644 --- a/docs/apache-airflow/img/airflow_erd.sha256 +++ b/docs/apache-airflow/img/airflow_erd.sha256 @@ -1 +1 @@ -8229a936107bee851d6a39c791b842b11f295ffa308b18106e45298a50871493 \ No newline at end of file +4739d87664d779f93e39b09ca6e5e662d72f1fa88857d8b6e44d2f2557656753 \ No newline at end of file diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index 72f467f6fb..af23ce50ca 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=================================+===================+===================+==============================================================+ -| ``bd5dfbe21f88`` (head) | ``f7bf2a57d0a6`` | ``2.8.0`` | Make connection login/password TEXT | +| ``624ecf3b6a5e`` (head) | ``bd5dfbe21f88`` | ``2.8.0`` | add priority_weight_strategy to task_instance | ++---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ +| ``bd5dfbe21f88`` | ``f7bf2a57d0a6`` | ``2.8.0`` | Make connection login/password TEXT | +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ | ``f7bf2a57d0a6`` | ``375a816bbbf4`` | ``2.8.0`` | Add owner_display_name to (Audit) Log table | +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index b8ef8dc0cf..d2b717bfc0 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -123,6 +123,7 @@ class TestGetTask(TestTaskEndpoint): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -134,7 +135,7 @@ class TestGetTask(TestTaskEndpoint): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, } response = self.client.get( @@ -158,6 +159,7 @@ class TestGetTask(TestTaskEndpoint): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "microseconds": 0, "seconds": 300}, @@ -169,7 +171,7 @@ class TestGetTask(TestTaskEndpoint): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, } response = self.client.get( f"/api/v1/dags/{self.mapped_dag_id}/tasks/{self.mapped_task_id}", @@ -209,6 +211,7 @@ class TestGetTask(TestTaskEndpoint): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -220,7 +223,7 @@ class TestGetTask(TestTaskEndpoint): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, } response = self.client.get( @@ -284,6 +287,7 @@ class TestGetTasks(TestTaskEndpoint): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -295,7 +299,7 @@ class TestGetTasks(TestTaskEndpoint): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, }, { @@ -314,6 +318,7 @@ class TestGetTasks(TestTaskEndpoint): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -325,7 +330,7 @@ class TestGetTasks(TestTaskEndpoint): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, }, ], @@ -354,6 +359,7 @@ class TestGetTasks(TestTaskEndpoint): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "microseconds": 0, "seconds": 300}, @@ -365,7 +371,7 @@ class TestGetTasks(TestTaskEndpoint): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, }, { "class_ref": { @@ -383,6 +389,7 @@ class TestGetTasks(TestTaskEndpoint): "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -394,7 +401,7 @@ class TestGetTasks(TestTaskEndpoint): "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, }, ], diff --git a/tests/api_connexion/schemas/test_task_schema.py b/tests/api_connexion/schemas/test_task_schema.py index 54403ebbf0..f76fa439e8 100644 --- a/tests/api_connexion/schemas/test_task_schema.py +++ b/tests/api_connexion/schemas/test_task_schema.py @@ -46,6 +46,7 @@ class TestTaskSchema: "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -57,7 +58,7 @@ class TestTaskSchema: "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, } assert expected == result @@ -93,6 +94,7 @@ class TestTaskCollectionSchema: "pool": "default_pool", "pool_slots": 1.0, "priority_weight": 1.0, + "priority_weight_strategy": "downstream", "queue": "default", "retries": 0.0, "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, @@ -104,7 +106,7 @@ class TestTaskCollectionSchema: "ui_color": "#e8f7e4", "ui_fgcolor": "#000", "wait_for_downstream": False, - "weight_rule": "downstream", + "weight_rule": None, "is_mapped": False, } ], diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index fb46fd39c7..28b76f3684 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -784,12 +784,20 @@ class TestBaseOperator: def test_weight_rule_default(self): op = BaseOperator(task_id="test_task") - assert WeightRule.DOWNSTREAM == op.weight_rule + assert op.weight_rule is None - def test_weight_rule_override(self): + def test_priority_weight_strategy_default(self): + op = BaseOperator(task_id="test_task") + assert op.priority_weight_strategy == "downstream" + + def test_deprecated_weight_rule_override(self): op = BaseOperator(task_id="test_task", weight_rule="upstream") assert WeightRule.UPSTREAM == op.weight_rule + def test_priority_weight_strategy_override(self): + op = BaseOperator(task_id="test_task", priority_weight_strategy="upstream") + assert op.priority_weight_strategy == "upstream" + # ensure the default logging config is used for this test, no matter what ran before @pytest.mark.usefixtures("reset_logging_config") def test_logging_propogated_by_default(self, caplog): diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index f7bf1ad6d0..ba5a047f56 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -29,6 +29,7 @@ from contextlib import redirect_stdout from datetime import timedelta from io import StringIO from pathlib import Path +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import patch @@ -69,6 +70,7 @@ from airflow.operators.empty import EmptyOperator from airflow.operators.python import PythonOperator from airflow.operators.subdag import SubDagOperator from airflow.security import permissions +from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.templates import NativeEnvironment, SandboxedEnvironment from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.simple import ( @@ -93,6 +95,9 @@ from tests.test_utils.db import clear_db_dags, clear_db_datasets, clear_db_runs, from tests.test_utils.mapping import expand_mapped_task from tests.test_utils.timetables import cron_timetable, delta_timetable +if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstance + pytestmark = pytest.mark.db_test TEST_DATE = datetime_tz(2015, 1, 2, 0, 0) @@ -116,6 +121,11 @@ def clear_datasets(): clear_db_datasets() +class TestPriorityWeightStrategy(PriorityWeightStrategy): + def get_weight(self, ti: TaskInstance): + return 99 + + class TestDag: def setup_method(self) -> None: clear_db_runs() @@ -430,6 +440,16 @@ class TestDag: with pytest.raises(AirflowException): EmptyOperator(task_id="should_fail", weight_rule="no rule") + def test_dag_task_custom_weight_strategy(self): + with DAG("dag", start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) as dag: + task = EmptyOperator( + task_id="empty_task", + priority_weight_strategy="tests.models.test_dag.TestPriorityWeightStrategy", + ) + dr = dag.create_dagrun(state=None, run_id="test", execution_date=DEFAULT_DATE) + ti = dr.get_task_instance(task.task_id) + assert ti.priority_weight == 99 + def test_get_num_task_instances(self): test_dag_id = "test_get_num_task_instances_dag" test_task_id = "task_1" diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 27ce80df1a..a1c4281285 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3093,6 +3093,7 @@ class TestTaskInstance: "pool_slots": 25, "queue": "some_queue_id", "priority_weight": 123, + "priority_weight_strategy": "downstream", "operator": "some_custom_operator", "custom_operator_name": "some_custom_operator", "queued_dttm": run_date + datetime.timedelta(hours=1), diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 30407eb945..3c0ce045ee 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -1243,6 +1243,7 @@ class TestStringifiedDAGs: "pool": "default_pool", "pool_slots": 1, "priority_weight": 1, + "priority_weight_strategy": "downstream", "queue": "default", "resources": None, "retries": 0, @@ -1254,7 +1255,7 @@ class TestStringifiedDAGs: "trigger_rule": "all_success", "wait_for_downstream": False, "wait_for_past_depends_before_skipping": False, - "weight_rule": "downstream", + "weight_rule": None, }, """ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 55568d4d8f..c432daab4c 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -1136,6 +1136,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 2, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None, @@ -1168,6 +1169,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 2, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None, @@ -1200,6 +1202,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 1, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None, @@ -1232,6 +1235,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 3, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None, @@ -1264,6 +1268,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 3, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None, @@ -1296,6 +1301,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 3, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None, @@ -1328,6 +1334,7 @@ def test_task_instances(admin_client): "pool": "default_pool", "pool_slots": 1, "priority_weight": 2, + "priority_weight_strategy": "downstream", "queue": "default", "queued_by_job_id": None, "queued_dttm": None,
