This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch revert-59780-sdk-weight-rule
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 4e43a718275241ceb7695019eb3407e58b687a56
Author: Jens Scheffler <[email protected]>
AuthorDate: Fri Dec 26 17:04:46 2025 +0100

    Revert "Remove PriorityWeightStrategy reference in SDK (#59780)"

    This reverts commit 60b4ed48e1a2a4b16d4de1ff4b04f61a9d7253c1.
---
 airflow-core/newsfragments/59780.significant.rst   |  4 -
 .../airflow/serialization/serialized_objects.py    |  5 +-
 airflow-core/src/airflow/task/priority_strategy.py | 49 +++++++++----
 airflow-core/tests/unit/jobs/test_triggerer_job.py | 21 ++----
 airflow-core/tests/unit/models/test_dagrun.py      |  2 +-
 .../tests/unit/models/test_mappedoperator.py       |  3 +-
 .../tests/unit/models/test_taskinstance.py         | 25 +++----
 .../unit/serialization/test_dag_serialization.py   | 85 ++++++----------------
 .../integration/celery/test_celery_executor.py     | 46 ++++--------
 .../cncf/kubernetes/cli/kubernetes_command.py      | 14 +---
 .../providers/cncf/kubernetes/version_compat.py    |  2 -
 task-sdk/src/airflow/sdk/bases/operator.py         | 11 ++-
 .../src/airflow/sdk/definitions/mappedoperator.py  |  8 +-
 task-sdk/tests/task_sdk/bases/test_operator.py     | 31 ++++++--
 14 files changed, 137 insertions(+), 169 deletions(-)

diff --git a/airflow-core/newsfragments/59780.significant.rst b/airflow-core/newsfragments/59780.significant.rst
deleted file mode 100644
index df4c64895c7..00000000000
--- a/airflow-core/newsfragments/59780.significant.rst
+++ /dev/null
@@ -1,4 +0,0 @@
-Usused methods removed from (experimental) PriorityWeightStrategy
-
-Functions ``serialize`` and ``deserialize`` were never used anywhere, and have
-been removed from the class. They should not be relied in in user code.
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index dad7e4834f8..170879c5339 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -96,7 +96,6 @@ from airflow.task.priority_strategy import (
     PriorityWeightStrategy,
     airflow_priority_weight_strategies,
     airflow_priority_weight_strategies_classes,
-    validate_and_load_priority_weight_strategy,
 )
 from airflow.timetables.base import DagRunInfo, Timetable
 from airflow.triggers.base import BaseTrigger, StartTriggerArgs
@@ -250,7 +249,7 @@ def decode_partition_mapper(var: dict[str, Any]) -> PartitionMapper:
     return partition_mapper_class.deserialize(var[Encoding.VAR])


-def encode_priority_weight_strategy(var: PriorityWeightStrategy | str) -> str:
+def encode_priority_weight_strategy(var: PriorityWeightStrategy) -> str:
     """
     Encode a priority weight strategy instance.

@@ -258,7 +257,7 @@ def encode_priority_weight_strategy(var: PriorityWeightStrategy | str) -> str:
     for any parameters to be passed to it. If you need to store the parameters,
     you should store them in the class itself.
""" - priority_weight_strategy_class = type(validate_and_load_priority_weight_strategy(var)) + priority_weight_strategy_class = type(var) if priority_weight_strategy_class in airflow_priority_weight_strategies_classes: return airflow_priority_weight_strategies_classes[priority_weight_strategy_class] importable_string = qualname(priority_weight_strategy_class) diff --git a/airflow-core/src/airflow/task/priority_strategy.py b/airflow-core/src/airflow/task/priority_strategy.py index 65b6e3ff34e..a330ca9198b 100644 --- a/airflow-core/src/airflow/task/priority_strategy.py +++ b/airflow-core/src/airflow/task/priority_strategy.py @@ -20,9 +20,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from airflow._shared.module_loading import qualname from airflow.task.weight_rule import WeightRule if TYPE_CHECKING: @@ -42,22 +41,46 @@ class PriorityWeightStrategy(ABC): """ @abstractmethod - def get_weight(self, ti: TaskInstance) -> int: + def get_weight(self, ti: TaskInstance): """Get the priority weight of a task.""" - raise NotImplementedError("must be implemented by a subclass") + ... + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy: + """ + Deserialize a priority weight strategy from data. + + This is called when a serialized DAG is deserialized. ``data`` will be whatever + was returned by ``serialize`` during DAG serialization. The default + implementation constructs the priority weight strategy without any arguments. + """ + return cls(**data) + + def serialize(self) -> dict[str, Any]: + """ + Serialize the priority weight strategy for JSON encoding. + + This is called during DAG serialization to store priority weight strategy information + in the database. This should return a JSON-serializable dict that will be fed into + ``deserialize`` when the DAG is deserialized. The default implementation returns + an empty dict. 
+ """ + return {} def __eq__(self, other: object) -> bool: """Equality comparison.""" - return isinstance(other, type(self)) + if not isinstance(other, type(self)): + return False + return self.serialize() == other.serialize() - def __hash__(self) -> int: - return hash(None) + def __hash__(self): + return hash(self.serialize()) class _AbsolutePriorityWeightStrategy(PriorityWeightStrategy): """Priority weight strategy that uses the task's priority weight directly.""" - def get_weight(self, ti: TaskInstance) -> int: + def get_weight(self, ti: TaskInstance): if TYPE_CHECKING: assert ti.task return ti.task.priority_weight @@ -81,7 +104,7 @@ class _DownstreamPriorityWeightStrategy(PriorityWeightStrategy): class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy): """Priority weight strategy that uses the sum of the priority weights of all upstream tasks.""" - def get_weight(self, ti: TaskInstance) -> int: + def get_weight(self, ti: TaskInstance): if TYPE_CHECKING: assert ti.task dag = ti.task.get_dag() @@ -93,9 +116,6 @@ class _UpstreamPriorityWeightStrategy(PriorityWeightStrategy): airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = { - qualname(_AbsolutePriorityWeightStrategy): _AbsolutePriorityWeightStrategy, - qualname(_DownstreamPriorityWeightStrategy): _DownstreamPriorityWeightStrategy, - qualname(_UpstreamPriorityWeightStrategy): _UpstreamPriorityWeightStrategy, WeightRule.ABSOLUTE: _AbsolutePriorityWeightStrategy, WeightRule.DOWNSTREAM: _DownstreamPriorityWeightStrategy, WeightRule.UPSTREAM: _UpstreamPriorityWeightStrategy, @@ -103,9 +123,7 @@ airflow_priority_weight_strategies: dict[str, type[PriorityWeightStrategy]] = { airflow_priority_weight_strategies_classes = { - _AbsolutePriorityWeightStrategy: WeightRule.ABSOLUTE, - _DownstreamPriorityWeightStrategy: WeightRule.DOWNSTREAM, - _UpstreamPriorityWeightStrategy: WeightRule.UPSTREAM, + cls: name for name, cls in airflow_priority_weight_strategies.items() } @@ -121,6 +139,7 @@ def validate_and_load_priority_weight_strategy( :meta private: """ + from airflow._shared.module_loading import qualname from airflow.serialization.serialized_objects import _get_registered_priority_weight_strategy if priority_weight_strategy is None: diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 0e3b1647482..b9df47ec768 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -130,17 +130,12 @@ def create_trigger_in_db(session, trigger, operator=None): operator = BaseOperator(task_id="test_ti", dag=dag) session.add(dag_model) - lazy_serdag = LazyDeserializedDAG.from_dag(dag) - SerializedDagModel.write_dag(lazy_serdag, bundle_name=bundle_name) + SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name) session.add(run) session.add(trigger_orm) session.flush() dag_version = DagVersion.get_latest_version(dag.dag_id) - task_instance = TaskInstance( - lazy_serdag._real_dag.get_task(operator.task_id), - run_id=run.run_id, - dag_version_id=dag_version.id, - ) + task_instance = TaskInstance(operator, run_id=run.run_id, dag_version_id=dag_version.id) task_instance.trigger_id = trigger_orm.id session.add(task_instance) session.commit() @@ -445,8 +440,7 @@ class TestTriggerRunner: @pytest.mark.asyncio [email protected]("testing_dag_bundle") -async def test_trigger_create_race_condition_38599(session, supervisor_builder): +async def 
test_trigger_create_race_condition_38599(session, supervisor_builder, testing_dag_bundle): """ This verifies the resolution of race condition documented in github issue #38599. More details in the issue description. @@ -471,17 +465,14 @@ async def test_trigger_create_race_condition_38599(session, supervisor_builder): session.flush() bundle_name = "testing" - with DAG(dag_id="test-dag") as dag: - task = PythonOperator(task_id="dummy-task", python_callable=print) + dag = DAG(dag_id="test-dag") dm = DagModel(dag_id="test-dag", bundle_name=bundle_name) session.add(dm) - - lazy_serdag = LazyDeserializedDAG.from_dag(dag) - SerializedDagModel.write_dag(lazy_serdag, bundle_name=bundle_name) + SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=bundle_name) dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none", run_after=timezone.utcnow()) dag_version = DagVersion.get_latest_version(dag.dag_id) ti = TaskInstance( - lazy_serdag._real_dag.get_task(task.task_id), + PythonOperator(task_id="dummy-task", python_callable=print), run_id=dag_run.run_id, state=TaskInstanceState.DEFERRED, dag_version_id=dag_version.id, diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index bd904d761b7..5d2e3d7f9e4 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -1479,7 +1479,7 @@ def test_mapped_literal_to_xcom_arg_verify_integrity(dag_maker, session): t1 = BaseOperator(task_id="task_1") task_2.expand(arg2=t1.output) - dr.dag = dag_maker.serialized_model.dag + dr.dag = dag_maker.dag dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id dr.verify_integrity(dag_version_id=dag_version_id, session=session) diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py b/airflow-core/tests/unit/models/test_mappedoperator.py index 408e76b22a6..e5bfe78a059 100644 --- a/airflow-core/tests/unit/models/test_mappedoperator.py +++ b/airflow-core/tests/unit/models/test_mappedoperator.py @@ -34,6 +34,7 @@ from airflow.models.taskmap import TaskMap from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk import DAG, BaseOperator, TaskGroup, setup, task, task_group, teardown from airflow.serialization.definitions.baseoperator import SerializedBaseOperator +from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.task.trigger_rule import TriggerRule from airflow.utils.state import TaskInstanceState @@ -1523,7 +1524,7 @@ class TestMappedSetupTeardown: assert op.pool == SerializedBaseOperator.pool assert op.pool_slots == SerializedBaseOperator.pool_slots assert op.priority_weight == SerializedBaseOperator.priority_weight - assert op.weight_rule == "downstream" + assert isinstance(op.weight_rule, PriorityWeightStrategy) assert op.email == email assert op.execution_timeout == execution_timeout assert op.retry_delay == retry_delay diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index da49fd81b01..e83f906fa04 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -82,7 +82,7 @@ from airflow.sdk.execution_time.comms import AssetEventsResult from airflow.serialization.definitions.assets import SerializedAsset from airflow.serialization.definitions.dag import SerializedDAG from airflow.serialization.encoders import ensure_serialized_asset -from 
airflow.serialization.serialized_objects import OperatorSerialization, create_scheduler_operator +from airflow.serialization.serialized_objects import OperatorSerialization from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS from airflow.ti_deps.dependencies_states import RUNNABLE_STATES @@ -2812,7 +2812,7 @@ def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch): monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", mock_policy) - sdk_task = EmptyOperator( + task = EmptyOperator( task_id="empty", queue=default_queue, pool="test_pool1", @@ -2822,28 +2822,27 @@ def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch): retries=30, executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}}, ) - ser_task = create_scheduler_operator(sdk_task) - ti = TI(ser_task, run_id=None, dag_version_id=mock.MagicMock()) - ti.refresh_from_task(ser_task, pool_override=pool_override) + ti = TI(task, run_id=None, dag_version_id=mock.MagicMock()) + ti.refresh_from_task(task, pool_override=pool_override) assert ti.queue == expected_queue if pool_override: assert ti.pool == pool_override else: - assert ti.pool == sdk_task.pool + assert ti.pool == task.pool - assert ti.pool_slots == sdk_task.pool_slots - assert ti.priority_weight == ser_task.weight_rule.get_weight(ti) - assert ti.run_as_user == sdk_task.run_as_user - assert ti.max_tries == sdk_task.retries - assert ti.executor_config == sdk_task.executor_config + assert ti.pool_slots == task.pool_slots + assert ti.priority_weight == task.weight_rule.get_weight(ti) + assert ti.run_as_user == task.run_as_user + assert ti.max_tries == task.retries + assert ti.executor_config == task.executor_config assert ti.operator == EmptyOperator.__name__ # Test that refresh_from_task does not reset ti.max_tries - expected_max_tries = sdk_task.retries + 10 + expected_max_tries = task.retries + 10 ti.max_tries = expected_max_tries - ti.refresh_from_task(ser_task) + ti.refresh_from_task(task) assert ti.max_tries == expected_max_tries diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 0cbdedb1b6d..7b87c6e00a8 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -84,12 +84,7 @@ from airflow.serialization.serialized_objects import ( OperatorSerialization, _XComRef, ) -from airflow.task.priority_strategy import ( - PriorityWeightStrategy, - _DownstreamPriorityWeightStrategy, - airflow_priority_weight_strategies, - validate_and_load_priority_weight_strategy, -) +from airflow.task.priority_strategy import _AbsolutePriorityWeightStrategy, _DownstreamPriorityWeightStrategy from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep from airflow.timetables.simple import NullTimetable, OnceTimetable from airflow.triggers.base import StartTriggerArgs @@ -784,7 +779,6 @@ class TestStringifiedDAGs: "inlets", "outlets", "task_type", - "weight_rule", "_operator_name", # Type is excluded, so don't check it "_log", @@ -818,7 +812,6 @@ class TestStringifiedDAGs: "operator_class", "partial_kwargs", "expand_input", - "weight_rule", } assert serialized_task.task_type == task.task_type @@ -853,12 +846,6 @@ class TestStringifiedDAGs: if isinstance(task.params, ParamsDict) and isinstance(serialized_task.params, ParamsDict): assert 
serialized_task.params.dump() == task.params.dump() - if isinstance(task.weight_rule, PriorityWeightStrategy): - assert task.weight_rule == serialized_task.weight_rule - else: - task_weight_strat = validate_and_load_priority_weight_strategy(task.weight_rule) - assert task_weight_strat == serialized_task.weight_rule - if isinstance(task, MappedOperator): # MappedOperator.operator_class now stores only minimal type information # for memory efficiency (task_type and _operator_name). @@ -1585,7 +1572,7 @@ class TestStringifiedDAGs: "ui_fgcolor": "#000", "wait_for_downstream": False, "wait_for_past_depends_before_skipping": False, - "weight_rule": WeightRule.DOWNSTREAM, + "weight_rule": _DownstreamPriorityWeightStrategy(), "multiple_outputs": False, }, """ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -3999,6 +3986,27 @@ def test_task_callback_backward_compatibility(old_callback_name, new_callback_na assert getattr(deserialized_task_empty, new_callback_name) is False +def test_weight_rule_absolute_serialization_deserialization(): + """Test that weight_rule can be serialized and deserialized correctly.""" + from airflow.sdk import task + + with DAG("test_weight_rule_dag") as dag: + + @task(weight_rule=WeightRule.ABSOLUTE) + def test_task(): + return "test" + + test_task() + + serialized_dag = DagSerialization.to_dict(dag) + assert serialized_dag["dag"]["tasks"][0]["__var"]["weight_rule"] == "absolute" + + deserialized_dag = DagSerialization.from_dict(serialized_dag) + + deserialized_task = deserialized_dag.task_dict["test_task"] + assert isinstance(deserialized_task.weight_rule, _AbsolutePriorityWeightStrategy) + + class TestClientDefaultsGeneration: """Test client defaults generation functionality.""" @@ -4471,50 +4479,3 @@ def test_dag_default_args_callbacks_serialization(callbacks, expected_has_flags, deserialized_dag = DagSerialization.deserialize_dag(serialized_dag_dict) assert deserialized_dag.dag_id == "test_default_args_callbacks" - - -class RegisteredPriorityWeightStrategy(PriorityWeightStrategy): - def get_weight(self, ti): - return 99 - - -class TestWeightRule: - def test_default(self): - sdkop = BaseOperator(task_id="should_fail") - serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop)) - assert serop.weight_rule == _DownstreamPriorityWeightStrategy() - - @pytest.mark.parametrize(("value", "expected"), list(airflow_priority_weight_strategies.items())) - def test_builtin(self, value, expected): - sdkop = BaseOperator(task_id="should_fail", weight_rule=value) - serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop)) - assert serop.weight_rule == expected() - - def test_custom(self): - sdkop = BaseOperator(task_id="should_fail", weight_rule=RegisteredPriorityWeightStrategy()) - with mock.patch( - "airflow.serialization.serialized_objects._get_registered_priority_weight_strategy", - return_value=RegisteredPriorityWeightStrategy, - ) as mock_get_registered_priority_weight_strategy: - serop = OperatorSerialization.deserialize(OperatorSerialization.serialize(sdkop)) - - assert serop.weight_rule == RegisteredPriorityWeightStrategy() - assert mock_get_registered_priority_weight_strategy.mock_calls == [ - mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"), - mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"), - mock.call("unit.serialization.test_dag_serialization.RegisteredPriorityWeightStrategy"), - ] - - 
def test_invalid(self): - op = BaseOperator(task_id="should_fail", weight_rule="no rule") - with pytest.raises(ValueError, match="Unknown priority strategy"): - OperatorSerialization.serialize(op) - - def test_not_registered_custom(self): - class NotRegisteredPriorityWeightStrategy(PriorityWeightStrategy): - def get_weight(self, ti): - return 99 - - op = BaseOperator(task_id="empty_task", weight_rule=NotRegisteredPriorityWeightStrategy()) - with pytest.raises(ValueError, match="Unknown priority strategy"): - OperatorSerialization.serialize(op) diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py index 18c0f10f42a..4d60d6546c7 100644 --- a/providers/celery/tests/integration/celery/test_celery_executor.py +++ b/providers/celery/tests/integration/celery/test_celery_executor.py @@ -49,7 +49,7 @@ from airflow.providers.standard.operators.bash import BashOperator from airflow.utils.state import State from tests_common.test_utils import db -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS logger = logging.getLogger(__name__) @@ -215,21 +215,13 @@ class TestCeleryExecutor: # fake_execute_command takes no arguments while execute_workload takes 1, # which will cause TypeError when calling task.apply_async() executor = celery_executor.CeleryExecutor() - with DAG(dag_id="dag_id") as dag: - task = BashOperator( - task_id="test", - bash_command="true", - start_date=datetime.now(), - ) - if AIRFLOW_V_3_1_PLUS: - from tests_common.test_utils.dag import create_scheduler_dag - - ti = TaskInstance( - task=create_scheduler_dag(dag).get_task(task.task_id), - run_id="abc", - dag_version_id=uuid6.uuid7(), - ) - elif AIRFLOW_V_3_0_PLUS: + task = BashOperator( + task_id="test", + bash_command="true", + dag=DAG(dag_id="dag_id"), + start_date=datetime.now(), + ) + if AIRFLOW_V_3_0_PLUS: ti = TaskInstance(task=task, run_id="abc", dag_version_id=uuid6.uuid7()) else: ti = TaskInstance(task=task, run_id="abc") @@ -262,21 +254,13 @@ class TestCeleryExecutor: assert executor.task_publish_retries == {} assert executor.task_publish_max_retries == 3, "Assert Default Max Retries is 3" - with DAG(dag_id="id") as dag: - task = BashOperator( - task_id="test", - bash_command="true", - start_date=datetime.now(), - ) - if AIRFLOW_V_3_1_PLUS: - from tests_common.test_utils.dag import create_scheduler_dag - - ti = TaskInstance( - task=create_scheduler_dag(dag).get_task(task.task_id), - run_id="abc", - dag_version_id=uuid6.uuid7(), - ) - elif AIRFLOW_V_3_0_PLUS: + task = BashOperator( + task_id="test", + bash_command="true", + dag=DAG(dag_id="id"), + start_date=datetime.now(), + ) + if AIRFLOW_V_3_0_PLUS: ti = TaskInstance(task=task, run_id="abc", dag_version_id=uuid6.uuid7()) else: ti = TaskInstance(task=task, run_id="abc") diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py index 3b298a17791..e2b72589351 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/cli/kubernetes_command.py @@ -32,11 +32,7 @@ from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import Kube from airflow.providers.cncf.kubernetes.kube_client import get_kube_client from 
airflow.providers.cncf.kubernetes.kubernetes_helper_functions import create_unique_id from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, generate_pod_command_args -from airflow.providers.cncf.kubernetes.version_compat import ( - AIRFLOW_V_3_0_PLUS, - AIRFLOW_V_3_1_PLUS, - AIRFLOW_V_3_2_PLUS, -) +from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS from airflow.utils import cli as cli_utils, yaml from airflow.utils.providers_configuration_loader import providers_configuration_loaded from airflow.utils.types import DagRunType @@ -74,13 +70,7 @@ def generate_pod_yaml(args): kube_config = KubeConfig() for task in dag.tasks: - if AIRFLOW_V_3_2_PLUS: - from uuid6 import uuid7 - - from airflow.serialization.serialized_objects import create_scheduler_operator - - ti = TaskInstance(create_scheduler_operator(task), run_id=dr.run_id, dag_version_id=uuid7()) - elif AIRFLOW_V_3_0_PLUS: + if AIRFLOW_V_3_0_PLUS: from uuid6 import uuid7 ti = TaskInstance(task, run_id=dr.run_id, dag_version_id=uuid7()) diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py index 7751da07081..2fb2ac93a12 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/version_compat.py @@ -34,11 +34,9 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0) -AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0) __all__ = [ "AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_1_PLUS", - "AIRFLOW_V_3_2_PLUS", ] diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index f8163a98f83..7f6d6cf297a 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -62,6 +62,11 @@ from airflow.sdk.definitions.edges import EdgeModifier from airflow.sdk.definitions.mappedoperator import OperatorPartial, validate_mapping_kwargs from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.exceptions import RemovedInAirflow4Warning +from airflow.task.priority_strategy import ( + PriorityWeightStrategy, + airflow_priority_weight_strategies, + validate_and_load_priority_weight_strategy, +) # Databases do not support arbitrary precision integers, so we need to limit the range of priority weights. 
 # postgres: -2147483648 to +2147483647 (see https://www.postgresql.org/docs/current/datatype-numeric.html)
@@ -838,7 +843,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     params: ParamsDict | dict = field(default_factory=ParamsDict)
     default_args: dict | None = None
     priority_weight: int = DEFAULT_PRIORITY_WEIGHT
-    weight_rule: PriorityWeightStrategy | str = field(default=DEFAULT_WEIGHT_RULE)
+    weight_rule: PriorityWeightStrategy = field(
+        default_factory=airflow_priority_weight_strategies[DEFAULT_WEIGHT_RULE]
+    )
     queue: str = DEFAULT_QUEUE
     pool: str = DEFAULT_POOL_NAME
     pool_slots: int = DEFAULT_POOL_SLOTS
@@ -1136,7 +1143,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         self.params = ParamsDict(params)

         self.priority_weight = priority_weight
-        self.weight_rule = weight_rule
+        self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
         self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
         self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index dd3ea15ccc6..6aff9ca3c68 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -51,6 +51,7 @@ from airflow.sdk.definitions._internal.expandinput import (
 )
 from airflow.sdk.definitions._internal.types import NOTSET
 from airflow.serialization.enums import DagAttributeTypes
+from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy

 if TYPE_CHECKING:
     import datetime
@@ -66,7 +67,6 @@ if TYPE_CHECKING:
     )
     from airflow.sdk.definitions.operator_resources import Resources
     from airflow.sdk.definitions.param import ParamsDict
-    from airflow.task.priority_strategy import PriorityWeightStrategy
     from airflow.triggers.base import StartTriggerArgs

     ValidationSource = Literal["expand"] | Literal["partial"]
@@ -556,11 +556,13 @@ class MappedOperator(AbstractOperator):

     @property
     def weight_rule(self) -> PriorityWeightStrategy:
-        return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
+        return validate_and_load_priority_weight_strategy(
+            self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
+        )

     @weight_rule.setter
     def weight_rule(self, value: str | PriorityWeightStrategy) -> None:
-        self.partial_kwargs["weight_rule"] = value
+        self.partial_kwargs["weight_rule"] = validate_and_load_priority_weight_strategy(value)

     @property
     def max_active_tis_per_dag(self) -> int | None:
diff --git a/task-sdk/tests/task_sdk/bases/test_operator.py b/task-sdk/tests/task_sdk/bases/test_operator.py
index 85b127dfce6..5db27774e23 100644
--- a/task-sdk/tests/task_sdk/bases/test_operator.py
+++ b/task-sdk/tests/task_sdk/bases/test_operator.py
@@ -29,7 +29,7 @@ import jinja2
 import pytest
 import structlog

-from airflow.sdk import DAG, Label, TaskGroup, task as task_decorator
+from airflow.sdk import task as task_decorator
 from airflow.sdk._shared.secrets_masker import _secrets_masker, mask_secret
 from airflow.sdk.bases.operator import (
     BaseOperator,
@@ -39,8 +39,12 @@ from airflow.sdk.bases.operator import (
     chain_linear,
     cross_downstream,
 )
+from airflow.sdk.definitions.dag import DAG
+from airflow.sdk.definitions.edges import Label
 from airflow.sdk.definitions.param import ParamsDict
+from airflow.sdk.definitions.taskgroup import TaskGroup
 from airflow.sdk.definitions.template import literal
+from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy

 DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)

@@ -267,12 +271,29 @@ class TestBaseOperator:

     def test_weight_rule_default(self):
         op = BaseOperator(task_id="test_task")
-        assert op.weight_rule == "downstream"
+        assert _DownstreamPriorityWeightStrategy() == op.weight_rule

     def test_weight_rule_override(self):
-        whatever_value = object()
-        op = BaseOperator(task_id="test_task", weight_rule=whatever_value)
-        assert op.weight_rule is whatever_value
+        op = BaseOperator(task_id="test_task", weight_rule="upstream")
+        assert _UpstreamPriorityWeightStrategy() == op.weight_rule
+
+    def test_dag_task_invalid_weight_rule(self):
+        # Test if we enter an invalid weight rule
+        with pytest.raises(ValueError, match="Unknown priority strategy"):
+            BaseOperator(task_id="should_fail", weight_rule="no rule")
+
+    def test_dag_task_not_registered_weight_strategy(self):
+        from airflow.task.priority_strategy import PriorityWeightStrategy
+
+        class NotRegisteredPriorityWeightStrategy(PriorityWeightStrategy):
+            def get_weight(self, ti):
+                return 99
+
+        with pytest.raises(ValueError, match="Unknown priority strategy"):
+            BaseOperator(
+                task_id="empty_task",
+                weight_rule=NotRegisteredPriorityWeightStrategy(),
+            )

     def test_db_safe_priority(self):
         """Test the db_safe_priority function."""
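
Note for reviewers: with this revert, ``BaseOperator`` again normalizes ``weight_rule`` at construction
time through ``validate_and_load_priority_weight_strategy``, so the attribute always holds a
``PriorityWeightStrategy`` instance rather than a raw string. A minimal sketch of the resulting
behavior, mirroring the restored task-sdk tests above (illustrative only, not part of the commit)::

    from airflow.sdk.bases.operator import BaseOperator
    from airflow.task.priority_strategy import (
        _DownstreamPriorityWeightStrategy,
        _UpstreamPriorityWeightStrategy,
    )

    # A string weight rule resolves to a registered strategy instance.
    op = BaseOperator(task_id="t1", weight_rule="upstream")
    assert op.weight_rule == _UpstreamPriorityWeightStrategy()

    # The default remains the downstream strategy.
    assert BaseOperator(task_id="t2").weight_rule == _DownstreamPriorityWeightStrategy()

    # Unknown rules now fail fast at operator construction.
    try:
        BaseOperator(task_id="t3", weight_rule="no rule")
    except ValueError as err:
        assert "Unknown priority strategy" in str(err)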
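
The revert also restores the ``serialize``/``deserialize`` hooks on ``PriorityWeightStrategy``, along
with the ``__eq__`` that compares serialized state. A sketch of the round trip those hooks enable for a
parameterized strategy; the ``FactorPriorityWeightStrategy`` class and its ``factor`` field are
hypothetical, and (as the restored tests suggest) a custom strategy must still be registered before it
can be used as an operator's ``weight_rule``::

    from __future__ import annotations

    from typing import Any

    from airflow.task.priority_strategy import PriorityWeightStrategy


    class FactorPriorityWeightStrategy(PriorityWeightStrategy):
        """Hypothetical strategy: scale the task's static priority weight."""

        def __init__(self, factor: int = 1) -> None:
            self.factor = factor

        def get_weight(self, ti):
            return ti.task.priority_weight * self.factor

        def serialize(self) -> dict[str, Any]:
            # Stored in the serialized DAG; must be JSON-serializable.
            return {"factor": self.factor}

        # The inherited deserialize() does cls(**data), so no override is needed.


    strategy = FactorPriorityWeightStrategy(factor=2)
    restored = FactorPriorityWeightStrategy.deserialize(strategy.serialize())
    assert restored == strategy  # __eq__ now compares serialize() output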
