This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ee844b6972 Consolidate asset orphanization and activation (#43254)
3ee844b6972 is described below

commit 3ee844b6972b9cd6dfd9f8ad4a0f85c49262988c
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Sat Oct 26 23:01:15 2024 +0800

    Consolidate asset orphanization and activation (#43254)
---
 airflow/dag_processing/collection.py |  40 +++++-------
 airflow/dag_processing/processor.py  |  11 +++-
 airflow/jobs/scheduler_job_runner.py | 118 +++++++++++++++++++++++++++--------
 airflow/models/asset.py              |   6 +-
 airflow/models/dag.py                |   1 -
 airflow/models/dagwarning.py         |   1 +
 tests/jobs/test_scheduler_job.py     | 108 +++++++++++++++++++-------------
 tests/models/test_dag.py             |  80 +++++++++++-------------
 tests_common/pytest_plugin.py        |  27 +++++++-
 9 files changed, 244 insertions(+), 148 deletions(-)

diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index 034c9c05401..f27f45dda82 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -37,7 +37,6 @@ from sqlalchemy.orm import joinedload, load_only
 from airflow.assets import Asset, AssetAlias
 from airflow.assets.manager import asset_manager
 from airflow.models.asset import (
-    AssetActive,
     AssetAliasModel,
     AssetModel,
     DagScheduleAssetAliasReference,
@@ -277,7 +276,7 @@ class AssetModelOperation(NamedTuple):
     schedule_asset_references: dict[str, list[Asset]]
     schedule_asset_alias_references: dict[str, list[AssetAlias]]
     outlet_references: dict[str, list[tuple[str, Asset]]]
-    assets: dict[str, Asset]
+    assets: dict[tuple[str, str], Asset]
     asset_aliases: dict[str, AssetAlias]
 
     @classmethod
@@ -300,22 +299,25 @@ class AssetModelOperation(NamedTuple):
                 ]
                 for dag_id, dag in dags.items()
             },
-            assets={asset.uri: asset for asset in _find_all_assets(dags.values())},
+            assets={(asset.name, asset.uri): asset for asset in _find_all_assets(dags.values())},
             asset_aliases={alias.name: alias for alias in _find_all_asset_aliases(dags.values())},
         )
         return coll
 
-    def add_assets(self, *, session: Session) -> dict[str, AssetModel]:
+    def add_assets(self, *, session: Session) -> dict[tuple[str, str], AssetModel]:
         # Optimization: skip all database calls if no assets were collected.
         if not self.assets:
             return {}
-        orm_assets: dict[str, AssetModel] = {
-            am.uri: am for am in session.scalars(select(AssetModel).where(AssetModel.uri.in_(self.assets)))
+        orm_assets: dict[tuple[str, str], AssetModel] = {
+            (am.name, am.uri): am
+            for am in session.scalars(
+                select(AssetModel).where(tuple_(AssetModel.name, AssetModel.uri).in_(self.assets))
+            )
         }
         orm_assets.update(
-            (model.uri, model)
+            ((model.name, model.uri), model)
             for model in asset_manager.create_assets(
-                [asset for uri, asset in self.assets.items() if uri not in orm_assets],
+                [asset for name_uri, asset in self.assets.items() if name_uri not in orm_assets],
                 session=session,
             )
         )
@@ -340,24 +342,10 @@ class AssetModelOperation(NamedTuple):
         )
         return orm_aliases
 
-    def add_asset_active_references(self, assets: Collection[AssetModel], *, session: Session) -> None:
-        existing_entries = set(
-            session.execute(
-                select(AssetActive.name, AssetActive.uri).where(
-                    tuple_(AssetActive.name, AssetActive.uri).in_((asset.name, asset.uri) for asset in assets)
-                )
-            )
-        )
-        session.add_all(
-            AssetActive.for_asset(asset)
-            for asset in assets
-            if (asset.name, asset.uri) not in existing_entries
-        )
-
     def add_dag_asset_references(
         self,
         dags: dict[str, DagModel],
-        assets: dict[str, AssetModel],
+        assets: dict[tuple[str, str], AssetModel],
         *,
         session: Session,
     ) -> None:
@@ -369,7 +357,7 @@ class AssetModelOperation(NamedTuple):
             if not references:
                 dags[dag_id].schedule_asset_references = []
                 continue
-            referenced_asset_ids = {asset.id for asset in (assets[r.uri] for r in references)}
+            referenced_asset_ids = {asset.id for asset in (assets[r.name, r.uri] for r in references)}
             orm_refs = {r.asset_id: r for r in dags[dag_id].schedule_asset_references}
             for asset_id, ref in orm_refs.items():
                 if asset_id not in referenced_asset_ids:
@@ -409,7 +397,7 @@ class AssetModelOperation(NamedTuple):
     def add_task_asset_references(
         self,
         dags: dict[str, DagModel],
-        assets: dict[str, AssetModel],
+        assets: dict[tuple[str, str], AssetModel],
         *,
         session: Session,
     ) -> None:
@@ -423,7 +411,7 @@ class AssetModelOperation(NamedTuple):
                 continue
             referenced_outlets = {
                 (task_id, asset.id)
-                for task_id, asset in ((task_id, assets[d.uri]) for task_id, d in references)
+                for task_id, asset in ((task_id, assets[d.name, d.uri]) for task_id, d in references)
             }
             orm_refs = {(r.task_id, r.asset_id): r for r in dags[dag_id].task_outlet_asset_references}
             for key, ref in orm_refs.items():
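
The collection changes above replace the bare-URI asset key with a (name, uri) tuple, matched in SQL with a composite IN. A minimal standalone sketch of that pattern on a toy model (not Airflow's real schema; assumes SQLAlchemy 1.4+ and a SQLite build with row-value support):

    from sqlalchemy import Column, String, create_engine, select, tuple_
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Thing(Base):  # hypothetical stand-in for AssetModel
        __tablename__ = "thing"
        name = Column(String, primary_key=True)
        uri = Column(String, primary_key=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Thing(name="n1", uri="u1"))
        session.flush()
        wanted = [("n1", "u1"), ("n2", "u2")]
        # Composite IN: fetch rows whose (name, uri) pair is in the wanted list.
        found = {
            (t.name, t.uri): t
            for t in session.scalars(select(Thing).where(tuple_(Thing.name, Thing.uri).in_(wanted)))
        }
        missing = [key for key in wanted if key not in found]
        assert missing == [("n2", "u2")]  # these would be created, as add_assets does
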
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index f030cb75019..8694f5890cc 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -28,7 +28,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Generator, Iterable
 
 from setproctitle import setproctitle
-from sqlalchemy import delete, event
+from sqlalchemy import delete, event, select
 
 from airflow import settings
 from airflow.api_internal.internal_api_call import internal_api_call
@@ -533,7 +533,14 @@ class DagFileProcessor(LoggingMixin):
                     )
                 )
 
-        stored_warnings = set(session.query(DagWarning).filter(DagWarning.dag_id.in_(dag_ids)).all())
+        stored_warnings = set(
+            session.scalars(
+                select(DagWarning).where(
+                    DagWarning.dag_id.in_(dag_ids),
+                    DagWarning.warning_type == DagWarningType.NONEXISTENT_POOL,
+                )
+            )
+        )
 
         for warning_to_delete in stored_warnings - warnings:
             session.delete(warning_to_delete)
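
The new warning_type filter matters because this method deletes stored_warnings - warnings just below: without restricting the query to NONEXISTENT_POOL, the pool-warning reconciliation would also sweep away the ASSET_CONFLICT warnings the scheduler now writes. The reconcile step in miniature, with made-up values:

    # Made-up (dag_id, warning_type) pairs illustrating delete-stale reconciliation.
    stored_warnings = {("dag_a", "non-existent pool"), ("dag_b", "non-existent pool")}
    warnings = {("dag_a", "non-existent pool")}  # warnings produced by this parse
    assert stored_warnings - warnings == {("dag_b", "non-existent pool")}  # gets deleted
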
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 04ea8c5e616..15042b0d3f1 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import itertools
 import multiprocessing
+import operator
 import os
 import signal
 import sys
@@ -55,6 +56,7 @@ from airflow.models.backfill import Backfill
 from airflow.models.dag import DAG, DagModel
 from airflow.models.dagbag import DagBag
 from airflow.models.dagrun import DagRun
+from airflow.models.dagwarning import DagWarning, DagWarningType
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
 from airflow.stats import Stats
@@ -1078,7 +1080,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
 
         timers.call_regular_interval(
             conf.getfloat("scheduler", "parsing_cleanup_interval"),
-            self._orphan_unreferenced_assets,
+            self._update_asset_orphanage,
         )
 
         if self._standalone_dag_processor:
@@ -2068,44 +2070,106 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             SerializedDagModel.remove_dag(dag_id=dag.dag_id, session=session)
         session.flush()
 
-    def _get_orphaning_identifier(self, asset: AssetModel) -> tuple[str, str]:
-        self.log.info("Orphaning unreferenced %s", asset)
-        return asset.name, asset.uri
-
     @provide_session
-    def _orphan_unreferenced_assets(self, session: Session = NEW_SESSION) -> None:
+    def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None:
         """
-        Detect orphaned assets and remove their active entry.
+        Check for asset orphanization and update active entries accordingly.
 
-        An orphaned asset is no longer referenced in any DAG schedule parameters or task outlets.
+        An orphaned asset is no longer referenced in any DAG schedule parameters
+        or task outlets. Active assets (non-orphaned) have entries in AssetActive
+        and must have unique names and URIs.
         """
-        orphaned_asset_query = session.scalars(
-            select(AssetModel)
-            .join(
-                DagScheduleAssetReference,
-                isouter=True,
-            )
-            .join(
-                TaskOutletAssetReference,
-                isouter=True,
-            )
+        # Group assets into orphaned=True and orphaned=False groups.
+        orphaned = (
+            (func.count(DagScheduleAssetReference.dag_id) + func.count(TaskOutletAssetReference.dag_id)) == 0
+        ).label("orphaned")
+        asset_reference_query = session.execute(
+            select(orphaned, AssetModel)
+            .outerjoin(DagScheduleAssetReference)
+            .outerjoin(TaskOutletAssetReference)
             .group_by(AssetModel.id)
-            .where(AssetModel.active.has())
-            .having(
-                and_(
-                    func.count(DagScheduleAssetReference.dag_id) == 0,
-                    func.count(TaskOutletAssetReference.dag_id) == 0,
+            .order_by(orphaned)
+        )
+        asset_orphanation: dict[bool, Collection[AssetModel]] = {
+            orphaned: [asset for _, asset in group]
+            for orphaned, group in itertools.groupby(asset_reference_query, key=operator.itemgetter(0))
+        }
+        self._orphan_unreferenced_assets(asset_orphanation.get(True, ()), session=session)
+        self._activate_referenced_assets(asset_orphanation.get(False, ()), session=session)
+
+    @staticmethod
+    def _orphan_unreferenced_assets(assets: Collection[AssetModel], *, session: Session) -> None:
+        if assets:
+            session.execute(
+                delete(AssetActive).where(
+                    tuple_in_condition((AssetActive.name, AssetActive.uri), ((a.name, a.uri) for a in assets))
+                )
+            )
+        Stats.gauge("asset.orphaned", len(assets))
+
+    @staticmethod
+    def _activate_referenced_assets(assets: Collection[AssetModel], *, session: Session) -> None:
+        if not assets:
+            return
+
+        active_assets = set(
+            session.execute(
+                select(AssetActive.name, AssetActive.uri).where(
+                    tuple_in_condition((AssetActive.name, AssetActive.uri), ((a.name, a.uri) for a in assets))
                 )
             )
         )
 
-        orphaning_identifiers = [self._get_orphaning_identifier(asset) for asset in orphaned_asset_query]
+        active_name_to_uri: dict[str, str] = {name: uri for name, uri in active_assets}
+        active_uri_to_name: dict[str, str] = {uri: name for name, uri in active_assets}
+
+        def _generate_dag_warnings(offending: AssetModel, attr: str, value: str) -> Iterator[DagWarning]:
+            for ref in itertools.chain(offending.consuming_dags, offending.producing_tasks):
+                yield DagWarning(
+                    dag_id=ref.dag_id,
+                    error_type=DagWarningType.ASSET_CONFLICT,
+                    message=f"Cannot activate asset {offending}; {attr} is 
already associated to {value!r}",
+                )
+
+        def _activate_assets_generate_warnings() -> Iterator[DagWarning]:
+            incoming_name_to_uri: dict[str, str] = {}
+            incoming_uri_to_name: dict[str, str] = {}
+            for asset in assets:
+                if (asset.name, asset.uri) in active_assets:
+                    continue
+                existing_uri = active_name_to_uri.get(asset.name) or incoming_name_to_uri.get(asset.name)
+                if existing_uri is not None and existing_uri != asset.uri:
+                    yield from _generate_dag_warnings(asset, "name", existing_uri)
+                    continue
+                existing_name = active_uri_to_name.get(asset.uri) or incoming_uri_to_name.get(asset.uri)
+                if existing_name is not None and existing_name != asset.name:
+                    yield from _generate_dag_warnings(asset, "uri", existing_name)
+                    continue
+                incoming_name_to_uri[asset.name] = asset.uri
+                incoming_uri_to_name[asset.uri] = asset.name
+                session.add(AssetActive.for_asset(asset))
+
+        warnings_to_have = {w.dag_id: w for w in _activate_assets_generate_warnings()}
         session.execute(
-            delete(AssetActive).where(
-                tuple_in_condition((AssetActive.name, AssetActive.uri), orphaning_identifiers)
+            delete(DagWarning).where(
+                DagWarning.warning_type == DagWarningType.ASSET_CONFLICT,
+                DagWarning.dag_id.not_in(warnings_to_have),
+            )
+        )
+        existing_warned_dag_ids: set[str] = set(
+            session.scalars(
+                select(DagWarning.dag_id).where(
+                    DagWarning.warning_type == DagWarningType.ASSET_CONFLICT,
+                    DagWarning.dag_id.in_(warnings_to_have),
+                )
             )
         )
-        Stats.gauge("asset.orphaned", len(orphaning_identifiers))
+        for dag_id, warning in warnings_to_have.items():
+            if dag_id in existing_warned_dag_ids:
+                session.merge(warning)
+                continue
+            session.add(warning)
+            existing_warned_dag_ids.add(warning.dag_id)
 
     def _executor_to_tis(self, tis: list[TaskInstance]) -> dict[BaseExecutor, list[TaskInstance]]:
         """Organize TIs into lists per their respective executor."""
diff --git a/airflow/models/asset.py b/airflow/models/asset.py
index fc77cb7a31d..d6092aaff1b 100644
--- a/airflow/models/asset.py
+++ b/airflow/models/asset.py
@@ -181,7 +181,7 @@ class AssetModel(Base):
     created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
     updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
 
-    active = relationship("AssetActive", uselist=False, viewonly=True)
+    active = relationship("AssetActive", uselist=False, viewonly=True, 
back_populates="asset")
 
     consuming_dags = relationship("DagScheduleAssetReference", 
back_populates="asset")
     producing_tasks = relationship("TaskOutletAssetReference", 
back_populates="asset")
@@ -221,7 +221,7 @@ class AssetModel(Base):
         return hash((self.name, self.uri))
 
     def __repr__(self):
-        return f"{self.__class__.__name__}(uri={self.uri!r}, 
extra={self.extra!r})"
+        return f"{self.__class__.__name__}(name={self.name!r}, 
uri={self.uri!r}, extra={self.extra!r})"
 
     def to_public(self) -> Asset:
         return Asset(name=self.name, uri=self.uri, group=self.group, extra=self.extra)
@@ -264,6 +264,8 @@ class AssetActive(Base):
         nullable=False,
     )
 
+    asset = relationship("AssetModel", back_populates="active")
+
     __tablename__ = "asset_active"
     __table_args__ = (
         PrimaryKeyConstraint(name, uri, name="asset_active_pkey"),
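
The relationship change pairs AssetModel.active with the new AssetActive.asset through back_populates, making the link navigable from both sides (AssetModel.active stays viewonly, so only the AssetActive side is writable). A generic sketch of the same wiring on toy models, without the viewonly flag:

    from sqlalchemy import Column, ForeignKey, Integer, create_engine
    from sqlalchemy.orm import Session, declarative_base, relationship

    Base = declarative_base()

    class Parent(Base):  # stands in for AssetModel
        __tablename__ = "parent"
        id = Column(Integer, primary_key=True)
        active = relationship("Child", uselist=False, back_populates="parent")

    class Child(Base):  # stands in for AssetActive
        __tablename__ = "child"
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey("parent.id"))
        parent = relationship("Parent", back_populates="active")

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        parent = Parent(active=Child())
        session.add(parent)
        session.flush()
        assert parent.active.parent is parent  # back_populates syncs both sides
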
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 00943ec2ee2..fd1c67debe2 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2571,7 +2571,6 @@ class DAG(LoggingMixin):
         orm_asset_aliases = asset_op.add_asset_aliases(session=session)
         session.flush()  # This populates id so we can create fks in later calls.
 
-        asset_op.add_asset_active_references(orm_assets.values(), session=session)
         asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session)
         asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session)
         asset_op.add_task_asset_references(orm_dags, orm_assets, session=session)
diff --git a/airflow/models/dagwarning.py b/airflow/models/dagwarning.py
index ffab515f854..e0c271c4c8e 100644
--- a/airflow/models/dagwarning.py
+++ b/airflow/models/dagwarning.py
@@ -104,4 +104,5 @@ class DagWarningType(str, Enum):
     in the DagWarning model.
     """
 
+    ASSET_CONFLICT = "asset conflict"
     NONEXISTENT_POOL = "non-existent pool"
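
For reference, the scheduler code earlier in this patch constructs the new warning kind by passing the enum as error_type, which the DagWarning initializer stores in the warning_type column. An illustrative construction (the dag id and message here are made up):

    from airflow.models.dagwarning import DagWarning, DagWarningType

    warning = DagWarning(
        dag_id="example_dag",
        error_type=DagWarningType.ASSET_CONFLICT,
        message="Cannot activate asset ...; name is already associated to 'other'",
    )
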
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index ac48344435a..3d71d598799 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -52,7 +52,7 @@ from airflow.executors.executor_loader import ExecutorLoader
 from airflow.jobs.job import Job, run_job
 from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
 from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
-from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel
+from airflow.models.asset import AssetActive, AssetDagRunQueue, AssetEvent, AssetModel
 from airflow.models.backfill import Backfill, _create_backfill
 from airflow.models.dag import DAG, DagModel
 from airflow.models.dagbag import DagBag
@@ -6160,84 +6160,102 @@ class TestSchedulerJob:
         (backfill_run,) = DagRun.find(dag_id=dag.dag_id, run_type=DagRunType.BACKFILL_JOB, session=session)
         assert backfill_run.state == State.SUCCESS
 
+    @staticmethod
+    def _find_assets_activation(session) -> tuple[list[AssetModel], list[AssetModel]]:
+        assets = session.execute(
+            select(AssetModel, AssetActive)
+            .outerjoin(
+                AssetActive,
+                (AssetModel.name == AssetActive.name) & (AssetModel.uri == 
AssetActive.uri),
+            )
+            .order_by(AssetModel.uri)
+        ).all()
+        return [a for a, v in assets if not v], [a for a, v in assets if v]
+
+    @pytest.mark.want_activate_assets(False)
     def test_asset_orphaning(self, dag_maker, session):
+        self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull)
+
         asset1 = Asset(uri="ds1")
         asset2 = Asset(uri="ds2")
         asset3 = Asset(uri="ds3")
         asset4 = Asset(uri="ds4")
+        asset5 = Asset(uri="ds5")
 
         with dag_maker(dag_id="assets-1", schedule=[asset1, asset2], 
session=session):
             BashOperator(task_id="task", bash_command="echo 1", 
outlets=[asset3, asset4])
 
-        non_orphaned_asset_count = session.query(AssetModel).filter(AssetModel.active.has()).count()
-        assert non_orphaned_asset_count == 4
-        orphaned_asset_count = session.query(AssetModel).filter(~AssetModel.active.has()).count()
-        assert orphaned_asset_count == 0
+        # Assets not activated yet; asset5 is not even registered (since it's not used anywhere).
+        orphaned, active = self._find_assets_activation(session)
+        assert active == []
+        assert orphaned == [asset1, asset2, asset3, asset4]
 
-        # now remove 2 asset references
+        self.job_runner._update_asset_orphanage(session=session)
+        session.flush()
+
+        # Assets are activated after scheduler loop.
+        orphaned, active = self._find_assets_activation(session)
+        assert active == [asset1, asset2, asset3, asset4]
+        assert orphaned == []
+
+        # Now remove 2 asset references and add asset5.
         with dag_maker(dag_id="assets-1", schedule=[asset1], session=session):
-            BashOperator(task_id="task", bash_command="echo 1", 
outlets=[asset3])
+            BashOperator(task_id="task", bash_command="echo 1", 
outlets=[asset3, asset5])
 
-        scheduler_job = Job()
-        self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
+        # The DAG parser finds asset5, but it's not activated yet.
+        orphaned, active = self._find_assets_activation(session)
+        assert active == [asset1, asset2, asset3, asset4]
+        assert orphaned == [asset5]
 
-        self.job_runner._orphan_unreferenced_assets(session=session)
+        self.job_runner._update_asset_orphanage(session=session)
         session.flush()
 
-        # and find the orphans
-        non_orphaned_assets = [
-            asset.uri
-            for asset in session.query(AssetModel.uri)
-            .filter(AssetModel.active.has())
-            .order_by(AssetModel.uri)
-        ]
-        assert non_orphaned_assets == ["ds1", "ds3"]
-        orphaned_assets = session.scalars(
-            select(AssetModel.uri).where(~AssetModel.active.has()).order_by(AssetModel.uri)
-        ).all()
-        assert orphaned_assets == ["ds2", "ds4"]
+        # Now we get the updated result.
+        orphaned, active = self._find_assets_activation(session)
+        assert active == [asset1, asset3, asset5]
+        assert orphaned == [asset2, asset4]
 
+    @pytest.mark.want_activate_assets(False)
     def test_asset_orphaning_ignore_orphaned_assets(self, dag_maker, session):
+        self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull)
+
         asset1 = Asset(uri="ds1")
 
         with dag_maker(dag_id="assets-1", schedule=[asset1], session=session):
             BashOperator(task_id="task", bash_command="echo 1")
 
-        non_orphaned_asset_count = session.query(AssetModel).filter(AssetModel.active.has()).count()
-        assert non_orphaned_asset_count == 1
-        orphaned_asset_count = session.query(AssetModel).filter(~AssetModel.active.has()).count()
-        assert orphaned_asset_count == 0
+        orphaned, active = self._find_assets_activation(session)
+        assert active == []
+        assert orphaned == [asset1]
+
+        self.job_runner._update_asset_orphanage(session=session)
+        session.flush()
+
+        orphaned, active = self._find_assets_activation(session)
+        assert active == [asset1]
+        assert orphaned == []
 
         # now remove asset1 reference
         with dag_maker(dag_id="assets-1", schedule=None, session=session):
             BashOperator(task_id="task", bash_command="echo 1")
 
-        scheduler_job = Job()
-        self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
-
-        self.job_runner._orphan_unreferenced_assets(session=session)
+        self.job_runner._update_asset_orphanage(session=session)
         session.flush()
 
-        orphaned_assets_before_rerun = (
-            session.query(AssetModel.updated_at, AssetModel.uri)
-            .filter(~AssetModel.active.has())
-            .order_by(AssetModel.uri)
-        )
-        assert [asset.uri for asset in orphaned_assets_before_rerun] == ["ds1"]
-        updated_at_timestamps = [asset.updated_at for asset in orphaned_assets_before_rerun]
+        orphaned, active = self._find_assets_activation(session)
+        assert active == []
+        assert orphaned == [asset1]
+        updated_at_timestamps = [asset.updated_at for asset in orphaned]
 
         # when rerunning we should ignore the already orphaned assets and thus the updated_at timestamp
         # should remain the same
-        self.job_runner._orphan_unreferenced_assets(session=session)
+        self.job_runner._update_asset_orphanage(session=session)
         session.flush()
 
-        orphaned_assets_after_rerun = (
-            session.query(AssetModel.updated_at, AssetModel.uri)
-            .filter(~AssetModel.active.has())
-            .order_by(AssetModel.uri)
-        )
-        assert [asset.uri for asset in orphaned_assets_after_rerun] == ["ds1"]
-        assert updated_at_timestamps == [asset.updated_at for asset in orphaned_assets_after_rerun]
+        orphaned, active = self._find_assets_activation(session)
+        assert active == []
+        assert orphaned == [asset1]
+        assert [asset.updated_at for asset in orphaned] == updated_at_timestamps
 
     def test_misconfigured_dags_doesnt_crash_scheduler(self, session, dag_maker, caplog):
         """Test that if dagrun creation throws an exception, the scheduler doesn't crash"""
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 86e499edb3b..1d7a69ba843 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -47,6 +47,7 @@ from airflow.exceptions import (
     UnknownExecutorException,
 )
 from airflow.models.asset import (
+    AssetActive,
     AssetAliasModel,
     AssetDagRunQueue,
     AssetEvent,
@@ -1070,54 +1071,47 @@ class TestDag:
             .all()
         ) == {(task_id, dag_id1, asset2_orm.id)}
 
-    def test_bulk_write_to_db_unorphan_assets(self):
+    @staticmethod
+    def _find_assets_activation(session) -> tuple[list[AssetModel], list[AssetModel]]:
+        assets = session.execute(
+            select(AssetModel, AssetActive)
+            .outerjoin(
+                AssetActive,
+                (AssetModel.name == AssetActive.name) & (AssetModel.uri == 
AssetActive.uri),
+            )
+            .order_by(AssetModel.uri)
+        ).all()
+        return [a for a, v in assets if not v], [a for a, v in assets if v]
+
+    def test_bulk_write_to_db_does_not_activate(self, dag_maker, session):
         """
-        Assets can lose their last reference and be orphaned, but then if a reference to them reappears, we
-        need to un-orphan those assets
+        Assets are not activated on write, but later in the scheduler by the SchedulerJob.
         """
-        with create_session() as session:
-            # Create four assets - two that have references and two that are unreferenced and marked as
-            # orphans
-            asset1 = Asset(uri="ds1")
-            asset2 = Asset(uri="ds2")
-            session.add(AssetModel(uri=asset2.uri))
-            asset3 = Asset(uri="ds3")
-            asset4 = Asset(uri="ds4")
-            session.add(AssetModel(uri=asset4.uri))
-            session.flush()
-
-            dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, 
schedule=[asset1])
-            BashOperator(dag=dag1, task_id="task", bash_command="echo 1", 
outlets=[asset3])
-
-            DAG.bulk_write_to_db([dag1], session=session)
-
-            # Double check
-            non_orphaned_assets = [
-                asset.uri
-                for asset in session.query(AssetModel.uri)
-                .filter(AssetModel.active.has())
-                .order_by(AssetModel.uri)
-            ]
-            assert non_orphaned_assets == ["ds1", "ds3"]
-            orphaned_assets = [
-                asset.uri
-                for asset in session.query(AssetModel.uri)
-                .filter(~AssetModel.active.has())
-                .order_by(AssetModel.uri)
-            ]
-            assert orphaned_assets == ["ds2", "ds4"]
+        # Create four assets - two referenced by the DAG right away, and two
+        # that only gain references later
+        asset1 = Asset(uri="ds1")
+        asset2 = Asset(uri="ds2")
+        asset3 = Asset(uri="ds3")
+        asset4 = Asset(uri="ds4")
 
-            # Now add references to the two unreferenced assets
-            dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, 
schedule=[asset1, asset2])
-            BashOperator(dag=dag1, task_id="task", bash_command="echo 1", 
outlets=[asset3, asset4])
+        dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, 
schedule=[asset1])
+        BashOperator(dag=dag1, task_id="task", bash_command="echo 1", 
outlets=[asset3])
+        DAG.bulk_write_to_db([dag1], session=session)
 
-            DAG.bulk_write_to_db([dag1], session=session)
+        assert session.scalars(select(AssetModel).order_by(AssetModel.uri)).all() == [asset1, asset3]
+        assert session.scalars(select(AssetActive)).all() == []
 
-            # and count the orphans and non-orphans
-            non_orphaned_asset_count = session.query(AssetModel).filter(AssetModel.active.has()).count()
-            assert non_orphaned_asset_count == 4
-            orphaned_asset_count = session.query(AssetModel).filter(~AssetModel.active.has()).count()
-            assert orphaned_asset_count == 0
+        dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, 
schedule=[asset1, asset2])
+        BashOperator(dag=dag1, task_id="task", bash_command="echo 1", 
outlets=[asset3, asset4])
+        DAG.bulk_write_to_db([dag1], session=session)
+
+        assert session.scalars(select(AssetModel).order_by(AssetModel.uri)).all() == [
+            asset1,
+            asset2,
+            asset3,
+            asset4,
+        ]
+        assert session.scalars(select(AssetActive)).all() == []
 
     def test_bulk_write_to_db_asset_aliases(self):
         """
diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py
index 0c7ed57fea6..18c3779d7fa 100644
--- a/tests_common/pytest_plugin.py
+++ b/tests_common/pytest_plugin.py
@@ -404,6 +404,7 @@ def pytest_configure(config: pytest.Config) -> None:
     config.addinivalue_line(
         "markers", "need_serialized_dag: mark tests that require dags in 
serialized form to be present"
     )
+    config.addinivalue_line("markers", "want_activate_assets: mark tests that 
require assets to be activated")
     config.addinivalue_line(
         "markers",
         "db_test: mark tests that require database to be present",
@@ -759,12 +760,14 @@ def dag_maker(request):
     # and "baked" in to various constants
 
     want_serialized = False
+    want_activate_assets = True  # Only has effect if want_serialized=True on Airflow 3.
 
     # Allow changing default serialized behaviour with `@pytest.mark.need_serialized_dag` or
     # `@pytest.mark.need_serialized_dag(False)`
-    serialized_marker = request.node.get_closest_marker("need_serialized_dag")
-    if serialized_marker:
+    if serialized_marker := request.node.get_closest_marker("need_serialized_dag"):
         (want_serialized,) = serialized_marker.args or (True,)
+    if serialized_marker := request.node.get_closest_marker("want_activate_assets"):
+        (want_activate_assets,) = serialized_marker.args or (True,)
 
     from airflow.utils.log.logging_mixin import LoggingMixin
 
@@ -802,10 +805,26 @@ def dag_maker(request):
                 return self.dagbag.bag_dag(dag, root_dag=dag)
             return self.dagbag.bag_dag(dag)
 
+        def _activate_assets(self):
+            from sqlalchemy import select
+
+            from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
+            from airflow.models.asset import AssetModel, DagScheduleAssetReference, TaskOutletAssetReference
+
+            assets = self.session.scalars(
+                select(AssetModel).where(
+                    AssetModel.consuming_dags.any(DagScheduleAssetReference.dag_id == self.dag.dag_id)
+                    | AssetModel.producing_tasks.any(TaskOutletAssetReference.dag_id == self.dag.dag_id)
+                )
+            ).all()
+            SchedulerJobRunner._activate_referenced_assets(assets, session=self.session)
+
         def __exit__(self, type, value, traceback):
             from airflow.models import DagModel
             from airflow.models.serialized_dag import SerializedDagModel
 
+            from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
+
             dag = self.dag
             dag.__exit__(type, value, traceback)
             if type is not None:
@@ -822,6 +841,8 @@ def dag_maker(request):
                 self.session.merge(self.serialized_model)
                 serialized_dag = self._serialized_dag()
                 self._bag_dag_compat(serialized_dag)
+                if AIRFLOW_V_3_0_PLUS and self.want_activate_assets:
+                    self._activate_assets()
                 self.session.flush()
             else:
                 self._bag_dag_compat(self.dag)
@@ -887,6 +908,7 @@ def dag_maker(request):
             dag_id="test_dag",
             schedule=timedelta(days=1),
             serialized=want_serialized,
+            activate_assets=want_activate_assets,
             fileloc=None,
             processor_subdir=None,
             session=None,
@@ -919,6 +941,7 @@ def dag_maker(request):
             self.dag = DAG(dag_id, schedule=schedule, **self.kwargs)
             self.dag.fileloc = fileloc or request.module.__file__
             self.want_serialized = serialized
+            self.want_activate_assets = activate_assets
             self.processor_subdir = processor_subdir
 
             return self
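
With these plugin changes, serialized dag_maker DAGs on Airflow 3 get their assets activated automatically, and a test can opt out with the new marker. A hypothetical test in the same shape as the scheduler tests above:

    import pytest

    @pytest.mark.want_activate_assets(False)  # keep assets inactive after dag_maker
    def test_scheduler_activates_assets(dag_maker, session):
        ...  # exercise SchedulerJobRunner._update_asset_orphanage here
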
