This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a89dcaa46dc remove large in clause in assets with cte and join (#62114)
a89dcaa46dc is described below

commit a89dcaa46dcf7d1cc1a4e2291da87c90f5541bf6
Author: Nataneljpwd <[email protected]>
AuthorDate: Fri Mar 13 08:08:12 2026 +0200

    remove large in clause in assets with cte and join (#62114)
    
    * fixed large in clause
    
    * fixed tests
    
    * changed to delete using
    
    * added compatibility for tests
    
    * Change asset selection to use CTE
    
    * Fix asset selection for Airflow versions
    
    * fixed some tests and optimized the query
    
    * fixed all tests
    
    * fixed mypy
    
    * fixup! address CR comments
    
    * address CR comments, added tests and fixed plugin
    
    * address CR comment
    
    ---------
    
    Co-authored-by: Natanel Rudyuklakir <[email protected]>
    Co-authored-by: Elad Kalif <[email protected]>
---
 .../src/airflow/jobs/scheduler_job_runner.py       | 87 ++++++++++++----------
 airflow-core/tests/unit/jobs/test_scheduler_job.py | 41 ++++++++--
 devel-common/src/tests_common/pytest_plugin.py     |  8 +-
 3 files changed, 89 insertions(+), 47 deletions(-)

diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py 
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index c667631756e..b078cb183bc 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -33,18 +33,7 @@ from functools import lru_cache, partial
 from itertools import groupby
 from typing import TYPE_CHECKING, Any
 
-from sqlalchemy import (
-    and_,
-    delete,
-    exists,
-    func,
-    inspect,
-    or_,
-    select,
-    text,
-    tuple_,
-    update,
-)
+from sqlalchemy import CTE, and_, delete, exists, func, inspect, or_, select, 
text, tuple_, update
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, 
selectinload
 from sqlalchemy.sql import expression
@@ -822,7 +811,11 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                                 task_instance,
                             )
                             starved_tasks_task_dagrun_concurrency.add(
-                                (task_instance.dag_id, task_instance.run_id, 
task_instance.task_id)
+                                (
+                                    task_instance.dag_id,
+                                    task_instance.run_id,
+                                    task_instance.task_id,
+                                )
                             )
                             continue
 
@@ -2940,44 +2933,41 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             )
             == 0
         ).label("orphaned")
-        asset_reference_query = session.execute(
-            select(orphaned, AssetModel)
+        asset_reference_query = (
+            select(AssetModel)
             .outerjoin(DagScheduleAssetReference)
             .outerjoin(TaskOutletAssetReference)
             .outerjoin(TaskInletAssetReference)
             .group_by(AssetModel.id)
-            .order_by(orphaned)
         )
-        asset_orphanation: dict[bool, Collection[AssetModel]] = {
-            orphaned: [asset for _, asset in group]
-            for orphaned, group in itertools.groupby(asset_reference_query, 
key=operator.itemgetter(0))
-        }
-        self._orphan_unreferenced_assets(asset_orphanation.get(True, ()), 
session=session)
-        self._activate_referenced_assets(asset_orphanation.get(False, ()), 
session=session)
+
+        orphan_query = asset_reference_query.having(orphaned).cte()
+        activate_query = asset_reference_query.having(~orphaned).cte()
+
+        self._orphan_unreferenced_assets(orphan_query, session=session)
+        self._activate_referenced_assets(activate_query, session=session)
 
     @staticmethod
-    def _orphan_unreferenced_assets(assets: Collection[AssetModel], *, 
session: Session) -> None:
-        if assets:
-            session.execute(
-                delete(AssetActive).where(
-                    tuple_(AssetActive.name, AssetActive.uri).in_((a.name, 
a.uri) for a in assets)
+    def _orphan_unreferenced_assets(assets_query: CTE, *, session: Session) -> 
None:
+        deleted_orphaned_assets = session.execute(
+            delete(AssetActive).where(
+                exists().where(
+                    and_(AssetActive.name == assets_query.c.name, 
AssetActive.uri == assets_query.c.uri)
                 )
             )
-        Stats.gauge("asset.orphaned", len(assets))
+        )
 
-    @staticmethod
-    def _activate_referenced_assets(assets: Collection[AssetModel], *, 
session: Session) -> None:
-        if not assets:
-            return
+        Stats.gauge("asset.orphaned", max(getattr(deleted_orphaned_assets, 
"rowcount", 0), 0))
 
-        active_assets = set(
-            session.execute(
-                select(AssetActive.name, AssetActive.uri).where(
-                    tuple_(AssetActive.name, AssetActive.uri).in_((a.name, 
a.uri) for a in assets)
-                )
-            )
+    @staticmethod
+    def _activate_referenced_assets(assets_query: CTE, *, session: Session) -> 
None:
+        active_assets_query = select(AssetActive.name, AssetActive.uri).join(
+            assets_query,
+            and_(AssetActive.name == assets_query.c.name, AssetActive.uri == 
assets_query.c.uri),
         )
 
+        active_assets = session.execute(active_assets_query).all()
+
         active_name_to_uri: dict[str, str] = {name: uri for name, uri in 
active_assets}
         active_uri_to_name: dict[str, str] = {uri: name for name, uri in 
active_assets}
 
@@ -3002,9 +2992,24 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]:
             incoming_name_to_uri: dict[str, str] = {}
             incoming_uri_to_name: dict[str, str] = {}
-            for asset in assets:
-                if (asset.name, asset.uri) in active_assets:
-                    continue
+
+            inactive_assets_query = (
+                select(AssetModel)
+                .join(
+                    assets_query,
+                    and_(
+                        assets_query.c.name == AssetModel.name,
+                        assets_query.c.uri == AssetModel.uri,
+                    ),
+                )
+                .where(
+                    ~active_assets_query.where(
+                        and_(AssetActive.name == AssetModel.name, 
AssetActive.uri == AssetModel.uri)
+                    ).exists()
+                )
+            )
+
+            for asset in session.scalars(inactive_assets_query):
                 existing_uri = active_name_to_uri.get(asset.name) or 
incoming_name_to_uri.get(asset.name)
                 if existing_uri is not None and existing_uri != asset.uri:
                     yield from _generate_warning_message(asset, "name", 
existing_uri)
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py 
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index c23f180f3a4..7c8d2ff4ffb 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -24,7 +24,7 @@ import os
 import re
 from collections import Counter, deque
 from collections.abc import Callable, Generator, Iterator
-from contextlib import ExitStack
+from contextlib import ExitStack, contextmanager
 from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -7362,6 +7362,37 @@ class TestSchedulerJob:
         job_runner._create_dag_runs([dm1], session)
         assert "Failed creating DagRun" in caplog.text
 
+    def test_activate_referenced_assets_no_in_check_inside_query(self, 
session, testing_dag_bundle):
+        dag_id1 = "test_asset_dag1"
+        asset1_name = "asset1"
+        asset_extra = {"foo": "bar"}
+
+        asset1 = Asset(name=asset1_name, uri="s3://bucket/key/1", 
extra=asset_extra)
+        dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1])
+        sync_dag_to_db(dag1, session=session)
+
+        @contextmanager
+        def assert_no_in_clause(session):
+            from sqlalchemy import event
+
+            def fail_on_in_clause_found(execute_statement):
+                if " IN " in str(execute_statement).upper():
+                    execute_statement = str(execute_statement).upper()
+                    pytest.fail(
+                        f"Query contains IN clause which was removed in PR 
#62114, query: {execute_statement}"
+                    )
+
+            event.listen(session, "do_orm_execute", fail_on_in_clause_found)
+            try:
+                yield
+            finally:
+                event.remove(session, "do_orm_execute", 
fail_on_in_clause_found)
+
+        asset_models = select(AssetModel).cte()
+
+        with assert_no_in_clause(session):
+            SchedulerJobRunner._activate_referenced_assets(asset_models, 
session=session)
+
     def test_activate_referenced_assets_with_no_existing_warning(self, 
session, testing_dag_bundle):
         dag_warnings = session.scalars(select(DagWarning)).all()
         assert dag_warnings == []
@@ -7376,8 +7407,8 @@ class TestSchedulerJob:
         dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1, 
asset1_1, asset1_2])
         sync_dag_to_db(dag1, session=session)
 
-        asset_models = session.scalars(select(AssetModel)).all()
-        assert len(asset_models) == 3
+        asset_models = select(AssetModel).cte()
+        assert len(session.execute(select(asset_models)).all()) == 3
 
         SchedulerJobRunner._activate_referenced_assets(asset_models, 
session=session)
         session.flush()
@@ -7414,7 +7445,7 @@ class TestSchedulerJob:
         )
         session.flush()
 
-        asset_models = session.scalars(select(AssetModel)).all()
+        asset_models = select(AssetModel).cte()
 
         SchedulerJobRunner._activate_referenced_assets(asset_models, 
session=session)
         session.flush()
@@ -7463,7 +7494,7 @@ class TestSchedulerJob:
         session.add(DagWarning(dag_id=dag_id, warning_type="asset conflict", 
message="will not exist"))
         session.flush()
 
-        asset_models = session.scalars(select(AssetModel)).all()
+        asset_models = select(AssetModel).cte()
 
         SchedulerJobRunner._activate_referenced_assets(asset_models, 
session=session)
         session.flush()
diff --git a/devel-common/src/tests_common/pytest_plugin.py 
b/devel-common/src/tests_common/pytest_plugin.py
index 5d473bdfea0..9f3aa0a7bc4 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -972,7 +972,13 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
                     
AssetModel.producing_tasks.any(TaskOutletAssetReference.dag_id == 
self.dag.dag_id),
                 )
 
-            assets = 
self.session.scalars(select(AssetModel).where(assets_select_condition)).all()
+            assets = select(AssetModel).where(assets_select_condition).cte()
+
+            if not AIRFLOW_V_3_2_PLUS:
+                assets = self.session.scalars(
+                    select(AssetModel).join(assets, AssetModel.id == 
AssetModel.id)
+                ).all()
+
             SchedulerJobRunner._activate_referenced_assets(assets, 
session=self.session)
 
         def __exit__(self, type, value, traceback):

Reply via email to