This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 28b3f6eb921 Decouple deadline reference types from core in task SDK
(#61461)
28b3f6eb921 is described below
commit 28b3f6eb9213b794a55a147c0ed70a1fe8f9ad26
Author: Amogh Desai <[email protected]>
AuthorDate: Wed Mar 11 14:03:57 2026 +0530
Decouple deadline reference types from core in task SDK (#61461)
Custom deadline references now serialize and deserialize using a wrapper
pattern.
---
airflow-core/docs/howto/deadline-alerts.rst | 7 +-
airflow-core/src/airflow/models/deadline.py | 6 +-
airflow-core/src/airflow/models/deadline_alert.py | 7 +-
airflow-core/src/airflow/serialization/decoders.py | 17 ++-
.../airflow/serialization/definitions/deadline.py | 91 ++++++++++-
airflow-core/src/airflow/serialization/encoders.py | 28 +++-
airflow-core/tests/unit/models/test_deadline.py | 103 ++++++++-----
.../tests/unit/models/test_deadline_alert.py | 58 ++++++-
task-sdk/src/airflow/sdk/definitions/deadline.py | 170 +++++++++++++++++----
9 files changed, 396 insertions(+), 91 deletions(-)
diff --git a/airflow-core/docs/howto/deadline-alerts.rst
b/airflow-core/docs/howto/deadline-alerts.rst
index 1ed9750bf4e..64f39c02440 100644
--- a/airflow-core/docs/howto/deadline-alerts.rst
+++ b/airflow-core/docs/howto/deadline-alerts.rst
@@ -425,17 +425,16 @@ implement an ``_evaluate_with()`` method.
.. code-block:: python
- from airflow.models.deadline import ReferenceModels
from sqlalchemy.orm import Session
from airflow.sdk import DeadlineReference
- from airflow.sdk.definitions.deadline import deadline_reference
+ from airflow.sdk.definitions.deadline import BaseDeadlineReference,
deadline_reference
from airflow.sdk.timezone import datetime
# By default, the evaluate_with method will be executed when the dagrun is
created.
@deadline_reference()
- class MyCustomDecoratedReference(ReferenceModels.BaseDeadlineReference):
+ class MyCustomDecoratedReference(BaseDeadlineReference):
"""A custom reference evaluated when Dag runs are created."""
def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
@@ -445,7 +444,7 @@ implement an ``_evaluate_with()`` method.
# You can specify when evaluate_with will be called by providing a
DeadlineReference.TYPES value.
@deadline_reference(DeadlineReference.TYPES.DAGRUN_QUEUED)
- class MyQueuedReference(ReferenceModels.BaseDeadlineReference):
+ class MyQueuedReference(BaseDeadlineReference):
"""A custom reference evaluated when Dag runs are queued."""
required_kwargs = {"custom_param"}
diff --git a/airflow-core/src/airflow/models/deadline.py
b/airflow-core/src/airflow/models/deadline.py
index debfe949b31..ec3ab5ad99c 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -314,7 +314,11 @@ class ReferenceModels:
)
if extra_kwargs := kwargs.keys() - filtered_kwargs.keys():
- self.log.debug("Ignoring unexpected parameters: %s", ",
".join(extra_kwargs))
+ self.log.debug(
+ "%s ignoring unexpected parameters: %s",
+ self.reference_name,
+ ", ".join(extra_kwargs),
+ )
base_time = self._evaluate_with(session=session, **filtered_kwargs)
return base_time + interval if base_time is not None else None
diff --git a/airflow-core/src/airflow/models/deadline_alert.py
b/airflow-core/src/airflow/models/deadline_alert.py
index 8afc35d7560..0b8a8eba9b1 100644
--- a/airflow-core/src/airflow/models/deadline_alert.py
+++ b/airflow-core/src/airflow/models/deadline_alert.py
@@ -86,9 +86,10 @@ class DeadlineAlert(Base):
@property
def reference_class(self) ->
type[SerializedReferenceModels.SerializedBaseDeadlineReference]:
"""Return the deserialized reference class."""
- return SerializedReferenceModels.get_reference_class(
- self.reference[SerializedReferenceModels.REFERENCE_TYPE_FIELD]
- )
+ ref_name =
self.reference.get(SerializedReferenceModels.REFERENCE_TYPE_FIELD)
+ if ref_name and
SerializedReferenceModels.is_builtin_reference(ref_name):
+ return SerializedReferenceModels.get_reference_class(ref_name)
+ return SerializedReferenceModels.SerializedCustomReference
@classmethod
@provide_session
diff --git a/airflow-core/src/airflow/serialization/decoders.py
b/airflow-core/src/airflow/serialization/decoders.py
index a438010424f..8d19a196e92 100644
--- a/airflow-core/src/airflow/serialization/decoders.py
+++ b/airflow-core/src/airflow/serialization/decoders.py
@@ -136,6 +136,18 @@ def decode_asset_like(var: dict[str, Any]) ->
SerializedAssetBase:
raise ValueError(f"deserialization not implemented for DAT
{data_type!r}")
+def decode_deadline_reference(reference_data: dict):
+ """Decode a previously serialized deadline reference."""
+ ref_name =
reference_data.get(SerializedReferenceModels.REFERENCE_TYPE_FIELD)
+
+ if ref_name and SerializedReferenceModels.is_builtin_reference(ref_name):
+ reference_class =
SerializedReferenceModels.get_reference_class(ref_name)
+ else:
+ reference_class = SerializedReferenceModels.SerializedCustomReference
+
+ return reference_class.deserialize_reference(reference_data)
+
+
def decode_deadline_alert(encoded_data: dict):
"""
Decode a previously serialized deadline alert.
@@ -147,10 +159,7 @@ def decode_deadline_alert(encoded_data: dict):
data = encoded_data.get(Encoding.VAR, encoded_data)
reference_data = data[DeadlineAlertFields.REFERENCE]
- reference_type =
reference_data[SerializedReferenceModels.REFERENCE_TYPE_FIELD]
-
- reference_class =
SerializedReferenceModels.get_reference_class(reference_type)
- reference = reference_class.deserialize_reference(reference_data)
+ reference = decode_deadline_reference(reference_data)
return SerializedDeadlineAlert(
reference=reference,
diff --git a/airflow-core/src/airflow/serialization/definitions/deadline.py
b/airflow-core/src/airflow/serialization/definitions/deadline.py
index 78adc6b9a76..93af9ef19e7 100644
--- a/airflow-core/src/airflow/serialization/definitions/deadline.py
+++ b/airflow-core/src/airflow/serialization/definitions/deadline.py
@@ -20,6 +20,7 @@ import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta
+from inspect import isclass
from typing import TYPE_CHECKING, Any
import attrs
@@ -62,6 +63,15 @@ class SerializedReferenceModels:
REFERENCE_TYPE_FIELD = "reference_type"
+ @classmethod
+ def is_builtin_reference(cls, ref_name: str) -> bool:
+ """Check if a reference type is a built-in reference."""
+ return any(
+ r.__name__ == ref_name
+ for r in vars(cls).values()
+ if isclass(r) and issubclass(r,
cls.SerializedBaseDeadlineReference)
+ )
+
@classmethod
def get_reference_class(cls, reference_name: str) ->
type[SerializedBaseDeadlineReference]:
"""
@@ -99,7 +109,11 @@ class SerializedReferenceModels:
)
if extra_kwargs := kwargs.keys() - filtered_kwargs.keys():
- self.log.debug("Ignoring unexpected parameters: %s", ",
".join(extra_kwargs))
+ self.log.debug(
+ "%s ignoring unexpected parameters: %s",
+ self.reference_name,
+ ", ".join(extra_kwargs),
+ )
base_time = self._evaluate_with(session=session, **filtered_kwargs)
return base_time + interval if base_time is not None else None
@@ -225,8 +239,19 @@ class SerializedReferenceModels:
)
return None
- avg_duration_seconds = sum(durations) / len(durations)
- return timezone.utcnow() + timedelta(seconds=avg_duration_seconds)
+ # Convert to float to handle Decimal types from MySQL while
preserving precision
+ # Use Decimal arithmetic for higher precision, then convert to
float
+ from decimal import Decimal
+
+ decimal_durations = [Decimal(str(d)) for d in durations]
+ avg_seconds = float(sum(decimal_durations) /
len(decimal_durations))
+ logger.info(
+ "Average runtime for dag_id %s (from %d runs): %.2f seconds",
+ dag_id,
+ len(durations),
+ avg_seconds,
+ )
+ return timezone.utcnow() + timedelta(seconds=avg_seconds)
def serialize_reference(self) -> dict:
return {
@@ -239,6 +264,62 @@ class SerializedReferenceModels:
def deserialize_reference(cls, reference_data: dict):
return cls(max_runs=reference_data["max_runs"],
min_runs=reference_data.get("min_runs"))
+ class SerializedCustomReference(SerializedBaseDeadlineReference):
+ """
+ Wrapper for custom deadline references.
+
+ This class dynamically delegates to the wrapped reference for
required_kwargs and evaluation logic.
+ """
+
+ def __init__(self, inner_ref):
+ self.inner_ref = inner_ref
+
+ @property
+ def reference_name(self) -> str:
+ return self.inner_ref.reference_name
+
+ def evaluate_with(self, *, session: Session, interval: timedelta,
**kwargs: Any) -> datetime | None:
+ """Validate the provided kwargs and evaluate this deadline with
the given conditions."""
+ required_kwargs: set[str] = getattr(self.inner_ref,
"required_kwargs", set())
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in
required_kwargs}
+
+ if missing_kwargs := required_kwargs - filtered_kwargs.keys():
+ raise ValueError(
+ f"{self.inner_ref.__class__.__name__} is missing required
parameters: {', '.join(missing_kwargs)}"
+ )
+
+ if extra_kwargs := kwargs.keys() - filtered_kwargs.keys():
+ self.log.debug(
+ "%s ignoring unexpected parameters: %s",
+ self.reference_name,
+ ", ".join(extra_kwargs),
+ )
+
+ deadline = self.inner_ref._evaluate_with(session=session,
**filtered_kwargs)
+ return deadline + interval if deadline is not None else None
+
+ def _evaluate_with(self, *, session: Session, **kwargs: Any) ->
datetime | None:
+ return self.inner_ref._evaluate_with(session=session, **kwargs)
+
+ def serialize_reference(self) -> dict:
+ return self.inner_ref.serialize_reference()
+
+ @classmethod
+ def deserialize_reference(cls, reference_data: dict):
+ from airflow._shared.module_loading import import_string
+
+ custom_class = import_string(reference_data["__class_path"])
+ inner_ref = custom_class.deserialize_reference(reference_data)
+ return cls(inner_ref)
+
+ def __eq__(self, other) -> bool:
+ if not isinstance(other,
SerializedReferenceModels.SerializedCustomReference):
+ return False
+ return self.inner_ref == other.inner_ref
+
+ def __hash__(self) -> int:
+ return hash(self.inner_ref)
+
class TYPES:
"""Collection of SerializedDeadlineReference types for type
checking."""
@@ -259,7 +340,9 @@ SerializedReferenceModels.TYPES.DAGRUN_CREATED = (
)
SerializedReferenceModels.TYPES.DAGRUN_QUEUED =
(SerializedReferenceModels.DagRunQueuedAtDeadline,)
SerializedReferenceModels.TYPES.DAGRUN = (
- SerializedReferenceModels.TYPES.DAGRUN_CREATED +
SerializedReferenceModels.TYPES.DAGRUN_QUEUED
+ *SerializedReferenceModels.TYPES.DAGRUN_CREATED,
+ *SerializedReferenceModels.TYPES.DAGRUN_QUEUED,
+ SerializedReferenceModels.SerializedCustomReference,
)
diff --git a/airflow-core/src/airflow/serialization/encoders.py
b/airflow-core/src/airflow/serialization/encoders.py
index 7db97b844f6..239b2b1b97b 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -202,19 +202,43 @@ def encode_deadline_alert(d: DeadlineAlert |
SerializedDeadlineAlert) -> dict[st
from airflow.sdk.serde import serialize
return {
- "reference": d.reference.serialize_reference(),
+ "reference": encode_deadline_reference(d.reference),
"interval": d.interval.total_seconds(),
"callback": serialize(d.callback),
}
+_BUILTIN_DEADLINE_MODULES = (
+ "airflow.sdk.definitions.deadline",
+ "airflow.serialization.definitions.deadline",
+ # Include airflow.models.deadline to treat core's deadline references as
builtins.
+ # This is to maintain backcompat with 3.1.x custom refs that inherit from
+ # airflow.models.deadline.ReferenceModels.BaseDeadlineReference.
+ "airflow.models.deadline",
+)
+
+
def encode_deadline_reference(ref) -> dict[str, Any]:
"""
Encode a deadline reference.
+ For custom (non-builtin) deadline references, includes the class path
+ so the decoder can import the user's class at runtime.
+
:meta private:
"""
- return ref.serialize_reference()
+ from airflow._shared.module_loading import qualname
+
+ serialized = ref.serialize_reference()
+
+ # Custom types (not built-in) need __class_path so the decoder can import
them.
+ # Unlike built-in types which are looked up in SerializedReferenceModels,
+ # custom types are discovered via import_string(__class_path) at
deserialization time.
+ module = type(ref).__module__
+ if module not in _BUILTIN_DEADLINE_MODULES:
+ serialized["__class_path"] = qualname(ref)
+
+ return serialized
def _get_serialized_timetable_import_path(var: BaseTimetable | CoreTimetable)
-> str:
diff --git a/airflow-core/tests/unit/models/test_deadline.py
b/airflow-core/tests/unit/models/test_deadline.py
index f4f435291ce..94c6977ae0c 100644
--- a/airflow-core/tests/unit/models/test_deadline.py
+++ b/airflow-core/tests/unit/models/test_deadline.py
@@ -28,11 +28,20 @@ from sqlalchemy.exc import SQLAlchemyError
from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
from airflow.models import DagRun
-from airflow.models.deadline import Deadline, ReferenceModels, _fetch_from_db
+from airflow.models.deadline import Deadline, _fetch_from_db
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import timezone
from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
-from airflow.sdk.definitions.deadline import DeadlineReference,
deadline_reference
+from airflow.sdk.definitions.deadline import (
+ AverageRuntimeDeadline,
+ BaseDeadlineReference,
+ DagRunLogicalDateDeadline,
+ DagRunQueuedAtDeadline,
+ DeadlineReference,
+ FixedDatetimeDeadline,
+ deadline_reference,
+)
+from airflow.serialization.definitions.deadline import
SerializedReferenceModels
from airflow.utils.state import DagRunState
from tests_common.test_utils import db
@@ -46,10 +55,12 @@ INVALID_DAG_ID = "invalid_dag_id"
INVALID_RUN_ID = -1
REFERENCE_TYPES = [
- pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE, id="logical_date"),
- pytest.param(DeadlineReference.DAGRUN_QUEUED_AT, id="queued_at"),
- pytest.param(DeadlineReference.FIXED_DATETIME(DEFAULT_DATE),
id="fixed_deadline"),
- pytest.param(DeadlineReference.AVERAGE_RUNTIME(), id="average_runtime"),
+ pytest.param(SerializedReferenceModels.DagRunLogicalDateDeadline(),
id="logical_date"),
+ pytest.param(SerializedReferenceModels.DagRunQueuedAtDeadline(),
id="queued_at"),
+
pytest.param(SerializedReferenceModels.FixedDatetimeDeadline(DEFAULT_DATE),
id="fixed_deadline"),
+ pytest.param(
+ SerializedReferenceModels.AverageRuntimeDeadline(max_runs=10,
min_runs=10), id="average_runtime"
+ ),
]
@@ -320,10 +331,20 @@ class TestCalculatedDeadlineDatabaseCalls:
@pytest.mark.parametrize(
("reference", "expected_column"),
[
- pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE,
DagRun.logical_date, id="logical_date"),
- pytest.param(DeadlineReference.DAGRUN_QUEUED_AT, DagRun.queued_at,
id="queued_at"),
- pytest.param(DeadlineReference.FIXED_DATETIME(DEFAULT_DATE), None,
id="fixed_deadline"),
- pytest.param(DeadlineReference.AVERAGE_RUNTIME(), None,
id="average_runtime"),
+ pytest.param(
+ SerializedReferenceModels.DagRunLogicalDateDeadline(),
DagRun.logical_date, id="logical_date"
+ ),
+ pytest.param(
+ SerializedReferenceModels.DagRunQueuedAtDeadline(),
DagRun.queued_at, id="queued_at"
+ ),
+ pytest.param(
+ SerializedReferenceModels.FixedDatetimeDeadline(DEFAULT_DATE),
None, id="fixed_deadline"
+ ),
+ pytest.param(
+ SerializedReferenceModels.AverageRuntimeDeadline(max_runs=10,
min_runs=10),
+ None,
+ id="average_runtime",
+ ),
],
)
def test_deadline_database_integration(self, reference, expected_column,
session):
@@ -337,13 +358,13 @@ class TestCalculatedDeadlineDatabaseCalls:
"""
conditions = {"dag_id": DAG_ID, "run_id": "dagrun_1"}
interval = timedelta(hours=1)
- with mock.patch("airflow.models.deadline._fetch_from_db") as
mock_fetch:
+ with
mock.patch("airflow.serialization.definitions.deadline._fetch_from_db") as
mock_fetch:
mock_fetch.return_value = DEFAULT_DATE
if expected_column is not None:
result = reference.evaluate_with(session=session,
interval=interval, **conditions)
mock_fetch.assert_called_once_with(expected_column,
session=session, **conditions)
- elif reference == DeadlineReference.AVERAGE_RUNTIME():
+ elif isinstance(reference,
SerializedReferenceModels.AverageRuntimeDeadline):
with mock.patch("airflow._shared.timezones.timezone.utcnow")
as mock_utcnow:
mock_utcnow.return_value = DEFAULT_DATE
# No DAG runs exist, so it should use 24-hour default
@@ -380,7 +401,7 @@ class TestCalculatedDeadlineDatabaseCalls:
session.commit()
# Test with default max_runs (10)
- reference = DeadlineReference.AVERAGE_RUNTIME()
+ reference =
SerializedReferenceModels.AverageRuntimeDeadline(max_runs=10, min_runs=10)
interval = timedelta(hours=1)
with mock.patch("airflow._shared.timezones.timezone.utcnow") as
mock_utcnow:
@@ -417,7 +438,7 @@ class TestCalculatedDeadlineDatabaseCalls:
session.commit()
- reference = DeadlineReference.AVERAGE_RUNTIME()
+ reference =
SerializedReferenceModels.AverageRuntimeDeadline(max_runs=10, min_runs=10)
interval = timedelta(hours=1)
with mock.patch("airflow._shared.timezones.timezone.utcnow") as
mock_utcnow:
@@ -451,7 +472,7 @@ class TestCalculatedDeadlineDatabaseCalls:
session.commit()
# Test with min_runs=2, should work with 3 runs
- reference = DeadlineReference.AVERAGE_RUNTIME(max_runs=10, min_runs=2)
+ reference =
SerializedReferenceModels.AverageRuntimeDeadline(max_runs=10, min_runs=2)
interval = timedelta(hours=1)
with mock.patch("airflow._shared.timezones.timezone.utcnow") as
mock_utcnow:
@@ -465,7 +486,7 @@ class TestCalculatedDeadlineDatabaseCalls:
assert result.replace(second=0, microsecond=0) ==
expected.replace(second=0, microsecond=0)
# Test with min_runs=5, should return None with only 3 runs
- reference = DeadlineReference.AVERAGE_RUNTIME(max_runs=10, min_runs=5)
+ reference =
SerializedReferenceModels.AverageRuntimeDeadline(max_runs=10, min_runs=5)
with mock.patch("airflow._shared.timezones.timezone.utcnow") as
mock_utcnow:
mock_utcnow.return_value = DEFAULT_DATE
@@ -535,17 +556,17 @@ class TestDeadlineReference:
def test_deadline_reference_creation(self):
"""Test that DeadlineReference provides consistent interface and
types."""
fixed_reference = DeadlineReference.FIXED_DATETIME(DEFAULT_DATE)
- assert isinstance(fixed_reference,
ReferenceModels.FixedDatetimeDeadline)
+ assert isinstance(fixed_reference, FixedDatetimeDeadline)
assert fixed_reference._datetime == DEFAULT_DATE
logical_date_reference = DeadlineReference.DAGRUN_LOGICAL_DATE
- assert isinstance(logical_date_reference,
ReferenceModels.DagRunLogicalDateDeadline)
+ assert isinstance(logical_date_reference, DagRunLogicalDateDeadline)
queued_reference = DeadlineReference.DAGRUN_QUEUED_AT
- assert isinstance(queued_reference,
ReferenceModels.DagRunQueuedAtDeadline)
+ assert isinstance(queued_reference, DagRunQueuedAtDeadline)
average_runtime_reference = DeadlineReference.AVERAGE_RUNTIME()
- assert isinstance(average_runtime_reference,
ReferenceModels.AverageRuntimeDeadline)
+ assert isinstance(average_runtime_reference, AverageRuntimeDeadline)
assert average_runtime_reference.max_runs == 10
assert average_runtime_reference.min_runs == 10
@@ -556,14 +577,14 @@ class TestDeadlineReference:
class TestCustomDeadlineReference:
- class MyCustomRef(ReferenceModels.BaseDeadlineReference):
+ class MyCustomRef(BaseDeadlineReference):
def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
return timezone.datetime(DEFAULT_DATE)
class MyInvalidCustomRef:
pass
- class MyCustomRefWithKwargs(ReferenceModels.BaseDeadlineReference):
+ class MyCustomRefWithKwargs(BaseDeadlineReference):
required_kwargs = {"custom_id"}
def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
@@ -573,7 +594,6 @@ class TestCustomDeadlineReference:
self.original_dagrun_created = DeadlineReference.TYPES.DAGRUN_CREATED
self.original_dagrun_queued = DeadlineReference.TYPES.DAGRUN_QUEUED
self.original_dagrun = DeadlineReference.TYPES.DAGRUN
- self.original_attrs = set(dir(ReferenceModels))
self.original_deadline_attrs = set(dir(DeadlineReference))
def teardown_method(self):
@@ -581,10 +601,6 @@ class TestCustomDeadlineReference:
DeadlineReference.TYPES.DAGRUN_QUEUED = self.original_dagrun_queued
DeadlineReference.TYPES.DAGRUN = self.original_dagrun
- for attr in set(dir(ReferenceModels)):
- if attr not in self.original_attrs:
- delattr(ReferenceModels, attr)
-
for attr in set(dir(DeadlineReference)):
if attr not in self.original_deadline_attrs:
delattr(DeadlineReference, attr)
@@ -613,7 +629,7 @@ class TestCustomDeadlineReference:
expected_timing = timing
assert result is reference
- assert getattr(ReferenceModels, reference.__name__) is reference
+ assert hasattr(DeadlineReference, reference.__name__)
assert getattr(DeadlineReference, reference.__name__).__class__ is
reference
assert_correct_timing(reference, expected_timing)
@@ -637,12 +653,15 @@ class TestCustomDeadlineReference:
):
DeadlineReference.register_custom_reference(self.MyCustomRef,
invalid_timing)
- def test_custom_reference_discoverable_by_get_reference_class(self):
+ def test_custom_reference_discoverable_on_deadline_reference(self):
+ # Custom references are only registered on DeadlineReference, not on
ReferenceModels.
+ # During deserialization, custom refs are discovered via __class_path
in the
+ # serialized data (using import_string), not through ReferenceModels
lookup.
DeadlineReference.register_custom_reference(self.MyCustomRef)
- found_class =
ReferenceModels.get_reference_class(self.MyCustomRef.__name__)
-
- assert found_class is self.MyCustomRef
+ assert hasattr(DeadlineReference, self.MyCustomRef.__name__)
+ found_instance = getattr(DeadlineReference, self.MyCustomRef.__name__)
+ assert isinstance(found_instance, self.MyCustomRef)
class TestDeadlineReferenceDecorator:
@@ -650,21 +669,21 @@ class TestDeadlineReferenceDecorator:
self.original_dagrun_created = DeadlineReference.TYPES.DAGRUN_CREATED
self.original_dagrun_queued = DeadlineReference.TYPES.DAGRUN_QUEUED
self.original_dagrun = DeadlineReference.TYPES.DAGRUN
- self.original_attrs = set(dir(ReferenceModels))
+ self.original_deadline_attrs = set(dir(DeadlineReference))
def teardown_method(self):
DeadlineReference.TYPES.DAGRUN_CREATED = self.original_dagrun_created
DeadlineReference.TYPES.DAGRUN_QUEUED = self.original_dagrun_queued
DeadlineReference.TYPES.DAGRUN = self.original_dagrun
- for attr in set(dir(ReferenceModels)):
- if attr not in self.original_attrs:
- delattr(ReferenceModels, attr)
+ for attr in set(dir(DeadlineReference)):
+ if attr not in self.original_deadline_attrs:
+ delattr(DeadlineReference, attr)
@staticmethod
def create_decorated_custom_ref():
@deadline_reference()
- class DecoratedCustomRef(ReferenceModels.BaseDeadlineReference):
+ class DecoratedCustomRef(BaseDeadlineReference):
def _evaluate_with(self, *, session: Session, **kwargs) ->
datetime:
return timezone.datetime(DEFAULT_DATE)
@@ -673,7 +692,7 @@ class TestDeadlineReferenceDecorator:
@staticmethod
def create_decorated_custom_ref_with_kwargs():
@deadline_reference()
- class
DecoratedCustomRefWithKwargs(ReferenceModels.BaseDeadlineReference):
+ class DecoratedCustomRefWithKwargs(BaseDeadlineReference):
required_kwargs = {"custom_id"}
def _evaluate_with(self, *, session: Session, **kwargs) ->
datetime:
@@ -684,7 +703,7 @@ class TestDeadlineReferenceDecorator:
@staticmethod
def create_decorated_custom_ref_queued():
@deadline_reference(DeadlineReference.TYPES.DAGRUN_QUEUED)
- class DecoratedCustomRefQueued(ReferenceModels.BaseDeadlineReference):
+ class DecoratedCustomRefQueued(BaseDeadlineReference):
def _evaluate_with(self, *, session: Session, **kwargs) ->
datetime:
return timezone.datetime(DEFAULT_DATE)
@@ -713,7 +732,7 @@ class TestDeadlineReferenceDecorator:
def test_deadline_reference_decorator(self, reference_factory,
expected_timing):
reference = reference_factory()
- assert getattr(ReferenceModels, reference.__name__) is reference
+ assert hasattr(DeadlineReference, reference.__name__)
assert getattr(DeadlineReference, reference.__name__).__class__ is
reference
assert_correct_timing(reference, expected_timing)
@@ -741,7 +760,7 @@ class TestDeadlineReferenceDecorator:
):
@deadline_reference(invalid_timing)
- class DecoratedCustomRef(ReferenceModels.BaseDeadlineReference):
+ class DecoratedCustomRef(BaseDeadlineReference):
def _evaluate_with(self, *, session: Session, **kwargs) ->
datetime:
return timezone.datetime(DEFAULT_DATE)
@@ -750,7 +769,7 @@ class TestDeadlineReferenceDecorator:
timing = DeadlineReference.TYPES.DAGRUN_QUEUED
@deadline_reference(timing)
- class DecoratedCustomRef(ReferenceModels.BaseDeadlineReference):
+ class DecoratedCustomRef(BaseDeadlineReference):
def _evaluate_with(self, *, session: Session, **kwargs) ->
datetime:
return timezone.datetime(DEFAULT_DATE)
diff --git a/airflow-core/tests/unit/models/test_deadline_alert.py
b/airflow-core/tests/unit/models/test_deadline_alert.py
index 879203814b3..9d69577a6d6 100644
--- a/airflow-core/tests/unit/models/test_deadline_alert.py
+++ b/airflow-core/tests/unit/models/test_deadline_alert.py
@@ -16,13 +16,17 @@
# under the License.
from __future__ import annotations
+from datetime import timedelta
+from unittest.mock import Mock
+
import pytest
import time_machine
from sqlalchemy import select
+from airflow._shared.timezones import timezone
from airflow.models.deadline_alert import DeadlineAlert
from airflow.models.serialized_dag import SerializedDagModel
-from airflow.sdk.definitions.deadline import DeadlineReference
+from airflow.sdk.definitions.deadline import BaseDeadlineReference,
DeadlineReference
from airflow.serialization.definitions.deadline import
SerializedReferenceModels
from tests_common.test_utils import db
@@ -172,3 +176,55 @@ class TestDeadlineAlert:
nonexistent_uuid = "00000000-0000-7000-8000-000000000000"
with pytest.raises(NoResultFound, match="No DeadlineAlert found"):
DeadlineAlert.get_by_id(nonexistent_uuid, session=session)
+
+ def test_serialized_custom_reference_kwargs_handling(self):
+ """Test that SerializedCustomReference properly filters and validates
kwargs."""
+
+ class StrictCustomRef(BaseDeadlineReference):
+ reference_name = "StrictCustomRef"
+ required_kwargs = {"dag_id", "run_id"}
+
+ def _evaluate_with(self, *, session, dag_id, run_id):
+ return timezone.utcnow()
+
+ inner_ref = StrictCustomRef()
+ inner_ref._evaluate_with = Mock(return_value=timezone.utcnow())
+
+ wrapper =
SerializedReferenceModels.SerializedCustomReference(inner_ref)
+
+ wrapper.evaluate_with(
+ session=None,
+ interval=timedelta(hours=1),
+ dag_id="test_dag",
+ run_id="test_run",
+ extra_param="should_be_filtered",
+ )
+
+ inner_ref._evaluate_with.assert_called_once_with(session=None,
dag_id="test_dag", run_id="test_run")
+
+ # try calling with missing required parameters
+ with pytest.raises(ValueError, match="missing required parameters:
run_id"):
+ wrapper.evaluate_with(
+ session=None,
+ interval=timedelta(hours=1),
+ dag_id="test_dag",
+ )
+
+ def test_core_deadline_reference_treated_as_builtins(self):
+ """Test that refs from airflow.models.deadline are still treated as
builtins."""
+ from airflow.models.deadline import ReferenceModels
+ from airflow.serialization.encoders import encode_deadline_reference
+
+ ref = ReferenceModels.DagRunLogicalDateDeadline()
+ serialized = encode_deadline_reference(ref)
+
+ assert "__class_path" not in serialized
+ assert serialized["reference_type"] == "DagRunLogicalDateDeadline"
+
+ def test_is_builtin_reference(self):
+ """Test that is_builtin_reference correctly identifies built-in vs
custom references."""
+ assert
SerializedReferenceModels.is_builtin_reference("DagRunLogicalDateDeadline") is
True
+ assert
SerializedReferenceModels.is_builtin_reference("DagRunQueuedAtDeadline") is True
+ assert
SerializedReferenceModels.is_builtin_reference("AverageRuntimeDeadline") is True
+
+ assert SerializedReferenceModels.is_builtin_reference("MyCustomRef")
is False
diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py
b/task-sdk/src/airflow/sdk/definitions/deadline.py
index 8c55e10d45c..2fe220e789d 100644
--- a/task-sdk/src/airflow/sdk/definitions/deadline.py
+++ b/task-sdk/src/airflow/sdk/definitions/deadline.py
@@ -17,10 +17,11 @@
from __future__ import annotations
import logging
+from abc import ABC
+from dataclasses import dataclass
from datetime import datetime, timedelta
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
-from airflow.models.deadline import DeadlineReferenceType, ReferenceModels
from airflow.sdk.definitions.callback import AsyncCallback, Callback,
SyncCallback
if TYPE_CHECKING:
@@ -29,7 +30,111 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-DeadlineReferenceTypes: TypeAlias =
tuple[type[ReferenceModels.BaseDeadlineReference], ...]
+# Field name used in serialization - must be in sync with
SerializedReferenceModels.REFERENCE_TYPE_FIELD
+REFERENCE_TYPE_FIELD = "reference_type"
+
+
+class BaseDeadlineReference(ABC):
+ """
+ Base class for all Deadline Reference implementations.
+
+ This is a lightweight SDK class for DAG authoring. It only handles
serialization.
+ The actual evaluation logic (_evaluate_with) is in Core's
SerializedReferenceModels.
+
+ For custom deadline references, users should inherit from this class and
implement
+ _evaluate_with() with deferred Core imports (imports inside the method
body).
+ """
+
+ @property
+ def reference_name(self) -> str:
+ """Return the class name as the reference identifier."""
+ return self.__class__.__name__
+
+ def serialize_reference(self) -> dict[str, Any]:
+ """
+ Serialize this reference type into a dictionary representation.
+
+ Override this method in subclasses if additional data is needed for
serialization.
+ """
+ return {REFERENCE_TYPE_FIELD: self.reference_name}
+
+ @classmethod
+ def deserialize_reference(cls, reference_data: dict[str, Any]) ->
BaseDeadlineReference:
+ """
+ Deserialize a reference type from its dictionary representation.
+
+ :param reference_data: Dictionary containing serialized reference data.
+ """
+ return cls()
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, BaseDeadlineReference):
+ return NotImplemented
+ return self.serialize_reference() == other.serialize_reference()
+
+ def __hash__(self) -> int:
+ return hash(frozenset(self.serialize_reference().items()))
+
+
+class DagRunLogicalDateDeadline(BaseDeadlineReference):
+ """A deadline that returns a DagRun's logical date."""
+
+
+class DagRunQueuedAtDeadline(BaseDeadlineReference):
+ """A deadline that returns when a DagRun was queued."""
+
+
+@dataclass
+class FixedDatetimeDeadline(BaseDeadlineReference):
+ """A deadline that always returns a fixed datetime."""
+
+ _datetime: datetime
+
+ def serialize_reference(self) -> dict[str, Any]:
+ return {
+ REFERENCE_TYPE_FIELD: self.reference_name,
+ "datetime": self._datetime.timestamp(),
+ }
+
+ @classmethod
+ def deserialize_reference(cls, reference_data: dict[str, Any]) ->
FixedDatetimeDeadline:
+ from airflow._shared.timezones import timezone
+
+ return
cls(_datetime=timezone.from_timestamp(reference_data["datetime"]))
+
+
+@dataclass
+class AverageRuntimeDeadline(BaseDeadlineReference):
+ """A deadline that calculates the average runtime from past DAG runs."""
+
+ DEFAULT_LIMIT = 10
+ max_runs: int
+ min_runs: int | None = None
+
+ def __post_init__(self):
+ if self.min_runs is None:
+ self.min_runs = self.max_runs
+ if self.min_runs < 1:
+ raise ValueError("min_runs must be at least 1")
+
+ def serialize_reference(self) -> dict[str, Any]:
+ return {
+ REFERENCE_TYPE_FIELD: self.reference_name,
+ "max_runs": self.max_runs,
+ "min_runs": self.min_runs,
+ }
+
+ @classmethod
+ def deserialize_reference(cls, reference_data: dict[str, Any]) ->
AverageRuntimeDeadline:
+ max_runs = reference_data.get("max_runs", cls.DEFAULT_LIMIT)
+ min_runs = reference_data.get("min_runs", max_runs)
+ if min_runs < 1:
+ raise ValueError("min_runs must be at least 1")
+ return cls(max_runs=max_runs, min_runs=min_runs)
+
+
+DeadlineReferenceType: TypeAlias = BaseDeadlineReference
+DeadlineReferenceTypes: TypeAlias = tuple[type[BaseDeadlineReference], ...]
class DeadlineAlert:
@@ -118,33 +223,31 @@ class DeadlineReference:
# Deadlines that should be created when the DagRun is created.
DAGRUN_CREATED: DeadlineReferenceTypes = (
- ReferenceModels.DagRunLogicalDateDeadline,
- ReferenceModels.FixedDatetimeDeadline,
- ReferenceModels.AverageRuntimeDeadline,
+ DagRunLogicalDateDeadline,
+ FixedDatetimeDeadline,
+ AverageRuntimeDeadline,
)
# Deadlines that should be created when the DagRun is queued.
- DAGRUN_QUEUED: DeadlineReferenceTypes = (ReferenceModels.DagRunQueuedAtDeadline,)
+ DAGRUN_QUEUED: DeadlineReferenceTypes = (DagRunQueuedAtDeadline,)
# All DagRun-related deadline types.
DAGRUN: DeadlineReferenceTypes = DAGRUN_CREATED + DAGRUN_QUEUED
- from airflow.models.deadline import ReferenceModels
-
- DAGRUN_LOGICAL_DATE: DeadlineReferenceType = ReferenceModels.DagRunLogicalDateDeadline()
- DAGRUN_QUEUED_AT: DeadlineReferenceType = ReferenceModels.DagRunQueuedAtDeadline()
+ DAGRUN_LOGICAL_DATE: DeadlineReferenceType = DagRunLogicalDateDeadline()
+ DAGRUN_QUEUED_AT: DeadlineReferenceType = DagRunQueuedAtDeadline()
@classmethod
def AVERAGE_RUNTIME(cls, max_runs: int = 0, min_runs: int | None = None)
-> DeadlineReferenceType:
if max_runs == 0:
- max_runs = cls.ReferenceModels.AverageRuntimeDeadline.DEFAULT_LIMIT
+ max_runs = AverageRuntimeDeadline.DEFAULT_LIMIT
if min_runs is None:
min_runs = max_runs
- return cls.ReferenceModels.AverageRuntimeDeadline(max_runs, min_runs)
+ return AverageRuntimeDeadline(max_runs, min_runs)
@classmethod
- def FIXED_DATETIME(cls, datetime: datetime) -> DeadlineReferenceType:
- return cls.ReferenceModels.FixedDatetimeDeadline(datetime)
+ def FIXED_DATETIME(cls, dt: datetime) -> DeadlineReferenceType:
+ return FixedDatetimeDeadline(dt)
# TODO: Remove this once other deadline types exist.
# This is a temporary reference type used only in tests to verify that
@@ -152,16 +255,16 @@ class DeadlineReference:
# It should be replaced with a real non-dagrun deadline type when one is
available.
_TEMPORARY_TEST_REFERENCE = type(
"TemporaryTestDeadlineForTypeChecking",
- (DeadlineReferenceType,),
- {"_evaluate_with": lambda self, **kwargs: datetime.now()},
+ (BaseDeadlineReference,),
+ {"serialize_reference": lambda self: {REFERENCE_TYPE_FIELD: "TemporaryTestDeadlineForTypeChecking"}},
)()
@classmethod
def register_custom_reference(
cls,
- reference_class: type[ReferenceModels.BaseDeadlineReference],
+ reference_class: type[BaseDeadlineReference],
deadline_reference_type: DeadlineReferenceTypes | None = None,
- ) -> type[ReferenceModels.BaseDeadlineReference]:
"""
Register a custom deadline reference class.
@@ -169,18 +272,18 @@ class DeadlineReference:
:param deadline_reference_type: A DeadlineReference.TYPES for when the
deadline should be evaluated ("DAGRUN_CREATED",
"DAGRUN_QUEUED", etc.); defaults to
DeadlineReference.TYPES.DAGRUN_CREATED
"""
- from airflow.models.deadline import ReferenceModels
-
# Default to DAGRUN_CREATED if no deadline_reference_type specified
if deadline_reference_type is None:
deadline_reference_type = cls.TYPES.DAGRUN_CREATED
# Validate the reference class inherits from BaseDeadlineReference
- if not issubclass(reference_class, ReferenceModels.BaseDeadlineReference):
+ # Accept both sdk and core base classes for backward compatibility for now
+ from airflow.models.deadline import ReferenceModels
+
+ if not issubclass(reference_class, (BaseDeadlineReference, ReferenceModels.BaseDeadlineReference)):
raise ValueError(f"{reference_class.__name__} must inherit from BaseDeadlineReference")
- # Register the new reference with ReferenceModels and DeadlineReference for discoverability
- setattr(ReferenceModels, reference_class.__name__, reference_class)
+ # Register the new reference with DeadlineReference for discoverability
setattr(cls, reference_class.__name__, reference_class())
logger.info("Registered DeadlineReference %s",
reference_class.__name__)
@@ -203,29 +306,36 @@ class DeadlineReference:
def deadline_reference(
deadline_reference_type: DeadlineReferenceTypes | None = None,
-) -> Callable[[type[ReferenceModels.BaseDeadlineReference]], type[ReferenceModels.BaseDeadlineReference]]:
+) -> Callable[[type[BaseDeadlineReference]], type[BaseDeadlineReference]]:
"""
Decorate a class to register a custom deadline reference.
Usage:
@deadline_reference()
- class MyCustomReference(ReferenceModels.BaseDeadlineReference):
+ class MyCustomReference(BaseDeadlineReference):
# By default, evaluate_with will be called when a new dagrun is
created.
def _evaluate_with(self, *, session: Session, **kwargs) ->
datetime:
- # Put your business logic here
+ # Put your business logic here (use deferred imports for Core
types)
+ from airflow.models import DagRun
return some_datetime
+ def serialize_reference(self) -> dict:
+ return {"reference_type": self.reference_name}
+
@deadline_reference(DeadlineReference.TYPES.DAGRUN_QUEUED)
- class MyQueuedRef(ReferenceModels.BaseDeadlineReference):
+ class MyQueuedRef(BaseDeadlineReference):
# Optionally, you can specify when you want it calculated by
providing a DeadlineReference.TYPES
def _evaluate_with(self, *, session: Session, **kwargs) ->
datetime:
# Put your business logic here
return some_datetime
+
+ def serialize_reference(self) -> dict:
+ return {"reference_type": self.reference_name}
"""
def decorator(
- reference_class: type[ReferenceModels.BaseDeadlineReference],
- ) -> type[ReferenceModels.BaseDeadlineReference]:
+ reference_class: type[BaseDeadlineReference],
+ ) -> type[BaseDeadlineReference]:
DeadlineReference.register_custom_reference(reference_class,
deadline_reference_type)
return reference_class