This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b698b7b5c6b Update scheduler to create dag runs for partitioned-asset-driven dags (#59006)
b698b7b5c6b is described below

commit b698b7b5c6ba4b5b71a73c035888f6ae776bd16c
Author: Daniel Standish <[email protected]>
AuthorDate: Fri Dec 5 08:08:08 2025 -0800

    Update scheduler to create dag runs for partitioned-asset-driven dags (#59006)
    
    Now that the events and log records needed for partition-driven scheduling are being recorded, we update the scheduler to create dag runs based on them.
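
    For context, a producer dag emits a partitioned asset event when a run
    that carries a partition_key writes to an asset outlet. A minimal sketch
    of such a producer, modeled on the test added below (the dag id and the
    EmptyOperator import path are illustrative assumptions):

        from airflow.providers.standard.operators.empty import EmptyOperator
        from airflow.sdk import DAG, Asset

        asset = Asset(name="hello")
        with DAG(dag_id="asset_event_producer", schedule=None):
            # Each run of this task emits an event for `asset`; when the run
            # has a partition_key, the key is recorded on that event.
            EmptyOperator(task_id="hi", outlets=[asset])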
    
    ---------
    
    Co-authored-by: Wei Lee <[email protected]>
---
 .../src/airflow/jobs/scheduler_job_runner.py       |  50 ++++++++-
 airflow-core/src/airflow/timetables/simple.py      |   2 +-
 airflow-core/tests/unit/jobs/test_scheduler_job.py | 119 +++++++++++++++------
 .../tests/unit/models/test_taskinstance.py         |   2 +-
 devel-common/src/tests_common/test_utils/db.py     |   2 +
 5 files changed, 136 insertions(+), 39 deletions(-)

diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 7a43f6416a6..2ae008026c7 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -59,6 +59,7 @@ from airflow.models.asset import (
     AssetDagRunQueue,
     AssetEvent,
     AssetModel,
+    AssetPartitionDagRun,
     AssetWatcherModel,
     DagScheduleAssetAliasReference,
     DagScheduleAssetReference,
@@ -1672,19 +1673,58 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
 
         return num_queued_tis
 
+    def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[str]:
+        partition_dag_ids: set[str] = set()
+        apdrs: Iterable[AssetPartitionDagRun] = session.scalars(
+            select(AssetPartitionDagRun).where(AssetPartitionDagRun.created_dag_run_id.is_(None))
+        )
+        for apdr in apdrs:
+            partition_dag_ids.add(apdr.target_dag_id)
+            dag = _get_current_dag(dag_id=apdr.target_dag_id, session=session)
+            if not dag:
+                self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id)
+                continue
+
+            run_after = timezone.utcnow()
+            dag_run = dag.create_dagrun(
+                run_id=DagRun.generate_run_id(
+                    run_type=DagRunType.ASSET_TRIGGERED, logical_date=None, run_after=run_after
+                ),
+                logical_date=None,
+                data_interval=None,
+                partition_key=apdr.partition_key,
+                run_after=run_after,
+                run_type=DagRunType.ASSET_TRIGGERED,
+                triggered_by=DagRunTriggeredByType.ASSET,
+                state=DagRunState.QUEUED,
+                creating_job_id=self.job.id,
+                session=session,
+            )
+            session.flush()
+            apdr.created_dag_run_id = dag_run.id
+            session.flush()
+
+        return partition_dag_ids
+
     @retry_db_transaction
     def _create_dagruns_for_dags(self, guard: CommitProhibitorGuard, session: Session) -> None:
         """Find Dag Models needing DagRuns and Create Dag Runs with retries in 
case of OperationalError."""
+        partition_dag_ids: set[str] = self._create_dagruns_for_partitioned_asset_dags(session)
+
         query, triggered_date_by_dag = DagModel.dags_needing_dagruns(session)
         all_dags_needing_dag_runs = set(query.all())
-        asset_triggered_dags = [
-            dag for dag in all_dags_needing_dag_runs if dag.dag_id in triggered_date_by_dag
-        ]
-        non_asset_dags = all_dags_needing_dag_runs.difference(asset_triggered_dags)
+        asset_triggered_dags = [d for d in all_dags_needing_dag_runs if d.dag_id in triggered_date_by_dag]
+        non_asset_dags = {
+            d
+            # filter asset-triggered Dags
+            for d in all_dags_needing_dag_runs.difference(asset_triggered_dags)
+            # filter asset partition triggered Dags
+            if d.dag_id not in partition_dag_ids
+        }
         self._create_dag_runs(non_asset_dags, session)
         if asset_triggered_dags:
             self._create_dag_runs_asset_triggered(
-                dag_models=asset_triggered_dags,
+                dag_models=[d for d in asset_triggered_dags if d.dag_id not in partition_dag_ids],
                 triggered_date_by_dag=triggered_date_by_dag,
                 session=session,
             )
diff --git a/airflow-core/src/airflow/timetables/simple.py b/airflow-core/src/airflow/timetables/simple.py
index a0108029458..76af6a844e4 100644
--- a/airflow-core/src/airflow/timetables/simple.py
+++ b/airflow-core/src/airflow/timetables/simple.py
@@ -260,7 +260,7 @@ class PartitionedAssetTimetable(AssetTriggeredTimetable):
     def summary(self) -> str:
         return "Partitioned Asset"
 
-    def __init__(self, assets: BaseAsset, partition_mapper: PartitionMapper) -> None:
+    def __init__(self, *, assets: BaseAsset, partition_mapper: PartitionMapper) -> None:
         super().__init__(assets=assets)
         self.asset_condition = assets
         self.partition_mapper = partition_mapper
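
With this change the constructor is keyword-only, so callers must pass assets=
and partition_mapper= by name. A minimal sketch of a consumer dag using the
timetable, modeled on the tests in this commit (IdentityMapper is the mapper
they use; the dag body is elided):

    from airflow.sdk import DAG, Asset
    from airflow.timetables.simple import IdentityMapper, PartitionedAssetTimetable

    asset = Asset(name="hello")
    with DAG(
        dag_id="asset_event_listener",
        # keyword-only after this change; positional arguments now raise TypeError
        schedule=PartitionedAssetTimetable(assets=asset, partition_mapper=IdentityMapper()),
    ):
        ...
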
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 5f03ab7aa4b..767159bde59 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -52,7 +52,15 @@ from airflow.executors.executor_loader import ExecutorLoader
 from airflow.executors.executor_utils import ExecutorName
 from airflow.jobs.job import Job, run_job
 from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
-from airflow.models.asset import AssetActive, AssetAliasModel, AssetDagRunQueue, AssetEvent, AssetModel
+from airflow.models.asset import (
+    AssetActive,
+    AssetAliasModel,
+    AssetDagRunQueue,
+    AssetEvent,
+    AssetModel,
+    AssetPartitionDagRun,
+    PartitionedAssetKeyLog,
+)
 from airflow.models.backfill import Backfill, _create_backfill
 from airflow.models.dag import DagModel, get_last_dagrun, infer_automated_data_interval
 from airflow.models.dag_version import DagVersion
@@ -75,6 +83,7 @@ from airflow.sdk import DAG, Asset, AssetAlias, AssetWatcher, task
 from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
 from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
 from airflow.timetables.base import DataInterval
+from airflow.timetables.simple import IdentityMapper, PartitionedAssetTimetable
 from airflow.utils.session import create_session, provide_session
 from airflow.utils.span_status import SpanStatus
 from airflow.utils.state import DagRunState, State, TaskInstanceState
@@ -96,7 +105,6 @@ from tests_common.test_utils.db import (
     clear_db_jobs,
     clear_db_pools,
     clear_db_runs,
-    clear_db_serialized_dags,
     clear_db_teams,
     clear_db_triggers,
     set_default_pool_slots,
@@ -181,33 +189,33 @@ def create_dagrun(session):
     return _create_dagrun
 
 
+def _clean_db():
+    clear_db_dags()
+    clear_db_runs()
+    clear_db_backfills()
+    clear_db_pools()
+    clear_db_import_errors()
+    clear_db_jobs()
+    clear_db_assets()
+    clear_db_deadline()
+    clear_db_callbacks()
+    clear_db_triggers()
+
+
 @patch.dict(
     ExecutorLoader.executors, {MOCK_EXECUTOR: f"{MockExecutor.__module__}.{MockExecutor.__qualname__}"}
 )
 @pytest.mark.usefixtures("disable_load_example")
 @pytest.mark.need_serialized_dag
 class TestSchedulerJob:
-    @staticmethod
-    def clean_db():
-        clear_db_dags()
-        clear_db_runs()
-        clear_db_backfills()
-        clear_db_pools()
-        clear_db_import_errors()
-        clear_db_jobs()
-        clear_db_assets()
-        clear_db_deadline()
-        clear_db_callbacks()
-        clear_db_triggers()
-
     @pytest.fixture(autouse=True)
     def per_test(self) -> Generator:
-        self.clean_db()
+        _clean_db()
         self.job_runner: SchedulerJobRunner | None = None
 
         yield
 
-        self.clean_db()
+        _clean_db()
 
     @pytest.fixture(autouse=True)
     def set_instance_attrs(self) -> Generator:
@@ -5024,7 +5032,7 @@ class TestSchedulerJob:
                 ti.state = State.SUCCESS
                 session.flush()
 
-        self.clean_db()
+        _clean_db()
 
         # Explicitly set catchup=True as test specifically expects runs to be created in date order
         with dag_maker(max_active_runs=3, session=session, catchup=True) as dag:
@@ -6371,7 +6379,7 @@ class TestSchedulerJob:
                 assert ti1.next_method == "__fail__"
                 assert ti2.state == State.DEFERRED
             finally:
-                self.clean_db()
+                _clean_db()
 
         # Positive case, will retry until success before reach max retry times
         check_if_trigger_timeout(retry_times)
@@ -7944,24 +7952,13 @@ class TestSchedulerJobQueriesCount:
 
     scheduler_job: Job | None
 
-    @staticmethod
-    def clean_db():
-        clear_db_runs()
-        clear_db_pools()
-        clear_db_backfills()
-        clear_db_dags()
-        clear_db_dag_bundles()
-        clear_db_import_errors()
-        clear_db_jobs()
-        clear_db_serialized_dags()
-
     @pytest.fixture(autouse=True)
     def per_test(self) -> Generator:
-        self.clean_db()
+        _clean_db()
 
         yield
 
-        self.clean_db()
+        _clean_db()
 
     @pytest.mark.parametrize(
         ("expected_query_count", "dag_count", "task_count"),
@@ -8140,3 +8137,61 @@ def test_mark_backfills_completed(dag_maker, session):
     runner._mark_backfills_complete()
     b = session.get(Backfill, b.id)
     assert b.completed_at.timestamp() > 0
+
+
+def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_populated(
+    dag_maker,
+    session,
+):
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="asset_event_tester", schedule=None, 
session=session) as dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    dag1_id = dag.dag_id
+    dr = dag_maker.create_dagrun(partition_key="abc123", session=session)
+    assert dr.partition_key == "abc123"
+    [ti] = dr.get_task_instances(session=session)
+    session.commit()
+
+    with dag_maker(
+        dag_id="asset_event_listener",
+        schedule=PartitionedAssetTimetable(assets=asset, partition_mapper=IdentityMapper()),
+        session=session,
+    ):
+        EmptyOperator(task_id="hi")
+    session.commit()
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[asset.asprofile()],
+        outlet_events=[],
+        session=session,
+    )
+    session.commit()
+    event = session.scalar(
+        select(AssetEvent).where(
+            AssetEvent.source_dag_id == dag1_id,
+            AssetEvent.source_run_id == dr.run_id,
+        )
+    )
+    assert event.partition_key == "abc123"
+    pakl = session.scalar(
+        select(PartitionedAssetKeyLog).where(
+            PartitionedAssetKeyLog.asset_event_id == event.id,
+        )
+    )
+    apdr = session.scalar(
+        select(AssetPartitionDagRun).where(AssetPartitionDagRun.id == pakl.asset_partition_dag_run_id)
+    )
+    assert apdr is not None
+    assert apdr.created_dag_run_id is None
+    # ok, now we have established that the needed rows are there.
+    # let's see what the scheduler does
+
+    runner = SchedulerJobRunner(
+        job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False))
+    )
+    partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session)
+    session.refresh(apdr)
+    assert apdr.created_dag_run_id is not None
+    assert len(partition_dags) == 1
+    assert partition_dags == {"asset_event_listener"}
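
A natural follow-up assertion, not in the test above, would check that the
created run carries the partition key forward, since the scheduler change
passes partition_key=apdr.partition_key when creating it. A sketch, assuming
DagRun and DagRunType are imported in this module:

    # look up the run the scheduler just created from the APDR row
    created_run = session.get(DagRun, apdr.created_dag_run_id)
    assert created_run.partition_key == "abc123"
    assert created_run.run_type == DagRunType.ASSET_TRIGGERED
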
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 2e50fee1eac..8b70e9abecd 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -3304,7 +3304,7 @@ def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_popula
 
     with dag_maker(
         dag_id="asset_event_listener",
-        schedule=PartitionedAssetTimetable(asset, IdentityMapper()),
+        schedule=PartitionedAssetTimetable(assets=asset, partition_mapper=IdentityMapper()),
         session=session,
     ):
         EmptyOperator(task_id="hi")
diff --git a/devel-common/src/tests_common/test_utils/db.py b/devel-common/src/tests_common/test_utils/db.py
index 89a79f91216..a643a1ad432 100644
--- a/devel-common/src/tests_common/test_utils/db.py
+++ b/devel-common/src/tests_common/test_utils/db.py
@@ -915,6 +915,8 @@ def create_default_connections_for_tests():
 def clear_all():
     clear_db_runs()
     clear_db_assets()
+    clear_db_apdr()
+    clear_db_pakl()
     clear_db_triggers()
     clear_db_dags()
     clear_db_serialized_dags()
