This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 10c04a4efc Optimize max_execution_date query in single dag case (#33242)
10c04a4efc is described below

commit 10c04a4efcee6b92f8057a5a7dbd21d9ca2de710
Author: Josh Owen <[email protected]>
AuthorDate: Mon Jan 22 11:08:25 2024 -0500

    Optimize max_execution_date query in single dag case (#33242)
    
    We can make better use of an index when we're only dealing with one dag, which is a common case.
    
    ---------
    
    Co-authored-by: Elad Kalif <[email protected]>
    Co-authored-by: Daniel Standish <[email protected]>
---
 airflow/models/dag.py    | 69 +++++++++++++++++++++++++++-----------
 tests/models/test_dag.py | 86 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 136 insertions(+), 19 deletions(-)

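Read together, the removed and added code below amount to two query shapes. The following is a minimal standalone sketch of both, assuming a stand-in dag_run model in plain SQLAlchemy rather than Airflow's actual models; "example_dag", "dag_a"/"dag_b" and the AUTOMATED tuple are illustrative placeholders:

    from sqlalchemy import Column, DateTime, Integer, String, func, select
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class DagRun(Base):
        __tablename__ = "dag_run"
        id = Column(Integer, primary_key=True)
        dag_id = Column(String, index=True)
        run_type = Column(String)
        execution_date = Column(DateTime)

    AUTOMATED = ("backfill", "scheduled")  # placeholder for the automated run types

    # Single-dag fast path: a scalar max() constrained to one dag_id can be
    # answered straight from an index on (dag_id, execution_date).
    single_max = (
        select(func.max(DagRun.execution_date))
        .where(DagRun.dag_id == "example_dag", DagRun.run_type.in_(AUTOMATED))
        .scalar_subquery()
    )
    single_dag_query = select(DagRun).where(
        DagRun.dag_id == "example_dag", DagRun.execution_date == single_max
    )

    # General path: one GROUP BY over all requested dag_ids, joined back to the
    # dag_run rows matching each per-dag maximum.
    per_dag_max = (
        select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
        .where(DagRun.dag_id.in_(["dag_a", "dag_b"]), DagRun.run_type.in_(AUTOMATED))
        .group_by(DagRun.dag_id)
        .subquery()
    )
    multi_dag_query = select(DagRun).where(
        DagRun.dag_id == per_dag_max.c.dag_id,
        DagRun.execution_date == per_dag_max.c.max_execution_date,
    )

    # Print the generated SQL to compare the shapes of the two statements.
    print(single_dag_query)
    print(multi_dag_query)
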
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 9ee3409c0d..dac7be010a 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -73,7 +73,7 @@ from sqlalchemy import (
     update,
 )
 from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.orm import backref, joinedload, relationship
+from sqlalchemy.orm import backref, joinedload, load_only, relationship
 from sqlalchemy.sql import Select, expression
 
 import airflow.templates
@@ -3062,27 +3062,13 @@ class DAG(LoggingMixin):
             session.add(orm_dag)
             orm_dags.append(orm_dag)
 
-        dag_id_to_last_automated_run: dict[str, DagRun] = {}
+        latest_runs: dict[str, DagRun] = {}
         num_active_runs: dict[str, int] = {}
         # Skip these queries entirely if no DAGs can be scheduled to save time.
         if any(dag.timetable.can_be_scheduled for dag in dags):
             # Get the latest automated dag run for each existing dag as a single query (avoid n+1 query)
-            last_automated_runs_subq = (
-                select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
-                .where(
-                    DagRun.dag_id.in_(existing_dags),
-                    or_(DagRun.run_type == DagRunType.BACKFILL_JOB, DagRun.run_type == DagRunType.SCHEDULED),
-                )
-                .group_by(DagRun.dag_id)
-                .subquery()
-            )
-            last_automated_runs = session.scalars(
-                select(DagRun).where(
-                    DagRun.dag_id == last_automated_runs_subq.c.dag_id,
-                    DagRun.execution_date == last_automated_runs_subq.c.max_execution_date,
-                )
-            )
-            dag_id_to_last_automated_run = {run.dag_id: run for run in last_automated_runs}
+            query = cls._get_latest_runs_query(existing_dags, session)
+            latest_runs = {run.dag_id: run for run in session.scalars(query)}
 
             # Get number of active dagruns for all dags we are processing as a single query.
             num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session)
@@ -3116,7 +3102,7 @@ class DAG(LoggingMixin):
             orm_dag.timetable_description = dag.timetable.description
             orm_dag.processor_subdir = processor_subdir
 
-            last_automated_run: DagRun | None = dag_id_to_last_automated_run.get(dag.dag_id)
+            last_automated_run: DagRun | None = latest_runs.get(dag.dag_id)
             if last_automated_run is None:
                 last_automated_data_interval = None
             else:
@@ -3253,6 +3239,51 @@ class DAG(LoggingMixin):
         for dag in dags:
             cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session)
 
+    @classmethod
+    def _get_latest_runs_query(cls, dags, session) -> Query:
+        """
+        Query the database to retrieve the last automated run for each dag.
+
+        :param dags: dags to query
+        :param session: sqlalchemy session object
+        """
+        if len(dags) == 1:
+            # Index optimized fast path to avoid more complicated & slower groupby queryplan
+            existing_dag_id = list(dags)[0].dag_id
+            last_automated_runs_subq = (
+                select(func.max(DagRun.execution_date).label("max_execution_date"))
+                .where(
+                    DagRun.dag_id == existing_dag_id,
+                    DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
+                )
+                .subquery()
+            )
+            query = select(DagRun).where(
+                DagRun.dag_id == existing_dag_id, DagRun.execution_date == last_automated_runs_subq
+            )
+        else:
+            last_automated_runs_subq = (
+                select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
+                .where(
+                    DagRun.dag_id.in_(dags),
+                    DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
+                )
+                .group_by(DagRun.dag_id)
+                .subquery()
+            )
+            query = select(DagRun).where(
+                DagRun.dag_id == last_automated_runs_subq.c.dag_id,
+                DagRun.execution_date == last_automated_runs_subq.c.max_execution_date,
+            )
+        return query.options(
+            load_only(
+                DagRun.dag_id,
+                DagRun.execution_date,
+                DagRun.data_interval_start,
+                DagRun.data_interval_end,
+            )
+        )
+
     @provide_session
     def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION):
         """
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index f367b00abe..7c337ed965 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -952,6 +952,59 @@ class TestDag:
             for row in session.query(DagModel.last_parsed_time).all():
                 assert row[0] is not None
 
+    def test_bulk_write_to_db_single_dag(self):
+        """
+        Test bulk_write_to_db for a single dag using the index optimized query
+        """
+        clear_db_dags()
+        dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE, 
tags=["test-dag"]) for i in range(1)]
+
+        with assert_queries_count(5):
+            DAG.bulk_write_to_db(dags)
+        with create_session() as session:
+            assert {"dag-bulk-sync-0"} == {row[0] for row in 
session.query(DagModel.dag_id).all()}
+            assert {
+                ("dag-bulk-sync-0", "test-dag"),
+            } == set(session.query(DagTag.dag_id, DagTag.name).all())
+
+            for row in session.query(DagModel.last_parsed_time).all():
+                assert row[0] is not None
+
+        # Re-sync should do fewer queries
+        with assert_queries_count(8):
+            DAG.bulk_write_to_db(dags)
+        with assert_queries_count(8):
+            DAG.bulk_write_to_db(dags)
+
+    def test_bulk_write_to_db_multiple_dags(self):
+        """
+        Test bulk_write_to_db for multiple dags which does not use the index optimized query
+        """
+        clear_db_dags()
+        dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE, 
tags=["test-dag"]) for i in range(4)]
+
+        with assert_queries_count(5):
+            DAG.bulk_write_to_db(dags)
+        with create_session() as session:
+            assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", 
"dag-bulk-sync-3"} == {
+                row[0] for row in session.query(DagModel.dag_id).all()
+            }
+            assert {
+                ("dag-bulk-sync-0", "test-dag"),
+                ("dag-bulk-sync-1", "test-dag"),
+                ("dag-bulk-sync-2", "test-dag"),
+                ("dag-bulk-sync-3", "test-dag"),
+            } == set(session.query(DagTag.dag_id, DagTag.name).all())
+
+            for row in session.query(DagModel.last_parsed_time).all():
+                assert row[0] is not None
+
+        # Re-sync should do fewer queries
+        with assert_queries_count(8):
+            DAG.bulk_write_to_db(dags)
+        with assert_queries_count(8):
+            DAG.bulk_write_to_db(dags)
+
     @pytest.mark.parametrize("interval", [None, "@daily"])
     def test_bulk_write_to_db_interval_save_runtime(self, interval):
        mock_active_runs_of_dags = mock.MagicMock(side_effect=DagRun.active_runs_of_dags)
@@ -4082,3 +4135,36 @@ class TestTaskClearingSetupTeardownBehavior:
                 Exception, match="Setup tasks must be followed with trigger 
rule ALL_SUCCESS."
             ):
                 dag.validate_setup_teardown()
+
+
+def test_get_latest_runs_query_one_dag(dag_maker, session):
+    with dag_maker(dag_id="dag1") as dag1:
+        ...
+    query = DAG._get_latest_runs_query(dags=[dag1], session=session)
+    actual = [x.strip() for x in str(query.compile()).splitlines()]
+    expected = [
+        "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, 
dag_run.data_interval_start, dag_run.data_interval_end",
+        "FROM dag_run",
+        "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.execution_date = (SELECT 
max(dag_run.execution_date) AS max_execution_date",
+        "FROM dag_run",
+        "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN 
(__[POSTCOMPILE_run_type_1]))",
+    ]
+    assert actual == expected
+
+
+def test_get_latest_runs_query_two_dags(dag_maker, session):
+    with dag_maker(dag_id="dag1") as dag1:
+        ...
+    with dag_maker(dag_id="dag2") as dag2:
+        ...
+    query = DAG._get_latest_runs_query(dags=[dag1, dag2], session=session)
+    actual = [x.strip() for x in str(query.compile()).splitlines()]
+    print("\n".join(actual))
+    expected = [
+        "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, 
dag_run.data_interval_start, dag_run.data_interval_end",
+        "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, 
max(dag_run.execution_date) AS max_execution_date",
+        "FROM dag_run",
+        "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) AND 
dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS 
anon_1",
+        "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.execution_date = 
anon_1.max_execution_date",
+    ]
+    assert actual == expected
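
The assertions in these tests compare against SQLAlchemy's compilation of the statement with the default dialect. A minimal sketch of that mechanism, using a stand-in model (illustrative only, not Airflow's code); the __[POSTCOMPILE_...] placeholders appear because IN() bind parameters are only expanded at execution time:

    from sqlalchemy import Column, Integer, String, select
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class DagRun(Base):
        __tablename__ = "dag_run"
        id = Column(Integer, primary_key=True)
        dag_id = Column(String)
        run_type = Column(String)

    stmt = select(DagRun.dag_id).where(DagRun.run_type.in_(["scheduled", "backfill"]))
    for line in str(stmt.compile()).splitlines():
        print(line.strip())
    # Prints roughly:
    #   SELECT dag_run.dag_id
    #   FROM dag_run
    #   WHERE dag_run.run_type IN (__[POSTCOMPILE_run_type_1])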
