This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a7b8c7f0123 Make weight_rule independent of airflow-core priority_strategy (#62210)
a7b8c7f0123 is described below
commit a7b8c7f0123dc402ef581d872d2315a68fc591fe
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Mar 2 14:11:54 2026 +0530
Make weight_rule independent of airflow-core priority_strategy (#62210)
* Make weight_rule independent of airflow-core priority_strategy
* better typing for protocol
---------
Co-authored-by: Tzu-ping Chung <[email protected]>
---
task-sdk/.pre-commit-config.yaml | 13 ++++++++-----
task-sdk/src/airflow/sdk/bases/operator.py | 8 ++++----
task-sdk/src/airflow/sdk/definitions/mappedoperator.py | 6 +++---
task-sdk/src/airflow/sdk/types.py | 17 +++++++++++++++++
4 files changed, 32 insertions(+), 12 deletions(-)
diff --git a/task-sdk/.pre-commit-config.yaml b/task-sdk/.pre-commit-config.yaml
index e9d83fddfbd..e31e46faf35 100644
--- a/task-sdk/.pre-commit-config.yaml
+++ b/task-sdk/.pre-commit-config.yaml
@@ -33,6 +33,7 @@ repos:
exclude: |
(?x)
# TODO: These files need to be refactored to remove core coupling
+ ^src/airflow/sdk/bases/operator\.py$|
^src/airflow/sdk/definitions/decorators/__init__\.pyi$|
^src/airflow/sdk/definitions/decorators/setup_teardown\.py$|
^src/airflow/sdk/definitions/asset/__init__\.py$|
@@ -41,13 +42,15 @@ repos:
^src/airflow/sdk/definitions/mappedoperator\.py$|
^src/airflow/sdk/definitions/deadline\.py$|
^src/airflow/sdk/definitions/dag\.py$|
- ^src/airflow/sdk/execution_time/execute_workload\.py$|
^src/airflow/sdk/definitions/_internal/types\.py$|
- ^src/airflow/sdk/serde/serializers/kubernetes\.py$|
- ^src/airflow/sdk/execution_time/task_runner\.py$|
- ^src/airflow/sdk/execution_time/supervisor\.py$|
+ ^src/airflow/sdk/execution_time/execute_workload\.py$|
^src/airflow/sdk/execution_time/secrets_masker\.py$|
- ^src/airflow/sdk/bases/operator\.py$
+ ^src/airflow/sdk/execution_time/supervisor\.py$|
+ ^src/airflow/sdk/execution_time/task_runner\.py$|
+ ^src/airflow/sdk/io/path.py$|
+ ^src/airflow/sdk/log.py$|
+ ^src/airflow/sdk/serde/serializers/kubernetes\.py$|
+ ^src/airflow/sdk/types.py$
- id: check-init-decorator-arguments
name: Sync model __init__ and decorator arguments
language: python
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index 4cabe358d0b..6e88f0a94ad 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -96,7 +96,7 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.operator_resources import Resources
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.sdk.definitions.xcom_arg import XComArg
- from airflow.task.priority_strategy import PriorityWeightStrategy
+ from airflow.sdk.types import WeightRuleParam
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
TaskPreExecuteHook = Callable[[Context], None]
@@ -298,7 +298,7 @@ if TYPE_CHECKING:
retry_delay: timedelta | float = ...,
retry_exponential_backoff: float = ...,
priority_weight: int = ...,
- weight_rule: str | PriorityWeightStrategy = ...,
+ weight_rule: WeightRuleParam = ...,
sla: timedelta | None = ...,
map_index_template: str | None = ...,
max_active_tis_per_dag: int | None = ...,
@@ -868,7 +868,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
params: ParamsDict | dict = field(default_factory=ParamsDict)
default_args: dict | None = None
priority_weight: int = DEFAULT_PRIORITY_WEIGHT
- weight_rule: PriorityWeightStrategy | str = field(default=DEFAULT_WEIGHT_RULE)
+ weight_rule: WeightRuleParam = field(default=DEFAULT_WEIGHT_RULE)
queue: str = DEFAULT_QUEUE
pool: str = DEFAULT_POOL_NAME
pool_slots: int = DEFAULT_POOL_SLOTS
@@ -1024,7 +1024,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
params: collections.abc.MutableMapping[str, Any] | None = None,
default_args: dict | None = None,
priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
- weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
+ weight_rule: WeightRuleParam = DEFAULT_WEIGHT_RULE,
queue: str = DEFAULT_QUEUE,
pool: str | None = None,
pool_slots: int = DEFAULT_POOL_SLOTS,
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index 345973e11af..f217306f4da 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -66,7 +66,7 @@ if TYPE_CHECKING:
)
from airflow.sdk.definitions.operator_resources import Resources
from airflow.sdk.definitions.param import ParamsDict
- from airflow.task.priority_strategy import PriorityWeightStrategy
+ from airflow.sdk.types import WeightRuleParam
from airflow.triggers.base import StartTriggerArgs
ValidationSource = Literal["expand"] | Literal["partial"]
@@ -555,11 +555,11 @@ class MappedOperator(AbstractOperator):
self.partial_kwargs["priority_weight"] = value
@property
- def weight_rule(self) -> PriorityWeightStrategy:
+ def weight_rule(self) -> WeightRuleParam:
return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
@weight_rule.setter
- def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
+ def weight_rule(self, value: WeightRuleParam) -> None:
self.partial_kwargs["weight_rule"] = value
@property
diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py
index 2f191a6e080..3477c05b491 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -21,6 +21,7 @@ import uuid
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypeAlias
+from airflow.sdk.api.datamodels._generated import WeightRule
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
@@ -29,6 +30,7 @@ if TYPE_CHECKING:
from pydantic import AwareDatetime, JsonValue
+ from airflow.models.taskinstance import TaskInstance as SchedulerTaskInstance
from airflow.sdk._shared.logging.types import Logger as Logger
from airflow.sdk.api.datamodels._generated import PreviousTIResponse, TaskInstanceState
from airflow.sdk.bases.operator import BaseOperator
@@ -39,6 +41,21 @@ if TYPE_CHECKING:
Operator: TypeAlias = BaseOperator | MappedOperator
+class WeightRuleProtocol(Protocol):
+ """
+ Protocol for custom weight strategy instances.
+
+ Matches objects that implement get_weight(ti).
+ """
+
+ def get_weight(self, ti: SchedulerTaskInstance) -> int:
+ """Return the priority weight for the task instance."""
+ ...
+
+
+WeightRuleParam: TypeAlias = str | WeightRule | WeightRuleProtocol
+
+
class TaskInstanceKey(NamedTuple):
"""Key used to identify task instance."""