This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v3-1-test by this push:
     new a81eb25eef0 [v3-1-test] remove large in clause in assets with cte and 
join (#62114) (#63511)
a81eb25eef0 is described below

commit a81eb25eef0f631c49d7947a0f98f336f6f8ea23
Author: Nataneljpwd <[email protected]>
AuthorDate: Fri Mar 13 21:18:36 2026 +0200

    [v3-1-test] remove large in clause in assets with cte and join (#62114) 
(#63511)
    
    * [v3-1-test] remove large in clause in assets with cte and join 
(#62114)\n\n* fixed large in clause\n\n* fixed tests\n\n* changed to delete 
using\n\n* added compatibility for tests\n\n* Change asset selection to use 
CTE\n\n* Fix asset selection for Airflow versions\n\n* fixed some tests and 
optimized the query\n\n* fixed all tests\n\n* fixed mypy\n\n* fixup! address cr 
comments\n\n* address CR comments, added tests and fixed plugin\n\n* address CR 
comment\n\n---------\n(cherry picked  [...]
    
    * fix CTE import
    
    * set execution options
    
    * lint
    
    * fixed ruff
    
    ---------
    
    Co-authored-by: Natanel Rudyuklakir <[email protected]>
---
 .../src/airflow/jobs/scheduler_job_runner.py       | 78 +++++++++++++---------
 airflow-core/tests/unit/jobs/test_scheduler_job.py | 40 +++++++++--
 devel-common/src/tests_common/pytest_plugin.py     |  3 +-
 3 files changed, 86 insertions(+), 35 deletions(-)

diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py 
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index c66be676137..fc223d7ca80 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -96,6 +96,7 @@ if TYPE_CHECKING:
 
     from pendulum.datetime import DateTime
     from sqlalchemy.orm import Load, Query, Session
+    from sqlalchemy.sql.selectable import CTE
 
     from airflow._shared.logging.types import Logger
     from airflow.executors.base_executor import BaseExecutor
@@ -592,7 +593,11 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                                 task_instance,
                             )
                             starved_tasks_task_dagrun_concurrency.add(
-                                (task_instance.dag_id, task_instance.run_id, 
task_instance.task_id)
+                                (
+                                    task_instance.dag_id,
+                                    task_instance.run_id,
+                                    task_instance.task_id,
+                                )
                             )
                             continue
 
@@ -2646,44 +2651,42 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             )
             == 0
         ).label("orphaned")
-        asset_reference_query = session.execute(
-            select(orphaned, AssetModel)
+        asset_reference_query = (
+            select(AssetModel)
             .outerjoin(DagScheduleAssetReference)
             .outerjoin(TaskOutletAssetReference)
             .outerjoin(TaskInletAssetReference)
             .group_by(AssetModel.id)
-            .order_by(orphaned)
         )
-        asset_orphanation: dict[bool, Collection[AssetModel]] = {
-            orphaned: [asset for _, asset in group]
-            for orphaned, group in itertools.groupby(asset_reference_query, 
key=operator.itemgetter(0))
-        }
-        self._orphan_unreferenced_assets(asset_orphanation.get(True, ()), 
session=session)
-        self._activate_referenced_assets(asset_orphanation.get(False, ()), 
session=session)
+
+        orphan_query = asset_reference_query.having(orphaned).cte()
+        activate_query = asset_reference_query.having(~orphaned).cte()
+
+        self._orphan_unreferenced_assets(orphan_query, session=session)
+        self._activate_referenced_assets(activate_query, session=session)
 
     @staticmethod
-    def _orphan_unreferenced_assets(assets: Collection[AssetModel], *, 
session: Session) -> None:
-        if assets:
-            session.execute(
-                delete(AssetActive).where(
-                    tuple_(AssetActive.name, AssetActive.uri).in_((a.name, 
a.uri) for a in assets)
+    def _orphan_unreferenced_assets(assets_query: CTE, *, session: Session) -> 
None:
+        deleted_orphaned_assets = session.execute(
+            delete(AssetActive).where(
+                exists().where(
+                    and_(AssetActive.name == assets_query.c.name, 
AssetActive.uri == assets_query.c.uri)
                 )
-            )
-        Stats.gauge("asset.orphaned", len(assets))
+            ),
+            execution_options={"synchronize_session": "fetch"},
+        )
 
-    @staticmethod
-    def _activate_referenced_assets(assets: Collection[AssetModel], *, 
session: Session) -> None:
-        if not assets:
-            return
+        Stats.gauge("asset.orphaned", max(getattr(deleted_orphaned_assets, 
"rowcount", 0), 0))
 
-        active_assets = set(
-            session.execute(
-                select(AssetActive.name, AssetActive.uri).where(
-                    tuple_(AssetActive.name, AssetActive.uri).in_((a.name, 
a.uri) for a in assets)
-                )
-            )
+    @staticmethod
+    def _activate_referenced_assets(assets_query: CTE, *, session: Session) -> 
None:
+        active_assets_query = select(AssetActive.name, AssetActive.uri).join(
+            assets_query,
+            and_(AssetActive.name == assets_query.c.name, AssetActive.uri == 
assets_query.c.uri),
         )
 
+        active_assets = session.execute(active_assets_query).all()
+
         active_name_to_uri: dict[str, str] = {name: uri for name, uri in 
active_assets}
         active_uri_to_name: dict[str, str] = {uri: name for name, uri in 
active_assets}
 
@@ -2708,9 +2711,24 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]:
             incoming_name_to_uri: dict[str, str] = {}
             incoming_uri_to_name: dict[str, str] = {}
-            for asset in assets:
-                if (asset.name, asset.uri) in active_assets:
-                    continue
+
+            inactive_assets_query = (
+                select(AssetModel)
+                .join(
+                    assets_query,
+                    and_(
+                        assets_query.c.name == AssetModel.name,
+                        assets_query.c.uri == AssetModel.uri,
+                    ),
+                )
+                .where(
+                    ~active_assets_query.where(
+                        and_(AssetActive.name == AssetModel.name, 
AssetActive.uri == AssetModel.uri)
+                    ).exists()
+                )
+            )
+
+            for asset in session.scalars(inactive_assets_query):
                 existing_uri = active_name_to_uri.get(asset.name) or 
incoming_name_to_uri.get(asset.name)
                 if existing_uri is not None and existing_uri != asset.uri:
                     yield from _generate_warning_message(asset, "name", 
existing_uri)
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py 
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index d9cc4e0a4e5..0d6d2ddf360 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -23,6 +23,7 @@ import logging
 import os
 from collections import Counter, deque
 from collections.abc import Generator
+from contextlib import contextmanager
 from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -7076,6 +7077,37 @@ class TestSchedulerJob:
         # Check if the second dagrun was created
         assert DagRun.find(dag_id="testdag2", session=session)
 
+    def test_activate_referenced_assets_no_in_check_inside_query(self, 
session, testing_dag_bundle):
+        dag_id1 = "test_asset_dag1"
+        asset1_name = "asset1"
+        asset_extra = {"foo": "bar"}
+
+        asset1 = Asset(name=asset1_name, uri="s3://bucket/key/1", 
extra=asset_extra)
+        dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1])
+        sync_dag_to_db(dag1, session=session)
+
+        @contextmanager
+        def assert_no_in_clause(session):
+            from sqlalchemy import event
+
+            def fail_on_in_clause_found(execute_statement):
+                if " IN " in str(execute_statement).upper():
+                    execute_statement = str(execute_statement).upper()
+                    pytest.fail(
+                        f"Query contains IN clause which was removed in PR 
#62114, query: {execute_statement}"
+                    )
+
+            event.listen(session, "do_orm_execute", fail_on_in_clause_found)
+            try:
+                yield
+            finally:
+                event.remove(session, "do_orm_execute", 
fail_on_in_clause_found)
+
+        asset_models = select(AssetModel).cte()
+
+        with assert_no_in_clause(session):
+            SchedulerJobRunner._activate_referenced_assets(asset_models, 
session=session)
+
     def test_activate_referenced_assets_with_no_existing_warning(self, 
session, testing_dag_bundle):
         dag_warnings = session.query(DagWarning).all()
         assert dag_warnings == []
@@ -7090,8 +7122,8 @@ class TestSchedulerJob:
         dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1, 
asset1_1, asset1_2])
         sync_dag_to_db(dag1, session=session)
 
-        asset_models = session.scalars(select(AssetModel)).all()
-        assert len(asset_models) == 3
+        asset_models = select(AssetModel).cte()
+        assert len(session.execute(select(asset_models)).all()) == 3
 
         SchedulerJobRunner._activate_referenced_assets(asset_models, 
session=session)
         session.flush()
@@ -7128,7 +7160,7 @@ class TestSchedulerJob:
         )
         session.flush()
 
-        asset_models = session.scalars(select(AssetModel)).all()
+        asset_models = select(AssetModel).cte()
 
         SchedulerJobRunner._activate_referenced_assets(asset_models, 
session=session)
         session.flush()
@@ -7177,7 +7209,7 @@ class TestSchedulerJob:
         session.add(DagWarning(dag_id=dag_id, warning_type="asset conflict", 
message="will not exist"))
         session.flush()
 
-        asset_models = session.scalars(select(AssetModel)).all()
+        asset_models = select(AssetModel).cte()
 
         SchedulerJobRunner._activate_referenced_assets(asset_models, 
session=session)
         session.flush()
diff --git a/devel-common/src/tests_common/pytest_plugin.py 
b/devel-common/src/tests_common/pytest_plugin.py
index 69314717376..264e6e08a2e 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -937,7 +937,8 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
                     
AssetModel.producing_tasks.any(TaskOutletAssetReference.dag_id == 
self.dag.dag_id),
                 )
 
-            assets = 
self.session.scalars(select(AssetModel).where(assets_select_condition)).all()
+            assets = select(AssetModel).where(assets_select_condition).cte()
+
             SchedulerJobRunner._activate_referenced_assets(assets, 
session=self.session)
 
         def __exit__(self, type, value, traceback):

Reply via email to