This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 10c04a4efc Optimize max_execution_date query in single dag case (#33242)
10c04a4efc is described below
commit 10c04a4efcee6b92f8057a5a7dbd21d9ca2de710
Author: Josh Owen <[email protected]>
AuthorDate: Mon Jan 22 11:08:25 2024 -0500
Optimize max_execution_date query in single dag case (#33242)
We can make better use of an index when we're only dealing with one DAG,
which is a common case.
---------
Co-authored-by: Elad Kalif <[email protected]>
Co-authored-by: Daniel Standish <[email protected]>
---
airflow/models/dag.py | 69 +++++++++++++++++++++++++++-----------
tests/models/test_dag.py | 86 ++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 136 insertions(+), 19 deletions(-)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 9ee3409c0d..dac7be010a 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -73,7 +73,7 @@ from sqlalchemy import (
update,
)
from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.orm import backref, joinedload, relationship
+from sqlalchemy.orm import backref, joinedload, load_only, relationship
from sqlalchemy.sql import Select, expression
import airflow.templates
@@ -3062,27 +3062,13 @@ class DAG(LoggingMixin):
session.add(orm_dag)
orm_dags.append(orm_dag)
- dag_id_to_last_automated_run: dict[str, DagRun] = {}
+ latest_runs: dict[str, DagRun] = {}
num_active_runs: dict[str, int] = {}
# Skip these queries entirely if no DAGs can be scheduled to save time.
if any(dag.timetable.can_be_scheduled for dag in dags):
# Get the latest automated dag run for each existing dag as a
single query (avoid n+1 query)
- last_automated_runs_subq = (
- select(DagRun.dag_id,
func.max(DagRun.execution_date).label("max_execution_date"))
- .where(
- DagRun.dag_id.in_(existing_dags),
- or_(DagRun.run_type == DagRunType.BACKFILL_JOB,
DagRun.run_type == DagRunType.SCHEDULED),
- )
- .group_by(DagRun.dag_id)
- .subquery()
- )
- last_automated_runs = session.scalars(
- select(DagRun).where(
- DagRun.dag_id == last_automated_runs_subq.c.dag_id,
- DagRun.execution_date ==
last_automated_runs_subq.c.max_execution_date,
- )
- )
- dag_id_to_last_automated_run = {run.dag_id: run for run in
last_automated_runs}
+ query = cls._get_latest_runs_query(existing_dags, session)
+ latest_runs = {run.dag_id: run for run in session.scalars(query)}
# Get number of active dagruns for all dags we are processing as a
single query.
num_active_runs =
DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session)
@@ -3116,7 +3102,7 @@ class DAG(LoggingMixin):
orm_dag.timetable_description = dag.timetable.description
orm_dag.processor_subdir = processor_subdir
- last_automated_run: DagRun | None =
dag_id_to_last_automated_run.get(dag.dag_id)
+ last_automated_run: DagRun | None = latest_runs.get(dag.dag_id)
if last_automated_run is None:
last_automated_data_interval = None
else:
@@ -3253,6 +3239,51 @@ class DAG(LoggingMixin):
for dag in dags:
cls.bulk_write_to_db(dag.subdags,
processor_subdir=processor_subdir, session=session)
+ @classmethod
+ def _get_latest_runs_query(cls, dags, session) -> Query:
+ """
+ Query the database to retrieve the last automated run for each dag.
+
+ :param dags: dags to query
+ :param session: sqlalchemy session object
+ """
+ if len(dags) == 1:
+ # Index optimized fast path to avoid more complicated & slower
groupby queryplan
+ existing_dag_id = list(dags)[0].dag_id
+ last_automated_runs_subq = (
+
select(func.max(DagRun.execution_date).label("max_execution_date"))
+ .where(
+ DagRun.dag_id == existing_dag_id,
+ DagRun.run_type.in_((DagRunType.BACKFILL_JOB,
DagRunType.SCHEDULED)),
+ )
+ .subquery()
+ )
+ query = select(DagRun).where(
+ DagRun.dag_id == existing_dag_id, DagRun.execution_date ==
last_automated_runs_subq
+ )
+ else:
+ last_automated_runs_subq = (
+ select(DagRun.dag_id,
func.max(DagRun.execution_date).label("max_execution_date"))
+ .where(
+ DagRun.dag_id.in_(dags),
+ DagRun.run_type.in_((DagRunType.BACKFILL_JOB,
DagRunType.SCHEDULED)),
+ )
+ .group_by(DagRun.dag_id)
+ .subquery()
+ )
+ query = select(DagRun).where(
+ DagRun.dag_id == last_automated_runs_subq.c.dag_id,
+ DagRun.execution_date ==
last_automated_runs_subq.c.max_execution_date,
+ )
+ return query.options(
+ load_only(
+ DagRun.dag_id,
+ DagRun.execution_date,
+ DagRun.data_interval_start,
+ DagRun.data_interval_end,
+ )
+ )
+
@provide_session
def sync_to_db(self, processor_subdir: str | None = None,
session=NEW_SESSION):
"""
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index f367b00abe..7c337ed965 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -952,6 +952,59 @@ class TestDag:
for row in session.query(DagModel.last_parsed_time).all():
assert row[0] is not None
+ def test_bulk_write_to_db_single_dag(self):
+ """
+ Test bulk_write_to_db for a single dag using the index optimized query
+ """
+ clear_db_dags()
+ dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE,
tags=["test-dag"]) for i in range(1)]
+
+ with assert_queries_count(5):
+ DAG.bulk_write_to_db(dags)
+ with create_session() as session:
+ assert {"dag-bulk-sync-0"} == {row[0] for row in
session.query(DagModel.dag_id).all()}
+ assert {
+ ("dag-bulk-sync-0", "test-dag"),
+ } == set(session.query(DagTag.dag_id, DagTag.name).all())
+
+ for row in session.query(DagModel.last_parsed_time).all():
+ assert row[0] is not None
+
+ # Re-sync should do fewer queries
+ with assert_queries_count(8):
+ DAG.bulk_write_to_db(dags)
+ with assert_queries_count(8):
+ DAG.bulk_write_to_db(dags)
+
+ def test_bulk_write_to_db_multiple_dags(self):
+ """
+ Test bulk_write_to_db for multiple dags which does not use the index
optimized query
+ """
+ clear_db_dags()
+ dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE,
tags=["test-dag"]) for i in range(4)]
+
+ with assert_queries_count(5):
+ DAG.bulk_write_to_db(dags)
+ with create_session() as session:
+ assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2",
"dag-bulk-sync-3"} == {
+ row[0] for row in session.query(DagModel.dag_id).all()
+ }
+ assert {
+ ("dag-bulk-sync-0", "test-dag"),
+ ("dag-bulk-sync-1", "test-dag"),
+ ("dag-bulk-sync-2", "test-dag"),
+ ("dag-bulk-sync-3", "test-dag"),
+ } == set(session.query(DagTag.dag_id, DagTag.name).all())
+
+ for row in session.query(DagModel.last_parsed_time).all():
+ assert row[0] is not None
+
+ # Re-sync should do fewer queries
+ with assert_queries_count(8):
+ DAG.bulk_write_to_db(dags)
+ with assert_queries_count(8):
+ DAG.bulk_write_to_db(dags)
+
@pytest.mark.parametrize("interval", [None, "@daily"])
def test_bulk_write_to_db_interval_save_runtime(self, interval):
mock_active_runs_of_dags =
mock.MagicMock(side_effect=DagRun.active_runs_of_dags)
@@ -4082,3 +4135,36 @@ class TestTaskClearingSetupTeardownBehavior:
Exception, match="Setup tasks must be followed with trigger
rule ALL_SUCCESS."
):
dag.validate_setup_teardown()
+
+
+def test_get_latest_runs_query_one_dag(dag_maker, session):
+ with dag_maker(dag_id="dag1") as dag1:
+ ...
+ query = DAG._get_latest_runs_query(dags=[dag1], session=session)
+ actual = [x.strip() for x in str(query.compile()).splitlines()]
+ expected = [
+ "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date,
dag_run.data_interval_start, dag_run.data_interval_end",
+ "FROM dag_run",
+ "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.execution_date = (SELECT
max(dag_run.execution_date) AS max_execution_date",
+ "FROM dag_run",
+ "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN
(__[POSTCOMPILE_run_type_1]))",
+ ]
+ assert actual == expected
+
+
+def test_get_latest_runs_query_two_dags(dag_maker, session):
+ with dag_maker(dag_id="dag1") as dag1:
+ ...
+ with dag_maker(dag_id="dag2") as dag2:
+ ...
+ query = DAG._get_latest_runs_query(dags=[dag1, dag2], session=session)
+ actual = [x.strip() for x in str(query.compile()).splitlines()]
+ print("\n".join(actual))
+ expected = [
+ "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date,
dag_run.data_interval_start, dag_run.data_interval_end",
+ "FROM dag_run, (SELECT dag_run.dag_id AS dag_id,
max(dag_run.execution_date) AS max_execution_date",
+ "FROM dag_run",
+ "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) AND
dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS
anon_1",
+ "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.execution_date =
anon_1.max_execution_date",
+ ]
+ assert actual == expected