This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a7b8c7f0123 Make weight_rule independent of airflow-core priority_strategy (#62210)
a7b8c7f0123 is described below

commit a7b8c7f0123dc402ef581d872d2315a68fc591fe
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Mar 2 14:11:54 2026 +0530

    Make weight_rule independent of airflow-core priority_strategy (#62210)
    
    * Make weight_rule independent of airflow-core priority_strategy
    
    * better typing for protocol
    
    ---------
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 task-sdk/.pre-commit-config.yaml                       | 13 ++++++++-----
 task-sdk/src/airflow/sdk/bases/operator.py             |  8 ++++----
 task-sdk/src/airflow/sdk/definitions/mappedoperator.py |  6 +++---
 task-sdk/src/airflow/sdk/types.py                      | 17 +++++++++++++++++
 4 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/task-sdk/.pre-commit-config.yaml b/task-sdk/.pre-commit-config.yaml
index e9d83fddfbd..e31e46faf35 100644
--- a/task-sdk/.pre-commit-config.yaml
+++ b/task-sdk/.pre-commit-config.yaml
@@ -33,6 +33,7 @@ repos:
         exclude: |
           (?x)
           # TODO: These files need to be refactored to remove core coupling
+          ^src/airflow/sdk/bases/operator\.py$|
           ^src/airflow/sdk/definitions/decorators/__init__\.pyi$|
           ^src/airflow/sdk/definitions/decorators/setup_teardown\.py$|
           ^src/airflow/sdk/definitions/asset/__init__\.py$|
@@ -41,13 +42,15 @@ repos:
           ^src/airflow/sdk/definitions/mappedoperator\.py$|
           ^src/airflow/sdk/definitions/deadline\.py$|
           ^src/airflow/sdk/definitions/dag\.py$|
-          ^src/airflow/sdk/execution_time/execute_workload\.py$|
           ^src/airflow/sdk/definitions/_internal/types\.py$|
-          ^src/airflow/sdk/serde/serializers/kubernetes\.py$|
-          ^src/airflow/sdk/execution_time/task_runner\.py$|
-          ^src/airflow/sdk/execution_time/supervisor\.py$|
+          ^src/airflow/sdk/execution_time/execute_workload\.py$|
           ^src/airflow/sdk/execution_time/secrets_masker\.py$|
-          ^src/airflow/sdk/bases/operator\.py$
+          ^src/airflow/sdk/execution_time/supervisor\.py$|
+          ^src/airflow/sdk/execution_time/task_runner\.py$|
+          ^src/airflow/sdk/io/path.py$|
+          ^src/airflow/sdk/log.py$|
+          ^src/airflow/sdk/serde/serializers/kubernetes\.py$|
+          ^src/airflow/sdk/types.py$
       - id: check-init-decorator-arguments
         name: Sync model __init__ and decorator arguments
         language: python
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index 4cabe358d0b..6e88f0a94ad 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -96,7 +96,7 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions.operator_resources import Resources
     from airflow.sdk.definitions.taskgroup import TaskGroup
     from airflow.sdk.definitions.xcom_arg import XComArg
-    from airflow.task.priority_strategy import PriorityWeightStrategy
+    from airflow.sdk.types import WeightRuleParam
     from airflow.triggers.base import BaseTrigger, StartTriggerArgs
 
     TaskPreExecuteHook = Callable[[Context], None]
@@ -298,7 +298,7 @@ if TYPE_CHECKING:
         retry_delay: timedelta | float = ...,
         retry_exponential_backoff: float = ...,
         priority_weight: int = ...,
-        weight_rule: str | PriorityWeightStrategy = ...,
+        weight_rule: WeightRuleParam = ...,
         sla: timedelta | None = ...,
         map_index_template: str | None = ...,
         max_active_tis_per_dag: int | None = ...,
@@ -868,7 +868,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     params: ParamsDict | dict = field(default_factory=ParamsDict)
     default_args: dict | None = None
     priority_weight: int = DEFAULT_PRIORITY_WEIGHT
-    weight_rule: PriorityWeightStrategy | str = 
field(default=DEFAULT_WEIGHT_RULE)
+    weight_rule: WeightRuleParam = field(default=DEFAULT_WEIGHT_RULE)
     queue: str = DEFAULT_QUEUE
     pool: str = DEFAULT_POOL_NAME
     pool_slots: int = DEFAULT_POOL_SLOTS
@@ -1024,7 +1024,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         params: collections.abc.MutableMapping[str, Any] | None = None,
         default_args: dict | None = None,
         priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
-        weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
+        weight_rule: WeightRuleParam = DEFAULT_WEIGHT_RULE,
         queue: str = DEFAULT_QUEUE,
         pool: str | None = None,
         pool_slots: int = DEFAULT_POOL_SLOTS,
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index 345973e11af..f217306f4da 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -66,7 +66,7 @@ if TYPE_CHECKING:
     )
     from airflow.sdk.definitions.operator_resources import Resources
     from airflow.sdk.definitions.param import ParamsDict
-    from airflow.task.priority_strategy import PriorityWeightStrategy
+    from airflow.sdk.types import WeightRuleParam
     from airflow.triggers.base import StartTriggerArgs
 
 ValidationSource = Literal["expand"] | Literal["partial"]
@@ -555,11 +555,11 @@ class MappedOperator(AbstractOperator):
         self.partial_kwargs["priority_weight"] = value
 
     @property
-    def weight_rule(self) -> PriorityWeightStrategy:
+    def weight_rule(self) -> WeightRuleParam:
         return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
 
     @weight_rule.setter
-    def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
+    def weight_rule(self, value: WeightRuleParam) -> None:
         self.partial_kwargs["weight_rule"] = value
 
     @property
diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py
index 2f191a6e080..3477c05b491 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -21,6 +21,7 @@ import uuid
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypeAlias
 
+from airflow.sdk.api.datamodels._generated import WeightRule
 from airflow.sdk.bases.xcom import BaseXCom
 from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
 
@@ -29,6 +30,7 @@ if TYPE_CHECKING:
 
     from pydantic import AwareDatetime, JsonValue
 
+    from airflow.models.taskinstance import TaskInstance as 
SchedulerTaskInstance
     from airflow.sdk._shared.logging.types import Logger as Logger
     from airflow.sdk.api.datamodels._generated import PreviousTIResponse, 
TaskInstanceState
     from airflow.sdk.bases.operator import BaseOperator
@@ -39,6 +41,21 @@ if TYPE_CHECKING:
     Operator: TypeAlias = BaseOperator | MappedOperator
 
 
+class WeightRuleProtocol(Protocol):
+    """
+    Protocol for custom weight strategy instances.
+
+    Matches objects that implement get_weight(ti).
+    """
+
+    def get_weight(self, ti: SchedulerTaskInstance) -> int:
+        """Return the priority weight for the task instance."""
+        ...
+
+
+WeightRuleParam: TypeAlias = str | WeightRule | WeightRuleProtocol
+
+
 class TaskInstanceKey(NamedTuple):
     """Key used to identify task instance."""
 

Reply via email to