This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1aaa87bc624 Make private dag-processing functions look private (#61661)
1aaa87bc624 is described below
commit 1aaa87bc624f30cdd487874537d8039a97c2f95a
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Feb 10 11:51:45 2026 +0800
Make private dag-processing functions look private (#61661)
---
.../api_fastapi/core_api/datamodels/tasks.py | 14 ++++--
.../airflow/serialization/serialized_objects.py | 52 +++++++++-------------
2 files changed, 33 insertions(+), 33 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/tasks.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/tasks.py
index 1c05e47d845..777f97f9b72 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/tasks.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/tasks.py
@@ -24,13 +24,17 @@ from typing import TYPE_CHECKING, Any
from pydantic import computed_field, field_validator, model_validator
+from airflow._shared.module_loading import qualname
from airflow.api_fastapi.common.types import TimeDeltaWithValidation
from airflow.api_fastapi.core_api.base import BaseModel
-from airflow.serialization.serialized_objects import encode_priority_weight_strategy
-from airflow.task.priority_strategy import PriorityWeightStrategy
+from airflow.task.priority_strategy import (
+ get_weight_rule_from_priority_weight_strategy,
+ validate_and_load_priority_weight_strategy,
+)
if TYPE_CHECKING:
from airflow.serialization.definitions.param import SerializedParamsDict
+ from airflow.task.priority_strategy import PriorityWeightStrategy
def _get_class_ref(obj) -> dict[str, str | None]:
@@ -92,7 +96,11 @@ class TaskResponse(BaseModel):
return None
if isinstance(wr, str):
return wr
- return encode_priority_weight_strategy(wr)
+ strat_type = type(validate_and_load_priority_weight_strategy(wr))
+ try:
+ return get_weight_rule_from_priority_weight_strategy(strat_type)
+ except KeyError:
+ return qualname(strat_type)
@field_validator("params", mode="before")
@classmethod
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index db79e799443..f1e5813b189 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -152,7 +152,7 @@ class _PriorityWeightStrategyNotRegistered(AirflowException):
)
-def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
+def _encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
key = var.key
return {
"key": BaseSerialization.serialize(key),
@@ -161,7 +161,7 @@ def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
}
-def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
+def _decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
asset_alias_events = var.get("asset_alias_events", [])
outlet_event_accessor = OutletEventAccessor(
key=BaseSerialization.deserialize(var["key"]),
@@ -182,26 +182,26 @@ def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
return outlet_event_accessor
-def encode_outlet_event_accessors(var: OutletEventAccessors) -> dict[str, Any]:
+def _encode_outlet_event_accessors(var: OutletEventAccessors) -> dict[str, Any]:
return {
"__type": DAT.ASSET_EVENT_ACCESSORS,
"_dict": [
- {"key": BaseSerialization.serialize(k), "value": encode_outlet_event_accessor(v)}
+ {"key": BaseSerialization.serialize(k), "value": _encode_outlet_event_accessor(v)}
for k, v in var._dict.items()
],
}
-def decode_outlet_event_accessors(var: dict[str, Any]) -> OutletEventAccessors:
+def _decode_outlet_event_accessors(var: dict[str, Any]) -> OutletEventAccessors:
d = OutletEventAccessors()
d._dict = {
- BaseSerialization.deserialize(row["key"]): decode_outlet_event_accessor(row["value"])
+ BaseSerialization.deserialize(row["key"]): _decode_outlet_event_accessor(row["value"])
for row in var["_dict"]
}
return d
-def encode_priority_weight_strategy(var: PriorityWeightStrategy | str) -> str:
+def _encode_priority_weight_strategy(var: PriorityWeightStrategy | str) -> str:
"""
Encode a priority weight strategy instance.
@@ -218,7 +218,7 @@ def encode_priority_weight_strategy(var: PriorityWeightStrategy | str) -> str:
return importable_string
-def decode_priority_weight_strategy(var: str) -> PriorityWeightStrategy:
+def _decode_priority_weight_strategy(var: str) -> PriorityWeightStrategy:
"""
Decode a previously serialized priority weight strategy.
@@ -231,12 +231,8 @@ def decode_priority_weight_strategy(var: str) -> PriorityWeightStrategy:
return priority_weight_strategy_class()
-def encode_start_trigger_args(var: StartTriggerArgs) -> dict[str, Any]:
- """
- Encode a StartTriggerArgs.
-
- :meta private:
- """
+def _encode_start_trigger_args(var: StartTriggerArgs) -> dict[str, Any]:
+ """Encode a StartTriggerArgs."""
def serialize_kwargs(key: str) -> Any:
if (val := getattr(var, key)) is None:
@@ -253,12 +249,8 @@ def encode_start_trigger_args(var: StartTriggerArgs) -> dict[str, Any]:
}
-def decode_start_trigger_args(var: dict[str, Any]) -> StartTriggerArgs:
- """
- Decode a StartTriggerArgs.
-
- :meta private:
- """
+def _decode_start_trigger_args(var: dict[str, Any]) -> StartTriggerArgs:
+ """Decode a StartTriggerArgs."""
def deserialize_kwargs(key: str) -> Any:
if (val := var[key]) is None:
@@ -448,7 +440,7 @@ class BaseSerialization:
elif key == "timetable" and value is not None:
serialized_object[key] = encode_timetable(value)
elif key == "weight_rule" and value is not None:
- encoded_priority_weight_strategy = encode_priority_weight_strategy(value)
+ encoded_priority_weight_strategy = _encode_priority_weight_strategy(value)
# Exclude if it is just default
default_pri_weight_stra = cls.get_schema_defaults("operator").get(key, None)
@@ -504,7 +496,7 @@ class BaseSerialization:
return cls._encode(json_pod, type_=DAT.POD)
elif isinstance(var, OutletEventAccessors):
return cls._encode(
- encode_outlet_event_accessors(var),
+ _encode_outlet_event_accessors(var),
type_=DAT.ASSET_EVENT_ACCESSORS,
)
elif isinstance(var, AssetUniqueKey):
@@ -639,7 +631,7 @@ class BaseSerialization:
if type_ == DAT.DICT:
return {k: cls.deserialize(v) for k, v in var.items()}
elif type_ == DAT.ASSET_EVENT_ACCESSORS:
- return decode_outlet_event_accessors(var)
+ return _decode_outlet_event_accessors(var)
elif type_ == DAT.ASSET_UNIQUE_KEY:
return AssetUniqueKey(name=var["name"], uri=var["uri"])
elif type_ == DAT.ASSET_ALIAS_UNIQUE_KEY:
@@ -859,7 +851,7 @@ class BaseSerialization:
return defaults
-class DependencyDetector:
+class _DependencyDetector:
"""
Detects dependencies between DAGs.
@@ -1038,7 +1030,7 @@ class OperatorSerialization(DAGNode, BaseSerialization):
serialize_op["_can_skip_downstream"] = True
if op.start_trigger_args:
- serialize_op["start_trigger_args"] = encode_start_trigger_args(op.start_trigger_args)
+ serialize_op["start_trigger_args"] = _encode_start_trigger_args(op.start_trigger_args)
if op.operator_extra_links:
serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links(
@@ -1163,7 +1155,7 @@ class OperatorSerialization(DAGNode, BaseSerialization):
k = "on_failure_fail_dagrun"
elif k == "weight_rule":
k = "_weight_rule"
- v = decode_priority_weight_strategy(v)
+ v = _decode_priority_weight_strategy(v)
elif k == "retry_exponential_backoff":
if isinstance(v, bool):
v = 2.0 if v else 0
@@ -1215,7 +1207,7 @@ class OperatorSerialization(DAGNode, BaseSerialization):
encoded_start_trigger_args = encoded_op.get("start_trigger_args", None)
if encoded_start_trigger_args:
encoded_start_trigger_args = cast("dict", encoded_start_trigger_args)
- start_trigger_args = decode_start_trigger_args(encoded_start_trigger_args)
+ start_trigger_args = _decode_start_trigger_args(encoded_start_trigger_args)
setattr(op, "start_trigger_args", start_trigger_args)
setattr(op, "start_from_trigger", bool(encoded_op.get("start_from_trigger", False)))
@@ -1371,7 +1363,7 @@ class OperatorSerialization(DAGNode, BaseSerialization):
@classmethod
def detect_dependencies(cls, op: SdkOperator) -> set[DagDependency]:
"""Detect between DAG dependencies for the operator."""
- dependency_detector = DependencyDetector()
+ dependency_detector = _DependencyDetector()
deps = set(dependency_detector.detect_task_dependencies(op))
return deps
@@ -1711,7 +1703,7 @@ class DagSerialization(BaseSerialization):
for task in dag.task_dict.values()
for dep in OperatorSerialization.detect_dependencies(task)
]
- dag_deps.extend(DependencyDetector.detect_dag_dependencies(dag))
+ dag_deps.extend(_DependencyDetector.detect_dag_dependencies(dag))
serialized_dag["dag_dependencies"] = [x.__dict__ for x in sorted(dag_deps)]
serialized_dag["task_group"] = TaskGroupSerialization.serialize_task_group(dag.task_group)
@@ -1811,7 +1803,7 @@ class DagSerialization(BaseSerialization):
elif k == "timetable":
v = decode_timetable(v)
elif k == "weight_rule":
- v = decode_priority_weight_strategy(v)
+ v = _decode_priority_weight_strategy(v)
elif k in cls._decorated_fields:
v = cls.deserialize(v)
elif k == "params":