This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ffcc60fc4a9 AIP-82 Send asset change event when trigger fires (#44369)
ffcc60fc4a9 is described below
commit ffcc60fc4a900d922de4abe40896e1e3a579942f
Author: Vincent <[email protected]>
AuthorDate: Wed Dec 4 10:58:22 2024 -0500
AIP-82 Send asset change event when trigger fires (#44369)
---
airflow/jobs/triggerer_job_runner.py | 20 ++++---
airflow/models/trigger.py | 49 +++++++++++++---
tests/models/test_trigger.py | 105 ++++++++++++++++++++++++++++++-----
3 files changed, 144 insertions(+), 30 deletions(-)
diff --git a/airflow/jobs/triggerer_job_runner.py
b/airflow/jobs/triggerer_job_runner.py
index c52a7514346..e44c6709d4b 100644
--- a/airflow/jobs/triggerer_job_runner.py
+++ b/airflow/jobs/triggerer_job_runner.py
@@ -530,11 +530,15 @@ class TriggerRunner(threading.Thread, LoggingMixin):
while self.to_create:
trigger_id, trigger_instance = self.to_create.popleft()
if trigger_id not in self.triggers:
- ti: TaskInstance = trigger_instance.task_instance
+ ti: TaskInstance | None = trigger_instance.task_instance
+ trigger_name = (
+
f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} (ID
{trigger_id})"
+ if ti
+ else f"ID {trigger_id}"
+ )
self.triggers[trigger_id] = {
"task": asyncio.create_task(self.run_trigger(trigger_id,
trigger_instance)),
- "name":
f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} "
- f"(ID {trigger_id})",
+ "name": trigger_name,
"events": 0,
}
else:
@@ -636,13 +640,14 @@ class TriggerRunner(threading.Thread, LoggingMixin):
name = self.triggers[trigger_id]["name"]
self.log.info("trigger %s starting", name)
try:
- self.set_individual_trigger_logging(trigger)
+ if trigger.task_instance:
+ self.set_individual_trigger_logging(trigger)
async for event in trigger.run():
self.log.info("Trigger %s fired: %s",
self.triggers[trigger_id]["name"], event)
self.triggers[trigger_id]["events"] += 1
self.events.append((trigger_id, event))
except asyncio.CancelledError:
- if timeout := trigger.task_instance.trigger_timeout:
+ if timeout := trigger.task_instance and
trigger.task_instance.trigger_timeout:
timeout = timeout.replace(tzinfo=timezone.utc) if not
timeout.tzinfo else timeout
if timeout < timezone.utcnow():
self.log.error("Trigger cancelled due to timeout")
@@ -696,6 +701,7 @@ class TriggerRunner(threading.Thread, LoggingMixin):
cancel_trigger_ids = running_trigger_ids - requested_trigger_ids
# Bulk-fetch new trigger records
new_triggers = Trigger.bulk_fetch(new_trigger_ids)
+ triggers_with_assets = Trigger.fetch_trigger_ids_with_asset()
# Add in new triggers
for new_id in new_trigger_ids:
# Check it didn't vanish in the meantime
@@ -711,11 +717,11 @@ class TriggerRunner(threading.Thread, LoggingMixin):
self.failed_triggers.append((new_id, e))
continue
- # If new_trigger_orm.task_instance is None, this means the
TaskInstance
+ # If the trigger is not associated with a task or an asset, this
means the TaskInstance
# row was updated by either Trigger.submit_event or
Trigger.submit_failure
# and can happen when a single trigger Job is being run on
multiple TriggerRunners
# in a High-Availability setup.
- if new_trigger_orm.task_instance is None:
+ if new_trigger_orm.task_instance is None and new_id not in
triggers_with_assets:
self.log.info(
(
"TaskInstance for Trigger ID %s is None. It was likely
updated by another trigger job. "
diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index f56512cdbc1..5a46fcbda27 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -26,6 +26,7 @@ from sqlalchemy import Column, Integer, String, Text, delete,
func, or_, select,
from sqlalchemy.orm import relationship, selectinload
from sqlalchemy.sql.functions import coalesce
+from airflow.assets.manager import AssetManager
from airflow.models.asset import asset_trigger_association_table
from airflow.models.base import Base
from airflow.models.taskinstance import TaskInstance
@@ -180,15 +181,21 @@ class Trigger(Base):
)
return {obj.id: obj for obj in session.scalars(stmt)}
+ @classmethod
+ @provide_session
+ def fetch_trigger_ids_with_asset(cls, session: Session = NEW_SESSION) ->
set[str]:
+ """Fetch all the trigger IDs associated with at least one asset."""
+ query = select(asset_trigger_association_table.columns.trigger_id)
+ return {trigger_id for trigger_id in session.scalars(query)}
+
@classmethod
@provide_session
def clean_unused(cls, session: Session = NEW_SESSION) -> None:
"""
- Delete all triggers that have no tasks dependent on them.
+ Delete all triggers that have no tasks dependent on them and are not
associated with an asset.
- Triggers have a one-to-many relationship to task instances, so we need
- to clean those up first. Afterwards we can drop the triggers not
- referenced by anyone.
+ Triggers have a one-to-many relationship to task instances, so we need
to clean those up first.
+ Afterward we can drop the triggers not referenced by anyone.
"""
# Update all task instances with trigger IDs that are not DEFERRED to
remove them
for attempt in run_with_db_retries():
@@ -201,9 +208,10 @@ class Trigger(Base):
.values(trigger_id=None)
)
- # Get all triggers that have no task instances depending on them and
delete them
+ # Get all triggers that have no task instances or assets depending on
them and delete them
ids = (
select(cls.id)
+ .where(~cls.assets.any())
.join(TaskInstance, cls.id == TaskInstance.trigger_id,
isouter=True)
.group_by(cls.id)
.having(func.count(TaskInstance.trigger_id) == 0)
@@ -218,7 +226,13 @@ class Trigger(Base):
@classmethod
@provide_session
def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION)
-> None:
- """Take an event from an instance of itself, and trigger all dependent
tasks to resume."""
+ """
+ Fire an event.
+
+ Resume all tasks that were in deferred state.
+ Send an event to all assets associated with the trigger.
+ """
+ # Resume deferred tasks
for task_instance in session.scalars(
select(TaskInstance).where(
TaskInstance.trigger_id == trigger_id, TaskInstance.state ==
TaskInstanceState.DEFERRED
@@ -226,6 +240,14 @@ class Trigger(Base):
):
event.handle_submit(task_instance=task_instance)
+ # Send an event to assets
+ trigger = session.scalars(select(cls).where(cls.id ==
trigger_id)).one()
+ for asset in trigger.assets:
+ AssetManager.register_asset_change(
+ asset=asset.to_public(),
+ session=session,
+ )
+
@classmethod
@provide_session
def submit_failure(cls, trigger_id, exc=None, session: Session =
NEW_SESSION) -> None:
@@ -264,7 +286,7 @@ class Trigger(Base):
@classmethod
@provide_session
def ids_for_triggerer(cls, triggerer_id, session: Session = NEW_SESSION)
-> list[int]:
- """Retrieve a list of triggerer_ids."""
+ """Retrieve a list of trigger ids."""
return session.scalars(select(cls.id).where(cls.triggerer_id ==
triggerer_id)).all()
@classmethod
@@ -326,4 +348,15 @@ class Trigger(Base):
session,
skip_locked=True,
)
- return session.execute(query).all()
+ ti_triggers = session.execute(query).all()
+
+ query = with_row_locks(
+
select(cls.id).where(cls.assets.any()).order_by(cls.created_date).limit(capacity),
+ session,
+ skip_locked=True,
+ )
+ asset_triggers = session.execute(query).all()
+
+ # Add triggers associated with assets after triggers associated with tasks
+ # It prioritizes DAGs over event-driven scheduling, which is fair
+ return ti_triggers + asset_triggers
diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py
index 235c6585798..97c2b102082 100644
--- a/tests/models/test_trigger.py
+++ b/tests/models/test_trigger.py
@@ -30,6 +30,7 @@ from cryptography.fernet import Fernet
from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models import TaskInstance, Trigger, XCom
+from airflow.models.asset import AssetEvent, AssetModel,
asset_trigger_association_table
from airflow.operators.empty import EmptyOperator
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import (
@@ -59,48 +60,92 @@ def session():
@pytest.fixture(autouse=True)
def clear_db(session):
session.query(TaskInstance).delete()
+ session.query(asset_trigger_association_table).delete()
session.query(Trigger).delete()
+ session.query(AssetModel).delete()
+ session.query(AssetEvent).delete()
session.query(Job).delete()
yield session
session.query(TaskInstance).delete()
+ session.query(asset_trigger_association_table).delete()
session.query(Trigger).delete()
+ session.query(AssetModel).delete()
+ session.query(AssetEvent).delete()
session.query(Job).delete()
session.commit()
+def test_fetch_trigger_ids_with_asset(session):
+ # Create triggers
+ trigger1 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger1",
kwargs={})
+ trigger1.id = 1
+ trigger2 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger2",
kwargs={})
+ trigger2.id = 2
+ session.add(trigger1)
+ session.add(trigger2)
+ # Create assets
+ asset = AssetModel("test")
+ asset.triggers.extend([trigger1])
+ session.add(asset)
+ session.commit()
+
+ results = Trigger.fetch_trigger_ids_with_asset()
+ assert results == {1}
+
+
def test_clean_unused(session, create_task_instance):
"""
Tests that unused triggers (those with no task instances referencing them)
are cleaned out automatically.
"""
- # Make three triggers
- trigger1 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger",
kwargs={})
+ # Create triggers
+ trigger1 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger1",
kwargs={})
trigger1.id = 1
- trigger2 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger",
kwargs={})
+ trigger2 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger2",
kwargs={})
trigger2.id = 2
- trigger3 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger",
kwargs={})
+ trigger3 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger3",
kwargs={})
trigger3.id = 3
+ trigger4 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger4",
kwargs={})
+ trigger4.id = 4
+ trigger5 = Trigger(classpath="airflow.triggers.testing.SuccessTrigger5",
kwargs={})
+ trigger5.id = 5
session.add(trigger1)
session.add(trigger2)
session.add(trigger3)
+ session.add(trigger4)
+ session.add(trigger5)
session.commit()
- assert session.query(Trigger).count() == 3
+ assert session.query(Trigger).count() == 5
# Tie one to a fake TaskInstance that is not deferred, and one to one that
is
task_instance = create_task_instance(
session=session, task_id="fake", state=State.DEFERRED,
logical_date=timezone.utcnow()
)
task_instance.trigger_id = trigger1.id
session.add(task_instance)
- fake_task = EmptyOperator(task_id="fake2", dag=task_instance.task.dag)
- task_instance = TaskInstance(task=fake_task, run_id=task_instance.run_id)
- task_instance.state = State.SUCCESS
- task_instance.trigger_id = trigger2.id
- session.add(task_instance)
+ fake_task1 = EmptyOperator(task_id="fake2", dag=task_instance.task.dag)
+ task_instance1 = TaskInstance(task=fake_task1, run_id=task_instance.run_id)
+ task_instance1.state = State.SUCCESS
+ task_instance1.trigger_id = trigger2.id
+ session.add(task_instance1)
+ fake_task2 = EmptyOperator(task_id="fake3", dag=task_instance.task.dag)
+ task_instance2 = TaskInstance(task=fake_task2, run_id=task_instance.run_id)
+ task_instance2.state = State.SUCCESS
+ task_instance2.trigger_id = trigger4.id
+ session.add(task_instance2)
+ session.commit()
+
+ # Create assets
+ asset = AssetModel("test")
+ asset.triggers.extend([trigger4, trigger5])
+ session.add(asset)
session.commit()
+ assert session.query(AssetModel).count() == 1
+
# Run clear operation
Trigger.clean_unused()
- # Verify that one trigger is gone, and the right one is left
- assert session.query(Trigger).one().id == trigger1.id
+ results = session.query(Trigger).all()
+ assert len(results) == 3
+ assert {result.id for result in results} == {1, 4, 5}
def test_submit_event(session, create_task_instance):
@@ -120,6 +165,15 @@ def test_submit_event(session, create_task_instance):
task_instance.trigger_id = trigger.id
task_instance.next_kwargs = {"cheesecake": True}
session.commit()
+ # Create assets
+ asset = AssetModel("test")
+ asset.id = 1
+ asset.triggers.extend([trigger])
+ session.add(asset)
+ session.commit()
+
+ # Check that the asset has 0 events prior to sending an event to the trigger
+ assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 0
# Call submit_event
Trigger.submit_event(trigger.id, TriggerEvent(42), session=session)
# commit changes made by submit event and expire all cache to read from db.
@@ -128,6 +182,8 @@ def test_submit_event(session, create_task_instance):
updated_task_instance = session.query(TaskInstance).one()
assert updated_task_instance.state == State.SCHEDULED
assert updated_task_instance.next_kwargs == {"event": 42, "cheesecake":
True}
+ # Check that the asset has received an event
+ assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 1
def test_submit_failure(session, create_task_instance):
@@ -349,13 +405,32 @@ def
test_get_sorted_triggers_same_priority_weight(session, create_task_instance)
TI_new.priority_weight = 1
TI_new.trigger_id = trigger_new.id
session.add(TI_new)
-
+ trigger_orphan = Trigger(
+ classpath="airflow.triggers.testing.TriggerOrphan",
+ kwargs={},
+ created_date=new_logical_date,
+ )
+ trigger_orphan.id = 3
+ session.add(trigger_orphan)
+ trigger_asset = Trigger(
+ classpath="airflow.triggers.testing.TriggerAsset",
+ kwargs={},
+ created_date=new_logical_date,
+ )
+ trigger_asset.id = 4
+ session.add(trigger_asset)
+ session.commit()
+ assert session.query(Trigger).count() == 4
+ # Create assets
+ asset = AssetModel("test")
+ asset.id = 1
+ asset.triggers.extend([trigger_asset])
+ session.add(asset)
session.commit()
- assert session.query(Trigger).count() == 2
trigger_ids_query = Trigger.get_sorted_triggers(capacity=100,
alive_triggerer_ids=[], session=session)
- assert trigger_ids_query == [(1,), (2,)]
+ assert trigger_ids_query == [(1,), (2,), (4,)]
def test_get_sorted_triggers_different_priority_weights(session,
create_task_instance):