This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new b698b7b5c6b  Update scheduler to create dag runs for partitioned-asset-driven dags (#59006)
b698b7b5c6b is described below
commit b698b7b5c6ba4b5b71a73c035888f6ae776bd16c
Author: Daniel Standish <[email protected]>
AuthorDate: Fri Dec 5 08:08:08 2025 -0800
    Update scheduler to create dag runs for partitioned-asset-driven dags (#59006)
    Now that the events and log records needed for partition-driven scheduling
    are being recorded, we update the scheduler to create dag runs based on them.
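
    For illustration only (an editor's minimal sketch, not part of this change
    set): a downstream dag that consumes a partitioned asset is declared with the
    PartitionedAssetTimetable exercised in the tests below. The dag and task ids
    mirror those tests, and the imports are the ones the new tests use.

        from airflow.sdk import DAG, Asset, task
        from airflow.timetables.simple import IdentityMapper, PartitionedAssetTimetable

        asset = Asset(name="hello")

        # The scheduler change below creates one queued ASSET_TRIGGERED run per
        # pending AssetPartitionDagRun row, carrying that row's partition_key.
        with DAG(
            dag_id="asset_event_listener",
            schedule=PartitionedAssetTimetable(assets=asset, partition_mapper=IdentityMapper()),
        ):
            @task
            def hi():
                pass

            hi()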
---------
Co-authored-by: Wei Lee <[email protected]>
---
 .../src/airflow/jobs/scheduler_job_runner.py       |  50 ++++++++-
 airflow-core/src/airflow/timetables/simple.py      |   2 +-
 airflow-core/tests/unit/jobs/test_scheduler_job.py | 119 +++++++++++++++------
 .../tests/unit/models/test_taskinstance.py         |   2 +-
 devel-common/src/tests_common/test_utils/db.py     |   2 +
 5 files changed, 136 insertions(+), 39 deletions(-)
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 7a43f6416a6..2ae008026c7 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -59,6 +59,7 @@ from airflow.models.asset import (
     AssetDagRunQueue,
     AssetEvent,
     AssetModel,
+    AssetPartitionDagRun,
     AssetWatcherModel,
     DagScheduleAssetAliasReference,
     DagScheduleAssetReference,
@@ -1672,19 +1673,58 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         return num_queued_tis
 
+    def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[str]:
+        partition_dag_ids: set[str] = set()
+        apdrs: Iterable[AssetPartitionDagRun] = session.scalars(
+            select(AssetPartitionDagRun).where(AssetPartitionDagRun.created_dag_run_id.is_(None))
+        )
+        for apdr in apdrs:
+            partition_dag_ids.add(apdr.target_dag_id)
+            dag = _get_current_dag(dag_id=apdr.target_dag_id, session=session)
+            if not dag:
+                self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id)
+                continue
+
+            run_after = timezone.utcnow()
+            dag_run = dag.create_dagrun(
+                run_id=DagRun.generate_run_id(
+                    run_type=DagRunType.ASSET_TRIGGERED, logical_date=None, run_after=run_after
+                ),
+                logical_date=None,
+                data_interval=None,
+                partition_key=apdr.partition_key,
+                run_after=run_after,
+                run_type=DagRunType.ASSET_TRIGGERED,
+                triggered_by=DagRunTriggeredByType.ASSET,
+                state=DagRunState.QUEUED,
+                creating_job_id=self.job.id,
+                session=session,
+            )
+            session.flush()
+            apdr.created_dag_run_id = dag_run.id
+            session.flush()
+
+        return partition_dag_ids
+
     @retry_db_transaction
     def _create_dagruns_for_dags(self, guard: CommitProhibitorGuard, session: Session) -> None:
         """Find Dag Models needing DagRuns and Create Dag Runs with retries in case of OperationalError."""
+        partition_dag_ids: set[str] = self._create_dagruns_for_partitioned_asset_dags(session)
+
         query, triggered_date_by_dag = DagModel.dags_needing_dagruns(session)
         all_dags_needing_dag_runs = set(query.all())
-        asset_triggered_dags = [
-            dag for dag in all_dags_needing_dag_runs if dag.dag_id in triggered_date_by_dag
-        ]
-        non_asset_dags = all_dags_needing_dag_runs.difference(asset_triggered_dags)
+        asset_triggered_dags = [d for d in all_dags_needing_dag_runs if d.dag_id in triggered_date_by_dag]
+        non_asset_dags = {
+            d
+            # filter asset-triggered Dags
+            for d in all_dags_needing_dag_runs.difference(asset_triggered_dags)
+            # filter asset partition triggered Dags
+            if d.dag_id not in partition_dag_ids
+        }
         self._create_dag_runs(non_asset_dags, session)
         if asset_triggered_dags:
             self._create_dag_runs_asset_triggered(
-                dag_models=asset_triggered_dags,
+                dag_models=[d for d in asset_triggered_dags if d.dag_id not in partition_dag_ids],
                 triggered_date_by_dag=triggered_date_by_dag,
                 session=session,
             )
diff --git a/airflow-core/src/airflow/timetables/simple.py b/airflow-core/src/airflow/timetables/simple.py
index a0108029458..76af6a844e4 100644
--- a/airflow-core/src/airflow/timetables/simple.py
+++ b/airflow-core/src/airflow/timetables/simple.py
@@ -260,7 +260,7 @@ class PartitionedAssetTimetable(AssetTriggeredTimetable):
     def summary(self) -> str:
         return "Partitioned Asset"
 
-    def __init__(self, assets: BaseAsset, partition_mapper: PartitionMapper) -> None:
+    def __init__(self, *, assets: BaseAsset, partition_mapper: PartitionMapper) -> None:
         super().__init__(assets=assets)
         self.asset_condition = assets
         self.partition_mapper = partition_mapper
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 5f03ab7aa4b..767159bde59 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -52,7 +52,15 @@ from airflow.executors.executor_loader import ExecutorLoader
 from airflow.executors.executor_utils import ExecutorName
 from airflow.jobs.job import Job, run_job
 from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
-from airflow.models.asset import AssetActive, AssetAliasModel, AssetDagRunQueue, AssetEvent, AssetModel
+from airflow.models.asset import (
+    AssetActive,
+    AssetAliasModel,
+    AssetDagRunQueue,
+    AssetEvent,
+    AssetModel,
+    AssetPartitionDagRun,
+    PartitionedAssetKeyLog,
+)
 from airflow.models.backfill import Backfill, _create_backfill
 from airflow.models.dag import DagModel, get_last_dagrun, infer_automated_data_interval
 from airflow.models.dag_version import DagVersion
@@ -75,6 +83,7 @@ from airflow.sdk import DAG, Asset, AssetAlias, AssetWatcher, task
 from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
 from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
 from airflow.timetables.base import DataInterval
+from airflow.timetables.simple import IdentityMapper, PartitionedAssetTimetable
 from airflow.utils.session import create_session, provide_session
 from airflow.utils.span_status import SpanStatus
 from airflow.utils.state import DagRunState, State, TaskInstanceState
@@ -96,7 +105,6 @@ from tests_common.test_utils.db import (
     clear_db_jobs,
     clear_db_pools,
     clear_db_runs,
-    clear_db_serialized_dags,
     clear_db_teams,
     clear_db_triggers,
     set_default_pool_slots,
@@ -181,33 +189,33 @@ def create_dagrun(session):
     return _create_dagrun
 
 
+def _clean_db():
+    clear_db_dags()
+    clear_db_runs()
+    clear_db_backfills()
+    clear_db_pools()
+    clear_db_import_errors()
+    clear_db_jobs()
+    clear_db_assets()
+    clear_db_deadline()
+    clear_db_callbacks()
+    clear_db_triggers()
+
+
 @patch.dict(
     ExecutorLoader.executors, {MOCK_EXECUTOR: f"{MockExecutor.__module__}.{MockExecutor.__qualname__}"}
 )
 @pytest.mark.usefixtures("disable_load_example")
 @pytest.mark.need_serialized_dag
 class TestSchedulerJob:
-    @staticmethod
-    def clean_db():
-        clear_db_dags()
-        clear_db_runs()
-        clear_db_backfills()
-        clear_db_pools()
-        clear_db_import_errors()
-        clear_db_jobs()
-        clear_db_assets()
-        clear_db_deadline()
-        clear_db_callbacks()
-        clear_db_triggers()
-
     @pytest.fixture(autouse=True)
     def per_test(self) -> Generator:
-        self.clean_db()
+        _clean_db()
         self.job_runner: SchedulerJobRunner | None = None
 
         yield
 
-        self.clean_db()
+        _clean_db()
 
     @pytest.fixture(autouse=True)
     def set_instance_attrs(self) -> Generator:
@@ -5024,7 +5032,7 @@ class TestSchedulerJob:
         ti.state = State.SUCCESS
         session.flush()
 
-        self.clean_db()
+        _clean_db()
 
         # Explicitly set catchup=True as test specifically expects runs to be created in date order
         with dag_maker(max_active_runs=3, session=session, catchup=True) as dag:
@@ -6371,7 +6379,7 @@ class TestSchedulerJob:
assert ti1.next_method == "__fail__"
assert ti2.state == State.DEFERRED
finally:
- self.clean_db()
+ _clean_db()
# Positive case, will retry until success before reach max retry times
check_if_trigger_timeout(retry_times)
@@ -7944,24 +7952,13 @@ class TestSchedulerJobQueriesCount:
scheduler_job: Job | None
- @staticmethod
- def clean_db():
- clear_db_runs()
- clear_db_pools()
- clear_db_backfills()
- clear_db_dags()
- clear_db_dag_bundles()
- clear_db_import_errors()
- clear_db_jobs()
- clear_db_serialized_dags()
-
@pytest.fixture(autouse=True)
def per_test(self) -> Generator:
- self.clean_db()
+ _clean_db()
yield
- self.clean_db()
+ _clean_db()
@pytest.mark.parametrize(
("expected_query_count", "dag_count", "task_count"),
@@ -8140,3 +8137,61 @@ def test_mark_backfills_completed(dag_maker, session):
     runner._mark_backfills_complete()
     b = session.get(Backfill, b.id)
     assert b.completed_at.timestamp() > 0
+
+
+def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_populated(
+    dag_maker,
+    session,
+):
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="asset_event_tester", schedule=None, session=session) as dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    dag1_id = dag.dag_id
+    dr = dag_maker.create_dagrun(partition_key="abc123", session=session)
+    assert dr.partition_key == "abc123"
+    [ti] = dr.get_task_instances(session=session)
+    session.commit()
+
+    with dag_maker(
+        dag_id="asset_event_listener",
+        schedule=PartitionedAssetTimetable(assets=asset, partition_mapper=IdentityMapper()),
+        session=session,
+    ):
+        EmptyOperator(task_id="hi")
+    session.commit()
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[asset.asprofile()],
+        outlet_events=[],
+        session=session,
+    )
+    session.commit()
+    event = session.scalar(
+        select(AssetEvent).where(
+            AssetEvent.source_dag_id == dag1_id,
+            AssetEvent.source_run_id == dr.run_id,
+        )
+    )
+    assert event.partition_key == "abc123"
+    pakl = session.scalar(
+        select(PartitionedAssetKeyLog).where(
+            PartitionedAssetKeyLog.asset_event_id == event.id,
+        )
+    )
+    apdr = session.scalar(
+        select(AssetPartitionDagRun).where(AssetPartitionDagRun.id == pakl.asset_partition_dag_run_id)
+    )
+    assert apdr is not None
+    assert apdr.created_dag_run_id is None
+    # ok, now we have established that the needed rows are there.
+    # let's see what the scheduler does
+
+    runner = SchedulerJobRunner(
+        job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False))
+    )
+    partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session)
+    session.refresh(apdr)
+    assert apdr.created_dag_run_id is not None
+    assert len(partition_dags) == 1
+    assert partition_dags == {"asset_event_listener"}
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 2e50fee1eac..8b70e9abecd 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -3304,7 +3304,7 @@ def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_popula
     with dag_maker(
         dag_id="asset_event_listener",
-        schedule=PartitionedAssetTimetable(asset, IdentityMapper()),
+        schedule=PartitionedAssetTimetable(assets=asset, partition_mapper=IdentityMapper()),
         session=session,
     ):
         EmptyOperator(task_id="hi")
diff --git a/devel-common/src/tests_common/test_utils/db.py b/devel-common/src/tests_common/test_utils/db.py
index 89a79f91216..a643a1ad432 100644
--- a/devel-common/src/tests_common/test_utils/db.py
+++ b/devel-common/src/tests_common/test_utils/db.py
@@ -915,6 +915,8 @@ def create_default_connections_for_tests():
 def clear_all():
     clear_db_runs()
     clear_db_assets()
+    clear_db_apdr()
+    clear_db_pakl()
     clear_db_triggers()
     clear_db_dags()
     clear_db_serialized_dags()