This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-1-test by this push:
new a81eb25eef0 [v3-1-test] remove large in clause in assets with cte and
join (#62114) (#63511)
a81eb25eef0 is described below
commit a81eb25eef0f631c49d7947a0f98f336f6f8ea23
Author: Nataneljpwd <[email protected]>
AuthorDate: Fri Mar 13 21:18:36 2026 +0200
[v3-1-test] remove large in clause in assets with cte and join (#62114)
(#63511)
* [v3-1-test] remove large in clause in assets with cte and join (#62114)

  * fixed large in clause
  * fixed tests
  * changed to delete using
  * added compatibility for tests
  * Change asset selection to use CTE
  * Fix asset selection for Airflow versions
  * fixed some tests and optimized the query
  * fixed all tests
  * fixed mypy
  * fixup! address cr comments
  * address CR comments, added tests and fixed plugin
  * address CR comment

  ---------
  (cherry picked [...]
* fix CTE import
* set execution options
* lint
* fixed ruff
---------
Co-authored-by: Natanel Rudyuklakir <[email protected]>
---
.../src/airflow/jobs/scheduler_job_runner.py | 78 +++++++++++++---------
airflow-core/tests/unit/jobs/test_scheduler_job.py | 40 +++++++++--
devel-common/src/tests_common/pytest_plugin.py | 3 +-
3 files changed, 86 insertions(+), 35 deletions(-)
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index c66be676137..fc223d7ca80 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -96,6 +96,7 @@ if TYPE_CHECKING:
from pendulum.datetime import DateTime
from sqlalchemy.orm import Load, Query, Session
+ from sqlalchemy.sql.selectable import CTE
from airflow._shared.logging.types import Logger
from airflow.executors.base_executor import BaseExecutor
@@ -592,7 +593,11 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
task_instance,
)
starved_tasks_task_dagrun_concurrency.add(
- (task_instance.dag_id, task_instance.run_id,
task_instance.task_id)
+ (
+ task_instance.dag_id,
+ task_instance.run_id,
+ task_instance.task_id,
+ )
)
continue
@@ -2646,44 +2651,42 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
)
== 0
).label("orphaned")
- asset_reference_query = session.execute(
- select(orphaned, AssetModel)
+ asset_reference_query = (
+ select(AssetModel)
.outerjoin(DagScheduleAssetReference)
.outerjoin(TaskOutletAssetReference)
.outerjoin(TaskInletAssetReference)
.group_by(AssetModel.id)
- .order_by(orphaned)
)
- asset_orphanation: dict[bool, Collection[AssetModel]] = {
- orphaned: [asset for _, asset in group]
- for orphaned, group in itertools.groupby(asset_reference_query,
key=operator.itemgetter(0))
- }
- self._orphan_unreferenced_assets(asset_orphanation.get(True, ()),
session=session)
- self._activate_referenced_assets(asset_orphanation.get(False, ()),
session=session)
+
+ orphan_query = asset_reference_query.having(orphaned).cte()
+ activate_query = asset_reference_query.having(~orphaned).cte()
+
+ self._orphan_unreferenced_assets(orphan_query, session=session)
+ self._activate_referenced_assets(activate_query, session=session)
@staticmethod
- def _orphan_unreferenced_assets(assets: Collection[AssetModel], *,
session: Session) -> None:
- if assets:
- session.execute(
- delete(AssetActive).where(
- tuple_(AssetActive.name, AssetActive.uri).in_((a.name,
a.uri) for a in assets)
+ def _orphan_unreferenced_assets(assets_query: CTE, *, session: Session) ->
None:
+ deleted_orphaned_assets = session.execute(
+ delete(AssetActive).where(
+ exists().where(
+ and_(AssetActive.name == assets_query.c.name,
AssetActive.uri == assets_query.c.uri)
)
- )
- Stats.gauge("asset.orphaned", len(assets))
+ ),
+ execution_options={"synchronize_session": "fetch"},
+ )
- @staticmethod
- def _activate_referenced_assets(assets: Collection[AssetModel], *,
session: Session) -> None:
- if not assets:
- return
+ Stats.gauge("asset.orphaned", max(getattr(deleted_orphaned_assets,
"rowcount", 0), 0))
- active_assets = set(
- session.execute(
- select(AssetActive.name, AssetActive.uri).where(
- tuple_(AssetActive.name, AssetActive.uri).in_((a.name,
a.uri) for a in assets)
- )
- )
+ @staticmethod
+ def _activate_referenced_assets(assets_query: CTE, *, session: Session) ->
None:
+ active_assets_query = select(AssetActive.name, AssetActive.uri).join(
+ assets_query,
+ and_(AssetActive.name == assets_query.c.name, AssetActive.uri ==
assets_query.c.uri),
)
+ active_assets = session.execute(active_assets_query).all()
+
active_name_to_uri: dict[str, str] = {name: uri for name, uri in
active_assets}
active_uri_to_name: dict[str, str] = {uri: name for name, uri in
active_assets}
@@ -2708,9 +2711,24 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]:
incoming_name_to_uri: dict[str, str] = {}
incoming_uri_to_name: dict[str, str] = {}
- for asset in assets:
- if (asset.name, asset.uri) in active_assets:
- continue
+
+ inactive_assets_query = (
+ select(AssetModel)
+ .join(
+ assets_query,
+ and_(
+ assets_query.c.name == AssetModel.name,
+ assets_query.c.uri == AssetModel.uri,
+ ),
+ )
+ .where(
+ ~active_assets_query.where(
+ and_(AssetActive.name == AssetModel.name,
AssetActive.uri == AssetModel.uri)
+ ).exists()
+ )
+ )
+
+ for asset in session.scalars(inactive_assets_query):
existing_uri = active_name_to_uri.get(asset.name) or
incoming_name_to_uri.get(asset.name)
if existing_uri is not None and existing_uri != asset.uri:
yield from _generate_warning_message(asset, "name",
existing_uri)
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index d9cc4e0a4e5..0d6d2ddf360 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -23,6 +23,7 @@ import logging
import os
from collections import Counter, deque
from collections.abc import Generator
+from contextlib import contextmanager
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING
@@ -7076,6 +7077,37 @@ class TestSchedulerJob:
# Check if the second dagrun was created
assert DagRun.find(dag_id="testdag2", session=session)
+ def test_activate_referenced_assets_no_in_check_inside_query(self,
session, testing_dag_bundle):
+ dag_id1 = "test_asset_dag1"
+ asset1_name = "asset1"
+ asset_extra = {"foo": "bar"}
+
+ asset1 = Asset(name=asset1_name, uri="s3://bucket/key/1",
extra=asset_extra)
+ dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1])
+ sync_dag_to_db(dag1, session=session)
+
+ @contextmanager
+ def assert_no_in_clause(session):
+ from sqlalchemy import event
+
+ def fail_on_in_clause_found(execute_statement):
+ if " IN " in str(execute_statement).upper():
+ execute_statement = str(execute_statement).upper()
+ pytest.fail(
+ f"Query contains IN clause which was removed in PR
#62114, query: {execute_statement}"
+ )
+
+ event.listen(session, "do_orm_execute", fail_on_in_clause_found)
+ try:
+ yield
+ finally:
+ event.remove(session, "do_orm_execute",
fail_on_in_clause_found)
+
+ asset_models = select(AssetModel).cte()
+
+ with assert_no_in_clause(session):
+ SchedulerJobRunner._activate_referenced_assets(asset_models,
session=session)
+
def test_activate_referenced_assets_with_no_existing_warning(self,
session, testing_dag_bundle):
dag_warnings = session.query(DagWarning).all()
assert dag_warnings == []
@@ -7090,8 +7122,8 @@ class TestSchedulerJob:
dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1,
asset1_1, asset1_2])
sync_dag_to_db(dag1, session=session)
- asset_models = session.scalars(select(AssetModel)).all()
- assert len(asset_models) == 3
+ asset_models = select(AssetModel).cte()
+ assert len(session.execute(select(asset_models)).all()) == 3
SchedulerJobRunner._activate_referenced_assets(asset_models,
session=session)
session.flush()
@@ -7128,7 +7160,7 @@ class TestSchedulerJob:
)
session.flush()
- asset_models = session.scalars(select(AssetModel)).all()
+ asset_models = select(AssetModel).cte()
SchedulerJobRunner._activate_referenced_assets(asset_models,
session=session)
session.flush()
@@ -7177,7 +7209,7 @@ class TestSchedulerJob:
session.add(DagWarning(dag_id=dag_id, warning_type="asset conflict",
message="will not exist"))
session.flush()
- asset_models = session.scalars(select(AssetModel)).all()
+ asset_models = select(AssetModel).cte()
SchedulerJobRunner._activate_referenced_assets(asset_models,
session=session)
session.flush()
diff --git a/devel-common/src/tests_common/pytest_plugin.py
b/devel-common/src/tests_common/pytest_plugin.py
index 69314717376..264e6e08a2e 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -937,7 +937,8 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
AssetModel.producing_tasks.any(TaskOutletAssetReference.dag_id ==
self.dag.dag_id),
)
- assets =
self.session.scalars(select(AssetModel).where(assets_select_condition)).all()
+ assets = select(AssetModel).where(assets_select_condition).cte()
+
SchedulerJobRunner._activate_referenced_assets(assets,
session=self.session)
def __exit__(self, type, value, traceback):