This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 60b4ed48e1a Remove PriorityWeightStrategy reference in SDK (#59780)
60b4ed48e1a is described below

commit 60b4ed48e1a2a4b16d4de1ff4b04f61a9d7253c1
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Dec 26 16:57:35 2025 +0800

    Remove PriorityWeightStrategy reference in SDK (#59780)
---
 airflow-core/newsfragments/59780.significant.rst   |  4 +
 .../airflow/serialization/serialized_objects.py    |  5 +-
 airflow-core/src/airflow/task/priority_strategy.py | 49 ++++---------
 airflow-core/tests/unit/jobs/test_triggerer_job.py | 21 ++++--
 airflow-core/tests/unit/models/test_dagrun.py      |  2 +-
 .../tests/unit/models/test_mappedoperator.py       |  3 +-
 .../tests/unit/models/test_taskinstance.py         | 25 ++++---
 .../unit/serialization/test_dag_serialization.py   | 85 ++++++++++++++++------
 .../integration/celery/test_celery_executor.py     | 46 ++++++++----
 .../cncf/kubernetes/cli/kubernetes_command.py      | 14 +++-
 .../providers/cncf/kubernetes/version_compat.py    |  2 +
 task-sdk/src/airflow/sdk/bases/operator.py         | 11 +--
 .../src/airflow/sdk/definitions/mappedoperator.py  |  8 +-
 task-sdk/tests/task_sdk/bases/test_operator.py     | 31 ++------
 14 files changed, 169 insertions(+), 137 deletions(-)

diff --git a/airflow-core/newsfragments/59780.significant.rst b/airflow-core/newsfragments/59780.significant.rst
new file mode 100644
index 00000000000..df4c64895c7
--- /dev/null
+++ b/airflow-core/newsfragments/59780.significant.rst
@@ -0,0 +1,4 @@
+Unused methods removed from (experimental) PriorityWeightStrategy
+
+The ``serialize`` and ``deserialize`` methods were never used anywhere and have
+been removed from the class. They should not be relied on in user code.
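
For context, a custom strategy after this change implements only ``get_weight``.
A minimal illustrative sketch (the class name is made up, not part of the
commit):

    from airflow.task.priority_strategy import PriorityWeightStrategy

    class FixedPriorityWeightStrategy(PriorityWeightStrategy):
        # Only get_weight() remains part of the contract; serialize() and
        # deserialize() no longer exist on the base class.
        def get_weight(self, ti) -> int:
            return 42
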
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index 5db5b9d8c7a..63e46b1bb2a 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -94,6 +94,7 @@ from airflow.task.priority_strategy import (
     PriorityWeightStrategy,
     airflow_priority_weight_strategies,
     airflow_priority_weight_strategies_classes,
+    validate_and_load_priority_weight_strategy,
 )
 from airflow.timetables.base import DagRunInfo, Timetable
 from airflow.triggers.base import BaseTrigger, StartTriggerArgs
@@ -247,7 +248,7 @@ def decode_partition_mapper(var: dict[str, Any]) -> PartitionMapper:
     return partition_mapper_class.deserialize(var[Encoding.VAR])
 
 
-def encode_priority_weight_strategy(var: PriorityWeightStrategy) -> str:
+def encode_priority_weight_strategy(var: PriorityWeightStrategy | str) -> str:
     """
     Encode a priority weight strategy instance.
 
@@ -255,7 +256,7 @@ def encode_priority_weight_strategy(var: PriorityWeightStrategy) -> str:
     for any parameters to be passed to it. If you need to store the parameters, you
     should store them in the class itself.
     """
-    priority_weight_strategy_class = type(var)
+    priority_weight_strategy_class = type(validate_and_load_priority_weight_strategy(var))
     if priority_weight_strategy_class in airflow_priority_weight_strategies_classes:
         return airflow_priority_weight_strategies_classes[priority_weight_strategy_class]
     importable_string = qualname(priority_weight_strategy_class)
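
As a rough sketch of the widened signature's expected behavior (assuming the
built-in strategies keep their short "absolute"/"downstream"/"upstream" keys):

    from airflow.serialization.serialized_objects import encode_priority_weight_strategy
    from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy

    # Both forms should now encode to the same short key.
    assert encode_priority_weight_strategy("downstream") == "downstream"
    assert encode_priority_weight_strategy(_DownstreamPriorityWeightStrategy()) == "downstream"
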
diff --git a/airflow-core/src/airflow/task/priority_strategy.py b/airflow-core/src/airflow/task/priority_strategy.py
index a330ca9198b..65b6e3ff34e 100644
--- a/airflow-core/src/airflow/task/priority_strategy.py
+++ b/airflow-core/src/airflow/task/priority_strategy.py
@@ -20,8 +20,9 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
 
+from airflow._shared.module_loading import qualname
 from airflow.task.weight_rule import WeightRule
 
 if TYPE_CHECKING:
@@ -41,46 +42,22 @@ class PriorityWeightStrategy(ABC):
     """
 
     @abstractmethod
-    def get_weight(self, ti: TaskInstance):
+    def get_weight(self, ti: TaskInstance) -> int:
         """Get the priority weight of a task."""
-        ...
-
-    @classmethod
-    def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy:
-        """
-        Deserialize a priority weight strategy from data.
-
-        This is called when a serialized DAG is deserialized. ``data`` will be whatever
-        was returned by ``serialize`` during DAG serialization. The default
-        implementation constructs the priority weight strategy without any arguments.
-        """
-        return cls(**data)
-
-    def serialize(self) -> dict[str, Any]:
-        """
-        Serialize the priority weight strategy for JSON encoding.
-
-        This is called during DAG serialization to store priority weight strategy information
-        in the database. This should return a JSON-serializable dict that will be fed into
-        ``deserialize`` when the DAG is deserialized. The default implementation returns
-        an empty dict.
-        """
-        return {}
+        raise NotImplementedError("must be implemented by a subclass")
 
     def __eq__(self, other: object) -> bool:
         """Equality comparison."""
-        if not isinstance(other, type(self)):
-            return False
-        return self.serialize() == other.serialize()
+        return isinstance(other, type(self))
 
-    def __hash__(self):
-        return hash(self.serialize())
+    def __hash__(self) -> int:
+        return hash(None)
 
 
 class _AbsolutePriorityWeightStrategy(PriorityWeightStrategy):
     """Priority weight strategy that uses the task's priority weight directly."""
 
-    def get_weight(self, ti: TaskInstance):
+    def get_weight(self, ti: TaskInstance) -> int:
         if TYPE_CHECKING:
             assert ti.task
         return ti.task.priority_weight
@@ -104,7 +81,7 @@ class _DownstreamPriorityWeightStrategy(PriorityWeightStrategy):
 class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy):
     """Priority weight strategy that uses the sum of the priority weights of all upstream tasks."""
 
-    def get_weight(self, ti: TaskInstance):
+    def get_weight(self, ti: TaskInstance) -> int:
         if TYPE_CHECKING:
             assert ti.task
         dag = ti.task.get_dag()
@@ -116,6 +93,9 @@ class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy):
 
 
 airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = {
+    qualname(_AbsolutePriorityWeightStrategy): _AbsolutePriorityWeightStrategy,
+    qualname(_DownstreamPriorityWeightStrategy): _DownstreamPriorityWeightStrategy,
+    qualname(_UpstreamPriorityWeightStrategy): _UpstreamPriorityWeightStrategy,
     WeightRule.ABSOLUTE: _AbsolutePriorityWeightStrategy,
     WeightRule.DOWNSTREAM: _DownstreamPriorityWeightStrategy,
     WeightRule.UPSTREAM: _UpstreamPriorityWeightStrategy,
@@ -123,7 +103,9 @@ airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = {
 
 
 airflow_priority_weight_strategies_classes = {
-    cls: name for name, cls in airflow_priority_weight_strategies.items()
+    _AbsolutePriorityWeightStrategy: WeightRule.ABSOLUTE,
+    _DownstreamPriorityWeightStrategy: WeightRule.DOWNSTREAM,
+    _UpstreamPriorityWeightStrategy: WeightRule.UPSTREAM,
 }
 
 
@@ -139,7 +121,6 @@ def validate_and_load_priority_weight_strategy(
 
     :meta private:
     """
-    from airflow._shared.module_loading import qualname
     from airflow.serialization.serialized_objects import _get_registered_priority_weight_strategy
 
     if priority_weight_strategy is None:
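
With the qualname keys added above, a built-in strategy can be named by its
short key or by its dotted path. A hedged sketch of the expected behavior:

    from airflow.task.priority_strategy import (
        _UpstreamPriorityWeightStrategy,
        validate_and_load_priority_weight_strategy,
    )

    # Short key and dotted path resolve to the same strategy class, and the
    # new __eq__ compares by type only.
    s1 = validate_and_load_priority_weight_strategy("upstream")
    s2 = validate_and_load_priority_weight_strategy(
        "airflow.task.priority_strategy._UpstreamPriorityWeightStrategy"
    )
    assert s1 == s2 == _UpstreamPriorityWeightStrategy()
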
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index b9df47ec768..0e3b1647482 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -130,12 +130,17 @@ def create_trigger_in_db(session, trigger, operator=None):
         operator = BaseOperator(task_id="test_ti", dag=dag)
     session.add(dag_model)
 
-    SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name)
+    lazy_serdag = LazyDeserializedDAG.from_dag(dag)
+    SerializedDagModel.write_dag(lazy_serdag, bundle_name=bundle_name)
     session.add(run)
     session.add(trigger_orm)
     session.flush()
     dag_version = DagVersion.get_latest_version(dag.dag_id)
-    task_instance = TaskInstance(operator, run_id=run.run_id, dag_version_id=dag_version.id)
+    task_instance = TaskInstance(
+        lazy_serdag._real_dag.get_task(operator.task_id),
+        run_id=run.run_id,
+        dag_version_id=dag_version.id,
+    )
     task_instance.trigger_id = trigger_orm.id
     session.add(task_instance)
     session.commit()
@@ -440,7 +445,8 @@ class TestTriggerRunner:
 
 
 @pytest.mark.asyncio
-async def test_trigger_create_race_condition_38599(session, supervisor_builder, testing_dag_bundle):
+@pytest.mark.usefixtures("testing_dag_bundle")
+async def test_trigger_create_race_condition_38599(session, supervisor_builder):
     """
     This verifies the resolution of the race condition documented in GitHub issue #38599.
     More details in the issue description.
@@ -465,14 +471,17 @@ async def test_trigger_create_race_condition_38599(session, supervisor_builder,
     session.flush()
 
     bundle_name = "testing"
-    dag = DAG(dag_id="test-dag")
+    with DAG(dag_id="test-dag") as dag:
+        task = PythonOperator(task_id="dummy-task", python_callable=print)
     dm = DagModel(dag_id="test-dag", bundle_name=bundle_name)
     session.add(dm)
-    SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name)
+
+    lazy_serdag = LazyDeserializedDAG.from_dag(dag)
+    SerializedDagModel.write_dag(lazy_serdag, bundle_name=bundle_name)
     dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none", run_after=timezone.utcnow())
     dag_version = DagVersion.get_latest_version(dag.dag_id)
     ti = TaskInstance(
-        PythonOperator(task_id="dummy-task", python_callable=print),
+        lazy_serdag._real_dag.get_task(task.task_id),
         run_id=dag_run.run_id,
         state=TaskInstanceState.DEFERRED,
         dag_version_id=dag_version.id,
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index 1c85d1762f8..9c154976891 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -1479,7 +1479,7 @@ def test_mapped_literal_to_xcom_arg_verify_integrity(dag_maker, session):
         t1 = BaseOperator(task_id="task_1")
         task_2.expand(arg2=t1.output)
 
-    dr.dag = dag_maker.dag
+    dr.dag = dag_maker.serialized_model.dag
     dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id
     dr.verify_integrity(dag_version_id=dag_version_id, session=session)
 
diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py b/airflow-core/tests/unit/models/test_mappedoperator.py
index e5bfe78a059..408e76b22a6 100644
--- a/airflow-core/tests/unit/models/test_mappedoperator.py
+++ b/airflow-core/tests/unit/models/test_mappedoperator.py
@@ -34,7 +34,6 @@ from airflow.models.taskmap import TaskMap
 from airflow.providers.standard.operators.python import PythonOperator
 from airflow.sdk import DAG, BaseOperator, TaskGroup, setup, task, task_group, teardown
 from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
-from airflow.task.priority_strategy import PriorityWeightStrategy
 from airflow.task.trigger_rule import TriggerRule
 from airflow.utils.state import TaskInstanceState
 
@@ -1524,7 +1523,7 @@ class TestMappedSetupTeardown:
         assert op.pool == SerializedBaseOperator.pool
         assert op.pool_slots == SerializedBaseOperator.pool_slots
         assert op.priority_weight == SerializedBaseOperator.priority_weight
-        assert isinstance(op.weight_rule, PriorityWeightStrategy)
+        assert op.weight_rule == "downstream"
         assert op.email == email
         assert op.execution_timeout == execution_timeout
         assert op.retry_delay == retry_delay
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index e83f906fa04..da49fd81b01 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -82,7 +82,7 @@ from airflow.sdk.execution_time.comms import AssetEventsResult
 from airflow.serialization.definitions.assets import SerializedAsset
 from airflow.serialization.definitions.dag import SerializedDAG
 from airflow.serialization.encoders import ensure_serialized_asset
-from airflow.serialization.serialized_objects import OperatorSerialization
+from airflow.serialization.serialized_objects import OperatorSerialization, create_scheduler_operator
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
 from airflow.ti_deps.dependencies_states import RUNNABLE_STATES
@@ -2812,7 +2812,7 @@ def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):
 
         monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", mock_policy)
 
-    task = EmptyOperator(
+    sdk_task = EmptyOperator(
         task_id="empty",
         queue=default_queue,
         pool="test_pool1",
@@ -2822,27 +2822,28 @@ def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):
         retries=30,
         executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}},
     )
-    ti = TI(task, run_id=None, dag_version_id=mock.MagicMock())
-    ti.refresh_from_task(task, pool_override=pool_override)
+    ser_task = create_scheduler_operator(sdk_task)
+    ti = TI(ser_task, run_id=None, dag_version_id=mock.MagicMock())
+    ti.refresh_from_task(ser_task, pool_override=pool_override)
 
     assert ti.queue == expected_queue
 
     if pool_override:
         assert ti.pool == pool_override
     else:
-        assert ti.pool == task.pool
+        assert ti.pool == sdk_task.pool
 
-    assert ti.pool_slots == task.pool_slots
-    assert ti.priority_weight == task.weight_rule.get_weight(ti)
-    assert ti.run_as_user == task.run_as_user
-    assert ti.max_tries == task.retries
-    assert ti.executor_config == task.executor_config
+    assert ti.pool_slots == sdk_task.pool_slots
+    assert ti.priority_weight == ser_task.weight_rule.get_weight(ti)
+    assert ti.run_as_user == sdk_task.run_as_user
+    assert ti.max_tries == sdk_task.retries
+    assert ti.executor_config == sdk_task.executor_config
     assert ti.operator == EmptyOperator.__name__
 
     # Test that refresh_from_task does not reset ti.max_tries
-    expected_max_tries = task.retries + 10
+    expected_max_tries = sdk_task.retries + 10
     ti.max_tries = expected_max_tries
-    ti.refresh_from_task(task)
+    ti.refresh_from_task(ser_task)
     assert ti.max_tries == expected_max_tries
 
 
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 7b87c6e00a8..0cbdedb1b6d 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -84,7 +84,12 @@ from airflow.serialization.serialized_objects import (
     OperatorSerialization,
     _XComRef,
 )
-from airflow.task.priority_strategy import _AbsolutePriorityWeightStrategy, _DownstreamPriorityWeightStrategy
+from airflow.task.priority_strategy import (
+    PriorityWeightStrategy,
+    _DownstreamPriorityWeightStrategy,
+    airflow_priority_weight_strategies,
+    validate_and_load_priority_weight_strategy,
+)
 from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
 from airflow.timetables.simple import NullTimetable, OnceTimetable
 from airflow.triggers.base import StartTriggerArgs
@@ -779,6 +784,7 @@ class TestStringifiedDAGs:
                 "inlets",
                 "outlets",
                 "task_type",
+                "weight_rule",
                 "_operator_name",
                 # Type is excluded, so don't check it
                 "_log",
@@ -812,6 +818,7 @@ class TestStringifiedDAGs:
                 "operator_class",
                 "partial_kwargs",
                 "expand_input",
+                "weight_rule",
             }
 
         assert serialized_task.task_type == task.task_type
@@ -846,6 +853,12 @@ class TestStringifiedDAGs:
         if isinstance(task.params, ParamsDict) and isinstance(serialized_task.params, ParamsDict):
             assert serialized_task.params.dump() == task.params.dump()
 
+        if isinstance(task.weight_rule, PriorityWeightStrategy):
+            assert task.weight_rule == serialized_task.weight_rule
+        else:
+            task_weight_strat = validate_and_load_priority_weight_strategy(task.weight_rule)
+            assert task_weight_strat == serialized_task.weight_rule
+
         if isinstance(task, MappedOperator):
             # MappedOperator.operator_class now stores only minimal type information
             # for memory efficiency (task_type and _operator_name).
@@ -1572,7 +1585,7 @@ class TestStringifiedDAGs:
             "ui_fgcolor": "#000",
             "wait_for_downstream": False,
             "wait_for_past_depends_before_skipping": False,
-            "weight_rule": _DownstreamPriorityWeightStrategy(),
+            "weight_rule": WeightRule.DOWNSTREAM,
             "multiple_outputs": False,
         }, """
 
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -3986,27 +3999,6 @@ def test_task_callback_backward_compatibility(old_callback_name, new_callback_na
     assert getattr(deserialized_task_empty, new_callback_name) is False
 
 
-def test_weight_rule_absolute_serialization_deserialization():
-    """Test that weight_rule can be serialized and deserialized correctly."""
-    from airflow.sdk import task
-
-    with DAG("test_weight_rule_dag") as dag:
-
-        @task(weight_rule=WeightRule.ABSOLUTE)
-        def test_task():
-            return "test"
-
-        test_task()
-
-    serialized_dag = DagSerialization.to_dict(dag)
-    assert serialized_dag["dag"]["tasks"][0]["__var"]["weight_rule"] == "absolute"
-
-    deserialized_dag = DagSerialization.from_dict(serialized_dag)
-
-    deserialized_task = deserialized_dag.task_dict["test_task"]
-    assert isinstance(deserialized_task.weight_rule, _AbsolutePriorityWeightStrategy)
-
-
 class TestClientDefaultsGeneration:
     """Test client defaults generation functionality."""
 
@@ -4479,3 +4471,50 @@ def test_dag_default_args_callbacks_serialization(callbacks, expected_has_flags,
 
     deserialized_dag = DagSerialization.deserialize_dag(serialized_dag_dict)
     assert deserialized_dag.dag_id == "test_default_args_callbacks"
+
+
+class RegisteredPriorityWeightStrategy(PriorityWeightStrategy):
+    def get_weight(self, ti):
+        return 99
+
+
+class TestWeightRule:
+    def test_default(self):
+        sdkop = BaseOperator(task_id="should_fail")
+        serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop))
+        assert serop.weight_rule == _DownstreamPriorityWeightStrategy()
+
+    @pytest.mark.parametrize(("value", "expected"), list(airflow_priority_weight_strategies.items()))
+    def test_builtin(self, value, expected):
+        sdkop = BaseOperator(task_id="should_fail", weight_rule=value)
+        serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop))
+        assert serop.weight_rule == expected()
+
+    def test_custom(self):
+        sdkop = BaseOperator(task_id="should_fail", weight_rule=RegisteredPriorityWeightStrategy())
+        with mock.patch(
+            "airflow.serialization.serialized_objects._get_registered_priority_weight_strategy",
+            return_value=RegisteredPriorityWeightStrategy,
+        ) as mock_get_registered_priority_weight_strategy:
+            serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop))
+
+        assert serop.weight_rule == RegisteredPriorityWeightStrategy()
+        assert mock_get_registered_priority_weight_strategy.mock_calls == [
+            mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"),
+            mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"),
+            mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"),
+        ]
+
+    def test_invalid(self):
+        op = BaseOperator(task_id="should_fail", weight_rule="no rule")
+        with pytest.raises(ValueError, match="Unknown priority strategy"):
+            OperatorSerialization.serialize(op)
+
+    def test_not_registered_custom(self):
+        class NotRegisteredPriorityWeightStrategy(PriorityWeightStrategy):
+            def get_weight(self, ti):
+                return 99
+
+        op = BaseOperator(task_id="empty_task", weight_rule=NotRegisteredPriorityWeightStrategy())
+        with pytest.raises(ValueError, match="Unknown priority strategy"):
+            OperatorSerialization.serialize(op)
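
Outside of tests, the registry lookup mocked above is normally fed by plugin
registration; roughly, based on the existing plugin mechanism rather than
anything introduced in this commit:

    from airflow.plugins_manager import AirflowPlugin

    class MyPriorityPlugin(AirflowPlugin):
        name = "my_priority_plugin"
        # Registering the strategy lets its dotted path be resolved during
        # DAG deserialization.
        priority_weight_strategies = [RegisteredPriorityWeightStrategy]
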
diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py
index 4d60d6546c7..18c0f10f42a 100644
--- a/providers/celery/tests/integration/celery/test_celery_executor.py
+++ b/providers/celery/tests/integration/celery/test_celery_executor.py
@@ -49,7 +49,7 @@ from airflow.providers.standard.operators.bash import BashOperator
 from airflow.utils.state import State
 
 from tests_common.test_utils import db
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS
 
 logger = logging.getLogger(__name__)
 
@@ -215,13 +215,21 @@ class TestCeleryExecutor:
             # fake_execute_command takes no arguments while execute_workload takes 1,
             # which will cause TypeError when calling task.apply_async()
             executor = celery_executor.CeleryExecutor()
-            task = BashOperator(
-                task_id="test",
-                bash_command="true",
-                dag=DAG(dag_id="dag_id"),
-                start_date=datetime.now(),
-            )
-            if AIRFLOW_V_3_0_PLUS:
+            with DAG(dag_id="dag_id") as dag:
+                task = BashOperator(
+                    task_id="test",
+                    bash_command="true",
+                    start_date=datetime.now(),
+                )
+            if AIRFLOW_V_3_1_PLUS:
+                from tests_common.test_utils.dag import create_scheduler_dag
+
+                ti = TaskInstance(
+                    task=create_scheduler_dag(dag).get_task(task.task_id),
+                    run_id="abc",
+                    dag_version_id=uuid6.uuid7(),
+                )
+            elif AIRFLOW_V_3_0_PLUS:
                 ti = TaskInstance(task=task, run_id="abc", dag_version_id=uuid6.uuid7())
             else:
                 ti = TaskInstance(task=task, run_id="abc")
@@ -254,13 +262,21 @@ class TestCeleryExecutor:
             assert executor.task_publish_retries == {}
             assert executor.task_publish_max_retries == 3, "Assert Default Max Retries is 3"
 
-            task = BashOperator(
-                task_id="test",
-                bash_command="true",
-                dag=DAG(dag_id="id"),
-                start_date=datetime.now(),
-            )
-            if AIRFLOW_V_3_0_PLUS:
+            with DAG(dag_id="id") as dag:
+                task = BashOperator(
+                    task_id="test",
+                    bash_command="true",
+                    start_date=datetime.now(),
+                )
+            if AIRFLOW_V_3_1_PLUS:
+                from tests_common.test_utils.dag import create_scheduler_dag
+
+                ti = TaskInstance(
+                    task=create_scheduler_dag(dag).get_task(task.task_id),
+                    run_id="abc",
+                    dag_version_id=uuid6.uuid7(),
+                )
+            elif AIRFLOW_V_3_0_PLUS:
                 ti = TaskInstance(task=task, run_id="abc", dag_version_id=uuid6.uuid7())
             else:
                 ti = TaskInstance(task=task, run_id="abc")
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py
index e2b72589351..3b298a17791 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py
@@ -32,7 +32,11 @@ from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import Kube
 from airflow.providers.cncf.kubernetes.kube_client import get_kube_client
 from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import create_unique_id
 from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, generate_pod_command_args
-from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS
+from airflow.providers.cncf.kubernetes.version_compat import (
+    AIRFLOW_V_3_0_PLUS,
+    AIRFLOW_V_3_1_PLUS,
+    AIRFLOW_V_3_2_PLUS,
+)
 from airflow.utils import cli as cli_utils, yaml
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 from airflow.utils.types import DagRunType
@@ -70,7 +74,13 @@ def generate_pod_yaml(args):
     kube_config = KubeConfig()
 
     for task in dag.tasks:
-        if AIRFLOW_V_3_0_PLUS:
+        if AIRFLOW_V_3_2_PLUS:
+            from uuid6 import uuid7
+
+            from airflow.serialization.serialized_objects import create_scheduler_operator
+
+            ti = TaskInstance(create_scheduler_operator(task), run_id=dr.run_id, dag_version_id=uuid7())
+        elif AIRFLOW_V_3_0_PLUS:
             from uuid6 import uuid7
 
             ti = TaskInstance(task, run_id=dr.run_id, dag_version_id=uuid7())
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py
index 2fb2ac93a12..7751da07081 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py
@@ -34,9 +34,11 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
 
 AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
 AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
+AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)
 
 
 __all__ = [
     "AIRFLOW_V_3_0_PLUS",
     "AIRFLOW_V_3_1_PLUS",
+    "AIRFLOW_V_3_2_PLUS",
 ]
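
A typical use of the new constant, mirroring the gating in kubernetes_command.py
above (to_scheduler_operator is a hypothetical helper, shown only to illustrate
the pattern):

    from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_2_PLUS

    def to_scheduler_operator(task):
        # On Airflow 3.2+, TaskInstance expects a scheduler-side operator.
        if AIRFLOW_V_3_2_PLUS:
            from airflow.serialization.serialized_objects import create_scheduler_operator

            return create_scheduler_operator(task)
        return task
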
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index 7f6d6cf297a..f8163a98f83 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -62,11 +62,6 @@ from airflow.sdk.definitions.edges import EdgeModifier
 from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs
 from airflow.sdk.definitions.param import ParamsDict
 from airflow.sdk.exceptions import RemovedInAirflow4Warning
-from airflow.task.priority_strategy import (
-    PriorityWeightStrategy,
-    airflow_priority_weight_strategies,
-    validate_and_load_priority_weight_strategy,
-)
 
 # Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
 # postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
@@ -843,9 +838,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     params: ParamsDict | dict = field(default_factory=ParamsDict)
     default_args: dict | None = None
     priority_weight: int = DEFAULT_PRIORITY_WEIGHT
-    weight_rule: PriorityWeightStrategy = field(
-        default_factory=airflow_priority_weight_strategies[DEFAULT_WEIGHT_RULE]
-    )
+    weight_rule: PriorityWeightStrategy | str = field(default=DEFAULT_WEIGHT_RULE)
     queue: str = DEFAULT_QUEUE
     pool: str = DEFAULT_POOL_NAME
     pool_slots: int = DEFAULT_POOL_SLOTS
@@ -1143,7 +1136,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         self.params = ParamsDict(params)
 
         self.priority_weight = priority_weight
-        self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
+        self.weight_rule = weight_rule
 
         self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
         self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
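
The net effect on the SDK side, sketched (consistent with the updated tests
below: no eager validation at definition time):

    from airflow.sdk import BaseOperator

    op = BaseOperator(task_id="t", weight_rule="upstream")
    # The raw value is stored as-is; it is only validated and resolved into a
    # PriorityWeightStrategy when the DAG is serialized for the scheduler.
    assert op.weight_rule == "upstream"
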
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index 8e2f3f2a990..29d0de286ec 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -52,7 +52,6 @@ from airflow.sdk.definitions._internal.expandinput import (
 )
 from airflow.sdk.definitions._internal.types import NOTSET
 from airflow.serialization.enums import DagAttributeTypes
-from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
 
 if TYPE_CHECKING:
     import datetime
@@ -68,6 +67,7 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions._internal.expandinput import ExpandInput
     from airflow.sdk.definitions.operator_resources import Resources
     from airflow.sdk.definitions.param import ParamsDict
+    from airflow.task.priority_strategy import PriorityWeightStrategy
     from airflow.triggers.base import StartTriggerArgs
 
 ValidationSource = Literal["expand"] | Literal["partial"]
@@ -557,13 +557,11 @@ class MappedOperator(AbstractOperator):
 
     @property
     def weight_rule(self) -> PriorityWeightStrategy:
-        return validate_and_load_priority_weight_strategy(
-            self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
-        )
+        return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
 
     @weight_rule.setter
     def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
-        self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value)
+        self.partial_kwargs["weight_rule"] = value
 
     @property
     def max_active_tis_per_dag(self) -> int | None:
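
MappedOperator gets the same pass-through treatment; roughly (MyOperator stands
in for any operator class):

    op = MyOperator.partial(task_id="t", weight_rule="absolute").expand(x=[1, 2])
    # The getter returns whatever was stored, without validating or loading it
    # into a strategy instance.
    assert op.weight_rule == "absolute"
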
diff --git a/task-sdk/tests/task_sdk/bases/test_operator.py b/task-sdk/tests/task_sdk/bases/test_operator.py
index 5db27774e23..85b127dfce6 100644
--- a/task-sdk/tests/task_sdk/bases/test_operator.py
+++ b/task-sdk/tests/task_sdk/bases/test_operator.py
@@ -29,7 +29,7 @@ import jinja2
 import pytest
 import structlog
 
-from airflow.sdk import task as task_decorator
+from airflow.sdk import DAG, Label, TaskGroup, task as task_decorator
 from airflow.sdk._shared.secrets_masker import _secrets_masker, mask_secret
 from airflow.sdk.bases.operator import (
     BaseOperator,
@@ -39,12 +39,8 @@ from airflow.sdk.bases.operator import (
     chain_linear,
     cross_downstream,
 )
-from airflow.sdk.definitions.dag import DAG
-from airflow.sdk.definitions.edges import Label
 from airflow.sdk.definitions.param import ParamsDict
-from airflow.sdk.definitions.taskgroup import TaskGroup
 from airflow.sdk.definitions.template import literal
-from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy
 
 DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
 
@@ -271,29 +267,12 @@ class TestBaseOperator:
 
     def test_weight_rule_default(self):
         op = BaseOperator(task_id="test_task")
-        assert _DownstreamPriorityWeightStrategy() == op.weight_rule
+        assert op.weight_rule == "downstream"
 
     def test_weight_rule_override(self):
-        op = BaseOperator(task_id="test_task", weight_rule="upstream")
-        assert _UpstreamPriorityWeightStrategy() == op.weight_rule
-
-    def test_dag_task_invalid_weight_rule(self):
-        # Test if we enter an invalid weight rule
-        with pytest.raises(ValueError, match="Unknown priority strategy"):
-            BaseOperator(task_id="should_fail", weight_rule="no rule")
-
-    def test_dag_task_not_registered_weight_strategy(self):
-        from airflow.task.priority_strategy import PriorityWeightStrategy
-
-        class NotRegisteredPriorityWeightStrategy(PriorityWeightStrategy):
-            def get_weight(self, ti):
-                return 99
-
-        with pytest.raises(ValueError, match="Unknown priority strategy"):
-            BaseOperator(
-                task_id="empty_task",
-                weight_rule=NotRegisteredPriorityWeightStrategy(),
-            )
+        whatever_value = object()
+        op = BaseOperator(task_id="test_task", weight_rule=whatever_value)
+        assert op.weight_rule is whatever_value
 
     def test_db_safe_priority(self):
         """Test the db_safe_priority function."""
