This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e9c2a6f5761 Revert "Remove PriorityWeightStrategy reference in SDK" (#59828)
e9c2a6f5761 is described below
commit e9c2a6f576159241b8022f42027fbfe49421d734
Author: Jens Scheffler <[email protected]>
AuthorDate: Fri Dec 26 22:48:40 2025 +0100
Revert "Remove PriorityWeightStrategy reference in SDK" (#59828)
* Revert "Remove PriorityWeightStrategy reference in SDK (#59780)"
This reverts commit 60b4ed48e1a2a4b16d4de1ff4b04f61a9d7253c1.
* Tip by TP
---
airflow-core/newsfragments/59780.significant.rst | 4 -
.../airflow/serialization/serialized_objects.py | 5 +-
airflow-core/src/airflow/task/priority_strategy.py | 49 +++++++++----
airflow-core/tests/unit/jobs/test_triggerer_job.py | 21 ++----
.../tests/unit/models/test_mappedoperator.py | 3 +-
.../tests/unit/models/test_taskinstance.py | 25 +++----
.../unit/serialization/test_dag_serialization.py | 85 ++++++----------------
.../integration/celery/test_celery_executor.py | 46 ++++--------
.../cncf/kubernetes/cli/kubernetes_command.py | 14 +---
.../providers/cncf/kubernetes/version_compat.py | 2 -
task-sdk/src/airflow/sdk/bases/operator.py | 11 ++-
.../src/airflow/sdk/definitions/mappedoperator.py | 8 +-
task-sdk/tests/task_sdk/bases/test_operator.py | 31 ++++++--
13 files changed, 136 insertions(+), 168 deletions(-)
diff --git a/airflow-core/newsfragments/59780.significant.rst b/airflow-core/newsfragments/59780.significant.rst
deleted file mode 100644
index df4c64895c7..00000000000
--- a/airflow-core/newsfragments/59780.significant.rst
+++ /dev/null
@@ -1,4 +0,0 @@
-Usused methods removed from (experimental) PriorityWeightStrategy
-
-Functions ``serialize`` and ``deserialize`` were never used anywhere, and have
-been removed from the class. They should not be relied in in user code.
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index dad7e4834f8..170879c5339 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -96,7 +96,6 @@ from airflow.task.priority_strategy import (
PriorityWeightStrategy,
airflow_priority_weight_strategies,
airflow_priority_weight_strategies_classes,
- validate_and_load_priority_weight_strategy,
)
from airflow.timetables.base import DagRunInfo, Timetable
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
@@ -250,7 +249,7 @@ def decode_partition_mapper(var: dict[str, Any]) -> PartitionMapper:
return partition_mapper_class.deserialize(var[Encoding.VAR])
-def encode_priority_weight_strategy(var: PriorityWeightStrategy | str) -> str:
+def encode_priority_weight_strategy(var: PriorityWeightStrategy) -> str:
"""
Encode a priority weight strategy instance.
@@ -258,7 +257,7 @@ def encode_priority_weight_strategy(var: PriorityWeightStrategy | str) -> str:
for any parameters to be passed to it. If you need to store the parameters, you
should store them in the class itself.
"""
- priority_weight_strategy_class = type(validate_and_load_priority_weight_strategy(var))
+ priority_weight_strategy_class = type(var)
if priority_weight_strategy_class in airflow_priority_weight_strategies_classes:
return airflow_priority_weight_strategies_classes[priority_weight_strategy_class]
importable_string = qualname(priority_weight_strategy_class)
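Note: per the hunk above, the reverted encoder again assumes it receives an already-validated PriorityWeightStrategy instance. A rough sketch of the expected behavior, assuming the registries shown in the priority_strategy.py hunk below and that WeightRule values are plain strings such as "absolute":

    # Sketch only; not part of this commit.
    from airflow.serialization.serialized_objects import encode_priority_weight_strategy
    from airflow.task.priority_strategy import _AbsolutePriorityWeightStrategy

    # Built-in strategies encode to their registered WeightRule name ...
    assert encode_priority_weight_strategy(_AbsolutePriorityWeightStrategy()) == "absolute"
    # ... while an unregistered custom class falls through to its importable
    # qualname, which must be resolvable when the DAG is deserialized.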
diff --git a/airflow-core/src/airflow/task/priority_strategy.py b/airflow-core/src/airflow/task/priority_strategy.py
index 65b6e3ff34e..a330ca9198b 100644
--- a/airflow-core/src/airflow/task/priority_strategy.py
+++ b/airflow-core/src/airflow/task/priority_strategy.py
@@ -20,9 +20,8 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
-from airflow._shared.module_loading import qualname
from airflow.task.weight_rule import WeightRule
if TYPE_CHECKING:
@@ -42,22 +41,46 @@ class PriorityWeightStrategy(ABC):
"""
@abstractmethod
- def get_weight(self, ti: TaskInstance) -> int:
+ def get_weight(self, ti: TaskInstance):
"""Get the priority weight of a task."""
- raise NotImplementedError("must be implemented by a subclass")
+ ...
+
+ @classmethod
+ def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy:
+ """
+ Deserialize a priority weight strategy from data.
+
+ This is called when a serialized DAG is deserialized. ``data`` will be whatever
+ was returned by ``serialize`` during DAG serialization. The default
+ implementation constructs the priority weight strategy without any arguments.
+ """
+ return cls(**data)
+
+ def serialize(self) -> dict[str, Any]:
+ """
+ Serialize the priority weight strategy for JSON encoding.
+
+ This is called during DAG serialization to store priority weight strategy information
+ in the database. This should return a JSON-serializable dict that will be fed into
+ ``deserialize`` when the DAG is deserialized. The default implementation returns
+ an empty dict.
+ """
+ return {}
def __eq__(self, other: object) -> bool:
"""Equality comparison."""
- return isinstance(other, type(self))
+ if not isinstance(other, type(self)):
+ return False
+ return self.serialize() == other.serialize()
- def __hash__(self) -> int:
- return hash(None)
+ def __hash__(self):
+ return hash(self.serialize())
class _AbsolutePriorityWeightStrategy(PriorityWeightStrategy):
"""Priority weight strategy that uses the task's priority weight
directly."""
- def get_weight(self, ti: TaskInstance) -> int:
+ def get_weight(self, ti: TaskInstance):
if TYPE_CHECKING:
assert ti.task
return ti.task.priority_weight
@@ -81,7 +104,7 @@ class _DownstreamPriorityWeightStrategy(PriorityWeightStrategy):
class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy):
"""Priority weight strategy that uses the sum of the priority weights of
all upstream tasks."""
- def get_weight(self, ti: TaskInstance) -> int:
+ def get_weight(self, ti: TaskInstance):
if TYPE_CHECKING:
assert ti.task
dag = ti.task.get_dag()
@@ -93,9 +116,6 @@ class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy):
airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = {
- qualname(_AbsolutePriorityWeightStrategy): _AbsolutePriorityWeightStrategy,
- qualname(_DownstreamPriorityWeightStrategy): _DownstreamPriorityWeightStrategy,
- qualname(_UpstreamPriorityWeightStrategy): _UpstreamPriorityWeightStrategy,
WeightRule.ABSOLUTE: _AbsolutePriorityWeightStrategy,
WeightRule.DOWNSTREAM: _DownstreamPriorityWeightStrategy,
WeightRule.UPSTREAM: _UpstreamPriorityWeightStrategy,
@@ -103,9 +123,7 @@ airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = {
airflow_priority_weight_strategies_classes = {
- _AbsolutePriorityWeightStrategy: WeightRule.ABSOLUTE,
- _DownstreamPriorityWeightStrategy: WeightRule.DOWNSTREAM,
- _UpstreamPriorityWeightStrategy: WeightRule.UPSTREAM,
+ cls: name for name, cls in airflow_priority_weight_strategies.items()
}
@@ -121,6 +139,7 @@ def validate_and_load_priority_weight_strategy(
:meta private:
"""
+ from airflow._shared.module_loading import qualname
from airflow.serialization.serialized_objects import _get_registered_priority_weight_strategy
if priority_weight_strategy is None:
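Note: restoring the serialize/deserialize hooks lets a parameterized custom strategy round-trip through DAG serialization again, and the restored __eq__ compares strategies by their serialized form. A minimal sketch with a hypothetical strategy class (the class and its "base" field are illustrative, not part of this commit):

    from airflow.task.priority_strategy import PriorityWeightStrategy

    class DecayingPriorityStrategy(PriorityWeightStrategy):
        """Hypothetical strategy: priority decays with each retry."""

        def __init__(self, base: int = 10):
            self.base = base

        def get_weight(self, ti):
            # try_number starts at 1 on the first attempt.
            return max(self.base - ti.try_number, 1)

        def serialize(self) -> dict:
            # Persisted with the serialized DAG; the default deserialize()
            # rebuilds the instance as cls(**data).
            return {"base": self.base}

With this shape, two DecayingPriorityStrategy(10) instances compare equal; to be usable in a serialized DAG the class must also be registered, as the test changes below exercise.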
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 0e3b1647482..b9df47ec768 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -130,17 +130,12 @@ def create_trigger_in_db(session, trigger, operator=None):
operator = BaseOperator(task_id="test_ti", dag=dag)
session.add(dag_model)
- lazy_serdag = LazyDeserializedDAG.from_dag(dag)
- SerializedDagModel.write_dag(lazy_serdag, bundle_name=bundle_name)
+ SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name)
session.add(run)
session.add(trigger_orm)
session.flush()
dag_version = DagVersion.get_latest_version(dag.dag_id)
- task_instance = TaskInstance(
- lazy_serdag._real_dag.get_task(operator.task_id),
- run_id=run.run_id,
- dag_version_id=dag_version.id,
- )
+ task_instance = TaskInstance(operator, run_id=run.run_id, dag_version_id=dag_version.id)
task_instance.trigger_id = trigger_orm.id
session.add(task_instance)
session.commit()
@@ -445,8 +440,7 @@ class TestTriggerRunner:
@pytest.mark.asyncio
-@pytest.mark.usefixtures("testing_dag_bundle")
-async def test_trigger_create_race_condition_38599(session, supervisor_builder):
+async def test_trigger_create_race_condition_38599(session, supervisor_builder, testing_dag_bundle):
"""
This verifies the resolution of race condition documented in github issue #38599.
More details in the issue description.
@@ -471,17 +465,14 @@ async def test_trigger_create_race_condition_38599(session, supervisor_builder):
session.flush()
bundle_name = "testing"
- with DAG(dag_id="test-dag") as dag:
- task = PythonOperator(task_id="dummy-task", python_callable=print)
+ dag = DAG(dag_id="test-dag")
dm = DagModel(dag_id="test-dag", bundle_name=bundle_name)
session.add(dm)
-
- lazy_serdag = LazyDeserializedDAG.from_dag(dag)
- SerializedDagModel.write_dag(lazy_serdag, bundle_name=bundle_name)
+ SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name)
dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none", run_after=timezone.utcnow())
dag_version = DagVersion.get_latest_version(dag.dag_id)
ti = TaskInstance(
- lazy_serdag._real_dag.get_task(task.task_id),
+ PythonOperator(task_id="dummy-task", python_callable=print),
run_id=dag_run.run_id,
state=TaskInstanceState.DEFERRED,
dag_version_id=dag_version.id,
diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py b/airflow-core/tests/unit/models/test_mappedoperator.py
index 408e76b22a6..e5bfe78a059 100644
--- a/airflow-core/tests/unit/models/test_mappedoperator.py
+++ b/airflow-core/tests/unit/models/test_mappedoperator.py
@@ -34,6 +34,7 @@ from airflow.models.taskmap import TaskMap
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import DAG, BaseOperator, TaskGroup, setup, task, task_group, teardown
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
+from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.task.trigger_rule import TriggerRule
from airflow.utils.state import TaskInstanceState
@@ -1523,7 +1524,7 @@ class TestMappedSetupTeardown:
assert op.pool == SerializedBaseOperator.pool
assert op.pool_slots == SerializedBaseOperator.pool_slots
assert op.priority_weight == SerializedBaseOperator.priority_weight
- assert op.weight_rule == "downstream"
+ assert isinstance(op.weight_rule, PriorityWeightStrategy)
assert op.email == email
assert op.execution_timeout == execution_timeout
assert op.retry_delay == retry_delay
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index da49fd81b01..e83f906fa04 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -82,7 +82,7 @@ from airflow.sdk.execution_time.comms import AssetEventsResult
from airflow.serialization.definitions.assets import SerializedAsset
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.encoders import ensure_serialized_asset
-from airflow.serialization.serialized_objects import OperatorSerialization, create_scheduler_operator
+from airflow.serialization.serialized_objects import OperatorSerialization
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.ti_deps.dependencies_states import RUNNABLE_STATES
@@ -2812,7 +2812,7 @@ def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):
monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", mock_policy)
- sdk_task = EmptyOperator(
+ task = EmptyOperator(
task_id="empty",
queue=default_queue,
pool="test_pool1",
@@ -2822,28 +2822,27 @@ def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):
retries=30,
executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}},
)
- ser_task = create_scheduler_operator(sdk_task)
- ti = TI(ser_task, run_id=None, dag_version_id=mock.MagicMock())
- ti.refresh_from_task(ser_task, pool_override=pool_override)
+ ti = TI(task, run_id=None, dag_version_id=mock.MagicMock())
+ ti.refresh_from_task(task, pool_override=pool_override)
assert ti.queue == expected_queue
if pool_override:
assert ti.pool == pool_override
else:
- assert ti.pool == sdk_task.pool
+ assert ti.pool == task.pool
- assert ti.pool_slots == sdk_task.pool_slots
- assert ti.priority_weight == ser_task.weight_rule.get_weight(ti)
- assert ti.run_as_user == sdk_task.run_as_user
- assert ti.max_tries == sdk_task.retries
- assert ti.executor_config == sdk_task.executor_config
+ assert ti.pool_slots == task.pool_slots
+ assert ti.priority_weight == task.weight_rule.get_weight(ti)
+ assert ti.run_as_user == task.run_as_user
+ assert ti.max_tries == task.retries
+ assert ti.executor_config == task.executor_config
assert ti.operator == EmptyOperator.__name__
# Test that refresh_from_task does not reset ti.max_tries
- expected_max_tries = sdk_task.retries + 10
+ expected_max_tries = task.retries + 10
ti.max_tries = expected_max_tries
- ti.refresh_from_task(ser_task)
+ ti.refresh_from_task(task)
assert ti.max_tries == expected_max_tries
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 0cbdedb1b6d..7b87c6e00a8 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -84,12 +84,7 @@ from airflow.serialization.serialized_objects import (
OperatorSerialization,
_XComRef,
)
-from airflow.task.priority_strategy import (
- PriorityWeightStrategy,
- _DownstreamPriorityWeightStrategy,
- airflow_priority_weight_strategies,
- validate_and_load_priority_weight_strategy,
-)
+from airflow.task.priority_strategy import _AbsolutePriorityWeightStrategy, _DownstreamPriorityWeightStrategy
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
from airflow.timetables.simple import NullTimetable, OnceTimetable
from airflow.triggers.base import StartTriggerArgs
@@ -784,7 +779,6 @@ class TestStringifiedDAGs:
"inlets",
"outlets",
"task_type",
- "weight_rule",
"_operator_name",
# Type is excluded, so don't check it
"_log",
@@ -818,7 +812,6 @@ class TestStringifiedDAGs:
"operator_class",
"partial_kwargs",
"expand_input",
- "weight_rule",
}
assert serialized_task.task_type == task.task_type
@@ -853,12 +846,6 @@ class TestStringifiedDAGs:
if isinstance(task.params, ParamsDict) and isinstance(serialized_task.params, ParamsDict):
assert serialized_task.params.dump() == task.params.dump()
- if isinstance(task.weight_rule, PriorityWeightStrategy):
- assert task.weight_rule == serialized_task.weight_rule
- else:
- task_weight_strat = validate_and_load_priority_weight_strategy(task.weight_rule)
- assert task_weight_strat == serialized_task.weight_rule
-
if isinstance(task, MappedOperator):
# MappedOperator.operator_class now stores only minimal type information
# for memory efficiency (task_type and _operator_name).
@@ -1585,7 +1572,7 @@ class TestStringifiedDAGs:
"ui_fgcolor": "#000",
"wait_for_downstream": False,
"wait_for_past_depends_before_skipping": False,
- "weight_rule": WeightRule.DOWNSTREAM,
+ "weight_rule": _DownstreamPriorityWeightStrategy(),
"multiple_outputs": False,
}, """
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -3999,6 +3986,27 @@ def test_task_callback_backward_compatibility(old_callback_name, new_callback_na
assert getattr(deserialized_task_empty, new_callback_name) is False
+def test_weight_rule_absolute_serialization_deserialization():
+ """Test that weight_rule can be serialized and deserialized correctly."""
+ from airflow.sdk import task
+
+ with DAG("test_weight_rule_dag") as dag:
+
+ @task(weight_rule=WeightRule.ABSOLUTE)
+ def test_task():
+ return "test"
+
+ test_task()
+
+ serialized_dag = DagSerialization.to_dict(dag)
+ assert serialized_dag["dag"]["tasks"][0]["__var"]["weight_rule"] ==
"absolute"
+
+ deserialized_dag = DagSerialization.from_dict(serialized_dag)
+
+ deserialized_task = deserialized_dag.task_dict["test_task"]
+ assert isinstance(deserialized_task.weight_rule, _AbsolutePriorityWeightStrategy)
+
+
class TestClientDefaultsGeneration:
"""Test client defaults generation functionality."""
@@ -4471,50 +4479,3 @@ def test_dag_default_args_callbacks_serialization(callbacks, expected_has_flags,
deserialized_dag = DagSerialization.deserialize_dag(serialized_dag_dict)
assert deserialized_dag.dag_id == "test_default_args_callbacks"
-
-
-class RegisteredPriorityWeightStrategy(PriorityWeightStrategy):
- def get_weight(self, ti):
- return 99
-
-
-class TestWeightRule:
- def test_default(self):
- sdkop = BaseOperator(task_id="should_fail")
- serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop))
- assert serop.weight_rule == _DownstreamPriorityWeightStrategy()
-
- @pytest.mark.parametrize(("value", "expected"), list(airflow_priority_weight_strategies.items()))
- def test_builtin(self, value, expected):
- sdkop = BaseOperator(task_id="should_fail", weight_rule=value)
- serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop))
- assert serop.weight_rule == expected()
-
- def test_custom(self):
- sdkop = BaseOperator(task_id="should_fail",
weight_rule=RegisteredPriorityWeightStrategy())
- with mock.patch(
-
"airflow.serialization.serialized_objects._get_registered_priority_weight_strategy",
- return_value=RegisteredPriorityWeightStrategy,
- ) as mock_get_registered_priority_weight_strategy:
- serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop))
-
- assert serop.weight_rule == RegisteredPriorityWeightStrategy()
- assert mock_get_registered_priority_weight_strategy.mock_calls == [
- mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"),
- mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"),
- mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"),
- ]
-
- def test_invalid(self):
- op = BaseOperator(task_id="should_fail", weight_rule="no rule")
- with pytest.raises(ValueError, match="Unknown priority strategy"):
- OperatorSerialization.serialize(op)
-
- def test_not_registered_custom(self):
- class NotRegisteredPriorityWeightStrategy(PriorityWeightStrategy):
- def get_weight(self, ti):
- return 99
-
- op = BaseOperator(task_id="empty_task",
weight_rule=NotRegisteredPriorityWeightStrategy())
- with pytest.raises(ValueError, match="Unknown priority strategy"):
- OperatorSerialization.serialize(op)
diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py
index 18c0f10f42a..4d60d6546c7 100644
--- a/providers/celery/tests/integration/celery/test_celery_executor.py
+++ b/providers/celery/tests/integration/celery/test_celery_executor.py
@@ -49,7 +49,7 @@ from airflow.providers.standard.operators.bash import BashOperator
from airflow.utils.state import State
from tests_common.test_utils import db
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
logger = logging.getLogger(__name__)
@@ -215,21 +215,13 @@ class TestCeleryExecutor:
# fake_execute_command takes no arguments while execute_workload takes 1,
# which will cause TypeError when calling task.apply_async()
executor = celery_executor.CeleryExecutor()
- with DAG(dag_id="dag_id") as dag:
- task = BashOperator(
- task_id="test",
- bash_command="true",
- start_date=datetime.now(),
- )
- if AIRFLOW_V_3_1_PLUS:
- from tests_common.test_utils.dag import create_scheduler_dag
-
- ti = TaskInstance(
- task=create_scheduler_dag(dag).get_task(task.task_id),
- run_id="abc",
- dag_version_id=uuid6.uuid7(),
- )
- elif AIRFLOW_V_3_0_PLUS:
+ task = BashOperator(
+ task_id="test",
+ bash_command="true",
+ dag=DAG(dag_id="dag_id"),
+ start_date=datetime.now(),
+ )
+ if AIRFLOW_V_3_0_PLUS:
ti = TaskInstance(task=task, run_id="abc",
dag_version_id=uuid6.uuid7())
else:
ti = TaskInstance(task=task, run_id="abc")
@@ -262,21 +254,13 @@ class TestCeleryExecutor:
assert executor.task_publish_retries == {}
assert executor.task_publish_max_retries == 3, "Assert Default Max Retries is 3"
- with DAG(dag_id="id") as dag:
- task = BashOperator(
- task_id="test",
- bash_command="true",
- start_date=datetime.now(),
- )
- if AIRFLOW_V_3_1_PLUS:
- from tests_common.test_utils.dag import create_scheduler_dag
-
- ti = TaskInstance(
- task=create_scheduler_dag(dag).get_task(task.task_id),
- run_id="abc",
- dag_version_id=uuid6.uuid7(),
- )
- elif AIRFLOW_V_3_0_PLUS:
+ task = BashOperator(
+ task_id="test",
+ bash_command="true",
+ dag=DAG(dag_id="id"),
+ start_date=datetime.now(),
+ )
+ if AIRFLOW_V_3_0_PLUS:
ti = TaskInstance(task=task, run_id="abc",
dag_version_id=uuid6.uuid7())
else:
ti = TaskInstance(task=task, run_id="abc")
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py
index 3b298a17791..e2b72589351 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py
@@ -32,11 +32,7 @@ from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import Kube
from airflow.providers.cncf.kubernetes.kube_client import get_kube_client
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import create_unique_id
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, generate_pod_command_args
-from airflow.providers.cncf.kubernetes.version_compat import (
- AIRFLOW_V_3_0_PLUS,
- AIRFLOW_V_3_1_PLUS,
- AIRFLOW_V_3_2_PLUS,
-)
+from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS
from airflow.utils import cli as cli_utils, yaml
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.types import DagRunType
@@ -74,13 +70,7 @@ def generate_pod_yaml(args):
kube_config = KubeConfig()
for task in dag.tasks:
- if AIRFLOW_V_3_2_PLUS:
- from uuid6 import uuid7
-
- from airflow.serialization.serialized_objects import create_scheduler_operator
-
- ti = TaskInstance(create_scheduler_operator(task), run_id=dr.run_id, dag_version_id=uuid7())
- elif AIRFLOW_V_3_0_PLUS:
+ if AIRFLOW_V_3_0_PLUS:
from uuid6 import uuid7
ti = TaskInstance(task, run_id=dr.run_id, dag_version_id=uuid7())
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py
index 7751da07081..2fb2ac93a12 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py
@@ -34,11 +34,9 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
-AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)
__all__ = [
"AIRFLOW_V_3_0_PLUS",
"AIRFLOW_V_3_1_PLUS",
- "AIRFLOW_V_3_2_PLUS",
]
diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py
index f8163a98f83..7f6d6cf297a 100644
--- a/task-sdk/src/airflow/sdk/bases/operator.py
+++ b/task-sdk/src/airflow/sdk/bases/operator.py
@@ -62,6 +62,11 @@ from airflow.sdk.definitions.edges import EdgeModifier
from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs
from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.exceptions import RemovedInAirflow4Warning
+from airflow.task.priority_strategy import (
+ PriorityWeightStrategy,
+ airflow_priority_weight_strategies,
+ validate_and_load_priority_weight_strategy,
+)
# Databases do not support arbitrary precision integers, so we need to limit the range of priority weights.
# postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
@@ -838,7 +843,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
params: ParamsDict | dict = field(default_factory=ParamsDict)
default_args: dict | None = None
priority_weight: int = DEFAULT_PRIORITY_WEIGHT
- weight_rule: PriorityWeightStrategy | str = field(default=DEFAULT_WEIGHT_RULE)
+ weight_rule: PriorityWeightStrategy = field(
+ default_factory=airflow_priority_weight_strategies[DEFAULT_WEIGHT_RULE]
+ )
queue: str = DEFAULT_QUEUE
pool: str = DEFAULT_POOL_NAME
pool_slots: int = DEFAULT_POOL_SLOTS
@@ -1136,7 +1143,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
self.params = ParamsDict(params)
self.priority_weight = priority_weight
- self.weight_rule = weight_rule
+ self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
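Note: with validation moved into BaseOperator.__init__, a string weight_rule is resolved into a strategy instance at definition time, and unknown names fail fast. A small usage sketch mirroring the test updates below:

    from airflow.sdk import BaseOperator
    from airflow.task.priority_strategy import _UpstreamPriorityWeightStrategy

    op = BaseOperator(task_id="example", weight_rule="upstream")
    # The string is resolved into a strategy instance during __init__.
    assert op.weight_rule == _UpstreamPriorityWeightStrategy()

    # An unknown name now raises at definition time instead of at serialization:
    # BaseOperator(task_id="bad", weight_rule="no rule")  # ValueError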
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index dd3ea15ccc6..6aff9ca3c68 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -51,6 +51,7 @@ from airflow.sdk.definitions._internal.expandinput import (
)
from airflow.sdk.definitions._internal.types import NOTSET
from airflow.serialization.enums import DagAttributeTypes
+from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
if TYPE_CHECKING:
import datetime
@@ -66,7 +67,6 @@ if TYPE_CHECKING:
)
from airflow.sdk.definitions.operator_resources import Resources
from airflow.sdk.definitions.param import ParamsDict
- from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.triggers.base import StartTriggerArgs
ValidationSource = Literal["expand"] | Literal["partial"]
@@ -556,11 +556,13 @@ class MappedOperator(AbstractOperator):
@property
def weight_rule(self) -> PriorityWeightStrategy:
- return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
+ return validate_and_load_priority_weight_strategy(
+ self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
+ )
@weight_rule.setter
def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
- self.partial_kwargs["weight_rule"] = value
+ self.partial_kwargs["weight_rule"] =
validate_and_load_priority_weight_strategy(value)
@property
def max_active_tis_per_dag(self) -> int | None:
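Note: MappedOperator validates weight_rule on read as well as on write, since partial_kwargs may still carry a plain string (e.g. from older serialized data). A sketch of the intended behavior; the operator and arguments are illustrative:

    from airflow.providers.standard.operators.bash import BashOperator
    from airflow.task.priority_strategy import _UpstreamPriorityWeightStrategy

    mapped = BashOperator.partial(task_id="m", weight_rule="upstream").expand(
        bash_command=["echo 1", "echo 2"]
    )
    # Whether stored as a string or an instance, the property returns a
    # validated strategy object.
    assert mapped.weight_rule == _UpstreamPriorityWeightStrategy()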
diff --git a/task-sdk/tests/task_sdk/bases/test_operator.py b/task-sdk/tests/task_sdk/bases/test_operator.py
index 85b127dfce6..5db27774e23 100644
--- a/task-sdk/tests/task_sdk/bases/test_operator.py
+++ b/task-sdk/tests/task_sdk/bases/test_operator.py
@@ -29,7 +29,7 @@ import jinja2
import pytest
import structlog
-from airflow.sdk import DAG, Label, TaskGroup, task as task_decorator
+from airflow.sdk import task as task_decorator
from airflow.sdk._shared.secrets_masker import _secrets_masker, mask_secret
from airflow.sdk.bases.operator import (
BaseOperator,
@@ -39,8 +39,12 @@ from airflow.sdk.bases.operator import (
chain_linear,
cross_downstream,
)
+from airflow.sdk.definitions.dag import DAG
+from airflow.sdk.definitions.edges import Label
from airflow.sdk.definitions.param import ParamsDict
+from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.sdk.definitions.template import literal
+from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy
DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
@@ -267,12 +271,29 @@ class TestBaseOperator:
def test_weight_rule_default(self):
op = BaseOperator(task_id="test_task")
- assert op.weight_rule == "downstream"
+ assert _DownstreamPriorityWeightStrategy() == op.weight_rule
def test_weight_rule_override(self):
- whatever_value = object()
- op = BaseOperator(task_id="test_task", weight_rule=whatever_value)
- assert op.weight_rule is whatever_value
+ op = BaseOperator(task_id="test_task", weight_rule="upstream")
+ assert _UpstreamPriorityWeightStrategy() == op.weight_rule
+
+ def test_dag_task_invalid_weight_rule(self):
+ # Test if we enter an invalid weight rule
+ with pytest.raises(ValueError, match="Unknown priority strategy"):
+ BaseOperator(task_id="should_fail", weight_rule="no rule")
+
+ def test_dag_task_not_registered_weight_strategy(self):
+ from airflow.task.priority_strategy import PriorityWeightStrategy
+
+ class NotRegisteredPriorityWeightStrategy(PriorityWeightStrategy):
+ def get_weight(self, ti):
+ return 99
+
+ with pytest.raises(ValueError, match="Unknown priority strategy"):
+ BaseOperator(
+ task_id="empty_task",
+ weight_rule=NotRegisteredPriorityWeightStrategy(),
+ )
def test_db_safe_priority(self):
"""Test the db_safe_priority function."""