This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a89dcaa46dc remove large in clause in assets with cte and join (#62114)
a89dcaa46dc is described below
commit a89dcaa46dcf7d1cc1a4e2291da87c90f5541bf6
Author: Nataneljpwd <[email protected]>
AuthorDate: Fri Mar 13 08:08:12 2026 +0200
remove large in clause in assets with cte and join (#62114)
* fixed large in clause
* fixed tests
* changed to delete using
* added compatibility for tests
* Change asset selection to use CTE
* Fix asset selection for Airflow versions
* fixed some tests and optimized the query
* fixed all tests
* fixed mypy
* fixup! address cr comments
* address CR comments, added tests and fixed plugin
* address CR comment
---------
Co-authored-by: Natanel Rudyuklakir <[email protected]>
Co-authored-by: Elad Kalif <[email protected]>
---
.../src/airflow/jobs/scheduler_job_runner.py | 87 ++++++++++++----------
airflow-core/tests/unit/jobs/test_scheduler_job.py | 41 ++++++++--
devel-common/src/tests_common/pytest_plugin.py | 8 +-
3 files changed, 89 insertions(+), 47 deletions(-)
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index c667631756e..b078cb183bc 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -33,18 +33,7 @@ from functools import lru_cache, partial
from itertools import groupby
from typing import TYPE_CHECKING, Any
-from sqlalchemy import (
- and_,
- delete,
- exists,
- func,
- inspect,
- or_,
- select,
- text,
- tuple_,
- update,
-)
+from sqlalchemy import CTE, and_, delete, exists, func, inspect, or_, select,
text, tuple_, update
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient,
selectinload
from sqlalchemy.sql import expression
@@ -822,7 +811,11 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
task_instance,
)
starved_tasks_task_dagrun_concurrency.add(
- (task_instance.dag_id, task_instance.run_id,
task_instance.task_id)
+ (
+ task_instance.dag_id,
+ task_instance.run_id,
+ task_instance.task_id,
+ )
)
continue
@@ -2940,44 +2933,41 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
)
== 0
).label("orphaned")
- asset_reference_query = session.execute(
- select(orphaned, AssetModel)
+ asset_reference_query = (
+ select(AssetModel)
.outerjoin(DagScheduleAssetReference)
.outerjoin(TaskOutletAssetReference)
.outerjoin(TaskInletAssetReference)
.group_by(AssetModel.id)
- .order_by(orphaned)
)
- asset_orphanation: dict[bool, Collection[AssetModel]] = {
- orphaned: [asset for _, asset in group]
- for orphaned, group in itertools.groupby(asset_reference_query,
key=operator.itemgetter(0))
- }
- self._orphan_unreferenced_assets(asset_orphanation.get(True, ()),
session=session)
- self._activate_referenced_assets(asset_orphanation.get(False, ()),
session=session)
+
+ orphan_query = asset_reference_query.having(orphaned).cte()
+ activate_query = asset_reference_query.having(~orphaned).cte()
+
+ self._orphan_unreferenced_assets(orphan_query, session=session)
+ self._activate_referenced_assets(activate_query, session=session)
@staticmethod
- def _orphan_unreferenced_assets(assets: Collection[AssetModel], *,
session: Session) -> None:
- if assets:
- session.execute(
- delete(AssetActive).where(
- tuple_(AssetActive.name, AssetActive.uri).in_((a.name,
a.uri) for a in assets)
+ def _orphan_unreferenced_assets(assets_query: CTE, *, session: Session) ->
None:
+ deleted_orphaned_assets = session.execute(
+ delete(AssetActive).where(
+ exists().where(
+ and_(AssetActive.name == assets_query.c.name,
AssetActive.uri == assets_query.c.uri)
)
)
- Stats.gauge("asset.orphaned", len(assets))
+ )
- @staticmethod
- def _activate_referenced_assets(assets: Collection[AssetModel], *,
session: Session) -> None:
- if not assets:
- return
+ Stats.gauge("asset.orphaned", max(getattr(deleted_orphaned_assets,
"rowcount", 0), 0))
- active_assets = set(
- session.execute(
- select(AssetActive.name, AssetActive.uri).where(
- tuple_(AssetActive.name, AssetActive.uri).in_((a.name,
a.uri) for a in assets)
- )
- )
+ @staticmethod
+ def _activate_referenced_assets(assets_query: CTE, *, session: Session) ->
None:
+ active_assets_query = select(AssetActive.name, AssetActive.uri).join(
+ assets_query,
+ and_(AssetActive.name == assets_query.c.name, AssetActive.uri ==
assets_query.c.uri),
)
+ active_assets = session.execute(active_assets_query).all()
+
active_name_to_uri: dict[str, str] = {name: uri for name, uri in
active_assets}
active_uri_to_name: dict[str, str] = {uri: name for name, uri in
active_assets}
@@ -3002,9 +2992,24 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]:
incoming_name_to_uri: dict[str, str] = {}
incoming_uri_to_name: dict[str, str] = {}
- for asset in assets:
- if (asset.name, asset.uri) in active_assets:
- continue
+
+ inactive_assets_query = (
+ select(AssetModel)
+ .join(
+ assets_query,
+ and_(
+ assets_query.c.name == AssetModel.name,
+ assets_query.c.uri == AssetModel.uri,
+ ),
+ )
+ .where(
+ ~active_assets_query.where(
+ and_(AssetActive.name == AssetModel.name,
AssetActive.uri == AssetModel.uri)
+ ).exists()
+ )
+ )
+
+ for asset in session.scalars(inactive_assets_query):
existing_uri = active_name_to_uri.get(asset.name) or
incoming_name_to_uri.get(asset.name)
if existing_uri is not None and existing_uri != asset.uri:
yield from _generate_warning_message(asset, "name",
existing_uri)
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index c23f180f3a4..7c8d2ff4ffb 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -24,7 +24,7 @@ import os
import re
from collections import Counter, deque
from collections.abc import Callable, Generator, Iterator
-from contextlib import ExitStack
+from contextlib import ExitStack, contextmanager
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING
@@ -7362,6 +7362,37 @@ class TestSchedulerJob:
job_runner._create_dag_runs([dm1], session)
assert "Failed creating DagRun" in caplog.text
+ def test_activate_referenced_assets_no_in_check_inside_query(self,
session, testing_dag_bundle):
+ dag_id1 = "test_asset_dag1"
+ asset1_name = "asset1"
+ asset_extra = {"foo": "bar"}
+
+ asset1 = Asset(name=asset1_name, uri="s3://bucket/key/1",
extra=asset_extra)
+ dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1])
+ sync_dag_to_db(dag1, session=session)
+
+ @contextmanager
+ def assert_no_in_clause(session):
+ from sqlalchemy import event
+
+ def fail_on_in_clause_found(execute_statement):
+ if " IN " in str(execute_statement).upper():
+ execute_statement = str(execute_statement).upper()
+ pytest.fail(
+ f"Query contains IN clause which was removed in PR
#62114, query: {execute_statement}"
+ )
+
+ event.listen(session, "do_orm_execute", fail_on_in_clause_found)
+ try:
+ yield
+ finally:
+ event.remove(session, "do_orm_execute",
fail_on_in_clause_found)
+
+ asset_models = select(AssetModel).cte()
+
+ with assert_no_in_clause(session):
+ SchedulerJobRunner._activate_referenced_assets(asset_models,
session=session)
+
def test_activate_referenced_assets_with_no_existing_warning(self,
session, testing_dag_bundle):
dag_warnings = session.scalars(select(DagWarning)).all()
assert dag_warnings == []
@@ -7376,8 +7407,8 @@ class TestSchedulerJob:
dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1,
asset1_1, asset1_2])
sync_dag_to_db(dag1, session=session)
- asset_models = session.scalars(select(AssetModel)).all()
- assert len(asset_models) == 3
+ asset_models = select(AssetModel).cte()
+ assert len(session.execute(select(asset_models)).all()) == 3
SchedulerJobRunner._activate_referenced_assets(asset_models,
session=session)
session.flush()
@@ -7414,7 +7445,7 @@ class TestSchedulerJob:
)
session.flush()
- asset_models = session.scalars(select(AssetModel)).all()
+ asset_models = select(AssetModel).cte()
SchedulerJobRunner._activate_referenced_assets(asset_models,
session=session)
session.flush()
@@ -7463,7 +7494,7 @@ class TestSchedulerJob:
session.add(DagWarning(dag_id=dag_id, warning_type="asset conflict",
message="will not exist"))
session.flush()
- asset_models = session.scalars(select(AssetModel)).all()
+ asset_models = select(AssetModel).cte()
SchedulerJobRunner._activate_referenced_assets(asset_models,
session=session)
session.flush()
diff --git a/devel-common/src/tests_common/pytest_plugin.py
b/devel-common/src/tests_common/pytest_plugin.py
index 5d473bdfea0..9f3aa0a7bc4 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -972,7 +972,13 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
AssetModel.producing_tasks.any(TaskOutletAssetReference.dag_id ==
self.dag.dag_id),
)
- assets =
self.session.scalars(select(AssetModel).where(assets_select_condition)).all()
+ assets = select(AssetModel).where(assets_select_condition).cte()
+
+ if not AIRFLOW_V_3_2_PLUS:
+ assets = self.session.scalars(
+ select(AssetModel).join(assets, AssetModel.id ==
AssetModel.id)
+ ).all()
+
SchedulerJobRunner._activate_referenced_assets(assets,
session=self.session)
def __exit__(self, type, value, traceback):