This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e9c2a6f5761 Revert "Remove PriorityWeightStrategy reference in SDK" (#59828)
e9c2a6f5761 is described below

commit e9c2a6f576159241b8022f42027fbfe49421d734
Author: Jens Scheffler <[email protected]>
AuthorDate: Fri Dec 26 22:48:40 2025 +0100

    Revert "Remove PriorityWeightStrategy reference in SDK" (#59828)
    
    * Revert "Remove PriorityWeightStrategy reference in SDK (#59780)"
    
    This reverts commit 60b4ed48e1a2a4b16d4de1ff4b04f61a9d7253c1.
    
    * Tip by TP
---
 airflow-core/newsfragments/59780.significant.rst   |  4 -
 .../airflow/serialization/serialized_objects.py    |  5 +-
 airflow-core/src/airflow/task/priority_strategy.py | 49 +++++++++----
 airflow-core/tests/unit/jobs/test_triggerer_job.py | 21 ++----
 .../tests/unit/models/test_mappedoperator.py       |  3 +-
 .../tests/unit/models/test_taskinstance.py         | 25 +++----
 .../unit/serialization/test_dag_serialization.py   | 85 ++++++----------------
 .../integration/celery/test_celery_executor.py     | 46 ++++--------
 .../cncf/kubernetes/cli/kubernetes_command.py      | 14 +---
 .../providers/cncf/kubernetes/version_compat.py    |  2 -
 task-sdk/src/airflow/sdk/bases/operator.py         | 11 ++-
 .../src/airflow/sdk/definitions/mappedoperator.py  |  8 +-
 task-sdk/tests/task_sdk/bases/test_operator.py     | 31 ++++++--
 13 files changed, 136 insertions(+), 168 deletions(-)

diff --git a/airflow-core/newsfragments/59780.significant.rst b/airflow-core/newsfragments/59780.significant.rst
deleted file mode 100644
index df4c64895c7..00000000000
--- a/airflow-core/newsfragments/59780.significant.rst
+++ /dev/null
@@ -1,4 +0,0 @@
-Unused methods removed from (experimental) PriorityWeightStrategy
-
-Functions ``serialize`` and ``deserialize`` were never used anywhere, and have
-been removed from the class. They should not be relied on in user code.
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index dad7e4834f8..170879c5339 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -96,7 +96,6 @@ from airflow.task.priority_strategy import (
     PriorityWeightStrategy,
     airflow_priority_weight_strategies,
     airflow_priority_weight_strategies_classes,
-    validate_and_load_priority_weight_strategy,
 )
 from airflow.timetables.base import DagRunInfo, Timetable
 from airflow.triggers.base import BaseTrigger, StartTriggerArgs
@@ -250,7 +249,7 @@ def decode_partition_mapper(var: dict[str, Any]) -> PartitionMapper:
     return partition_mapper_class.deserialize(var[Encoding.VAR])
 
 
-def encode_priority_weight_strategy(var: PriorityWeightStrategy | str) -> str:
+def encode_priority_weight_strategy(var: PriorityWeightStrategy) -> str:
     """
     Encode a priority weight strategy instance.
 
@@ -258,7 +257,7 @@ def encode_priority_weight_strategy(var: PriorityWeightStrategy | str) -> str:
     for any parameters to be passed to it. If you need to store the parameters, you
     should store them in the class itself.
     """
-    priority_weight_strategy_class = type(validate_and_load_priority_weight_strategy(var))
+    priority_weight_strategy_class = type(var)
     if priority_weight_strategy_class in airflow_priority_weight_strategies_classes:
         return airflow_priority_weight_strategies_classes[priority_weight_strategy_class]
     importable_string = qualname(priority_weight_strategy_class)
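
For illustration, a minimal sketch (not part of the patch) of what the restored encoder produces, assuming the classes named in this file:

    from airflow.serialization.serialized_objects import encode_priority_weight_strategy
    from airflow.task.priority_strategy import _AbsolutePriorityWeightStrategy

    # Built-in strategies map back to the short WeightRule value ("absolute",
    # "downstream", "upstream") via airflow_priority_weight_strategies_classes.
    assert encode_priority_weight_strategy(_AbsolutePriorityWeightStrategy()) == "absolute"

    # Anything else falls through to qualname() and is encoded as an importable
    # string such as "my_plugin.strategies.MyStrategy" (hypothetical path).
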
diff --git a/airflow-core/src/airflow/task/priority_strategy.py b/airflow-core/src/airflow/task/priority_strategy.py
index 65b6e3ff34e..a330ca9198b 100644
--- a/airflow-core/src/airflow/task/priority_strategy.py
+++ b/airflow-core/src/airflow/task/priority_strategy.py
@@ -20,9 +20,8 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
-from airflow._shared.module_loading import qualname
 from airflow.task.weight_rule import WeightRule
 
 if TYPE_CHECKING:
@@ -42,22 +41,46 @@ class PriorityWeightStrategy(ABC):
     """
 
     @abstractmethod
-    def get_weight(self, ti: TaskInstance) -> int:
+    def get_weight(self, ti: TaskInstance):
         """Get the priority weight of a task."""
-        raise NotImplementedError("must be implemented by a subclass")
+        ...
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy:
+        """
+        Deserialize a priority weight strategy from data.
+
+        This is called when a serialized DAG is deserialized. ``data`` will be whatever
+        was returned by ``serialize`` during DAG serialization. The default
+        implementation constructs the priority weight strategy without any arguments.
+        """
+        return cls(**data)
+
+    def serialize(self) -> dict[str, Any]:
+        """
+        Serialize the priority weight strategy for JSON encoding.
+
+        This is called during DAG serialization to store priority weight strategy information
+        in the database. This should return a JSON-serializable dict that will be fed into
+        ``deserialize`` when the DAG is deserialized. The default implementation returns
+        an empty dict.
+        """
+        return {}
 
     def __eq__(self, other: object) -> bool:
         """Equality comparison."""
-        return isinstance(other, type(self))
+        if not isinstance(other, type(self)):
+            return False
+        return self.serialize() == other.serialize()
 
-    def __hash__(self) -> int:
-        return hash(None)
+    def __hash__(self):
+        return hash(self.serialize())
 
 
 class _AbsolutePriorityWeightStrategy(PriorityWeightStrategy):
     """Priority weight strategy that uses the task's priority weight 
directly."""
 
-    def get_weight(self, ti: TaskInstance) -> int:
+    def get_weight(self, ti: TaskInstance):
         if TYPE_CHECKING:
             assert ti.task
         return ti.task.priority_weight
@@ -81,7 +104,7 @@ class _DownstreamPriorityWeightStrategy(PriorityWeightStrategy):
 class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy):
     """Priority weight strategy that uses the sum of the priority weights of 
all upstream tasks."""
 
-    def get_weight(self, ti: TaskInstance) -> int:
+    def get_weight(self, ti: TaskInstance):
         if TYPE_CHECKING:
             assert ti.task
         dag = ti.task.get_dag()
@@ -93,9 +116,6 @@ class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy):
 
 
 airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = {
-    qualname(_AbsolutePriorityWeightStrategy): _AbsolutePriorityWeightStrategy,
-    qualname(_DownstreamPriorityWeightStrategy): _DownstreamPriorityWeightStrategy,
-    qualname(_UpstreamPriorityWeightStrategy): _UpstreamPriorityWeightStrategy,
     WeightRule.ABSOLUTE: _AbsolutePriorityWeightStrategy,
     WeightRule.DOWNSTREAM: _DownstreamPriorityWeightStrategy,
     WeightRule.UPSTREAM: _UpstreamPriorityWeightStrategy,
@@ -103,9 +123,7 @@ airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = {
 
 
 airflow_priority_weight_strategies_classes = {
-    _AbsolutePriorityWeightStrategy: WeightRule.ABSOLUTE,
-    _DownstreamPriorityWeightStrategy: WeightRule.DOWNSTREAM,
-    _UpstreamPriorityWeightStrategy: WeightRule.UPSTREAM,
+    cls: name for name, cls in airflow_priority_weight_strategies.items()
 }
 
 
@@ -121,6 +139,7 @@ def validate_and_load_priority_weight_strategy(
 
     :meta private:
     """
+    from airflow._shared.module_loading import qualname
     from airflow.serialization.serialized_objects import _get_registered_priority_weight_strategy
 
     if priority_weight_strategy is None:
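
A minimal sketch (not part of the patch) of a custom strategy written against the restored interface; the class name, its ``factor`` parameter, and the weighting logic are invented for illustration:

    from typing import Any

    from airflow.task.priority_strategy import PriorityWeightStrategy

    class FactorPriorityWeightStrategy(PriorityWeightStrategy):
        """Hypothetical strategy: scale the task's own priority_weight by a factor."""

        def __init__(self, factor: int = 1) -> None:
            self.factor = factor

        def get_weight(self, ti):
            return ti.task.priority_weight * self.factor

        def serialize(self) -> dict[str, Any]:
            # Must be JSON-serializable and match the constructor signature,
            # since the default deserialize() reconstructs via cls(**data).
            return {"factor": self.factor}

Because ``__eq__`` now compares serialized payloads, two instances with the same factor compare equal.
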
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 0e3b1647482..b9df47ec768 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -130,17 +130,12 @@ def create_trigger_in_db(session, trigger, operator=None):
         operator = BaseOperator(task_id="test_ti", dag=dag)
     session.add(dag_model)
 
-    lazy_serdag = LazyDeserializedDAG.from_dag(dag)
-    SerializedDagModel.write_dag(lazy_serdag, bundle_name=bundle_name)
+    SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name)
     session.add(run)
     session.add(trigger_orm)
     session.flush()
     dag_version = DagVersion.get_latest_version(dag.dag_id)
-    task_instance = TaskInstance(
-        lazy_serdag._real_dag.get_task(operator.task_id),
-        run_id=run.run_id,
-        dag_version_id=dag_version.id,
-    )
+    task_instance = TaskInstance(operator, run_id=run.run_id, dag_version_id=dag_version.id)
     task_instance.trigger_id = trigger_orm.id
     session.add(task_instance)
     session.commit()
@@ -445,8 +440,7 @@ class TestTriggerRunner:
 
 
 @pytest.mark.asyncio
-@pytest.mark.usefixtures("testing_dag_bundle")
-async def test_trigger_create_race_condition_38599(session, supervisor_builder):
+async def test_trigger_create_race_condition_38599(session, supervisor_builder, testing_dag_bundle):
     """
    This verifies the resolution of race condition documented in github issue #38599.
     More details in the issue description.
@@ -471,17 +465,14 @@ async def test_trigger_create_race_condition_38599(session, supervisor_builder):
     session.flush()
 
     bundle_name = "testing"
-    with DAG(dag_id="test-dag") as dag:
-        task = PythonOperator(task_id="dummy-task", python_callable=print)
+    dag = DAG(dag_id="test-dag")
     dm = DagModel(dag_id="test-dag", bundle_name=bundle_name)
     session.add(dm)
-
-    lazy_serdag = LazyDeserializedDAG.from_dag(dag)
-    SerializedDagModel.write_dag(lazy_serdag, bundle_name=bundle_name)
+    SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name)
     dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none", run_after=timezone.utcnow())
     dag_version = DagVersion.get_latest_version(dag.dag_id)
     ti = TaskInstance(
-        lazy_serdag._real_dag.get_task(task.task_id),
+        PythonOperator(task_id="dummy-task", python_callable=print),
         run_id=dag_run.run_id,
         state=TaskInstanceState.DEFERRED,
         dag_version_id=dag_version.id,
diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py b/airflow-core/tests/unit/models/test_mappedoperator.py
index 408e76b22a6..e5bfe78a059 100644
--- a/airflow-core/tests/unit/models/test_mappedoperator.py
+++ b/airflow-core/tests/unit/models/test_mappedoperator.py
@@ -34,6 +34,7 @@ from airflow.models.taskmap import TaskMap
 from airflow.providers.standard.operators.python import PythonOperator
 from airflow.sdk import DAG, BaseOperator, TaskGroup, setup, task, task_group, teardown
 from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
+from airflow.task.priority_strategy import PriorityWeightStrategy
 from airflow.task.trigger_rule import TriggerRule
 from airflow.utils.state import TaskInstanceState
 
@@ -1523,7 +1524,7 @@ class TestMappedSetupTeardown:
         assert op.pool == SerializedBaseOperator.pool
         assert op.pool_slots == SerializedBaseOperator.pool_slots
         assert op.priority_weight == SerializedBaseOperator.priority_weight
-        assert op.weight_rule == "downstream"
+        assert isinstance(op.weight_rule, PriorityWeightStrategy)
         assert op.email == email
         assert op.execution_timeout == execution_timeout
         assert op.retry_delay == retry_delay
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index da49fd81b01..e83f906fa04 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -82,7 +82,7 @@ from airflow.sdk.execution_time.comms import AssetEventsResult
 from airflow.serialization.definitions.assets import SerializedAsset
 from airflow.serialization.definitions.dag import SerializedDAG
 from airflow.serialization.encoders import ensure_serialized_asset
-from airflow.serialization.serialized_objects import OperatorSerialization, create_scheduler_operator
+from airflow.serialization.serialized_objects import OperatorSerialization
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
 from airflow.ti_deps.dependencies_states import RUNNABLE_STATES
@@ -2812,7 +2812,7 @@ def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):
 
         monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", mock_policy)
 
-    sdk_task = EmptyOperator(
+    task = EmptyOperator(
         task_id="empty",
         queue=default_queue,
         pool="test_pool1",
@@ -2822,28 +2822,27 @@ def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):
         retries=30,
         executor_config={"KubernetesExecutor": {"image": 
"myCustomDockerImage"}},
     )
-    ser_task = create_scheduler_operator(sdk_task)
-    ti = TI(ser_task, run_id=None, dag_version_id=mock.MagicMock())
-    ti.refresh_from_task(ser_task, pool_override=pool_override)
+    ti = TI(task, run_id=None, dag_version_id=mock.MagicMock())
+    ti.refresh_from_task(task, pool_override=pool_override)
 
     assert ti.queue == expected_queue
 
     if pool_override:
         assert ti.pool == pool_override
     else:
-        assert ti.pool == sdk_task.pool
+        assert ti.pool == task.pool
 
-    assert ti.pool_slots == sdk_task.pool_slots
-    assert ti.priority_weight == ser_task.weight_rule.get_weight(ti)
-    assert ti.run_as_user == sdk_task.run_as_user
-    assert ti.max_tries == sdk_task.retries
-    assert ti.executor_config == sdk_task.executor_config
+    assert ti.pool_slots == task.pool_slots
+    assert ti.priority_weight == task.weight_rule.get_weight(ti)
+    assert ti.run_as_user == task.run_as_user
+    assert ti.max_tries == task.retries
+    assert ti.executor_config == task.executor_config
     assert ti.operator == EmptyOperator.__name__
 
     # Test that refresh_from_task does not reset ti.max_tries
-    expected_max_tries = sdk_task.retries + 10
+    expected_max_tries = task.retries + 10
     ti.max_tries = expected_max_tries
-    ti.refresh_from_task(ser_task)
+    ti.refresh_from_task(task)
     assert ti.max_tries == expected_max_tries
 
 
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 0cbdedb1b6d..7b87c6e00a8 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -84,12 +84,7 @@ from airflow.serialization.serialized_objects import (
     OperatorSerialization,
     _XComRef,
 )
-from airflow.task.priority_strategy import (
-    PriorityWeightStrategy,
-    _DownstreamPriorityWeightStrategy,
-    airflow_priority_weight_strategies,
-    validate_and_load_priority_weight_strategy,
-)
+from airflow.task.priority_strategy import _AbsolutePriorityWeightStrategy, _DownstreamPriorityWeightStrategy
 from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
 from airflow.timetables.simple import NullTimetable, OnceTimetable
 from airflow.triggers.base import StartTriggerArgs
@@ -784,7 +779,6 @@ class TestStringifiedDAGs:
                 "inlets",
                 "outlets",
                 "task_type",
-                "weight_rule",
                 "_operator_name",
                 # Type is excluded, so don't check it
                 "_log",
@@ -818,7 +812,6 @@ class TestStringifiedDAGs:
                 "operator_class",
                 "partial_kwargs",
                 "expand_input",
-                "weight_rule",
             }
 
         assert serialized_task.task_type == task.task_type
@@ -853,12 +846,6 @@ class TestStringifiedDAGs:
         if isinstance(task.params, ParamsDict) and isinstance(serialized_task.params, ParamsDict):
             assert serialized_task.params.dump() == task.params.dump()
 
-        if isinstance(task.weight_rule, PriorityWeightStrategy):
-            assert task.weight_rule == serialized_task.weight_rule
-        else:
-            task_weight_strat = validate_and_load_priority_weight_strategy(task.weight_rule)
-            assert task_weight_strat == serialized_task.weight_rule
-
         if isinstance(task, MappedOperator):
             # MappedOperator.operator_class now stores only minimal type information
             # for memory efficiency (task_type and _operator_name).
@@ -1585,7 +1572,7 @@ class TestStringifiedDAGs:
             "ui_fgcolor": "#000",
             "wait_for_downstream": False,
             "wait_for_past_depends_before_skipping": False,
-            "weight_rule": WeightRule.DOWNSTREAM,
+            "weight_rule": _DownstreamPriorityWeightStrategy(),
             "multiple_outputs": False,
         }, """
 
@@ -3999,6 +3986,27 @@ def test_task_callback_backward_compatibility(old_callback_name, new_callback_na
     assert getattr(deserialized_task_empty, new_callback_name) is False
 
 
+def test_weight_rule_absolute_serialization_deserialization():
+    """Test that weight_rule can be serialized and deserialized correctly."""
+    from airflow.sdk import task
+
+    with DAG("test_weight_rule_dag") as dag:
+
+        @task(weight_rule=WeightRule.ABSOLUTE)
+        def test_task():
+            return "test"
+
+        test_task()
+
+    serialized_dag = DagSerialization.to_dict(dag)
+    assert serialized_dag["dag"]["tasks"][0]["__var"]["weight_rule"] == 
"absolute"
+
+    deserialized_dag = DagSerialization.from_dict(serialized_dag)
+
+    deserialized_task = deserialized_dag.task_dict["test_task"]
+    assert isinstance(deserialized_task.weight_rule, _AbsolutePriorityWeightStrategy)
+
+
 class TestClientDefaultsGeneration:
     """Test client defaults generation functionality."""
 
@@ -4471,50 +4479,3 @@ def test_dag_default_args_callbacks_serialization(callbacks, expected_has_flags,
 
     deserialized_dag = DagSerialization.deserialize_dag(serialized_dag_dict)
     assert deserialized_dag.dag_id == "test_default_args_callbacks"
-
-
-class RegisteredPriorityWeightStrategy(PriorityWeightStrategy):
-    def get_weight(self, ti):
-        return 99
-
-
-class TestWeightRule:
-    def test_default(self):
-        sdkop = BaseOperator(task_id="should_fail")
-        serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop))
-        assert serop.weight_rule == _DownstreamPriorityWeightStrategy()
-
-    @pytest.mark.parametrize(("value", "expected"), 
list(airflow_priority_weight_strategies.items()))
-    def test_builtin(self, value, expected):
-        sdkop = BaseOperator(task_id="should_fail", weight_rule=value)
-        serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop))
-        assert serop.weight_rule == expected()
-
-    def test_custom(self):
-        sdkop = BaseOperator(task_id="should_fail", 
weight_rule=RegisteredPriorityWeightStrategy())
-        with mock.patch(
-            "airflow.serialization.serialized_objects._get_registered_priority_weight_strategy",
-            return_value=RegisteredPriorityWeightStrategy,
-        ) as mock_get_registered_priority_weight_strategy:
-            serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop))
-
-        assert serop.weight_rule == RegisteredPriorityWeightStrategy()
-        assert mock_get_registered_priority_weight_strategy.mock_calls == [
-            mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"),
-            mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"),
-            mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"),
-        ]
-
-    def test_invalid(self):
-        op = BaseOperator(task_id="should_fail", weight_rule="no rule")
-        with pytest.raises(ValueError, match="Unknown priority strategy"):
-            OperatorSerialization.serialize(op)
-
-    def test_not_registered_custom(self):
-        class NotRegisteredPriorityWeightStrategy(PriorityWeightStrategy):
-            def get_weight(self, ti):
-                return 99
-
-        op = BaseOperator(task_id="empty_task", 
weight_rule=NotRegisteredPriorityWeightStrategy())
-        with pytest.raises(ValueError, match="Unknown priority strategy"):
-            OperatorSerialization.serialize(op)
diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py
index 18c0f10f42a..4d60d6546c7 100644
--- a/providers/celery/tests/integration/celery/test_celery_executor.py
+++ b/providers/celery/tests/integration/celery/test_celery_executor.py
@@ -49,7 +49,7 @@ from airflow.providers.standard.operators.bash import BashOperator
 from airflow.utils.state import State
 
 from tests_common.test_utils import db
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 
 logger = logging.getLogger(__name__)
 
@@ -215,21 +215,13 @@ class TestCeleryExecutor:
             # fake_execute_command takes no arguments while execute_workload takes 1,
             # which will cause TypeError when calling task.apply_async()
             executor = celery_executor.CeleryExecutor()
-            with DAG(dag_id="dag_id") as dag:
-                task = BashOperator(
-                    task_id="test",
-                    bash_command="true",
-                    start_date=datetime.now(),
-                )
-            if AIRFLOW_V_3_1_PLUS:
-                from tests_common.test_utils.dag import create_scheduler_dag
-
-                ti = TaskInstance(
-                    task=create_scheduler_dag(dag).get_task(task.task_id),
-                    run_id="abc",
-                    dag_version_id=uuid6.uuid7(),
-                )
-            elif AIRFLOW_V_3_0_PLUS:
+            task = BashOperator(
+                task_id="test",
+                bash_command="true",
+                dag=DAG(dag_id="dag_id"),
+                start_date=datetime.now(),
+            )
+            if AIRFLOW_V_3_0_PLUS:
                 ti = TaskInstance(task=task, run_id="abc", 
dag_version_id=uuid6.uuid7())
             else:
                 ti = TaskInstance(task=task, run_id="abc")
@@ -262,21 +254,13 @@ class TestCeleryExecutor:
             assert executor.task_publish_retries == {}
             assert executor.task_publish_max_retries == 3, "Assert Default Max Retries is 3"
 
-            with DAG(dag_id="id") as dag:
-                task = BashOperator(
-                    task_id="test",
-                    bash_command="true",
-                    start_date=datetime.now(),
-                )
-            if AIRFLOW_V_3_1_PLUS:
-                from tests_common.test_utils.dag import create_scheduler_dag
-
-                ti = TaskInstance(
-                    task=create_scheduler_dag(dag).get_task(task.task_id),
-                    run_id="abc",
-                    dag_version_id=uuid6.uuid7(),
-                )
-            elif AIRFLOW_V_3_0_PLUS:
+            task = BashOperator(
+                task_id="test",
+                bash_command="true",
+                dag=DAG(dag_id="id"),
+                start_date=datetime.now(),
+            )
+            if AIRFLOW_V_3_0_PLUS:
                 ti = TaskInstance(task=task, run_id="abc", 
dag_version_id=uuid6.uuid7())
             else:
                 ti = TaskInstance(task=task, run_id="abc")
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py
index 3b298a17791..e2b72589351 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py
@@ -32,11 +32,7 @@ from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import Kube
 from airflow.providers.cncf.kubernetes.kube_client import get_kube_client
 from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import create_unique_id
 from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, generate_pod_command_args
-from airflow.providers.cncf.kubernetes.version_compat import (
-    AIRFLOW_V_3_0_PLUS,
-    AIRFLOW_V_3_1_PLUS,
-    AIRFLOW_V_3_2_PLUS,
-)
+from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS
 from airflow.utils import cli as cli_utils, yaml
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 from airflow.utils.types import DagRunType
@@ -74,13 +70,7 @@ def generate_pod_yaml(args):
     kube_config = KubeConfig()
 
     for task in dag.tasks:
-        if AIRFLOW_V_3_2_PLUS:
-            from uuid6 import uuid7
-
-            from airflow.serialization.serialized_objects import create_scheduler_operator
-
-            ti = TaskInstance(create_scheduler_operator(task), run_id=dr.run_id, dag_version_id=uuid7())
-        elif AIRFLOW_V_3_0_PLUS:
+        if AIRFLOW_V_3_0_PLUS:
             from uuid6 import uuid7
 
             ti = TaskInstance(task, run_id=dr.run_id, dag_version_id=uuid7())
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py
index 7751da07081..2fb2ac93a12 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py
@@ -34,11 +34,9 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
 
 AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
 AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
-AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)
 
 
 __all__ = [
     "AIRFLOW_V_3_0_PLUS",
     "AIRFLOW_V_3_1_PLUS",
-    "AIRFLOW_V_3_2_PLUS",
 ]
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index f8163a98f83..7f6d6cf297a 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -62,6 +62,11 @@ from airflow.sdk.definitions.edges import EdgeModifier
 from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs
 from airflow.sdk.definitions.param import ParamsDict
 from airflow.sdk.exceptions import RemovedInAirflow4Warning
+from airflow.task.priority_strategy import (
+    PriorityWeightStrategy,
+    airflow_priority_weight_strategies,
+    validate_and_load_priority_weight_strategy,
+)
 
 # Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
 # postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
@@ -838,7 +843,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     params: ParamsDict | dict = field(default_factory=ParamsDict)
     default_args: dict | None = None
     priority_weight: int = DEFAULT_PRIORITY_WEIGHT
-    weight_rule: PriorityWeightStrategy | str = field(default=DEFAULT_WEIGHT_RULE)
+    weight_rule: PriorityWeightStrategy = field(
+        default_factory=airflow_priority_weight_strategies[DEFAULT_WEIGHT_RULE]
+    )
     queue: str = DEFAULT_QUEUE
     pool: str = DEFAULT_POOL_NAME
     pool_slots: int = DEFAULT_POOL_SLOTS
@@ -1136,7 +1143,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         self.params = ParamsDict(params)
 
         self.priority_weight = priority_weight
-        self.weight_rule = weight_rule
+        self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
 
         self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
         self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
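
The net effect on the SDK operator, sketched here to match the updated tests further down (not part of the patch):

    from airflow.sdk.bases.operator import BaseOperator
    from airflow.task.priority_strategy import _UpstreamPriorityWeightStrategy

    # String weight rules are resolved to strategy instances at construction time.
    op = BaseOperator(task_id="example", weight_rule="upstream")
    assert op.weight_rule == _UpstreamPriorityWeightStrategy()

    # An unknown name such as weight_rule="no rule", or an unregistered custom
    # strategy instance, now raises a ValueError ("Unknown priority strategy").
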
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index dd3ea15ccc6..6aff9ca3c68 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -51,6 +51,7 @@ from airflow.sdk.definitions._internal.expandinput import (
 )
 from airflow.sdk.definitions._internal.types import NOTSET
 from airflow.serialization.enums import DagAttributeTypes
+from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
 
 if TYPE_CHECKING:
     import datetime
@@ -66,7 +67,6 @@ if TYPE_CHECKING:
     )
     from airflow.sdk.definitions.operator_resources import Resources
     from airflow.sdk.definitions.param import ParamsDict
-    from airflow.task.priority_strategy import PriorityWeightStrategy
     from airflow.triggers.base import StartTriggerArgs
 
 ValidationSource = Literal["expand"] | Literal["partial"]
@@ -556,11 +556,13 @@ class MappedOperator(AbstractOperator):
 
     @property
     def weight_rule(self) -> PriorityWeightStrategy:
-        return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
+        return validate_and_load_priority_weight_strategy(
+            self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
+        )
 
     @weight_rule.setter
     def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
-        self.partial_kwargs["weight_rule"] = value
+        self.partial_kwargs["weight_rule"] = 
validate_and_load_priority_weight_strategy(value)
 
     @property
     def max_active_tis_per_dag(self) -> int | None:
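
For mapped operators the same normalization happens lazily in the ``weight_rule`` property; a sketch (not part of the patch; the operator choice and arguments are illustrative only):

    from airflow.providers.standard.operators.python import PythonOperator
    from airflow.sdk import DAG
    from airflow.task.priority_strategy import PriorityWeightStrategy

    with DAG(dag_id="weight_rule_demo"):
        mapped = PythonOperator.partial(
            task_id="mapped", python_callable=print, weight_rule="upstream"
        ).expand(op_args=[[1], [2]])

    # The getter validates and loads whatever is stored in partial_kwargs, so a
    # raw string stored there comes back as a PriorityWeightStrategy instance.
    assert isinstance(mapped.weight_rule, PriorityWeightStrategy)
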
diff --git a/task-sdk/tests/task_sdk/bases/test_operator.py b/task-sdk/tests/task_sdk/bases/test_operator.py
index 85b127dfce6..5db27774e23 100644
--- a/task-sdk/tests/task_sdk/bases/test_operator.py
+++ b/task-sdk/tests/task_sdk/bases/test_operator.py
@@ -29,7 +29,7 @@ import jinja2
 import pytest
 import structlog
 
-from airflow.sdk import DAG, Label, TaskGroup, task as task_decorator
+from airflow.sdk import task as task_decorator
 from airflow.sdk._shared.secrets_masker import _secrets_masker, mask_secret
 from airflow.sdk.bases.operator import (
     BaseOperator,
@@ -39,8 +39,12 @@ from airflow.sdk.bases.operator import (
     chain_linear,
     cross_downstream,
 )
+from airflow.sdk.definitions.dag import DAG
+from airflow.sdk.definitions.edges import Label
 from airflow.sdk.definitions.param import ParamsDict
+from airflow.sdk.definitions.taskgroup import TaskGroup
 from airflow.sdk.definitions.template import literal
+from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy
 
 DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
 
@@ -267,12 +271,29 @@ class TestBaseOperator:
 
     def test_weight_rule_default(self):
         op = BaseOperator(task_id="test_task")
-        assert op.weight_rule == "downstream"
+        assert _DownstreamPriorityWeightStrategy() == op.weight_rule
 
     def test_weight_rule_override(self):
-        whatever_value = object()
-        op = BaseOperator(task_id="test_task", weight_rule=whatever_value)
-        assert op.weight_rule is whatever_value
+        op = BaseOperator(task_id="test_task", weight_rule="upstream")
+        assert _UpstreamPriorityWeightStrategy() == op.weight_rule
+
+    def test_dag_task_invalid_weight_rule(self):
+        # Test if we enter an invalid weight rule
+        with pytest.raises(ValueError, match="Unknown priority strategy"):
+            BaseOperator(task_id="should_fail", weight_rule="no rule")
+
+    def test_dag_task_not_registered_weight_strategy(self):
+        from airflow.task.priority_strategy import PriorityWeightStrategy
+
+        class NotRegisteredPriorityWeightStrategy(PriorityWeightStrategy):
+            def get_weight(self, ti):
+                return 99
+
+        with pytest.raises(ValueError, match="Unknown priority strategy"):
+            BaseOperator(
+                task_id="empty_task",
+                weight_rule=NotRegisteredPriorityWeightStrategy(),
+            )
 
     def test_db_safe_priority(self):
         """Test the db_safe_priority function."""
