This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 42a94927145 Combine asset events fetching logic into one SQL query and clean up unnecessary asset-triggered dag data (#46721)
42a94927145 is described below
commit 42a94927145d052183b09637157359eb044e472d
Author: Wei Lee <[email protected]>
AuthorDate: Fri Feb 14 16:15:38 2025 +0800
    Combine asset events fetching logic into one SQL query and clean up unnecessary asset-triggered dag data (#46721)
* refactor(dag): simplify asset_triggered_dag_info content
      It was {dag_id: (min_asset_event_date, max_asset_event_date)}. min_asset_event_date is no longer needed, as asset-triggered dag runs no longer carry a data interval (see the data-shape sketch after this list).
    * refactor(scheduler_job_runner): merge the asset event fetching logic into a single query (a standalone sketch of the pattern follows the diffstat below)
    * refactor(scheduler_job_runner): rename asset_triggered_dag_info to triggered_date_by_dag
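
    A minimal sketch of the data-shape change described above (the dag id and
    timestamps are invented for illustration, not taken from the commit):

        from datetime import datetime

        # Before: per dag_id, the (min, max) created_at of its queued
        # AssetDagRunQueue records.
        asset_triggered_dag_info = {
            "example_dag": (datetime(2025, 2, 14, 8), datetime(2025, 2, 14, 9)),
        }

        # After: only the latest queued time survives; the min was only used
        # to build a data interval that asset-triggered runs no longer have.
        triggered_date_by_dag = {
            "example_dag": datetime(2025, 2, 14, 9),
        }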
---
airflow/jobs/scheduler_job_runner.py | 34 +++++++++++++++-------------------
airflow/models/dag.py | 31 +++++++++++++++++--------------
tests/models/test_dag.py | 11 +++++------
3 files changed, 37 insertions(+), 39 deletions(-)
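
Before the diff itself, here is a standalone sketch of the query pattern the
scheduler_job_runner change adopts. The tables, names, and values are toy
stand-ins for DagRun/AssetEvent (SQLAlchemy 1.4+/2.x), not the real Airflow
models: the "latest previous asset-triggered run" lookup is folded into the
asset-event query as a single-row CTE, and func.coalesce() with date.min turns
a missing previous run into an always-true lower bound, so the previous
two-step flow (one query for the prior run, then a filtered event query)
becomes one round trip:

    from datetime import date, datetime

    from sqlalchemy import Column, DateTime, Integer, String, create_engine, func, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Run(Base):  # toy stand-in for DagRun
        __tablename__ = "run"
        id = Column(Integer, primary_key=True)
        dag_id = Column(String)
        run_after = Column(DateTime)

    class Event(Base):  # toy stand-in for AssetEvent
        __tablename__ = "event"
        id = Column(Integer, primary_key=True)
        dag_id = Column(String)
        timestamp = Column(DateTime)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all(
            [
                Run(dag_id="d", run_after=datetime(2025, 2, 14, 8)),
                Event(dag_id="d", timestamp=datetime(2025, 2, 14, 7)),  # consumed by the previous run
                Event(dag_id="d", timestamp=datetime(2025, 2, 14, 9)),  # new since the previous run
            ]
        )
        session.flush()

        triggered_date = datetime(2025, 2, 14, 10)
        # Single-row CTE: run_after of the latest prior run (NULL if none),
        # so the implicit cross join below is safe.
        cte = (
            select(func.max(Run.run_after).label("prev_run_after"))
            .where(Run.dag_id == "d", Run.run_after < triggered_date)
            .cte()
        )
        events = session.scalars(
            select(Event).where(
                Event.dag_id == "d",
                Event.timestamp <= triggered_date,
                # coalesce() makes the lower bound a no-op when no prior run exists.
                Event.timestamp > func.coalesce(cte.c.prev_run_after, date.min),
            )
        ).all()
        assert [e.timestamp for e in events] == [datetime(2025, 2, 14, 9)]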
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 4d81e464e68..e35e8573a6d 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -27,7 +27,7 @@ import time
 from collections import Counter, defaultdict, deque
 from collections.abc import Collection, Iterable, Iterator
 from contextlib import ExitStack, suppress
-from datetime import timedelta
+from datetime import date, timedelta
 from functools import lru_cache, partial
 from itertools import groupby
 from typing import TYPE_CHECKING, Any, Callable
@@ -1198,17 +1198,17 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
     @retry_db_transaction
     def _create_dagruns_for_dags(self, guard: CommitProhibitorGuard, session: Session) -> None:
         """Find Dag Models needing DagRuns and Create Dag Runs with retries in case of OperationalError."""
-        query, asset_triggered_dag_info = DagModel.dags_needing_dagruns(session)
+        query, triggered_date_by_dag = DagModel.dags_needing_dagruns(session)
         all_dags_needing_dag_runs = set(query.all())
         asset_triggered_dags = [
-            dag for dag in all_dags_needing_dag_runs if dag.dag_id in asset_triggered_dag_info
+            dag for dag in all_dags_needing_dag_runs if dag.dag_id in triggered_date_by_dag
         ]
         non_asset_dags = all_dags_needing_dag_runs.difference(asset_triggered_dags)
         self._create_dag_runs(non_asset_dags, session)
         if asset_triggered_dags:
             self._create_dag_runs_asset_triggered(
                 dag_models=asset_triggered_dags,
-                asset_triggered_dag_info=asset_triggered_dag_info,
+                triggered_date_by_dag=triggered_date_by_dag,
                 session=session,
             )
@@ -1325,13 +1325,13 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
     def _create_dag_runs_asset_triggered(
         self,
         dag_models: Collection[DagModel],
-        asset_triggered_dag_info: dict[str, tuple[datetime, datetime]],
+        triggered_date_by_dag: dict[str, datetime],
         session: Session,
     ) -> None:
         """For DAGs that are triggered by assets, create dag runs."""
         triggered_dates: dict[str, DateTime] = {
             dag_id: timezone.coerce_datetime(last_asset_event_time)
-            for dag_id, (_, last_asset_event_time) in asset_triggered_dag_info.items()
+            for dag_id, last_asset_event_time in triggered_date_by_dag.items()
         }

         for dag_model in dag_models:
@@ -1350,30 +1350,26 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             latest_dag_version = DagVersion.get_latest_version(dag.dag_id, session=session)
             triggered_date = triggered_dates[dag.dag_id]
-            previous_dag_run = session.scalar(
-                select(DagRun)
+            cte = (
+                select(func.max(DagRun.run_after).label("previous_dag_run_run_after"))
                 .where(
                     DagRun.dag_id == dag.dag_id,
-                    DagRun.run_after < triggered_date,
                     DagRun.run_type == DagRunType.ASSET_TRIGGERED,
+                    DagRun.run_after < triggered_date,
                 )
-                .order_by(DagRun.run_after.desc())
-                .limit(1)
+                .cte()
             )
-            asset_event_filters = [
-                DagScheduleAssetReference.dag_id == dag.dag_id,
-                AssetEvent.timestamp <= triggered_date,
-            ]
-            if previous_dag_run:
-                asset_event_filters.append(AssetEvent.timestamp > previous_dag_run.run_after)
-
             asset_events = session.scalars(
                 select(AssetEvent)
                 .join(
                     DagScheduleAssetReference,
                     AssetEvent.asset_id == DagScheduleAssetReference.asset_id,
                 )
-                .where(*asset_event_filters)
+                .where(
+                    DagScheduleAssetReference.dag_id == dag.dag_id,
+                    AssetEvent.timestamp <= triggered_date,
+                    AssetEvent.timestamp > func.coalesce(cte.c.previous_dag_run_run_after, date.min),
+                )
             ).all()

             dag_run = dag.create_dagrun(
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 37ff854c8da..cff6c857c03 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2317,7 +2317,7 @@ class DagModel(Base):
             dm.is_active = False

     @classmethod
-    def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[datetime, datetime]]]:
+    def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, datetime]]:
         """
         Return (and lock) a list of Dag objects that are due to create a new DagRun.
@@ -2341,11 +2341,12 @@ class DagModel(Base):
         adrq_by_dag: dict[str, list[AssetDagRunQueue]] = defaultdict(list)
         for r in session.scalars(select(AssetDagRunQueue)):
             adrq_by_dag[r.target_dag_id].append(r)
-        dag_statuses: dict[str, dict[AssetUniqueKey, bool]] = {}
-        for dag_id, records in adrq_by_dag.items():
-            dag_statuses[dag_id] = {AssetUniqueKey.from_asset(x.asset): True for x in records}
-        ser_dags = SerializedDagModel.get_latest_serialized_dags(dag_ids=list(dag_statuses), session=session)
+        dag_statuses: dict[str, dict[AssetUniqueKey, bool]] = {
+            dag_id: {AssetUniqueKey.from_asset(adrq.asset): True for adrq in adrqs}
+            for dag_id, adrqs in adrq_by_dag.items()
+        }
+        ser_dags = SerializedDagModel.get_latest_serialized_dags(dag_ids=list(dag_statuses), session=session)
         for ser_dag in ser_dags:
             dag_id = ser_dag.dag_id
             statuses = dag_statuses[dag_id]
@@ -2353,14 +2354,16 @@ class DagModel(Base):
                 del adrq_by_dag[dag_id]
                 del dag_statuses[dag_id]
         del dag_statuses
-        # TODO: make it more readable (rename it or make it attrs, dataclass or etc.)
-        asset_triggered_dag_info: dict[str, tuple[datetime, datetime]] = {}
-        for dag_id, records in adrq_by_dag.items():
-            times = sorted(x.created_at for x in records)
-            asset_triggered_dag_info[dag_id] = (times[0], times[-1])
+
+        # triggered dates for asset triggered dags
+        triggered_date_by_dag: dict[str, datetime] = {
+            dag_id: max(adrq.created_at for adrq in adrqs) for dag_id, adrqs in adrq_by_dag.items()
+        }
         del adrq_by_dag
-        asset_triggered_dag_ids = set(asset_triggered_dag_info.keys())
+
+        asset_triggered_dag_ids = set(triggered_date_by_dag.keys())
         if asset_triggered_dag_ids:
+            # exclude as max active runs has been reached
             exclusion_list = set(
                 session.scalars(
                     select(DagModel.dag_id)
@@ -2373,8 +2376,8 @@ class DagModel(Base):
             )
             if exclusion_list:
                 asset_triggered_dag_ids -= exclusion_list
-                asset_triggered_dag_info = {
-                    k: v for k, v in asset_triggered_dag_info.items() if k not in exclusion_list
+                triggered_date_by_dag = {
+                    k: v for k, v in triggered_date_by_dag.items() if k not in exclusion_list
                 }

         # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs
@@ -2395,7 +2398,7 @@ class DagModel(Base):
         return (
             session.scalars(with_row_locks(query, of=cls, session=session, skip_locked=True)),
-            asset_triggered_dag_info,
+            triggered_date_by_dag,
         )

     def calculate_dagrun_date_fields(
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index aee56bee883..a5330b17403 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -2377,7 +2377,7 @@ class TestDagModel:
         assert sdm.dag._processor_dags_folder == settings.DAGS_FOLDER

     @pytest.mark.need_serialized_dag
-    def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, session, dag_maker):
+    def test_dags_needing_dagruns_triggered_date_by_dag_queued_times(self, session, dag_maker):
         asset1 = Asset(uri="test://asset1", group="test-group")
         asset2 = Asset(uri="test://asset2", name="test_asset_2", group="test-group")
@@ -2417,11 +2417,10 @@ class TestDagModel:
         )
         session.flush()

-        query, asset_triggered_dag_info = DagModel.dags_needing_dagruns(session)
-        assert len(asset_triggered_dag_info) == 1
-        assert dag.dag_id in asset_triggered_dag_info
-        first_queued_time, last_queued_time = asset_triggered_dag_info[dag.dag_id]
-        assert first_queued_time == DEFAULT_DATE
+        query, triggered_date_by_dag = DagModel.dags_needing_dagruns(session)
+        assert len(triggered_date_by_dag) == 1
+        assert dag.dag_id in triggered_date_by_dag
+        last_queued_time = triggered_date_by_dag[dag.dag_id]
         assert last_queued_time == DEFAULT_DATE + timedelta(hours=1)
     def test_asset_expression(self, testing_dag_bundle, session: Session) -> None: