This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3ee844b6972 Consolidate asset orphanization and activation (#43254)
3ee844b6972 is described below
commit 3ee844b6972b9cd6dfd9f8ad4a0f85c49262988c
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Sat Oct 26 23:01:15 2024 +0800
Consolidate asset orphanization and activation (#43254)
---
airflow/dag_processing/collection.py | 40 +++++-------
airflow/dag_processing/processor.py | 11 +++-
airflow/jobs/scheduler_job_runner.py | 118 +++++++++++++++++++++++++++--------
airflow/models/asset.py | 6 +-
airflow/models/dag.py | 1 -
airflow/models/dagwarning.py | 1 +
tests/jobs/test_scheduler_job.py | 108 +++++++++++++++++++-------------
tests/models/test_dag.py | 80 +++++++++++-------------
tests_common/pytest_plugin.py | 27 +++++++-
9 files changed, 244 insertions(+), 148 deletions(-)
diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index 034c9c05401..f27f45dda82 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -37,7 +37,6 @@ from sqlalchemy.orm import joinedload, load_only
from airflow.assets import Asset, AssetAlias
from airflow.assets.manager import asset_manager
from airflow.models.asset import (
- AssetActive,
AssetAliasModel,
AssetModel,
DagScheduleAssetAliasReference,
@@ -277,7 +276,7 @@ class AssetModelOperation(NamedTuple):
schedule_asset_references: dict[str, list[Asset]]
schedule_asset_alias_references: dict[str, list[AssetAlias]]
outlet_references: dict[str, list[tuple[str, Asset]]]
- assets: dict[str, Asset]
+ assets: dict[tuple[str, str], Asset]
asset_aliases: dict[str, AssetAlias]
@classmethod
@@ -300,22 +299,25 @@ class AssetModelOperation(NamedTuple):
]
for dag_id, dag in dags.items()
},
- assets={asset.uri: asset for asset in
_find_all_assets(dags.values())},
+ assets={(asset.name, asset.uri): asset for asset in
_find_all_assets(dags.values())},
asset_aliases={alias.name: alias for alias in
_find_all_asset_aliases(dags.values())},
)
return coll
- def add_assets(self, *, session: Session) -> dict[str, AssetModel]:
+ def add_assets(self, *, session: Session) -> dict[tuple[str, str],
AssetModel]:
# Optimization: skip all database calls if no assets were collected.
if not self.assets:
return {}
- orm_assets: dict[str, AssetModel] = {
- am.uri: am for am in
session.scalars(select(AssetModel).where(AssetModel.uri.in_(self.assets)))
+ orm_assets: dict[tuple[str, str], AssetModel] = {
+ (am.name, am.uri): am
+ for am in session.scalars(
+ select(AssetModel).where(tuple_(AssetModel.name,
AssetModel.uri).in_(self.assets))
+ )
}
orm_assets.update(
- (model.uri, model)
+ ((model.name, model.uri), model)
for model in asset_manager.create_assets(
- [asset for uri, asset in self.assets.items() if uri not in
orm_assets],
+ [asset for name_uri, asset in self.assets.items() if name_uri
not in orm_assets],
session=session,
)
)
@@ -340,24 +342,10 @@ class AssetModelOperation(NamedTuple):
)
return orm_aliases
- def add_asset_active_references(self, assets: Collection[AssetModel], *,
session: Session) -> None:
- existing_entries = set(
- session.execute(
- select(AssetActive.name, AssetActive.uri).where(
- tuple_(AssetActive.name, AssetActive.uri).in_((asset.name,
asset.uri) for asset in assets)
- )
- )
- )
- session.add_all(
- AssetActive.for_asset(asset)
- for asset in assets
- if (asset.name, asset.uri) not in existing_entries
- )
-
def add_dag_asset_references(
self,
dags: dict[str, DagModel],
- assets: dict[str, AssetModel],
+ assets: dict[tuple[str, str], AssetModel],
*,
session: Session,
) -> None:
@@ -369,7 +357,7 @@ class AssetModelOperation(NamedTuple):
if not references:
dags[dag_id].schedule_asset_references = []
continue
- referenced_asset_ids = {asset.id for asset in (assets[r.uri] for r
in references)}
+ referenced_asset_ids = {asset.id for asset in (assets[r.name,
r.uri] for r in references)}
orm_refs = {r.asset_id: r for r in
dags[dag_id].schedule_asset_references}
for asset_id, ref in orm_refs.items():
if asset_id not in referenced_asset_ids:
@@ -409,7 +397,7 @@ class AssetModelOperation(NamedTuple):
def add_task_asset_references(
self,
dags: dict[str, DagModel],
- assets: dict[str, AssetModel],
+ assets: dict[tuple[str, str], AssetModel],
*,
session: Session,
) -> None:
@@ -423,7 +411,7 @@ class AssetModelOperation(NamedTuple):
continue
referenced_outlets = {
(task_id, asset.id)
- for task_id, asset in ((task_id, assets[d.uri]) for task_id, d
in references)
+ for task_id, asset in ((task_id, assets[d.name, d.uri]) for
task_id, d in references)
}
orm_refs = {(r.task_id, r.asset_id): r for r in
dags[dag_id].task_outlet_asset_references}
for key, ref in orm_refs.items():
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index f030cb75019..8694f5890cc 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -28,7 +28,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Generator, Iterable
from setproctitle import setproctitle
-from sqlalchemy import delete, event
+from sqlalchemy import delete, event, select
from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
@@ -533,7 +533,14 @@ class DagFileProcessor(LoggingMixin):
)
)
- stored_warnings =
set(session.query(DagWarning).filter(DagWarning.dag_id.in_(dag_ids)).all())
+ stored_warnings = set(
+ session.scalars(
+ select(DagWarning).where(
+ DagWarning.dag_id.in_(dag_ids),
+ DagWarning.warning_type == DagWarningType.NONEXISTENT_POOL,
+ )
+ )
+ )
for warning_to_delete in stored_warnings - warnings:
session.delete(warning_to_delete)
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 04ea8c5e616..15042b0d3f1 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import itertools
import multiprocessing
+import operator
import os
import signal
import sys
@@ -55,6 +56,7 @@ from airflow.models.backfill import Backfill
from airflow.models.dag import DAG, DagModel
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
+from airflow.models.dagwarning import DagWarning, DagWarningType
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.stats import Stats
@@ -1078,7 +1080,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
timers.call_regular_interval(
conf.getfloat("scheduler", "parsing_cleanup_interval"),
- self._orphan_unreferenced_assets,
+ self._update_asset_orphanage,
)
if self._standalone_dag_processor:
@@ -2068,44 +2070,106 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
SerializedDagModel.remove_dag(dag_id=dag.dag_id, session=session)
session.flush()
- def _get_orphaning_identifier(self, asset: AssetModel) -> tuple[str, str]:
- self.log.info("Orphaning unreferenced %s", asset)
- return asset.name, asset.uri
-
@provide_session
- def _orphan_unreferenced_assets(self, session: Session = NEW_SESSION) ->
None:
+ def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None:
"""
- Detect orphaned assets and remove their active entry.
+ Check assets orphanization and update their active entry.
- An orphaned asset is no longer referenced in any DAG schedule
parameters or task outlets.
+ An orphaned asset is no longer referenced in any DAG schedule
parameters
+ or task outlets. Active assets (non-orphaned) have entries in
AssetActive
+ and must have unique names and URIs.
"""
- orphaned_asset_query = session.scalars(
- select(AssetModel)
- .join(
- DagScheduleAssetReference,
- isouter=True,
- )
- .join(
- TaskOutletAssetReference,
- isouter=True,
- )
+ # Group assets into orphaned=True and orphaned=False groups.
+ orphaned = (
+ (func.count(DagScheduleAssetReference.dag_id) +
func.count(TaskOutletAssetReference.dag_id)) == 0
+ ).label("orphaned")
+ asset_reference_query = session.execute(
+ select(orphaned, AssetModel)
+ .outerjoin(DagScheduleAssetReference)
+ .outerjoin(TaskOutletAssetReference)
.group_by(AssetModel.id)
- .where(AssetModel.active.has())
- .having(
- and_(
- func.count(DagScheduleAssetReference.dag_id) == 0,
- func.count(TaskOutletAssetReference.dag_id) == 0,
+ .order_by(orphaned)
+ )
+ asset_orphanation: dict[bool, Collection[AssetModel]] = {
+ orphaned: [asset for _, asset in group]
+ for orphaned, group in itertools.groupby(asset_reference_query,
key=operator.itemgetter(0))
+ }
+ self._orphan_unreferenced_assets(asset_orphanation.get(True, ()),
session=session)
+ self._activate_referenced_assets(asset_orphanation.get(False, ()),
session=session)
+
+ @staticmethod
+ def _orphan_unreferenced_assets(assets: Collection[AssetModel], *,
session: Session) -> None:
+ if assets:
+ session.execute(
+ delete(AssetActive).where(
+ tuple_in_condition((AssetActive.name, AssetActive.uri),
((a.name, a.uri) for a in assets))
+ )
+ )
+ Stats.gauge("asset.orphaned", len(assets))
+
+ @staticmethod
+ def _activate_referenced_assets(assets: Collection[AssetModel], *,
session: Session) -> None:
+ if not assets:
+ return
+
+ active_assets = set(
+ session.execute(
+ select(AssetActive.name, AssetActive.uri).where(
+ tuple_in_condition((AssetActive.name, AssetActive.uri),
((a.name, a.uri) for a in assets))
)
)
)
- orphaning_identifiers = [self._get_orphaning_identifier(asset) for
asset in orphaned_asset_query]
+ active_name_to_uri: dict[str, str] = {name: uri for name, uri in
active_assets}
+ active_uri_to_name: dict[str, str] = {uri: name for name, uri in
active_assets}
+
+ def _generate_dag_warnings(offending: AssetModel, attr: str, value:
str) -> Iterator[DagWarning]:
+ for ref in itertools.chain(offending.consuming_dags,
offending.producing_tasks):
+ yield DagWarning(
+ dag_id=ref.dag_id,
+ error_type=DagWarningType.ASSET_CONFLICT,
+ message=f"Cannot activate asset {offending}; {attr} is
already associated to {value!r}",
+ )
+
+ def _activate_assets_generate_warnings() -> Iterator[DagWarning]:
+ incoming_name_to_uri: dict[str, str] = {}
+ incoming_uri_to_name: dict[str, str] = {}
+ for asset in assets:
+ if (asset.name, asset.uri) in active_assets:
+ continue
+ existing_uri = active_name_to_uri.get(asset.name) or
incoming_name_to_uri.get(asset.name)
+ if existing_uri is not None and existing_uri != asset.uri:
+ yield from _generate_dag_warnings(asset, "name",
existing_uri)
+ continue
+ existing_name = active_uri_to_name.get(asset.uri) or
incoming_uri_to_name.get(asset.uri)
+ if existing_name is not None and existing_name != asset.name:
+ yield from _generate_dag_warnings(asset, "uri",
existing_name)
+ continue
+ incoming_name_to_uri[asset.name] = asset.uri
+ incoming_uri_to_name[asset.uri] = asset.name
+ session.add(AssetActive.for_asset(asset))
+
+ warnings_to_have = {w.dag_id: w for w in
_activate_assets_generate_warnings()}
session.execute(
- delete(AssetActive).where(
- tuple_in_condition((AssetActive.name, AssetActive.uri),
orphaning_identifiers)
+ delete(DagWarning).where(
+ DagWarning.warning_type == DagWarningType.ASSET_CONFLICT,
+ DagWarning.dag_id.not_in(warnings_to_have),
+ )
+ )
+ existing_warned_dag_ids: set[str] = set(
+ session.scalars(
+ select(DagWarning.dag_id).where(
+ DagWarning.warning_type == DagWarningType.ASSET_CONFLICT,
+ DagWarning.dag_id.not_in(warnings_to_have),
+ )
)
)
- Stats.gauge("asset.orphaned", len(orphaning_identifiers))
+ for dag_id, warning in warnings_to_have.items():
+ if dag_id in existing_warned_dag_ids:
+ session.merge(warning)
+ continue
+ session.add(warning)
+ existing_warned_dag_ids.add(warning.dag_id)
def _executor_to_tis(self, tis: list[TaskInstance]) -> dict[BaseExecutor,
list[TaskInstance]]:
"""Organize TIs into lists per their respective executor."""
diff --git a/airflow/models/asset.py b/airflow/models/asset.py
index fc77cb7a31d..d6092aaff1b 100644
--- a/airflow/models/asset.py
+++ b/airflow/models/asset.py
@@ -181,7 +181,7 @@ class AssetModel(Base):
created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
updated_at = Column(UtcDateTime, default=timezone.utcnow,
onupdate=timezone.utcnow, nullable=False)
- active = relationship("AssetActive", uselist=False, viewonly=True)
+ active = relationship("AssetActive", uselist=False, viewonly=True,
back_populates="asset")
consuming_dags = relationship("DagScheduleAssetReference",
back_populates="asset")
producing_tasks = relationship("TaskOutletAssetReference",
back_populates="asset")
@@ -221,7 +221,7 @@ class AssetModel(Base):
return hash((self.name, self.uri))
def __repr__(self):
- return f"{self.__class__.__name__}(uri={self.uri!r},
extra={self.extra!r})"
+ return f"{self.__class__.__name__}(name={self.name!r},
uri={self.uri!r}, extra={self.extra!r})"
def to_public(self) -> Asset:
return Asset(name=self.name, uri=self.uri, group=self.group,
extra=self.extra)
@@ -264,6 +264,8 @@ class AssetActive(Base):
nullable=False,
)
+ asset = relationship("AssetModel", back_populates="active")
+
__tablename__ = "asset_active"
__table_args__ = (
PrimaryKeyConstraint(name, uri, name="asset_active_pkey"),
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 00943ec2ee2..fd1c67debe2 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2571,7 +2571,6 @@ class DAG(LoggingMixin):
orm_asset_aliases = asset_op.add_asset_aliases(session=session)
session.flush() # This populates id so we can create fks in later
calls.
- asset_op.add_asset_active_references(orm_assets.values(),
session=session)
asset_op.add_dag_asset_references(orm_dags, orm_assets,
session=session)
asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases,
session=session)
asset_op.add_task_asset_references(orm_dags, orm_assets,
session=session)
diff --git a/airflow/models/dagwarning.py b/airflow/models/dagwarning.py
index ffab515f854..e0c271c4c8e 100644
--- a/airflow/models/dagwarning.py
+++ b/airflow/models/dagwarning.py
@@ -104,4 +104,5 @@ class DagWarningType(str, Enum):
in the DagWarning model.
"""
+ ASSET_CONFLICT = "asset conflict"
NONEXISTENT_POOL = "non-existent pool"
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index ac48344435a..3d71d598799 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -52,7 +52,7 @@ from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.job import Job, run_job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
-from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel
+from airflow.models.asset import AssetActive, AssetDagRunQueue, AssetEvent,
AssetModel
from airflow.models.backfill import Backfill, _create_backfill
from airflow.models.dag import DAG, DagModel
from airflow.models.dagbag import DagBag
@@ -6160,84 +6160,102 @@ class TestSchedulerJob:
(backfill_run,) = DagRun.find(dag_id=dag.dag_id,
run_type=DagRunType.BACKFILL_JOB, session=session)
assert backfill_run.state == State.SUCCESS
+ @staticmethod
+ def _find_assets_activation(session) -> tuple[list[AssetModel],
list[AssetModel]]:
+ assets = session.execute(
+ select(AssetModel, AssetActive)
+ .outerjoin(
+ AssetActive,
+ (AssetModel.name == AssetActive.name) & (AssetModel.uri ==
AssetActive.uri),
+ )
+ .order_by(AssetModel.uri)
+ ).all()
+ return [a for a, v in assets if not v], [a for a, v in assets if v]
+
+ @pytest.mark.want_activate_assets(False)
def test_asset_orphaning(self, dag_maker, session):
+ self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull)
+
asset1 = Asset(uri="ds1")
asset2 = Asset(uri="ds2")
asset3 = Asset(uri="ds3")
asset4 = Asset(uri="ds4")
+ asset5 = Asset(uri="ds5")
with dag_maker(dag_id="assets-1", schedule=[asset1, asset2],
session=session):
BashOperator(task_id="task", bash_command="echo 1",
outlets=[asset3, asset4])
- non_orphaned_asset_count =
session.query(AssetModel).filter(AssetModel.active.has()).count()
- assert non_orphaned_asset_count == 4
- orphaned_asset_count =
session.query(AssetModel).filter(~AssetModel.active.has()).count()
- assert orphaned_asset_count == 0
+ # Assets not activated yet; asset5 is not even registered (since it's
not used anywhere).
+ orphaned, active = self._find_assets_activation(session)
+ assert active == []
+ assert orphaned == [asset1, asset2, asset3, asset4]
- # now remove 2 asset references
+ self.job_runner._update_asset_orphanage(session=session)
+ session.flush()
+
+ # Assets are activated after scheduler loop.
+ orphaned, active = self._find_assets_activation(session)
+ assert active == [asset1, asset2, asset3, asset4]
+ assert orphaned == []
+
+ # Now remove 2 asset references and add asset5.
with dag_maker(dag_id="assets-1", schedule=[asset1], session=session):
- BashOperator(task_id="task", bash_command="echo 1",
outlets=[asset3])
+ BashOperator(task_id="task", bash_command="echo 1",
outlets=[asset3, asset5])
- scheduler_job = Job()
- self.job_runner = SchedulerJobRunner(job=scheduler_job,
subdir=os.devnull)
+ # The DAG parser finds asset5, but it's not activated yet.
+ orphaned, active = self._find_assets_activation(session)
+ assert active == [asset1, asset2, asset3, asset4]
+ assert orphaned == [asset5]
- self.job_runner._orphan_unreferenced_assets(session=session)
+ self.job_runner._update_asset_orphanage(session=session)
session.flush()
- # and find the orphans
- non_orphaned_assets = [
- asset.uri
- for asset in session.query(AssetModel.uri)
- .filter(AssetModel.active.has())
- .order_by(AssetModel.uri)
- ]
- assert non_orphaned_assets == ["ds1", "ds3"]
- orphaned_assets = session.scalars(
-
select(AssetModel.uri).where(~AssetModel.active.has()).order_by(AssetModel.uri)
- ).all()
- assert orphaned_assets == ["ds2", "ds4"]
+ # Now we get the updated result.
+ orphaned, active = self._find_assets_activation(session)
+ assert active == [asset1, asset3, asset5]
+ assert orphaned == [asset2, asset4]
+ @pytest.mark.want_activate_assets(False)
def test_asset_orphaning_ignore_orphaned_assets(self, dag_maker, session):
+ self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull)
+
asset1 = Asset(uri="ds1")
with dag_maker(dag_id="assets-1", schedule=[asset1], session=session):
BashOperator(task_id="task", bash_command="echo 1")
- non_orphaned_asset_count =
session.query(AssetModel).filter(AssetModel.active.has()).count()
- assert non_orphaned_asset_count == 1
- orphaned_asset_count =
session.query(AssetModel).filter(~AssetModel.active.has()).count()
- assert orphaned_asset_count == 0
+ orphaned, active = self._find_assets_activation(session)
+ assert active == []
+ assert orphaned == [asset1]
+
+ self.job_runner._update_asset_orphanage(session=session)
+ session.flush()
+
+ orphaned, active = self._find_assets_activation(session)
+ assert active == [asset1]
+ assert orphaned == []
# now remove asset1 reference
with dag_maker(dag_id="assets-1", schedule=None, session=session):
BashOperator(task_id="task", bash_command="echo 1")
- scheduler_job = Job()
- self.job_runner = SchedulerJobRunner(job=scheduler_job,
subdir=os.devnull)
-
- self.job_runner._orphan_unreferenced_assets(session=session)
+ self.job_runner._update_asset_orphanage(session=session)
session.flush()
- orphaned_assets_before_rerun = (
- session.query(AssetModel.updated_at, AssetModel.uri)
- .filter(~AssetModel.active.has())
- .order_by(AssetModel.uri)
- )
- assert [asset.uri for asset in orphaned_assets_before_rerun] == ["ds1"]
- updated_at_timestamps = [asset.updated_at for asset in
orphaned_assets_before_rerun]
+ orphaned, active = self._find_assets_activation(session)
+ assert active == []
+ assert orphaned == [asset1]
+ updated_at_timestamps = [asset.updated_at for asset in orphaned]
# when rerunning we should ignore the already orphaned assets and thus
the updated_at timestamp
# should remain the same
- self.job_runner._orphan_unreferenced_assets(session=session)
+ self.job_runner._update_asset_orphanage(session=session)
session.flush()
- orphaned_assets_after_rerun = (
- session.query(AssetModel.updated_at, AssetModel.uri)
- .filter(~AssetModel.active.has())
- .order_by(AssetModel.uri)
- )
- assert [asset.uri for asset in orphaned_assets_after_rerun] == ["ds1"]
- assert updated_at_timestamps == [asset.updated_at for asset in
orphaned_assets_after_rerun]
+ orphaned, active = self._find_assets_activation(session)
+ assert active == []
+ assert orphaned == [asset1]
+ assert [asset.updated_at for asset in orphaned] ==
updated_at_timestamps
def test_misconfigured_dags_doesnt_crash_scheduler(self, session,
dag_maker, caplog):
"""Test that if dagrun creation throws an exception, the scheduler
doesn't crash"""
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 86e499edb3b..1d7a69ba843 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -47,6 +47,7 @@ from airflow.exceptions import (
UnknownExecutorException,
)
from airflow.models.asset import (
+ AssetActive,
AssetAliasModel,
AssetDagRunQueue,
AssetEvent,
@@ -1070,54 +1071,47 @@ class TestDag:
.all()
) == {(task_id, dag_id1, asset2_orm.id)}
- def test_bulk_write_to_db_unorphan_assets(self):
+ @staticmethod
+ def _find_assets_activation(session) -> tuple[list[AssetModel],
list[AssetModel]]:
+ assets = session.execute(
+ select(AssetModel, AssetActive)
+ .outerjoin(
+ AssetActive,
+ (AssetModel.name == AssetActive.name) & (AssetModel.uri ==
AssetActive.uri),
+ )
+ .order_by(AssetModel.uri)
+ ).all()
+ return [a for a, v in assets if not v], [a for a, v in assets if v]
+
+ def test_bulk_write_to_db_does_not_activate(self, dag_maker, session):
"""
- Assets can lose their last reference and be orphaned, but then if a
reference to them reappears, we
- need to un-orphan those assets
+ Assets are not activated on write, but later in the scheduler by the
SchedulerJob.
"""
- with create_session() as session:
- # Create four assets - two that have references and two that are
unreferenced and marked as
- # orphans
- asset1 = Asset(uri="ds1")
- asset2 = Asset(uri="ds2")
- session.add(AssetModel(uri=asset2.uri))
- asset3 = Asset(uri="ds3")
- asset4 = Asset(uri="ds4")
- session.add(AssetModel(uri=asset4.uri))
- session.flush()
-
- dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE,
schedule=[asset1])
- BashOperator(dag=dag1, task_id="task", bash_command="echo 1",
outlets=[asset3])
-
- DAG.bulk_write_to_db([dag1], session=session)
-
- # Double check
- non_orphaned_assets = [
- asset.uri
- for asset in session.query(AssetModel.uri)
- .filter(AssetModel.active.has())
- .order_by(AssetModel.uri)
- ]
- assert non_orphaned_assets == ["ds1", "ds3"]
- orphaned_assets = [
- asset.uri
- for asset in session.query(AssetModel.uri)
- .filter(~AssetModel.active.has())
- .order_by(AssetModel.uri)
- ]
- assert orphaned_assets == ["ds2", "ds4"]
+ # Create four assets - two that have references and two that are
unreferenced and marked as
+ # orphans
+ asset1 = Asset(uri="ds1")
+ asset2 = Asset(uri="ds2")
+ asset3 = Asset(uri="ds3")
+ asset4 = Asset(uri="ds4")
- # Now add references to the two unreferenced assets
- dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE,
schedule=[asset1, asset2])
- BashOperator(dag=dag1, task_id="task", bash_command="echo 1",
outlets=[asset3, asset4])
+ dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE,
schedule=[asset1])
+ BashOperator(dag=dag1, task_id="task", bash_command="echo 1",
outlets=[asset3])
+ DAG.bulk_write_to_db([dag1], session=session)
- DAG.bulk_write_to_db([dag1], session=session)
+ assert
session.scalars(select(AssetModel).order_by(AssetModel.uri)).all() == [asset1,
asset3]
+ assert session.scalars(select(AssetActive)).all() == []
- # and count the orphans and non-orphans
- non_orphaned_asset_count =
session.query(AssetModel).filter(AssetModel.active.has()).count()
- assert non_orphaned_asset_count == 4
- orphaned_asset_count =
session.query(AssetModel).filter(~AssetModel.active.has()).count()
- assert orphaned_asset_count == 0
+ dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE,
schedule=[asset1, asset2])
+ BashOperator(dag=dag1, task_id="task", bash_command="echo 1",
outlets=[asset3, asset4])
+ DAG.bulk_write_to_db([dag1], session=session)
+
+ assert
session.scalars(select(AssetModel).order_by(AssetModel.uri)).all() == [
+ asset1,
+ asset2,
+ asset3,
+ asset4,
+ ]
+ assert session.scalars(select(AssetActive)).all() == []
def test_bulk_write_to_db_asset_aliases(self):
"""
diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py
index 0c7ed57fea6..18c3779d7fa 100644
--- a/tests_common/pytest_plugin.py
+++ b/tests_common/pytest_plugin.py
@@ -404,6 +404,7 @@ def pytest_configure(config: pytest.Config) -> None:
config.addinivalue_line(
"markers", "need_serialized_dag: mark tests that require dags in
serialized form to be present"
)
+ config.addinivalue_line("markers", "want_activate_assets: mark tests that
require assets to be activated")
config.addinivalue_line(
"markers",
"db_test: mark tests that require database to be present",
@@ -759,12 +760,14 @@ def dag_maker(request):
# and "baked" in to various constants
want_serialized = False
+ want_activate_assets = True # Only has effect if want_serialized=True on
Airflow 3.
# Allow changing default serialized behaviour with
`@pytest.mark.need_serialized_dag` or
# `@pytest.mark.need_serialized_dag(False)`
- serialized_marker = request.node.get_closest_marker("need_serialized_dag")
- if serialized_marker:
+ if serialized_marker :=
request.node.get_closest_marker("need_serialized_dag"):
(want_serialized,) = serialized_marker.args or (True,)
+ if serialized_marker :=
request.node.get_closest_marker("want_activate_assets"):
+ (want_activate_assets,) = serialized_marker.args or (True,)
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -802,10 +805,26 @@ def dag_maker(request):
return self.dagbag.bag_dag(dag, root_dag=dag)
return self.dagbag.bag_dag(dag)
+ def _activate_assets(self):
+ from sqlalchemy import select
+
+ from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
+ from airflow.models.asset import AssetModel,
DagScheduleAssetReference, TaskOutletAssetReference
+
+ assets = self.session.scalars(
+ select(AssetModel).where(
+
AssetModel.consuming_dags.any(DagScheduleAssetReference.dag_id ==
self.dag.dag_id)
+ |
AssetModel.producing_tasks.any(TaskOutletAssetReference.dag_id ==
self.dag.dag_id)
+ )
+ ).all()
+ SchedulerJobRunner._activate_referenced_assets(assets,
session=self.session)
+
def __exit__(self, type, value, traceback):
from airflow.models import DagModel
from airflow.models.serialized_dag import SerializedDagModel
+ from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
+
dag = self.dag
dag.__exit__(type, value, traceback)
if type is not None:
@@ -822,6 +841,8 @@ def dag_maker(request):
self.session.merge(self.serialized_model)
serialized_dag = self._serialized_dag()
self._bag_dag_compat(serialized_dag)
+ if AIRFLOW_V_3_0_PLUS and self.want_activate_assets:
+ self._activate_assets()
self.session.flush()
else:
self._bag_dag_compat(self.dag)
@@ -887,6 +908,7 @@ def dag_maker(request):
dag_id="test_dag",
schedule=timedelta(days=1),
serialized=want_serialized,
+ activate_assets=want_activate_assets,
fileloc=None,
processor_subdir=None,
session=None,
@@ -919,6 +941,7 @@ def dag_maker(request):
self.dag = DAG(dag_id, schedule=schedule, **self.kwargs)
self.dag.fileloc = fileloc or request.module.__file__
self.want_serialized = serialized
+ self.want_activate_assets = activate_assets
self.processor_subdir = processor_subdir
return self