This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7b562ba381f Use average runtime as deadline reference (#55088)
7b562ba381f is described below
commit 7b562ba381f88c7a42b50aab7b0e234adf7941e4
Author: Sean Ghaeli <[email protected]>
AuthorDate: Mon Sep 22 15:26:06 2025 -0700
Use average runtime as deadline reference (#55088)
Co-authored-by: Ramit Kataria <[email protected]>
---
airflow-core/docs/howto/deadline-alerts.rst | 53 ++++++++
airflow-core/src/airflow/models/deadline.py | 98 +++++++++++++-
.../airflow/serialization/serialized_objects.py | 24 ++--
airflow-core/tests/unit/models/test_deadline.py | 147 ++++++++++++++++++++-
task-sdk/src/airflow/sdk/definitions/deadline.py | 9 ++
.../tests/task_sdk/definitions/test_deadline.py | 1 +
6 files changed, 315 insertions(+), 17 deletions(-)
diff --git a/airflow-core/docs/howto/deadline-alerts.rst
b/airflow-core/docs/howto/deadline-alerts.rst
index e52fdd6aeb1..74f10f1b636 100644
--- a/airflow-core/docs/howto/deadline-alerts.rst
+++ b/airflow-core/docs/howto/deadline-alerts.rst
@@ -104,6 +104,58 @@ Airflow provides several built-in reference points that
you can use with Deadlin
``DeadlineReference.FIXED_DATETIME``
Specifies a fixed point in time. Useful when Dags must complete by a
specific time.
+``DeadlineReference.AVERAGE_RUNTIME``
+ Calculates deadlines based on the average runtime of previous DAG runs.
This reference
+ analyzes historical execution data to predict when the current run should
complete.
+ The deadline is set to the current time plus the calculated average
runtime plus the interval.
+ If insufficient historical data exists, no deadline is created.
+
+ Parameters:
+ * ``max_runs`` (int, optional): Maximum number of recent DAG runs to
analyze. Defaults to 10.
+   * ``min_runs`` (int, optional): Minimum number of completed runs
required to calculate the average. Defaults to the same value as ``max_runs``.
+
+ Example usage:
+
+ .. code-block:: python
+
+ # Use default settings (analyze up to 10 runs, require 10 runs)
+ DeadlineReference.AVERAGE_RUNTIME()
+
+ # Analyze up to 20 runs but calculate with minimum 5 runs
+ DeadlineReference.AVERAGE_RUNTIME(max_runs=20, min_runs=5)
+
+ # Strict: require exactly 15 runs to calculate
+ DeadlineReference.AVERAGE_RUNTIME(max_runs=15, min_runs=15)
+
+Here's an example using average runtime:
+
+.. code-block:: python
+
+ with DAG(
+ dag_id="average_runtime_deadline",
+ deadline=DeadlineAlert(
+ reference=DeadlineReference.AVERAGE_RUNTIME(max_runs=15,
min_runs=5),
+ interval=timedelta(minutes=30), # Alert if 30 minutes past
average runtime
+ callback=AsyncCallback(
+ SlackWebhookNotifier,
+ kwargs={"text": "🚨 DAG {{ dag_run.dag_id }} is running longer
than expected!"},
+ ),
+ ),
+ ):
+ EmptyOperator(task_id="data_processing")
+
+If the calculated historical average were 30 minutes, the timeline for this
example would look like this:
+
+::
+
+ |------|----------|--------------|--------------|--------|
+ Queued Start | Deadline
+ 09:00 09:05 09:35 10:05
+ | | |
+ |--- Average --|-- Interval --|
+ (30 min) (30 min)
+
+
Here's an example using a fixed datetime:
.. code-block:: python
@@ -166,6 +218,7 @@ Here's an example using the Slack Notifier if the Dag run
has not finished withi
):
EmptyOperator(task_id="example_task")
+
Creating Custom Callbacks
^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/airflow-core/src/airflow/models/deadline.py
b/airflow-core/src/airflow/models/deadline.py
index cfb99160a11..4b3560fc99b 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -26,7 +26,7 @@ from typing import TYPE_CHECKING, Any, cast
import sqlalchemy_jsonfield
import uuid6
-from sqlalchemy import Column, ForeignKey, Index, Integer, String, and_, select
+from sqlalchemy import Column, ForeignKey, Index, Integer, String, and_, func,
select, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import relationship
from sqlalchemy_utils import UUIDType
@@ -283,7 +283,7 @@ class ReferenceModels:
def reference_name(cls: Any) -> str:
return cls.__name__
- def evaluate_with(self, *, session: Session, interval: timedelta,
**kwargs: Any) -> datetime:
+ def evaluate_with(self, *, session: Session, interval: timedelta,
**kwargs: Any) -> datetime | None:
"""Validate the provided kwargs and evaluate this deadline with
the given conditions."""
filtered_kwargs = {k: v for k, v in kwargs.items() if k in
self.required_kwargs}
@@ -295,10 +295,11 @@ class ReferenceModels:
if extra_kwargs := kwargs.keys() - filtered_kwargs.keys():
self.log.debug("Ignoring unexpected parameters: %s", ",
".join(extra_kwargs))
- return self._evaluate_with(session=session, **filtered_kwargs) +
interval
+ base_time = self._evaluate_with(session=session, **filtered_kwargs)
+ return base_time + interval if base_time is not None else None
@abstractmethod
- def _evaluate_with(self, *, session: Session, **kwargs: Any) ->
datetime:
+ def _evaluate_with(self, *, session: Session, **kwargs: Any) ->
datetime | None:
"""Must be implemented by subclasses to perform the actual
evaluation."""
raise NotImplementedError
@@ -366,6 +367,95 @@ class ReferenceModels:
return _fetch_from_db(DagRun.queued_at, session=session, **kwargs)
+ @dataclass
+ class AverageRuntimeDeadline(BaseDeadlineReference):
+ """A deadline that calculates the average runtime from past DAG
runs."""
+
+ DEFAULT_LIMIT = 10
+ max_runs: int
+ min_runs: int | None = None
+ required_kwargs = {"dag_id"}
+
+ def __post_init__(self):
+ if self.min_runs is None:
+ self.min_runs = self.max_runs
+ if self.min_runs < 1:
+ raise ValueError("min_runs must be at least 1")
+
+ @provide_session
+ def _evaluate_with(self, *, session: Session, **kwargs: Any) ->
datetime | None:
+ from airflow.models import DagRun
+
+ dag_id = kwargs["dag_id"]
+
+ # Get database dialect to use appropriate time difference
calculation
+ dialect = session.bind.dialect.name
+
+ # Create database-specific expression for calculating duration in
seconds
+ if dialect == "postgresql":
+ duration_expr = func.extract("epoch", DagRun.end_date -
DagRun.start_date)
+ elif dialect == "mysql":
+ # Use TIMESTAMPDIFF to get exact seconds like PostgreSQL
EXTRACT(epoch FROM ...)
+ duration_expr = func.timestampdiff(text("SECOND"),
DagRun.start_date, DagRun.end_date)
+ elif dialect == "sqlite":
+ duration_expr = (func.julianday(DagRun.end_date) -
func.julianday(DagRun.start_date)) * 86400
+ else:
+ raise ValueError(f"Unsupported database dialect: {dialect}")
+
+ # Query for completed DAG runs with both start and end dates
+ # Order by logical_date descending to get most recent runs first
+ query = (
+ select(duration_expr)
+ .filter(DagRun.dag_id == dag_id,
DagRun.start_date.isnot(None), DagRun.end_date.isnot(None))
+ .order_by(DagRun.logical_date.desc())
+ )
+
+ # Apply max_runs
+ query = query.limit(self.max_runs)
+
+ # Get all durations and calculate average
+ durations = session.execute(query).scalars().all()
+
+ if len(durations) < cast("int", self.min_runs):
+ logger.info(
+ "Only %d completed DAG runs found for dag_id: %s (need
%d), skipping deadline creation",
+ len(durations),
+ dag_id,
+ self.min_runs,
+ )
+ return None
+ # Convert to float to handle Decimal types from MySQL while
preserving precision
+ # Use Decimal arithmetic for higher precision, then convert to
float
+ from decimal import Decimal
+
+ decimal_durations = [Decimal(str(d)) for d in durations]
+ avg_seconds = float(sum(decimal_durations) /
len(decimal_durations))
+ logger.info(
+ "Average runtime for dag_id %s (from %d runs): %.2f seconds",
+ dag_id,
+ len(durations),
+ avg_seconds,
+ )
+ return timezone.utcnow() + timedelta(seconds=avg_seconds)
+
+ def serialize_reference(self) -> dict:
+ return {
+ ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name,
+ "max_runs": self.max_runs,
+ "min_runs": self.min_runs,
+ }
+
+ @classmethod
+ def deserialize_reference(cls, reference_data: dict):
+ max_runs = reference_data.get("max_runs", cls.DEFAULT_LIMIT)
+ min_runs = reference_data.get("min_runs", max_runs)
+ if min_runs < 1:
+ raise ValueError("min_runs must be at least 1")
+ return cls(
+ max_runs=max_runs,
+ min_runs=min_runs,
+ )
+
DeadlineReferenceType = ReferenceModels.BaseDeadlineReference
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py
b/airflow-core/src/airflow/serialization/serialized_objects.py
index 9dbe19df085..4b99bca45d0 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -3270,18 +3270,20 @@ class SerializedDAG(BaseSerialization):
if self.deadline:
for deadline in cast("list", self.deadline):
if isinstance(deadline.reference,
DeadlineReference.TYPES.DAGRUN):
- session.add(
- Deadline(
- deadline_time=deadline.reference.evaluate_with(
- session=session,
- interval=deadline.interval,
- dag_id=self.dag_id,
- run_id=run_id,
- ),
- callback=deadline.callback,
- dagrun_id=orm_dagrun.id,
- )
+ deadline_time = deadline.reference.evaluate_with(
+ session=session,
+ interval=deadline.interval,
+ dag_id=self.dag_id,
+ run_id=run_id,
)
+ if deadline_time is not None:
+ session.add(
+ Deadline(
+ deadline_time=deadline_time,
+ callback=deadline.callback,
+ dagrun_id=orm_dagrun.id,
+ )
+ )
return orm_dagrun
diff --git a/airflow-core/tests/unit/models/test_deadline.py
b/airflow-core/tests/unit/models/test_deadline.py
index 5e152935707..05229488fc3 100644
--- a/airflow-core/tests/unit/models/test_deadline.py
+++ b/airflow-core/tests/unit/models/test_deadline.py
@@ -43,6 +43,7 @@ REFERENCE_TYPES = [
pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE, id="logical_date"),
pytest.param(DeadlineReference.DAGRUN_QUEUED_AT, id="queued_at"),
pytest.param(DeadlineReference.FIXED_DATETIME(DEFAULT_DATE),
id="fixed_deadline"),
+ pytest.param(DeadlineReference.AVERAGE_RUNTIME(), id="average_runtime"),
]
@@ -356,6 +357,7 @@ class TestCalculatedDeadlineDatabaseCalls:
pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE,
DagRun.logical_date, id="logical_date"),
pytest.param(DeadlineReference.DAGRUN_QUEUED_AT, DagRun.queued_at,
id="queued_at"),
pytest.param(DeadlineReference.FIXED_DATETIME(DEFAULT_DATE), None,
id="fixed_deadline"),
+ pytest.param(DeadlineReference.AVERAGE_RUNTIME(), None,
id="average_runtime"),
],
)
def test_deadline_database_integration(self, reference, expected_column,
session):
@@ -375,11 +377,142 @@ class TestCalculatedDeadlineDatabaseCalls:
if expected_column is not None:
result = reference.evaluate_with(session=session,
interval=interval, **conditions)
mock_fetch.assert_called_once_with(expected_column,
session=session, **conditions)
+ elif reference == DeadlineReference.AVERAGE_RUNTIME():
+ with mock.patch("airflow._shared.timezones.timezone.utcnow")
as mock_utcnow:
+ mock_utcnow.return_value = DEFAULT_DATE
+                # No DAG runs exist, so no deadline should be created
+ result = reference.evaluate_with(session=session,
interval=interval, dag_id=DAG_ID)
+ mock_fetch.assert_not_called()
+ # Should return None when no DAG runs exist
+ assert result is None
else:
result = reference.evaluate_with(session=session,
interval=interval)
mock_fetch.assert_not_called()
+ assert result == DEFAULT_DATE + interval
- assert result == DEFAULT_DATE + interval
+ def test_average_runtime_with_sufficient_history(self, session, dag_maker):
+ """Test AverageRuntimeDeadline when enough historical data exists."""
+ with dag_maker(DAG_ID):
+ EmptyOperator(task_id="test_task")
+
+ # Create 10 completed DAG runs with known durations
+ base_time = DEFAULT_DATE
+ durations = [3600, 7200, 1800, 5400, 2700, 4500, 3300, 6000, 2400,
4200]
+
+ for i, duration in enumerate(durations):
+ logical_date = base_time + timedelta(days=i)
+ start_time = logical_date + timedelta(minutes=5)
+ end_time = start_time + timedelta(seconds=duration)
+
+ dagrun = dag_maker.create_dagrun(
+ logical_date=logical_date, run_id=f"test_run_{i}",
state=DagRunState.SUCCESS
+ )
+ # Manually set start and end times
+ dagrun.start_date = start_time
+ dagrun.end_date = end_time
+
+ session.commit()
+
+ # Test with default max_runs (10)
+ reference = DeadlineReference.AVERAGE_RUNTIME()
+ interval = timedelta(hours=1)
+
+ with mock.patch("airflow._shared.timezones.timezone.utcnow") as
mock_utcnow:
+ mock_utcnow.return_value = DEFAULT_DATE
+ result = reference.evaluate_with(session=session,
interval=interval, dag_id=DAG_ID)
+
+ # Calculate expected average: sum(durations) / len(durations)
+ expected_avg_seconds = sum(durations) / len(durations)
+ expected = DEFAULT_DATE + timedelta(seconds=expected_avg_seconds)
+ interval
+
+ # Compare only up to minutes to avoid sub-second timing issues in
CI
+ assert result.replace(second=0, microsecond=0) ==
expected.replace(second=0, microsecond=0)
+
+ def test_average_runtime_with_insufficient_history(self, session,
dag_maker):
+ """Test AverageRuntimeDeadline when insufficient historical data
exists."""
+ with dag_maker(DAG_ID):
+ EmptyOperator(task_id="test_task")
+
+        # Create only 5 completed DAG runs (fewer than the default min_runs of 10)
+ base_time = DEFAULT_DATE
+ durations = [3600, 7200, 1800, 5400, 2700]
+
+ for i, duration in enumerate(durations):
+ logical_date = base_time + timedelta(days=i)
+ start_time = logical_date + timedelta(minutes=5)
+ end_time = start_time + timedelta(seconds=duration)
+
+ dagrun = dag_maker.create_dagrun(
+ logical_date=logical_date, run_id=f"insufficient_run_{i}",
state=DagRunState.SUCCESS
+ )
+ # Manually set start and end times
+ dagrun.start_date = start_time
+ dagrun.end_date = end_time
+
+ session.commit()
+
+ reference = DeadlineReference.AVERAGE_RUNTIME()
+ interval = timedelta(hours=1)
+
+ with mock.patch("airflow._shared.timezones.timezone.utcnow") as
mock_utcnow:
+ mock_utcnow.return_value = DEFAULT_DATE
+ result = reference.evaluate_with(session=session,
interval=interval, dag_id=DAG_ID)
+
+ # Should return None since insufficient runs
+ assert result is None
+
+ def test_average_runtime_with_min_runs(self, session, dag_maker):
+ """Test AverageRuntimeDeadline with min_runs parameter allowing
calculation with fewer runs."""
+ with dag_maker(DAG_ID):
+ EmptyOperator(task_id="test_task")
+
+ # Create only 3 completed DAG runs
+ base_time = DEFAULT_DATE
+ durations = [3600, 7200, 1800] # 1h, 2h, 30min
+
+ for i, duration in enumerate(durations):
+ logical_date = base_time + timedelta(days=i)
+ start_time = logical_date + timedelta(minutes=5)
+ end_time = start_time + timedelta(seconds=duration)
+
+ dagrun = dag_maker.create_dagrun(
+ logical_date=logical_date, run_id=f"min_runs_test_{i}",
state=DagRunState.SUCCESS
+ )
+ # Manually set start and end times
+ dagrun.start_date = start_time
+ dagrun.end_date = end_time
+
+ session.commit()
+
+ # Test with min_runs=2, should work with 3 runs
+ reference = DeadlineReference.AVERAGE_RUNTIME(max_runs=10, min_runs=2)
+ interval = timedelta(hours=1)
+
+ with mock.patch("airflow._shared.timezones.timezone.utcnow") as
mock_utcnow:
+ mock_utcnow.return_value = DEFAULT_DATE
+ result = reference.evaluate_with(session=session,
interval=interval, dag_id=DAG_ID)
+
+ # Should calculate average from 3 runs
+ expected_avg_seconds = sum(durations) / len(durations) # 4200
seconds
+ expected = DEFAULT_DATE + timedelta(seconds=expected_avg_seconds)
+ interval
+ # Compare only up to minutes to avoid sub-second timing issues in
CI
+ assert result.replace(second=0, microsecond=0) ==
expected.replace(second=0, microsecond=0)
+
+ # Test with min_runs=5, should return None with only 3 runs
+ reference = DeadlineReference.AVERAGE_RUNTIME(max_runs=10, min_runs=5)
+
+ with mock.patch("airflow._shared.timezones.timezone.utcnow") as
mock_utcnow:
+ mock_utcnow.return_value = DEFAULT_DATE
+ result = reference.evaluate_with(session=session,
interval=interval, dag_id=DAG_ID)
+ assert result is None
+
+ def test_average_runtime_min_runs_validation(self):
+ """Test that min_runs must be at least 1."""
+ with pytest.raises(ValueError, match="min_runs must be at least 1"):
+ DeadlineReference.AVERAGE_RUNTIME(max_runs=10, min_runs=0)
+
+ with pytest.raises(ValueError, match="min_runs must be at least 1"):
+ DeadlineReference.AVERAGE_RUNTIME(max_runs=10, min_runs=-1)
class TestDeadlineReference:
@@ -424,7 +557,7 @@ class TestDeadlineReference:
f"{reference.__class__.__name__} is missing required
parameters: ",
*reference.required_kwargs,
}
- assert [substring in str(raised_exception) for substring in
expected_substrings]
+ assert all(substring in str(raised_exception.value) for substring
in expected_substrings)
else:
# Let the lack of an exception here effectively assert that no
exception is raised.
reference.evaluate_with(session=session, **self.DEFAULT_ARGS)
@@ -440,3 +573,13 @@ class TestDeadlineReference:
queued_reference = DeadlineReference.DAGRUN_QUEUED_AT
assert isinstance(queued_reference,
ReferenceModels.DagRunQueuedAtDeadline)
+
+ average_runtime_reference = DeadlineReference.AVERAGE_RUNTIME()
+ assert isinstance(average_runtime_reference,
ReferenceModels.AverageRuntimeDeadline)
+ assert average_runtime_reference.max_runs == 10
+ assert average_runtime_reference.min_runs == 10
+
+ # Test with custom parameters
+ custom_reference = DeadlineReference.AVERAGE_RUNTIME(max_runs=5,
min_runs=3)
+ assert custom_reference.max_runs == 5
+ assert custom_reference.min_runs == 3
diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py
b/task-sdk/src/airflow/sdk/definitions/deadline.py
index beb32d17382..1f31dd7ace5 100644
--- a/task-sdk/src/airflow/sdk/definitions/deadline.py
+++ b/task-sdk/src/airflow/sdk/definitions/deadline.py
@@ -286,6 +286,7 @@ class DeadlineReference:
DAGRUN_CREATED = (
ReferenceModels.DagRunLogicalDateDeadline,
ReferenceModels.FixedDatetimeDeadline,
+ ReferenceModels.AverageRuntimeDeadline,
)
# Deadlines that should be created when the DagRun is queued.
@@ -299,6 +300,14 @@ class DeadlineReference:
DAGRUN_LOGICAL_DATE: DeadlineReferenceType =
ReferenceModels.DagRunLogicalDateDeadline()
DAGRUN_QUEUED_AT: DeadlineReferenceType =
ReferenceModels.DagRunQueuedAtDeadline()
+ @classmethod
+ def AVERAGE_RUNTIME(cls, max_runs: int = 0, min_runs: int | None = None)
-> DeadlineReferenceType:
+ if max_runs == 0:
+ max_runs = cls.ReferenceModels.AverageRuntimeDeadline.DEFAULT_LIMIT
+ if min_runs is None:
+ min_runs = max_runs
+ return cls.ReferenceModels.AverageRuntimeDeadline(max_runs, min_runs)
+
@classmethod
def FIXED_DATETIME(cls, datetime: datetime) -> DeadlineReferenceType:
return cls.ReferenceModels.FixedDatetimeDeadline(datetime)
diff --git a/task-sdk/tests/task_sdk/definitions/test_deadline.py
b/task-sdk/tests/task_sdk/definitions/test_deadline.py
index e13bc435e8a..ffbcaa4aa85 100644
--- a/task-sdk/tests/task_sdk/definitions/test_deadline.py
+++ b/task-sdk/tests/task_sdk/definitions/test_deadline.py
@@ -41,6 +41,7 @@ REFERENCE_TYPES = [
pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE, id="logical_date"),
pytest.param(DeadlineReference.DAGRUN_QUEUED_AT, id="queued_at"),
pytest.param(DeadlineReference.FIXED_DATETIME(DEFAULT_DATE),
id="fixed_deadline"),
+ pytest.param(DeadlineReference.AVERAGE_RUNTIME, id="average_runtime"),
]