This is an automated email from the ASF dual-hosted git repository.

taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 51bd26b56f Fix 'implicitly coercing SELECT object to scalar subquery' 
in latest dag run statement (#37505)
51bd26b56f is described below

commit 51bd26b56f5f930add92764d8eade490f11f6fea
Author: Andrey Anshin <[email protected]>
AuthorDate: Mon Feb 19 11:45:35 2024 +0400

    Fix 'implicitly coercing SELECT object to scalar subquery' in latest dag 
run statement (#37505)
    
    * Fix 'implicitly coercing SELECT object to scalar subquery' in latest dag 
run statement
    
    * Remove redundant print
    
    * remove redundant dag_maker and session in tests
    
    * Beautify test output
---
 airflow/models/dag.py    |  8 +++---
 tests/models/test_dag.py | 65 +++++++++++++++++++++++++-----------------------
 2 files changed, 38 insertions(+), 35 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index d2366c0e9e..164e83a3f5 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -3081,7 +3081,7 @@ class DAG(LoggingMixin):
         # Skip these queries entirely if no DAGs can be scheduled to save time.
         if any(dag.timetable.can_be_scheduled for dag in dags):
             # Get the latest automated dag run for each existing dag as a 
single query (avoid n+1 query)
-            query = cls._get_latest_runs_query(dags=list(existing_dags.keys()))
+            query = cls._get_latest_runs_stmt(dags=list(existing_dags.keys()))
             latest_runs = {run.dag_id: run for run in session.scalars(query)}
 
             # Get number of active dagruns for all dags we are processing as a 
single query.
@@ -3254,9 +3254,9 @@ class DAG(LoggingMixin):
             cls.bulk_write_to_db(dag.subdags, 
processor_subdir=processor_subdir, session=session)
 
     @classmethod
-    def _get_latest_runs_query(cls, dags: list[str]) -> Query:
+    def _get_latest_runs_stmt(cls, dags: list[str]) -> Select:
         """
-        Query the database to retrieve the last automated run for each dag.
+        Build a select statement to retrieve the last automated run for each 
dag.
 
         :param dags: dags to query
         """
@@ -3269,7 +3269,7 @@ class DAG(LoggingMixin):
                     DagRun.dag_id == existing_dag_id,
                     DagRun.run_type.in_((DagRunType.BACKFILL_JOB, 
DagRunType.SCHEDULED)),
                 )
-                .subquery()
+                .scalar_subquery()
             )
             query = select(DagRun).where(
                 DagRun.dag_id == existing_dag_id, DagRun.execution_date == 
last_automated_runs_subq
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 05681cfe88..b46f2b2870 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -24,6 +24,7 @@ import os
 import pickle
 import re
 import sys
+import warnings
 import weakref
 from contextlib import redirect_stdout
 from datetime import timedelta
@@ -39,6 +40,7 @@ import time_machine
 from dateutil.relativedelta import relativedelta
 from pendulum.tz.timezone import Timezone
 from sqlalchemy import inspect
+from sqlalchemy.exc import SAWarning
 
 from airflow import settings
 from airflow.configuration import conf
@@ -4148,34 +4150,35 @@ class TestTaskClearingSetupTeardownBehavior:
                 dag.validate_setup_teardown()
 
 
-def test_get_latest_runs_query_one_dag(dag_maker, session):
-    with dag_maker(dag_id="dag1") as dag1:
-        ...
-    query = DAG._get_latest_runs_query(dags=[dag1.dag_id])
-    actual = [x.strip() for x in str(query.compile()).splitlines()]
-    expected = [
-        "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, 
dag_run.data_interval_start, dag_run.data_interval_end",
-        "FROM dag_run",
-        "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.execution_date = (SELECT 
max(dag_run.execution_date) AS max_execution_date",
-        "FROM dag_run",
-        "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN 
(__[POSTCOMPILE_run_type_1]))",
-    ]
-    assert actual == expected
-
-
-def test_get_latest_runs_query_two_dags(dag_maker, session):
-    with dag_maker(dag_id="dag1") as dag1:
-        ...
-    with dag_maker(dag_id="dag2") as dag2:
-        ...
-    query = DAG._get_latest_runs_query(dags=[dag1.dag_id, dag2.dag_id])
-    actual = [x.strip() for x in str(query.compile()).splitlines()]
-    print("\n".join(actual))
-    expected = [
-        "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, 
dag_run.data_interval_start, dag_run.data_interval_end",
-        "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, 
max(dag_run.execution_date) AS max_execution_date",
-        "FROM dag_run",
-        "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) AND 
dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS 
anon_1",
-        "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.execution_date = 
anon_1.max_execution_date",
-    ]
-    assert actual == expected
+def test_statement_latest_runs_one_dag():
+    with warnings.catch_warnings():
+        warnings.simplefilter("error", category=SAWarning)
+
+        stmt = DAG._get_latest_runs_stmt(dags=["fake-dag"])
+        compiled_stmt = str(stmt.compile())
+        actual = [x.strip() for x in compiled_stmt.splitlines()]
+        expected = [
+            "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, 
dag_run.data_interval_start, dag_run.data_interval_end",
+            "FROM dag_run",
+            "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.execution_date = 
(SELECT max(dag_run.execution_date) AS max_execution_date",
+            "FROM dag_run",
+            "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN 
(__[POSTCOMPILE_run_type_1]))",
+        ]
+        assert actual == expected, compiled_stmt
+
+
+def test_statement_latest_runs_many_dag():
+    with warnings.catch_warnings():
+        warnings.simplefilter("error", category=SAWarning)
+
+        stmt = DAG._get_latest_runs_stmt(dags=["fake-dag-1", "fake-dag-2"])
+        compiled_stmt = str(stmt.compile())
+        actual = [x.strip() for x in compiled_stmt.splitlines()]
+        expected = [
+            "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, 
dag_run.data_interval_start, dag_run.data_interval_end",
+            "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, 
max(dag_run.execution_date) AS max_execution_date",
+            "FROM dag_run",
+            "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) AND 
dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS 
anon_1",
+            "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.execution_date = 
anon_1.max_execution_date",
+        ]
+        assert actual == expected, compiled_stmt

Reply via email to