This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ffcc60fc4a9 AIP-82 Send asset change event when trigger fires (#44369)
ffcc60fc4a9 is described below

commit ffcc60fc4a900d922de4abe40896e1e3a579942f
Author: Vincent <[email protected]>
AuthorDate: Wed Dec 4 10:58:22 2024 -0500

    AIP-82 Send asset change event when trigger fires (#44369)
---
 airflow/jobs/triggerer_job_runner.py |  20 ++++---
 airflow/models/trigger.py            |  49 +++++++++++++---
 tests/models/test_trigger.py         | 105 ++++++++++++++++++++++++++++++-----
 3 files changed, 144 insertions(+), 30 deletions(-)

diff --git a/airflow/jobs/triggerer_job_runner.py b/airflow/jobs/triggerer_job_runner.py
index c52a7514346..e44c6709d4b 100644
--- a/airflow/jobs/triggerer_job_runner.py
+++ b/airflow/jobs/triggerer_job_runner.py
@@ -530,11 +530,15 @@ class TriggerRunner(threading.Thread, LoggingMixin):
         while self.to_create:
             trigger_id, trigger_instance = self.to_create.popleft()
             if trigger_id not in self.triggers:
-                ti: TaskInstance = trigger_instance.task_instance
+                ti: TaskInstance | None = trigger_instance.task_instance
+                trigger_name = (
+                    f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} (ID {trigger_id})"
+                    if ti
+                    else f"ID {trigger_id}"
+                )
                 self.triggers[trigger_id] = {
                     "task": asyncio.create_task(self.run_trigger(trigger_id, trigger_instance)),
-                    "name": f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} "
-                    f"(ID {trigger_id})",
+                    "name": trigger_name,
                     "events": 0,
                 }
             else:
@@ -636,13 +640,14 @@ class TriggerRunner(threading.Thread, LoggingMixin):
         name = self.triggers[trigger_id]["name"]
         self.log.info("trigger %s starting", name)
         try:
-            self.set_individual_trigger_logging(trigger)
+            if trigger.task_instance:
+                self.set_individual_trigger_logging(trigger)
             async for event in trigger.run():
                 self.log.info("Trigger %s fired: %s", self.triggers[trigger_id]["name"], event)
                 self.triggers[trigger_id]["events"] += 1
                 self.events.append((trigger_id, event))
         except asyncio.CancelledError:
-            if timeout := trigger.task_instance.trigger_timeout:
+            if timeout := trigger.task_instance and trigger.task_instance.trigger_timeout:
                 timeout = timeout.replace(tzinfo=timezone.utc) if not timeout.tzinfo else timeout
                 if timeout < timezone.utcnow():
                     self.log.error("Trigger cancelled due to timeout")
@@ -696,6 +701,7 @@ class TriggerRunner(threading.Thread, LoggingMixin):
         cancel_trigger_ids = running_trigger_ids - requested_trigger_ids
         # Bulk-fetch new trigger records
         new_triggers = Trigger.bulk_fetch(new_trigger_ids)
+        triggers_with_assets = Trigger.fetch_trigger_ids_with_asset()
         # Add in new triggers
         for new_id in new_trigger_ids:
             # Check it didn't vanish in the meantime
@@ -711,11 +717,11 @@ class TriggerRunner(threading.Thread, LoggingMixin):
                 self.failed_triggers.append((new_id, e))
                 continue
 
-            # If new_trigger_orm.task_instance is None, this means the TaskInstance
+            # If the trigger is not associated to a task or an asset, this means the TaskInstance
             # row was updated by either Trigger.submit_event or Trigger.submit_failure
             # and can happen when a single trigger Job is being run on multiple TriggerRunners
             # in a High-Availability setup.
-            if new_trigger_orm.task_instance is None:
+            if new_trigger_orm.task_instance is None and new_id not in triggers_with_assets:
                 self.log.info(
                     (
                         "TaskInstance for Trigger ID %s is None. It was likely updated by another trigger job. "
diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index f56512cdbc1..5a46fcbda27 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -26,6 +26,7 @@ from sqlalchemy import Column, Integer, String, Text, delete, func, or_, select,
 from sqlalchemy.orm import relationship, selectinload
 from sqlalchemy.sql.functions import coalesce
 
+from airflow.assets.manager import AssetManager
 from airflow.models.asset import asset_trigger_association_table
 from airflow.models.base import Base
 from airflow.models.taskinstance import TaskInstance
@@ -180,15 +181,21 @@ class Trigger(Base):
         )
         return {obj.id: obj for obj in session.scalars(stmt)}
 
+    @classmethod
+    @provide_session
+    def fetch_trigger_ids_with_asset(cls, session: Session = NEW_SESSION) -> set[str]:
+        """Fetch all the trigger IDs associated with at least one asset."""
+        query = select(asset_trigger_association_table.columns.trigger_id)
+        return {trigger_id for trigger_id in session.scalars(query)}
+
     @classmethod
     @provide_session
     def clean_unused(cls, session: Session = NEW_SESSION) -> None:
         """
-        Delete all triggers that have no tasks dependent on them.
+        Delete all triggers that have no tasks dependent on them and are not associated to an asset.
 
-        Triggers have a one-to-many relationship to task instances, so we need
-        to clean those up first. Afterwards we can drop the triggers not
-        referenced by anyone.
+        Triggers have a one-to-many relationship to task instances, so we need to clean those up first.
+        Afterward we can drop the triggers not referenced by anyone.
         """
         # Update all task instances with trigger IDs that are not DEFERRED to remove them
         for attempt in run_with_db_retries():
@@ -201,9 +208,10 @@ class Trigger(Base):
                     .values(trigger_id=None)
                 )
 
-        # Get all triggers that have no task instances depending on them and delete them
+        # Get all triggers that have no task instances and assets depending on them and delete them
         ids = (
             select(cls.id)
+            .where(~cls.assets.any())
             .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True)
             .group_by(cls.id)
             .having(func.count(TaskInstance.trigger_id) == 0)
@@ -218,7 +226,13 @@ class Trigger(Base):
     @classmethod
     @provide_session
     def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION) -> None:
-        """Take an event from an instance of itself, and trigger all dependent tasks to resume."""
+        """
+        Fire an event.
+
+        Resume all tasks that were in deferred state.
+        Send an event to all assets associated to the trigger.
+        """
+        # Resume deferred tasks
         for task_instance in session.scalars(
             select(TaskInstance).where(
                 TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED
@@ -226,6 +240,14 @@ class Trigger(Base):
         ):
             event.handle_submit(task_instance=task_instance)
 
+        # Send an event to assets
+        trigger = session.scalars(select(cls).where(cls.id == trigger_id)).one()
+        for asset in trigger.assets:
+            AssetManager.register_asset_change(
+                asset=asset.to_public(),
+                session=session,
+            )
+
     @classmethod
     @provide_session
     def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> None:
@@ -264,7 +286,7 @@ class Trigger(Base):
     @classmethod
     @provide_session
     def ids_for_triggerer(cls, triggerer_id, session: Session = NEW_SESSION) -> list[int]:
-        """Retrieve a list of triggerer_ids."""
+        """Retrieve a list of trigger ids."""
         return session.scalars(select(cls.id).where(cls.triggerer_id == triggerer_id)).all()
 
     @classmethod
@@ -326,4 +348,15 @@ class Trigger(Base):
             session,
             skip_locked=True,
         )
-        return session.execute(query).all()
+        ti_triggers = session.execute(query).all()
+
+        query = with_row_locks(
+            select(cls.id).where(cls.assets.any()).order_by(cls.created_date).limit(capacity),
+            session,
+            skip_locked=True,
+        )
+        asset_triggers = session.execute(query).all()
+
+        # Add triggers associated to assets after triggers associated to tasks
+        # It prioritizes DAGs over event driven scheduling which is fair
+        return ti_triggers + asset_triggers
diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py
index 235c6585798..97c2b102082 100644
--- a/tests/models/test_trigger.py
+++ b/tests/models/test_trigger.py
@@ -30,6 +30,7 @@ from cryptography.fernet import Fernet
 from airflow.jobs.job import Job
 from airflow.jobs.triggerer_job_runner import TriggererJobRunner
 from airflow.models import TaskInstance, Trigger, XCom
+from airflow.models.asset import AssetEvent, AssetModel, asset_trigger_association_table
 from airflow.operators.empty import EmptyOperator
 from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.triggers.base import (
@@ -59,48 +60,92 @@ def session():
 @pytest.fixture(autouse=True)
 def clear_db(session):
     session.query(TaskInstance).delete()
+    session.query(asset_trigger_association_table).delete()
     session.query(Trigger).delete()
+    session.query(AssetModel).delete()
+    session.query(AssetEvent).delete()
     session.query(Job).delete()
     yield session
     session.query(TaskInstance).delete()
+    session.query(asset_trigger_association_table).delete()
     session.query(Trigger).delete()
+    session.query(AssetModel).delete()
+    session.query(AssetEvent).delete()
     session.query(Job).delete()
     session.commit()
 
 
+def test_fetch_trigger_ids_with_asset(session):
+    # Create triggers
+    trigger1 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger1", kwargs={})
+    trigger1.id = 1
+    trigger2 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger2", kwargs={})
+    trigger2.id = 2
+    session.add(trigger1)
+    session.add(trigger2)
+    # Create assets
+    asset = AssetModel("test")
+    asset.triggers.extend([trigger1])
+    session.add(asset)
+    session.commit()
+
+    results = Trigger.fetch_trigger_ids_with_asset()
+    assert results == {1}
+
+
 def test_clean_unused(session, create_task_instance):
     """
     Tests that unused triggers (those with no task instances referencing them)
     are cleaned out automatically.
     """
-    # Make three triggers
-    trigger1 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+    # Create triggers
+    trigger1 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger1", kwargs={})
     trigger1.id = 1
-    trigger2 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+    trigger2 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger2", kwargs={})
     trigger2.id = 2
-    trigger3 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+    trigger3 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger3", kwargs={})
     trigger3.id = 3
+    trigger4 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger4", kwargs={})
+    trigger4.id = 4
+    trigger5 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger5", kwargs={})
+    trigger5.id = 5
     session.add(trigger1)
     session.add(trigger2)
     session.add(trigger3)
+    session.add(trigger4)
+    session.add(trigger5)
     session.commit()
-    assert session.query(Trigger).count() == 3
+    assert session.query(Trigger).count() == 5
     # Tie one to a fake TaskInstance that is not deferred, and one to one that is
     task_instance = create_task_instance(
         session=session, task_id="fake", state=State.DEFERRED, logical_date=timezone.utcnow()
     )
     task_instance.trigger_id = trigger1.id
     session.add(task_instance)
-    fake_task = EmptyOperator(task_id="fake2", dag=task_instance.task.dag)
-    task_instance = TaskInstance(task=fake_task, run_id=task_instance.run_id)
-    task_instance.state = State.SUCCESS
-    task_instance.trigger_id = trigger2.id
-    session.add(task_instance)
+    fake_task1 = EmptyOperator(task_id="fake2", dag=task_instance.task.dag)
+    task_instance1 = TaskInstance(task=fake_task1, run_id=task_instance.run_id)
+    task_instance1.state = State.SUCCESS
+    task_instance1.trigger_id = trigger2.id
+    session.add(task_instance1)
+    fake_task2 = EmptyOperator(task_id="fake3", dag=task_instance.task.dag)
+    task_instance2 = TaskInstance(task=fake_task2, run_id=task_instance.run_id)
+    task_instance2.state = State.SUCCESS
+    task_instance2.trigger_id = trigger4.id
+    session.add(task_instance2)
+    session.commit()
+
+    # Create assets
+    asset = AssetModel("test")
+    asset.triggers.extend([trigger4, trigger5])
+    session.add(asset)
     session.commit()
+    assert session.query(AssetModel).count() == 1
+
     # Run clear operation
     Trigger.clean_unused()
-    # Verify that one trigger is gone, and the right one is left
-    assert session.query(Trigger).one().id == trigger1.id
+    results = session.query(Trigger).all()
+    assert len(results) == 3
+    assert {result.id for result in results} == {1, 4, 5}
 
 
 def test_submit_event(session, create_task_instance):
@@ -120,6 +165,15 @@ def test_submit_event(session, create_task_instance):
     task_instance.trigger_id = trigger.id
     task_instance.next_kwargs = {"cheesecake": True}
     session.commit()
+    # Create assets
+    asset = AssetModel("test")
+    asset.id = 1
+    asset.triggers.extend([trigger])
+    session.add(asset)
+    session.commit()
+
+    # Check that the asset has 0 event prior to sending an event to the trigger
+    assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 0
     # Call submit_event
     Trigger.submit_event(trigger.id, TriggerEvent(42), session=session)
     # commit changes made by submit event and expire all cache to read from db.
@@ -128,6 +182,8 @@ def test_submit_event(session, create_task_instance):
     updated_task_instance = session.query(TaskInstance).one()
     assert updated_task_instance.state == State.SCHEDULED
     assert updated_task_instance.next_kwargs == {"event": 42, "cheesecake": True}
+    # Check that the asset has received an event
+    assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 1
 
 
 def test_submit_failure(session, create_task_instance):
@@ -349,13 +405,32 @@ def test_get_sorted_triggers_same_priority_weight(session, create_task_instance)
     TI_new.priority_weight = 1
     TI_new.trigger_id = trigger_new.id
     session.add(TI_new)
-
+    trigger_orphan = Trigger(
+        classpath="airflow.triggers.testing.TriggerOrphan",
+        kwargs={},
+        created_date=new_logical_date,
+    )
+    trigger_orphan.id = 3
+    session.add(trigger_orphan)
+    trigger_asset = Trigger(
+        classpath="airflow.triggers.testing.TriggerAsset",
+        kwargs={},
+        created_date=new_logical_date,
+    )
+    trigger_asset.id = 4
+    session.add(trigger_asset)
+    session.commit()
+    assert session.query(Trigger).count() == 4
+    # Create assets
+    asset = AssetModel("test")
+    asset.id = 1
+    asset.triggers.extend([trigger_asset])
+    session.add(asset)
     session.commit()
-    assert session.query(Trigger).count() == 2
 
     trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session)
 
-    assert trigger_ids_query == [(1,), (2,)]
+    assert trigger_ids_query == [(1,), (2,), (4,)]
 
 
 def test_get_sorted_triggers_different_priority_weights(session, create_task_instance):

Reply via email to