This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 42a94927145 Combine asset events fetching logic into one SQL query and 
clean up unnecessary asset-triggered dag data (#46721)
42a94927145 is described below

commit 42a94927145d052183b09637157359eb044e472d
Author: Wei Lee <[email protected]>
AuthorDate: Fri Feb 14 16:15:38 2025 +0800

    Combine asset events fetching logic into one SQL query and clean up unnecessary asset-triggered dag data (#46721)
    
    * refactor(dag): simplify asset_triggered_dag_info content
    
    It was previously {dag_id: (min_asset_event_date, max_asset_event_date)}.
    min_asset_event_date is no longer needed because we will no longer have a data interval for asset-triggered events.
    
    * refactor(scheduler_job_runner): merge asset event fetching logic
    
    * refactor(scheduler_job_runner): rename asset_triggered_dag_info to triggered_date_by_dag
---
 airflow/jobs/scheduler_job_runner.py | 34 +++++++++++++++-------------------
 airflow/models/dag.py                | 31 +++++++++++++++++--------------
 tests/models/test_dag.py             | 11 +++++------
 3 files changed, 37 insertions(+), 39 deletions(-)

diff --git a/airflow/jobs/scheduler_job_runner.py 
b/airflow/jobs/scheduler_job_runner.py
index 4d81e464e68..e35e8573a6d 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -27,7 +27,7 @@ import time
 from collections import Counter, defaultdict, deque
 from collections.abc import Collection, Iterable, Iterator
 from contextlib import ExitStack, suppress
-from datetime import timedelta
+from datetime import date, timedelta
 from functools import lru_cache, partial
 from itertools import groupby
 from typing import TYPE_CHECKING, Any, Callable
@@ -1198,17 +1198,17 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
     @retry_db_transaction
     def _create_dagruns_for_dags(self, guard: CommitProhibitorGuard, session: 
Session) -> None:
         """Find Dag Models needing DagRuns and Create Dag Runs with retries in 
case of OperationalError."""
-        query, asset_triggered_dag_info = 
DagModel.dags_needing_dagruns(session)
+        query, triggered_date_by_dag = DagModel.dags_needing_dagruns(session)
         all_dags_needing_dag_runs = set(query.all())
         asset_triggered_dags = [
-            dag for dag in all_dags_needing_dag_runs if dag.dag_id in 
asset_triggered_dag_info
+            dag for dag in all_dags_needing_dag_runs if dag.dag_id in 
triggered_date_by_dag
         ]
         non_asset_dags = 
all_dags_needing_dag_runs.difference(asset_triggered_dags)
         self._create_dag_runs(non_asset_dags, session)
         if asset_triggered_dags:
             self._create_dag_runs_asset_triggered(
                 dag_models=asset_triggered_dags,
-                asset_triggered_dag_info=asset_triggered_dag_info,
+                triggered_date_by_dag=triggered_date_by_dag,
                 session=session,
             )
 
@@ -1325,13 +1325,13 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
     def _create_dag_runs_asset_triggered(
         self,
         dag_models: Collection[DagModel],
-        asset_triggered_dag_info: dict[str, tuple[datetime, datetime]],
+        triggered_date_by_dag: dict[str, datetime],
         session: Session,
     ) -> None:
         """For DAGs that are triggered by assets, create dag runs."""
         triggered_dates: dict[str, DateTime] = {
             dag_id: timezone.coerce_datetime(last_asset_event_time)
-            for dag_id, (_, last_asset_event_time) in 
asset_triggered_dag_info.items()
+            for dag_id, last_asset_event_time in triggered_date_by_dag.items()
         }
 
         for dag_model in dag_models:
@@ -1350,30 +1350,26 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             latest_dag_version = DagVersion.get_latest_version(dag.dag_id, 
session=session)
 
             triggered_date = triggered_dates[dag.dag_id]
-            previous_dag_run = session.scalar(
-                select(DagRun)
+            cte = (
+                
select(func.max(DagRun.run_after).label("previous_dag_run_run_after"))
                 .where(
                     DagRun.dag_id == dag.dag_id,
-                    DagRun.run_after < triggered_date,
                     DagRun.run_type == DagRunType.ASSET_TRIGGERED,
+                    DagRun.run_after < triggered_date,
                 )
-                .order_by(DagRun.run_after.desc())
-                .limit(1)
+                .cte()
             )
-            asset_event_filters = [
-                DagScheduleAssetReference.dag_id == dag.dag_id,
-                AssetEvent.timestamp <= triggered_date,
-            ]
-            if previous_dag_run:
-                asset_event_filters.append(AssetEvent.timestamp > 
previous_dag_run.run_after)
-
             asset_events = session.scalars(
                 select(AssetEvent)
                 .join(
                     DagScheduleAssetReference,
                     AssetEvent.asset_id == DagScheduleAssetReference.asset_id,
                 )
-                .where(*asset_event_filters)
+                .where(
+                    DagScheduleAssetReference.dag_id == dag.dag_id,
+                    AssetEvent.timestamp <= triggered_date,
+                    AssetEvent.timestamp > 
func.coalesce(cte.c.previous_dag_run_run_after, date.min),
+                )
             ).all()
 
             dag_run = dag.create_dagrun(
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 37ff854c8da..cff6c857c03 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2317,7 +2317,7 @@ class DagModel(Base):
                 dm.is_active = False
 
     @classmethod
-    def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, 
tuple[datetime, datetime]]]:
+    def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, 
datetime]]:
         """
         Return (and lock) a list of Dag objects that are due to create a new 
DagRun.
 
@@ -2341,11 +2341,12 @@ class DagModel(Base):
         adrq_by_dag: dict[str, list[AssetDagRunQueue]] = defaultdict(list)
         for r in session.scalars(select(AssetDagRunQueue)):
             adrq_by_dag[r.target_dag_id].append(r)
-        dag_statuses: dict[str, dict[AssetUniqueKey, bool]] = {}
-        for dag_id, records in adrq_by_dag.items():
-            dag_statuses[dag_id] = {AssetUniqueKey.from_asset(x.asset): True 
for x in records}
-        ser_dags = 
SerializedDagModel.get_latest_serialized_dags(dag_ids=list(dag_statuses), 
session=session)
 
+        dag_statuses: dict[str, dict[AssetUniqueKey, bool]] = {
+            dag_id: {AssetUniqueKey.from_asset(adrq.asset): True for adrq in 
adrqs}
+            for dag_id, adrqs in adrq_by_dag.items()
+        }
+        ser_dags = 
SerializedDagModel.get_latest_serialized_dags(dag_ids=list(dag_statuses), 
session=session)
         for ser_dag in ser_dags:
             dag_id = ser_dag.dag_id
             statuses = dag_statuses[dag_id]
@@ -2353,14 +2354,16 @@ class DagModel(Base):
                 del adrq_by_dag[dag_id]
                 del dag_statuses[dag_id]
         del dag_statuses
-        # TODO: make it more readable (rename it or make it attrs, dataclass 
or etc.)
-        asset_triggered_dag_info: dict[str, tuple[datetime, datetime]] = {}
-        for dag_id, records in adrq_by_dag.items():
-            times = sorted(x.created_at for x in records)
-            asset_triggered_dag_info[dag_id] = (times[0], times[-1])
+
+        # triggered dates for asset triggered dags
+        triggered_date_by_dag: dict[str, datetime] = {
+            dag_id: max(adrq.created_at for adrq in adrqs) for dag_id, adrqs 
in adrq_by_dag.items()
+        }
         del adrq_by_dag
-        asset_triggered_dag_ids = set(asset_triggered_dag_info.keys())
+
+        asset_triggered_dag_ids = set(triggered_date_by_dag.keys())
         if asset_triggered_dag_ids:
+            # exclude as max active runs has been reached
             exclusion_list = set(
                 session.scalars(
                     select(DagModel.dag_id)
@@ -2373,8 +2376,8 @@ class DagModel(Base):
             )
             if exclusion_list:
                 asset_triggered_dag_ids -= exclusion_list
-                asset_triggered_dag_info = {
-                    k: v for k, v in asset_triggered_dag_info.items() if k not 
in exclusion_list
+                triggered_date_by_dag = {
+                    k: v for k, v in triggered_date_by_dag.items() if k not in 
exclusion_list
                 }
 
         # We limit so that _one_ scheduler doesn't try to do all the creation 
of dag runs
@@ -2395,7 +2398,7 @@ class DagModel(Base):
 
         return (
             session.scalars(with_row_locks(query, of=cls, session=session, 
skip_locked=True)),
-            asset_triggered_dag_info,
+            triggered_date_by_dag,
         )
 
     def calculate_dagrun_date_fields(
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index aee56bee883..a5330b17403 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -2377,7 +2377,7 @@ class TestDagModel:
         assert sdm.dag._processor_dags_folder == settings.DAGS_FOLDER
 
     @pytest.mark.need_serialized_dag
-    def test_dags_needing_dagruns_asset_triggered_dag_info_queued_times(self, 
session, dag_maker):
+    def test_dags_needing_dagruns_triggered_date_by_dag_queued_times(self, 
session, dag_maker):
         asset1 = Asset(uri="test://asset1", group="test-group")
         asset2 = Asset(uri="test://asset2", name="test_asset_2", 
group="test-group")
 
@@ -2417,11 +2417,10 @@ class TestDagModel:
         )
         session.flush()
 
-        query, asset_triggered_dag_info = 
DagModel.dags_needing_dagruns(session)
-        assert len(asset_triggered_dag_info) == 1
-        assert dag.dag_id in asset_triggered_dag_info
-        first_queued_time, last_queued_time = 
asset_triggered_dag_info[dag.dag_id]
-        assert first_queued_time == DEFAULT_DATE
+        query, triggered_date_by_dag = DagModel.dags_needing_dagruns(session)
+        assert len(triggered_date_by_dag) == 1
+        assert dag.dag_id in triggered_date_by_dag
+        last_queued_time = triggered_date_by_dag[dag.dag_id]
         assert last_queued_time == DEFAULT_DATE + timedelta(hours=1)
 
     def test_asset_expression(self, testing_dag_bundle, session: Session) -> 
None:

Reply via email to