This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 51bd26b56f Fix 'implicitly coercing SELECT object to scalar subquery'
in latest dag run statement (#37505)
51bd26b56f is described below
commit 51bd26b56f5f930add92764d8eade490f11f6fea
Author: Andrey Anshin <[email protected]>
AuthorDate: Mon Feb 19 11:45:35 2024 +0400
Fix 'implicitly coercing SELECT object to scalar subquery' in latest dag
run statement (#37505)
* Fix 'implicitly coercing SELECT object to scalar subquery' in latest dag
run statement
* Remove redundant print
* Remove redundant dag_maker and session in tests
* Beautify test output
---
airflow/models/dag.py | 8 +++---
tests/models/test_dag.py | 65 +++++++++++++++++++++++++-----------------------
2 files changed, 38 insertions(+), 35 deletions(-)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index d2366c0e9e..164e83a3f5 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -3081,7 +3081,7 @@ class DAG(LoggingMixin):
# Skip these queries entirely if no DAGs can be scheduled to save time.
if any(dag.timetable.can_be_scheduled for dag in dags):
# Get the latest automated dag run for each existing dag as a
single query (avoid n+1 query)
- query = cls._get_latest_runs_query(dags=list(existing_dags.keys()))
+ query = cls._get_latest_runs_stmt(dags=list(existing_dags.keys()))
latest_runs = {run.dag_id: run for run in session.scalars(query)}
# Get number of active dagruns for all dags we are processing as a
single query.
@@ -3254,9 +3254,9 @@ class DAG(LoggingMixin):
cls.bulk_write_to_db(dag.subdags,
processor_subdir=processor_subdir, session=session)
@classmethod
- def _get_latest_runs_query(cls, dags: list[str]) -> Query:
+ def _get_latest_runs_stmt(cls, dags: list[str]) -> Select:
"""
- Query the database to retrieve the last automated run for each dag.
+ Build a select statement to retrieve the last automated run for each dag.
:param dags: dags to query
"""
@@ -3269,7 +3269,7 @@ class DAG(LoggingMixin):
DagRun.dag_id == existing_dag_id,
DagRun.run_type.in_((DagRunType.BACKFILL_JOB,
DagRunType.SCHEDULED)),
)
- .subquery()
+ .scalar_subquery()
)
query = select(DagRun).where(
DagRun.dag_id == existing_dag_id, DagRun.execution_date ==
last_automated_runs_subq
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 05681cfe88..b46f2b2870 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -24,6 +24,7 @@ import os
import pickle
import re
import sys
+import warnings
import weakref
from contextlib import redirect_stdout
from datetime import timedelta
@@ -39,6 +40,7 @@ import time_machine
from dateutil.relativedelta import relativedelta
from pendulum.tz.timezone import Timezone
from sqlalchemy import inspect
+from sqlalchemy.exc import SAWarning
from airflow import settings
from airflow.configuration import conf
@@ -4148,34 +4150,35 @@ class TestTaskClearingSetupTeardownBehavior:
dag.validate_setup_teardown()
-def test_get_latest_runs_query_one_dag(dag_maker, session):
- with dag_maker(dag_id="dag1") as dag1:
- ...
- query = DAG._get_latest_runs_query(dags=[dag1.dag_id])
- actual = [x.strip() for x in str(query.compile()).splitlines()]
- expected = [
- "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date,
dag_run.data_interval_start, dag_run.data_interval_end",
- "FROM dag_run",
- "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.execution_date = (SELECT
max(dag_run.execution_date) AS max_execution_date",
- "FROM dag_run",
- "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN
(__[POSTCOMPILE_run_type_1]))",
- ]
- assert actual == expected
-
-
-def test_get_latest_runs_query_two_dags(dag_maker, session):
- with dag_maker(dag_id="dag1") as dag1:
- ...
- with dag_maker(dag_id="dag2") as dag2:
- ...
- query = DAG._get_latest_runs_query(dags=[dag1.dag_id, dag2.dag_id])
- actual = [x.strip() for x in str(query.compile()).splitlines()]
- print("\n".join(actual))
- expected = [
- "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date,
dag_run.data_interval_start, dag_run.data_interval_end",
- "FROM dag_run, (SELECT dag_run.dag_id AS dag_id,
max(dag_run.execution_date) AS max_execution_date",
- "FROM dag_run",
- "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) AND
dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS
anon_1",
- "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.execution_date =
anon_1.max_execution_date",
- ]
- assert actual == expected
+def test_statement_latest_runs_one_dag():
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", category=SAWarning)
+
+ stmt = DAG._get_latest_runs_stmt(dags=["fake-dag"])
+ compiled_stmt = str(stmt.compile())
+ actual = [x.strip() for x in compiled_stmt.splitlines()]
+ expected = [
+ "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date,
dag_run.data_interval_start, dag_run.data_interval_end",
+ "FROM dag_run",
+ "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.execution_date =
(SELECT max(dag_run.execution_date) AS max_execution_date",
+ "FROM dag_run",
+ "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN
(__[POSTCOMPILE_run_type_1]))",
+ ]
+ assert actual == expected, compiled_stmt
+
+
+def test_statement_latest_runs_many_dag():
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", category=SAWarning)
+
+ stmt = DAG._get_latest_runs_stmt(dags=["fake-dag-1", "fake-dag-2"])
+ compiled_stmt = str(stmt.compile())
+ actual = [x.strip() for x in compiled_stmt.splitlines()]
+ expected = [
+ "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date,
dag_run.data_interval_start, dag_run.data_interval_end",
+ "FROM dag_run, (SELECT dag_run.dag_id AS dag_id,
max(dag_run.execution_date) AS max_execution_date",
+ "FROM dag_run",
+ "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) AND
dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS
anon_1",
+ "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.execution_date =
anon_1.max_execution_date",
+ ]
+ assert actual == expected, compiled_stmt