This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 66a7428e75 Get dataset-driven scheduling working (#24743)
66a7428e75 is described below

commit 66a7428e754d0059b7bf735c7a83a4b6675c66ef
Author: Daniel Standish <[email protected]>
AuthorDate: Thu Jul 7 14:49:55 2022 -0700

    Get dataset-driven scheduling working (#24743)
    
    AIP-48
    (https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-48+Data+Dependency+Management+and+Data+Driven+Scheduling)
    allows DAGs to reference upstream "datasets", and tasks to reference
    downstream datasets through outlets. When a DAG's upstream datasets have
    been updated, a DAG run is triggered.
    
    This is essentially the "initial commit" for AIP-48. We will need to make
    changes, iterate, and tweak things here and there, but it is a starting
    point that implements the basic functionality of the AIP.
---
 airflow/api_connexion/openapi/v1.yaml              |   1 +
 airflow/example_dags/example_datasets.py           | 133 +++++++++++++++++++++
 .../versions/0114_2_4_0_add_dataset_model.py       |  73 ++++++++++-
 airflow/models/dag.py                              |  95 +++++++++++++--
 airflow/models/dagrun.py                           |  77 +++++++++++-
 airflow/models/dataset.py                          | 132 +++++++++++++++++++-
 airflow/models/taskinstance.py                     |  18 ++-
 airflow/serialization/enums.py                     |   1 +
 airflow/serialization/schema.json                  |  37 ++++++
 airflow/serialization/serialized_objects.py        |  10 ++
 airflow/utils/helpers.py                           |  19 +++
 airflow/utils/types.py                             |   1 +
 docs/apache-airflow/concepts/datasets.rst          |  46 +++++++
 docs/apache-airflow/concepts/index.rst             |   1 +
 tests/models/test_dag.py                           |  40 ++++++-
 tests/models/test_dagrun.py                        |  65 +++++++++-
 tests/models/test_taskinstance.py                  |  31 +++++
 tests/test_utils/db.py                             |   6 +
 tests/utils/test_db_cleanup.py                     |   3 +
 tests/utils/test_helpers.py                        |  30 +++++
 20 files changed, 797 insertions(+), 22 deletions(-)

diff --git a/airflow/api_connexion/openapi/v1.yaml 
b/airflow/api_connexion/openapi/v1.yaml
index 6ce6c51608..d9c0cc6ed8 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -2487,6 +2487,7 @@ components:
             - backfill
             - manual
             - scheduled
+            - dataset_triggered
         state:
           $ref: '#/components/schemas/DagState'
           readOnly: true
diff --git a/airflow/example_dags/example_datasets.py 
b/airflow/example_dags/example_datasets.py
new file mode 100644
index 0000000000..e04479b0f5
--- /dev/null
+++ b/airflow/example_dags/example_datasets.py
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example DAG for demonstrating behavior of Datasets feature.
+
+Notes on usage:
+
+Turn on all the DAGs.
+
+DAG dag1 should run because it's on a schedule.
+
+After dag1 runs, dag3 should be triggered immediately because its only
+dataset dependency is managed by dag1.
+
+No other DAGs should be triggered.  Note that even though dag4 depends on
+the dataset in dag1, it will not be triggered until dag2 runs (and dag2 is
+left with no schedule so that we can trigger it manually).
+
+Next, trigger dag2.  After dag2 finishes, dag4 should run.
+
+DAGs 5 and 6 should not run because they depend on datasets that never get
+updated.
+
+"""
+from datetime import datetime
+
+from airflow.models import DAG, Dataset
+from airflow.operators.bash import BashOperator
+
+# Datasets are identified by URI (Dataset equality and hashing use only the
+# URI); ``extra`` is optional additional data stored with the dataset.
+# [START dataset_def]
+dag1_dataset = Dataset('s3://dag1/output_1.txt', extra={'hi': 'bye'})
+# [END dataset_def]
+dag2_dataset = Dataset('s3://dag2/output_1.txt', extra={'hi': 'bye'})
+
+# Upstream DAG on a regular daily schedule.
+dag1 = DAG(
+    dag_id='dag1',
+    catchup=False,
+    start_date=datetime(2020, 1, 1),
+    schedule_interval='@daily',
+    tags=['upstream'],
+)
+
+# A task declares the datasets it updates through ``outlets``.
+# [START task_outlet]
+BashOperator(outlets=[dag1_dataset], task_id='upstream_task_1', bash_command="sleep 5", dag=dag1)
+# [END task_outlet]
+
+# Upstream DAG with no schedule (schedule_interval=None): trigger it manually.
+with DAG(
+    dag_id='dag2',
+    catchup=False,
+    start_date=datetime(2020, 1, 1),
+    schedule_interval=None,
+    tags=['upstream'],
+) as dag2:
+    BashOperator(
+        outlets=[dag2_dataset],
+        task_id='upstream_task_2',
+        bash_command="sleep 5",
+    )
+
+# ``schedule_on`` is given instead of ``schedule_interval``/``timetable`` (at
+# most one of the three is allowed): dag3 runs once dag1_dataset is updated.
+# [START dag_dep]
+dag3 = DAG(
+    dag_id='dag3',
+    catchup=False,
+    start_date=datetime(2020, 1, 1),
+    schedule_on=[dag1_dataset],
+    tags=['downstream'],
+)
+# [END dag_dep]
+
+BashOperator(
+    outlets=[Dataset('s3://downstream_1_task/dataset_other.txt')],
+    task_id='downstream_1',
+    bash_command="sleep 5",
+    dag=dag3,
+)
+
+# dag4 is triggered only after *both* dag1_dataset and dag2_dataset are updated.
+with DAG(
+    dag_id='dag4',
+    catchup=False,
+    start_date=datetime(2020, 1, 1),
+    schedule_on=[dag1_dataset, dag2_dataset],
+    tags=['downstream'],
+) as dag4:
+    BashOperator(
+        outlets=[Dataset('s3://downstream_2_task/dataset_other_unknown.txt')],
+        task_id='downstream_2',
+        bash_command="sleep 5",
+    )
+
+# The second dataset is not an outlet of any task in this file, so dag5 never runs.
+with DAG(
+    dag_id='dag5',
+    catchup=False,
+    start_date=datetime(2020, 1, 1),
+    schedule_on=[
+        dag1_dataset,
+        Dataset('s3://this-dataset-doesnt-get-triggered'),
+    ],
+    tags=['downstream'],
+) as dag5:
+    BashOperator(
+        outlets=[Dataset('s3://downstream_2_task/dataset_other_unknown.txt')],
+        task_id='downstream_3',
+        bash_command="sleep 5",
+    )
+
+# Neither dataset is an outlet of any task in this file, so dag6 never runs.
+with DAG(
+    dag_id='dag6',
+    catchup=False,
+    start_date=datetime(2020, 1, 1),
+    schedule_on=[
+        Dataset('s3://unrelated/dataset3.txt'),
+        Dataset('s3://unrelated/dataset_other_unknown.txt'),
+    ],
+    tags=['unrelated'],
+) as dag6:
+    BashOperator(
+        task_id='unrelated_task',
+        outlets=[Dataset('s3://unrelated_task/dataset_other_unknown.txt')],
+        bash_command="sleep 5",
+    )
diff --git a/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py 
b/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py
index 838f8780b0..9cfca3766c 100644
--- a/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py
+++ b/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py
@@ -26,9 +26,9 @@ Create Date: 2022-06-22 14:37:20.880672
 
 import sqlalchemy as sa
 from alembic import op
-from sqlalchemy import Integer, String
+from sqlalchemy import Integer, String, func
 
-from airflow.migrations.db_types import TIMESTAMP
+from airflow.migrations.db_types import TIMESTAMP, StringID
 from airflow.utils.sqlalchemy import ExtendedJSON
 
 revision = '0038cd0c28b4'
@@ -38,8 +38,7 @@ depends_on = None
 airflow_version = '2.4.0'
 
 
-def upgrade():
-    """Apply Add Dataset model"""
+def _create_dataset_table():
     op.create_table(
         'dataset',
         sa.Column('id', Integer, primary_key=True, autoincrement=True),
@@ -64,6 +63,72 @@ def upgrade():
     op.create_index('idx_uri_unique', 'dataset', ['uri'], unique=True)
 
 
+def _create_dataset_dag_ref_table():
+    op.create_table(
+        'dataset_dag_ref',
+        sa.Column('dataset_id', Integer, primary_key=True, nullable=False),
+        sa.Column('dag_id', String(250), primary_key=True, nullable=False),
+        sa.Column('created_at', TIMESTAMP, default=func.now, nullable=False),
+        sa.Column('updated_at', TIMESTAMP, default=func.now, nullable=False),
+        sa.ForeignKeyConstraint(
+            ('dataset_id',),
+            ['dataset.id'],
+            name="datasetdagref_dataset_fkey",
+            ondelete="CASCADE",
+        ),
+        sqlite_autoincrement=True,  # ensures PK values not reused
+    )
+
+
+def _create_dataset_task_ref_table():
+    op.create_table(
+        'dataset_task_ref',
+        sa.Column('dataset_id', Integer, primary_key=True, nullable=False),
+        sa.Column('dag_id', String(250), primary_key=True, nullable=False),
+        sa.Column('task_id', String(250), primary_key=True, nullable=False),
+        sa.Column('created_at', TIMESTAMP, default=func.now, nullable=False),
+        sa.Column('updated_at', TIMESTAMP, default=func.now, nullable=False),
+        sa.ForeignKeyConstraint(
+            ('dataset_id',),
+            ['dataset.id'],
+            name="datasettaskref_dataset_fkey",
+            ondelete="CASCADE",
+        ),
+    )
+
+
+def _create_dataset_dag_run_queue_table():
+    op.create_table(
+        'dataset_dag_run_queue',
+        sa.Column('dataset_id', Integer, primary_key=True, nullable=False),
+        sa.Column('target_dag_id', StringID(), primary_key=True, 
nullable=False),
+        sa.Column('created_at', TIMESTAMP, default=func.now, nullable=False),
+        sa.ForeignKeyConstraint(
+            ('dataset_id',),
+            ['dataset.id'],
+            name="ddrq_dataset_fkey",
+            ondelete="CASCADE",
+        ),
+        sa.ForeignKeyConstraint(
+            ('target_dag_id',),
+            ['dag.dag_id'],
+            name="ddrq_dag_fkey",
+            ondelete="CASCADE",
+        ),
+    )
+
+
+def upgrade():
+    """Apply Add Dataset model"""
+    # ``dataset`` must be created first: the other three tables have foreign keys to ``dataset.id``.
+    _create_dataset_table()
+    _create_dataset_dag_ref_table()
+    _create_dataset_task_ref_table()
+    _create_dataset_dag_run_queue_table()
+
+
 def downgrade():
     """Unapply Add Dataset model"""
+    # Drop the tables that reference ``dataset.id`` before dropping ``dataset`` itself.
+    op.drop_table('dataset_dag_ref')
+    op.drop_table('dataset_task_ref')
+    op.drop_table('dataset_dag_run_queue')
     op.drop_table('dataset')
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index b16126e569..4cef709b16 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -39,6 +39,7 @@ from typing import (
     FrozenSet,
     Iterable,
     List,
+    NamedTuple,
     Optional,
     Sequence,
     Set,
@@ -83,7 +84,7 @@ from airflow.utils import timezone
 from airflow.utils.dag_cycle_tester import check_cycle
 from airflow.utils.dates import cron_presets, date_range as utils_date_range
 from airflow.utils.file import correct_maybe_zipped
-from airflow.utils.helpers import exactly_one, validate_key
+from airflow.utils.helpers import at_most_one, exactly_one, validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, 
tuple_in_condition, with_row_locks
@@ -92,6 +93,7 @@ from airflow.utils.types import NOTSET, ArgNotSet, 
DagRunType, EdgeInfoType
 
 if TYPE_CHECKING:
     from airflow.decorators import TaskDecoratorCollection
+    from airflow.models.dataset import Dataset
     from airflow.models.slamiss import SlaMiss
     from airflow.utils.task_group import TaskGroup
 
@@ -279,6 +281,7 @@ class DAG(LoggingMixin):
         to render templates as native Python types. If False, a Jinja
         ``Environment`` is used to render templates as string values.
     :param tags: List of tags to help filtering DAGs in the UI.
+    :param schedule_on: List of upstream datasets for use in triggering DAG runs.
+        Cannot be combined with ``schedule_interval`` or ``timetable``.
     """
 
     _comps = {
@@ -338,6 +341,7 @@ class DAG(LoggingMixin):
         jinja_environment_kwargs: Optional[Dict] = None,
         render_template_as_native_obj: bool = False,
         tags: Optional[List[str]] = None,
+        schedule_on: Optional[Sequence["Dataset"]] = None,
     ):
         from airflow.utils.task_group import TaskGroup
 
@@ -415,17 +419,29 @@ class DAG(LoggingMixin):
         if 'end_date' in self.default_args:
             self.default_args['end_date'] = 
timezone.convert_to_utc(self.default_args['end_date'])
 
-        # Calculate the DAG's timetable.
-        if timetable is None:
-            self.timetable = create_timetable(schedule_interval, self.timezone)
-            if isinstance(schedule_interval, ArgNotSet):
-                schedule_interval = DEFAULT_SCHEDULE_INTERVAL
-            self.schedule_interval: ScheduleInterval = schedule_interval
-        elif isinstance(schedule_interval, ArgNotSet):
+        # sort out DAG's scheduling behavior
+        scheduling_args = [schedule_interval, timetable, schedule_on]
+        if not at_most_one(*scheduling_args):
+            raise ValueError(
+                "At most one allowed for args 'schedule_interval', 
'timetable', and 'schedule_on'."
+            )
+
+        self.timetable: Timetable
+        self.schedule_interval: ScheduleInterval
+        self.schedule_on: Optional[List["Dataset"]] = list(schedule_on) if 
schedule_on else None
+        if schedule_on:
+            if not isinstance(schedule_on, Sequence):
+                raise ValueError("Param `schedule_on` must be 
Sequence[Dataset]")
+            self.schedule_interval = None
+            self.timetable = NullTimetable()
+        elif timetable:
             self.timetable = timetable
             self.schedule_interval = self.timetable.summary
         else:
-            raise TypeError("cannot specify both 'schedule_interval' and 
'timetable'")
+            if isinstance(schedule_interval, ArgNotSet):
+                schedule_interval = DEFAULT_SCHEDULE_INTERVAL
+            self.schedule_interval = schedule_interval
+            self.timetable = create_timetable(schedule_interval, self.timezone)
 
         if isinstance(template_searchpath, str):
             template_searchpath = [template_searchpath]
@@ -2418,6 +2434,7 @@ class DAG(LoggingMixin):
 
         log.info("Sync %s DAGs", len(dags))
         dag_by_ids = {dag.dag_id: dag for dag in dags}
+
         dag_ids = set(dag_by_ids.keys())
         query = (
             session.query(DagModel)
@@ -2509,6 +2526,65 @@ class DAG(LoggingMixin):
 
         DagCode.bulk_sync_to_db(filelocs, session=session)
 
+        from airflow.models.dataset import Dataset, DatasetDagRef, 
DatasetTaskRef
+
+        class OutletRef(NamedTuple):
+            dag_id: str
+            task_id: str
+            uri: str
+
+        class InletRef(NamedTuple):
+            dag_id: str
+            uri: str
+
+        dag_references = set()
+        outlet_references = set()
+        outlet_datasets = set()
+        input_datasets = set()
+        for dag in dags:
+            for dataset in dag.schedule_on or []:
+                dag_references.add(InletRef(dag.dag_id, dataset.uri))
+                input_datasets.add(dataset)
+            for task in dag.tasks:
+                for obj in getattr(task, '_outlets', []):  # type: Dataset
+                    if isinstance(obj, Dataset):
+                        outlet_references.add(OutletRef(task.dag_id, 
task.task_id, obj.uri))
+                        outlet_datasets.add(obj)
+        all_datasets = outlet_datasets.union(input_datasets)
+
+        # store datasets
+        stored_datasets = {}
+        for dataset in all_datasets:
+            stored_dataset = session.query(Dataset).filter(Dataset.uri == 
dataset.uri).first()
+            if stored_dataset:
+                stored_datasets[stored_dataset.uri] = stored_dataset
+            else:
+                session.add(dataset)
+                stored_datasets[dataset.uri] = dataset
+
+        session.flush()  # this is required to ensure each dataset has its PK 
loaded
+
+        del all_datasets
+
+        # store dag-schedule-on-dataset references
+        for dag_ref in dag_references:
+            session.merge(
+                DatasetDagRef(
+                    dataset_id=stored_datasets[dag_ref.uri].id,
+                    dag_id=dag_ref.dag_id,
+                )
+            )
+
+        # store task-outlet-dataset references
+        for outlet_ref in outlet_references:
+            session.merge(
+                DatasetTaskRef(
+                    dataset_id=stored_datasets[outlet_ref.uri].id,
+                    dag_id=outlet_ref.dag_id,
+                    task_id=outlet_ref.task_id,
+                )
+            )
+
         # Issue SQL/finish "Unit of Work", but let @provide_session commit (or 
if passed a session, let caller
         # decide when to commit
         session.flush()
@@ -2731,7 +2807,6 @@ class DagModel(Base):
     schedule_interval = Column(Interval)
     # Timetable/Schedule Interval description
     timetable_description = Column(String(1000), nullable=True)
-
     # Tags for view filter
     tags = relationship('DagTag', cascade='all, delete, delete-orphan', 
backref=backref('dag'))
 
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index a22b7d34c9..31888fe22e 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -55,7 +55,7 @@ from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.declarative import declared_attr
 from sqlalchemy.orm import joinedload, relationship, synonym
 from sqlalchemy.orm.session import Session
-from sqlalchemy.sql.expression import false, select, true
+from sqlalchemy.sql.expression import case, false, select, true
 
 from airflow import settings
 from airflow.callbacks.callback_requests import DagCallbackRequest
@@ -631,8 +631,83 @@ class DagRun(Base, LoggingMixin):
         session.merge(self)
         # We do not flush here for performance reasons(It increases queries 
count by +20)
 
+        self._process_dataset_dagrun_events(session=session)
+
         return schedulable_tis, callback
 
+    def _process_dataset_dagrun_events(self, *, session=NEW_SESSION):
+        """
+        Looks at all outlet datasets that have been updated by this dag,
+        and creates DAG runs that have all dataset deps fulfilled.
+        """
+        from airflow.models.dataset import Dataset, DatasetDagRef, 
DatasetTaskRef
+
+        has_dataset_outlets = False
+        if self.dag:
+            for _, task in self.dag.task_dict.items():
+                if has_dataset_outlets is True:
+                    break
+                for obj in getattr(task, '_outlets', []):
+                    if isinstance(obj, Dataset):
+                        has_dataset_outlets = True
+                        break
+        dependent_dag_ids = []
+        if self.dag and has_dataset_outlets:
+            dependent_dag_ids = [
+                x.dag_id
+                for x in session.query(DatasetDagRef.dag_id)
+                .filter(DatasetTaskRef.dag_id == self.dag_id)
+                .all()
+            ]
+
+        from airflow.models.dataset import DatasetDagRunQueue as DDRQ
+        from airflow.models.serialized_dag import SerializedDagModel
+
+        dag_ids_to_trigger = None
+        if dependent_dag_ids:
+            dag_ids_to_trigger = [
+                x.dag_id
+                for x in session.query(
+                    DatasetDagRef.dag_id,
+                )
+                .join(
+                    DDRQ,
+                    and_(
+                        DDRQ.dataset_id == DatasetDagRef.dataset_id,
+                        DDRQ.target_dag_id == DatasetDagRef.dag_id,
+                    ),
+                    isouter=True,
+                )
+                .filter(DatasetDagRef.dag_id.in_(dependent_dag_ids))
+                .group_by(DatasetDagRef.dag_id)
+                .having(func.count() == 
func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
+                .all()
+            ]
+
+        if dag_ids_to_trigger:
+            dags_to_purge_from_queue = set()
+            for target_dag_id in dag_ids_to_trigger:
+                row = SerializedDagModel.get(target_dag_id, session)
+                if not row:
+                    self.log.warning("Could not find serialized DAG %s", 
target_dag_id)
+                    continue
+                dag = row.dag
+                if dag.schedule_on:
+                    dag.create_dagrun(
+                        run_type=DagRunType.MANUAL,
+                        run_id=self.generate_run_id(
+                            DagRunType.DATASET_TRIGGERED, 
execution_date=timezone.utcnow()
+                        ),
+                        state=DagRunState.QUEUED,
+                        session=session,
+                    )
+                else:
+                    self.log.warning(
+                        "DAG %s no longer has a dataset scheduling dep; 
purging queue records.", dag.dag_id
+                    )
+                dags_to_purge_from_queue.add(target_dag_id)
+            
session.query(DDRQ).filter(DDRQ.target_dag_id.in_(dags_to_purge_from_queue)).delete()
+
     @provide_session
     def task_instance_scheduling_decisions(self, session: Session = 
NEW_SESSION) -> TISchedulingDecision:
         tis = self.get_task_instances(session=session, state=State.task_states)
diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py
index 49f5dfd1f6..256ed2293e 100644
--- a/airflow/models/dataset.py
+++ b/airflow/models/dataset.py
@@ -17,9 +17,10 @@
 # under the License.
 from urllib.parse import urlparse
 
-from sqlalchemy import Column, Index, Integer, String
+from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, 
PrimaryKeyConstraint, String
+from sqlalchemy.orm import relationship
 
-from airflow.models.base import Base
+from airflow.models.base import ID_LEN, Base, StringID
 from airflow.utils import timezone
 from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime
 
@@ -49,6 +50,9 @@ class Dataset(Base):
     created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
     updated_at = Column(UtcDateTime, default=timezone.utcnow, 
onupdate=timezone.utcnow, nullable=False)
 
+    dag_references = relationship("DatasetDagRef", back_populates="dataset")
+    task_references = relationship("DatasetTaskRef", back_populates="dataset")
+
     __tablename__ = "dataset"
     __table_args__ = (
         Index('idx_uri_unique', uri, unique=True),
@@ -66,10 +70,132 @@ class Dataset(Base):
         super().__init__(uri=uri, **kwargs)
 
     def __eq__(self, other):
-        return self.uri == other.uri
+        # Datasets are equal when their URIs match. For non-Dataset operands return
+        # NotImplemented so Python falls back to the other operand's comparison.
+        if isinstance(other, self.__class__):
+            return self.uri == other.uri
+        else:
+            return NotImplemented
 
     def __hash__(self):
+        # Consistent with __eq__: both depend only on the URI.
         return hash(self.uri)
 
     def __repr__(self):
+        # Shows ``extra`` for debugging even though it does not take part in equality.
         return f"{self.__class__.__name__}(uri={self.uri!r}, 
extra={self.extra!r})"
+
+
+class DatasetDagRef(Base):
+    """References from a DAG to an upstream dataset."""
+
+    dataset_id = Column(Integer, primary_key=True, nullable=False)
+    dag_id = Column(String(ID_LEN), primary_key=True, nullable=False)
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at = Column(UtcDateTime, default=timezone.utcnow, 
onupdate=timezone.utcnow, nullable=False)
+
+    dataset = relationship('Dataset')
+
+    __tablename__ = "dataset_dag_ref"
+    __table_args__ = (
+        PrimaryKeyConstraint(dataset_id, dag_id, name="datasetdagref_pkey", 
mssql_clustered=True),
+        ForeignKeyConstraint(
+            (dataset_id,),
+            ["dataset.id"],
+            name='datasetdagref_dataset_fkey',
+            ondelete="CASCADE",
+        ),
+    )
+
+    def __eq__(self, other):
+        if isinstance(other, self.__class__):
+            return self.dataset_id == other.dataset_id and self.dag_id == 
other.dag_id
+        else:
+            return NotImplemented
+
+    def __hash__(self):
+        return hash(self.__mapper__.primary_key)
+
+    def __repr__(self):
+        args = []
+        for attr in [x.name for x in self.__mapper__.primary_key]:
+            args.append(f"{attr}={getattr(self, attr)!r}")
+        return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetTaskRef(Base):
+    """References from a task to a downstream dataset."""
+
+    dataset_id = Column(Integer, primary_key=True, nullable=False)
+    dag_id = Column(String(ID_LEN), primary_key=True, nullable=False)
+    task_id = Column(String(ID_LEN), primary_key=True, nullable=False)
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at = Column(UtcDateTime, default=timezone.utcnow, 
onupdate=timezone.utcnow, nullable=False)
+
+    dataset = relationship("Dataset", back_populates="task_references")
+
+    __tablename__ = "dataset_task_ref"
+    __table_args__ = (
+        ForeignKeyConstraint(
+            (dataset_id,),
+            ["dataset.id"],
+            name='datasettaskref_dataset_fkey',
+            ondelete="CASCADE",
+        ),
+        PrimaryKeyConstraint(dataset_id, dag_id, task_id, 
name="datasettaskref_pkey", mssql_clustered=True),
+    )
+
+    def __eq__(self, other):
+        if isinstance(other, self.__class__):
+            return (
+                self.dataset_id == other.dataset_id
+                and self.dag_id == other.dag_id
+                and self.task_id == other.task_id
+            )
+        else:
+            return NotImplemented
+
+    def __hash__(self):
+        return hash(self.__mapper__.primary_key)
+
+    def __repr__(self):
+        args = []
+        for attr in [x.name for x in self.__mapper__.primary_key]:
+            args.append(f"{attr}={getattr(self, attr)!r}")
+        return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetDagRunQueue(Base):
+    """Model for storing dataset events that need processing."""
+
+    dataset_id = Column(Integer, primary_key=True, nullable=False)
+    target_dag_id = Column(StringID(), primary_key=True, nullable=False)
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+
+    __tablename__ = "dataset_dag_run_queue"
+    __table_args__ = (
+        PrimaryKeyConstraint(dataset_id, target_dag_id, 
name="datasetdagrunqueue_pkey", mssql_clustered=True),
+        ForeignKeyConstraint(
+            (dataset_id,),
+            ["dataset.id"],
+            name='ddrq_dataset_fkey',
+            ondelete="CASCADE",
+        ),
+        ForeignKeyConstraint(
+            (target_dag_id,),
+            ["dag.dag_id"],
+            name='ddrq_dag_fkey',
+            ondelete="CASCADE",
+        ),
+    )
+
+    def __eq__(self, other):
+        if isinstance(other, self.__class__):
+            return self.dataset_id == other.dataset_id and self.target_dag_id 
== other.target_dag_id
+        else:
+            return NotImplemented
+
+    def __hash__(self):
+        return hash(self.__mapper__.primary_key)
+
+    def __repr__(self):
+        args = []
+        for attr in [x.name for x in self.__mapper__.primary_key]:
+            args.append(f"{attr}={getattr(self, attr)!r}")
+        return f"{self.__class__.__name__}({', '.join(args)})"
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 8950385248..18bae29a0c 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -95,6 +95,7 @@ from airflow.exceptions import (
     XComForMappingNotPushed,
 )
 from airflow.models.base import Base, StringID
+from airflow.models.dataset import DatasetDagRunQueue
 from airflow.models.log import Log
 from airflow.models.param import ParamsDict
 from airflow.models.taskfail import TaskFail
@@ -1512,9 +1513,24 @@ class TaskInstance(Base, LoggingMixin):
         if not test_mode:
             session.add(Log(self.state, self))
             session.merge(self)
-
+            self._create_dataset_dag_run_queue_records(session=session)
             session.commit()
 
+    def _create_dataset_dag_run_queue_records(self, *, session):
+        from airflow.models import Dataset
+
+        for obj in getattr(self.task, '_outlets', []):
+            self.log.debug("outlet obj %s", obj)
+            if isinstance(obj, Dataset):
+                dataset = session.query(Dataset).filter(Dataset.uri == 
obj.uri).one_or_none()
+                if not dataset:
+                    self.log.warning("Dataset %s not found", obj)
+                    continue
+                downstream_dag_ids = [x.dag_id for x in dataset.dag_references]
+                self.log.debug("downstream dag ids %s", downstream_dag_ids)
+                for dag_id in downstream_dag_ids:
+                    session.merge(DatasetDagRunQueue(dataset_id=dataset.id, 
target_dag_id=dag_id))
+
     def _execute_task_with_callbacks(self, context, test_mode=False):
         """Prepare Task for Execution"""
         from airflow.models.renderedtifields import RenderedTaskInstanceFields
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index f4227a6f7a..420b3e015b 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -49,3 +49,4 @@ class DagAttributeTypes(str, Enum):
     EDGE_INFO = 'edgeinfo'
     PARAM = 'param'
     XCOM_REF = 'xcomref'
+    DATASET = 'dataset'
diff --git a/airflow/serialization/schema.json 
b/airflow/serialization/schema.json
index 423950bb5e..1550387eed 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -53,6 +53,34 @@
         { "type": "integer" }
       ]
     },
+    "dataset": {
+      "type": "object",
+      "properties": {
+        "uri": { "type": "string" },
+        "extra": {
+            "anyOf": [
+                {"type": "null"},
+                { "$ref": "#/definitions/dict" }
+            ]
+        }
+      },
+      "required": [ "uri", "extra" ]
+    },
+    "typed_dataset": {
+      "type": "object",
+      "properties": {
+        "__type": {
+          "type": "string",
+            "const": "dataset"
+        },
+        "__var": { "$ref": "#/definitions/dataset" }
+      },
+      "required": [
+        "__type",
+        "__var"
+      ],
+      "additionalProperties": false
+    },
     "dict": {
       "description": "A python dictionary containing values of any type",
       "type": "object"
@@ -90,6 +118,15 @@
             { "$ref": "#/definitions/typed_relativedelta" }
           ]
         },
+        "schedule_on":{
+            "anyOf": [
+                { "type": "null" },
+                {
+                  "type": "array",
+                  "items": { "$ref": "#/definitions/typed_dataset" }
+                }
+            ]
+        },
         "timetable": {
           "type": "object",
           "properties": {
diff --git a/airflow/serialization/serialized_objects.py 
b/airflow/serialization/serialized_objects.py
index bd0430ba26..f4a4257f57 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -33,6 +33,7 @@ from pendulum.tz.timezone import FixedTimezone, Timezone
 from airflow.compat.functools import cache
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, SerializationError
+from airflow.models import Dataset
 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
 from airflow.models.connection import Connection
 from airflow.models.dag import DAG, create_timetable
@@ -370,6 +371,8 @@ class BaseSerialization:
             return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
         elif isinstance(var, XComArg):
             return cls._encode(cls._serialize_xcomarg(var), type_=DAT.XCOM_REF)
+        elif isinstance(var, Dataset):
+            return cls._encode(dict(uri=var.uri, extra=var.extra), 
type_=DAT.DATASET)
         else:
             log.debug('Cast type %s to str in serialization.', type(var))
             return str(var)
@@ -415,6 +418,8 @@ class BaseSerialization:
             return cls._deserialize_param(var)
         elif type_ == DAT.XCOM_REF:
             return cls._deserialize_xcomref(var)
+        elif type_ == DAT.DATASET:
+            return Dataset(**var)
         else:
             raise TypeError(f'Invalid type {type_!s} in deserialization.')
 
@@ -746,6 +751,9 @@ class SerializedBaseOperator(BaseOperator, 
BaseSerialization):
                 v = {arg: cls._deserialize(value) for arg, value in v.items()}
             elif k in cls._decorated_fields or k not in 
op.get_serialized_fields():
                 v = cls._deserialize(v)
+            elif k in ("_outlets", "_inlets"):
+                v = cls._deserialize(v)
+
             # else use v as it is
 
             setattr(op, k, v)
@@ -1024,6 +1032,8 @@ class SerializedDAG(DAG, BaseSerialization):
                 v = cls._deserialize(v)
             elif k == "params":
                 v = cls._deserialize_params_dict(v)
+            elif k == "schedule_on":
+                v = cls._deserialize(v)
             # else use v as it is
 
             setattr(dag, k, v)
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index e8366805b4..769f7986fb 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -39,6 +39,7 @@ from typing import (
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.utils.module_loading import import_string
+from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
     import jinja2
@@ -314,6 +315,24 @@ def exactly_one(*args) -> bool:
     return sum(map(bool, args)) == 1
 
 
+def at_most_one(*args) -> bool:
+    """
+    Returns True if at most one of *args is "truthy", and False otherwise.
+
+    NOTSET is treated the same as None, i.e. it never counts as set.
+
+    Iterables are not unpacked: a single iterable argument is judged by its own truthiness.
+    """
+
+    def is_set(val):
+        if val is NOTSET:
+            return False
+        else:
+            return bool(val)
+
+    return sum(map(is_set, args)) <= 1
+
+
 def prune_dict(val: Any, mode='strict'):
     """
     Given dict ``val``, returns new dict based on ``val`` with all
diff --git a/airflow/utils/types.py b/airflow/utils/types.py
index 04688a7f65..313780bb2f 100644
--- a/airflow/utils/types.py
+++ b/airflow/utils/types.py
@@ -45,6 +45,7 @@ class DagRunType(str, enum.Enum):
     BACKFILL_JOB = "backfill"
     SCHEDULED = "scheduled"
     MANUAL = "manual"
+    DATASET_TRIGGERED = "dataset_triggered"
 
     def __str__(self) -> str:
         return self.value
diff --git a/docs/apache-airflow/concepts/datasets.rst 
b/docs/apache-airflow/concepts/datasets.rst
new file mode 100644
index 0000000000..211349db56
--- /dev/null
+++ b/docs/apache-airflow/concepts/datasets.rst
@@ -0,0 +1,46 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+Datasets
+========
+
+.. versionadded:: 2.4
+
+With datasets, instead of running a DAG on a schedule, a DAG can be configured 
to run when a dataset has been updated.
+
+To use this feature, define a dataset:
+
+.. exampleinclude:: /../../airflow/example_dags/example_datasets.py
+    :language: python
+    :start-after: [START dataset_def]
+    :end-before: [END dataset_def]
+
+Then reference the dataset as a task outlet:
+
+.. exampleinclude:: /../../airflow/example_dags/example_datasets.py
+    :language: python
+    :start-after: [START task_outlet]
+    :end-before: [END task_outlet]
+
+Finally, define a DAG and reference this dataset in the DAG's ``schedule_on`` 
parameter:
+
+.. exampleinclude:: /../../airflow/example_dags/example_datasets.py
+    :language: python
+    :start-after: [START dag_dep]
+    :end-before: [END dag_dep]
+
+You can reference multiple datasets in the DAG's ``schedule_on`` parameter. Once every one of the upstream datasets has been updated since the DAG was last triggered, the DAG will be triggered. This means that the DAG runs only as often as its least-frequently-updated dataset is updated.
diff --git a/docs/apache-airflow/concepts/index.rst 
b/docs/apache-airflow/concepts/index.rst
index 122dc760fe..9297c54e00 100644
--- a/docs/apache-airflow/concepts/index.rst
+++ b/docs/apache-airflow/concepts/index.rst
@@ -38,6 +38,7 @@ Here you can find detailed documentation about each one of 
Airflow's core concep
     operators
     dynamic-task-mapping
     sensors
+    datasets
     deferring
     smart-sensors
     taskflow
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 0164ce0f87..50c49f3819 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -43,9 +43,10 @@ from airflow import models, settings
 from airflow.configuration import conf
 from airflow.decorators import task as task_decorator
 from airflow.exceptions import AirflowException, DuplicateTaskIdFound, 
ParamValidationError
-from airflow.models import DAG, DagModel, DagRun, DagTag, TaskFail, 
TaskInstance as TI
+from airflow.models import DAG, DagModel, DagRun, DagTag, Dataset, TaskFail, 
TaskInstance as TI
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dag import dag as dag_decorator
+from airflow.models.dataset import DatasetTaskRef
 from airflow.models.param import DagParam, Param, ParamsDict
 from airflow.operators.bash import BashOperator
 from airflow.operators.empty import EmptyOperator
@@ -821,6 +822,43 @@ class TestDag(unittest.TestCase):
         assert not model.has_import_errors
         session.close()
 
+    def test_bulk_write_to_db_datasets_schedule_on(self):
+        """
+        Ensure that datasets referenced in a dag are correctly loaded into the 
database.
+        """
+        # todo: clear db
+        dag_id1 = 'test_dataset_dag1'
+        dag_id2 = 'test_dataset_dag2'
+        task_id = 'test_dataset_task'
+        uri1 = 's3://dataset1'
+        d1 = Dataset(uri1, extra={"not": "used"})
+        d2 = Dataset('s3://dataset2')
+        d3 = Dataset('s3://dataset3')
+        dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule_on=[d1])
+        EmptyOperator(task_id=task_id, dag=dag1, outlets=[d2, d3])
+        dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE)
+        EmptyOperator(task_id=task_id, dag=dag2, outlets=[Dataset(uri1, 
extra={"should": "be used"})])
+        session = settings.Session()
+        dag1.clear()
+        DAG.bulk_write_to_db([dag1, dag2], session)
+        session.commit()
+        stored_datasets = {x.uri: x for x in session.query(Dataset).all()}
+        d1 = stored_datasets[d1.uri]
+        d2 = stored_datasets[d2.uri]
+        d3 = stored_datasets[d3.uri]
+        assert stored_datasets[uri1].extra == {"should": "be used"}
+        assert [x.dag_id for x in d1.dag_references] == [dag_id1]
+        assert [(x.task_id, x.dag_id) for x in d1.task_references] == 
[(task_id, dag_id2)]
+        assert set(
+            session.query(DatasetTaskRef.task_id, DatasetTaskRef.dag_id, 
DatasetTaskRef.dataset_id)
+            .filter(DatasetTaskRef.dag_id.in_((dag_id1, dag_id2)))
+            .all()
+        ) == {
+            (task_id, dag_id1, d2.id),
+            (task_id, dag_id1, d3.id),
+            (task_id, dag_id2, d1.id),
+        }
+
     def test_sync_to_db(self):
         dag = DAG(
             'dag',
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 6c3cc1c91c..397f519a50 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -20,6 +20,7 @@ import datetime
 from typing import Mapping, Optional
 from unittest import mock
 from unittest.mock import call
+from uuid import uuid4
 
 import pendulum
 import pytest
@@ -28,10 +29,21 @@ from sqlalchemy.orm.session import Session
 from airflow import settings
 from airflow.callbacks.callback_requests import DagCallbackRequest
 from airflow.decorators import task
-from airflow.models import DAG, DagBag, DagModel, DagRun, TaskInstance as TI, 
clear_task_instances
+from airflow.models import (
+    DAG,
+    DagBag,
+    DagModel,
+    DagRun,
+    TaskInstance,
+    TaskInstance as TI,
+    clear_task_instances,
+)
 from airflow.models.baseoperator import BaseOperator
+from airflow.models.dataset import Dataset, DatasetDagRunQueue
+from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom_arg import XComArg
+from airflow.operators.bash import BashOperator
 from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import ShortCircuitOperator
 from airflow.serialization.serialized_objects import SerializedDAG
@@ -41,7 +53,13 @@ from airflow.utils.state import DagRunState, State, 
TaskInstanceState
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
 from tests.models import DEFAULT_DATE as _DEFAULT_DATE
-from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, 
clear_db_variables
+from tests.test_utils.db import (
+    clear_db_dags,
+    clear_db_datasets,
+    clear_db_pools,
+    clear_db_runs,
+    clear_db_variables,
+)
 from tests.test_utils.mock_operators import MockOperator
 
 DEFAULT_DATE = pendulum.instance(_DEFAULT_DATE)
@@ -55,12 +73,14 @@ class TestDagRun:
         clear_db_pools()
         clear_db_dags()
         clear_db_variables()
+        clear_db_datasets()
 
     def teardown_method(self) -> None:
         clear_db_runs()
         clear_db_pools()
         clear_db_dags()
         clear_db_variables()
+        clear_db_datasets()
 
     def create_dag_run(
         self,
@@ -1291,6 +1311,47 @@ def test_mapped_task_upstream_failed(dag_maker, session):
     assert dr.state == DagRunState.FAILED
 
 
+def test_dataset_dagruns_triggered(session):
+    unique_id = str(uuid4())
+    session = settings.Session()
+    dag1 = DAG(dag_id=f"datasets-{unique_id}-1", start_date=timezone.utcnow())
+    dataset1 = Dataset(uri=f"s3://{unique_id}-1")
+    dataset2 = Dataset(uri=f"s3://{unique_id}-2")
+    dag2 = DAG(dag_id=f"datasets-{unique_id}-2", schedule_on=[dataset1, dataset2])
+    dag3 = DAG(dag_id=f"datasets-{unique_id}-3", schedule_on=[dataset1])
+    task = BashOperator(task_id="task", bash_command="echo 1", dag=dag1, outlets=[dataset1])
+    # dag2 and dag3 are consumers only and have no tasks: the assertions below
+    # only check whether dag runs / queue records exist for them.
+    DAG.bulk_write_to_db(dags=[dag1, dag2, dag3], session=session)
+    session.commit()
+    dr = DagRun(dag1.dag_id, run_id=unique_id, run_type='anything')
+    dr.dag = dag1
+    session.add(dr)
+    session.add(TaskInstance(task=task, run_id=unique_id, state=State.SUCCESS))
+    session.commit()
+
+    session.bulk_save_objects(
+        [
+            *[SerializedDagModel(dag) for dag in [dag1, dag2, dag3]],
+            DatasetDagRunQueue(dataset_id=dataset1.id, target_dag_id=dag2.dag_id),
+            DatasetDagRunQueue(dataset_id=dataset1.id, target_dag_id=dag3.dag_id),
+        ]
+    )
+    session.commit()
+    session.expunge_all()
+    dr.update_state(session=session)
+    session.commit()
+
+    # dag3 should be triggered since it only depends on dataset1, and it's been queued
+    assert session.query(DagRun).filter(DagRun.dag_id == dag3.dag_id).one() is not None
+    # dag3 DDRQ record should be deleted since the dag run was triggered
+    assert session.query(DatasetDagRunQueue).filter_by(target_dag_id=dag3.dag_id).one_or_none() is None
+    # dag2 should not be triggered since it depends on both dataset 1 and 2
+    assert session.query(DagRun).filter(DagRun.dag_id == dag2.dag_id).one_or_none() is None
+    # dag2 DDRQ record should still be there since the dag run was *not* triggered
+    assert session.query(DatasetDagRunQueue).filter_by(target_dag_id=dag2.dag_id).one() is not None
+
+
 def test_mapped_task_all_finish_before_downstream(dag_maker, session):
     result = None
 
diff --git a/tests/models/test_taskinstance.py 
b/tests/models/test_taskinstance.py
index da3d138306..0bb38071e8 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -27,6 +27,7 @@ from traceback import format_exception
 from typing import List, Optional, Union, cast
 from unittest import mock
 from unittest.mock import call, mock_open, patch
+from uuid import uuid4
 
 import pendulum
 import pytest
@@ -47,6 +48,7 @@ from airflow.exceptions import (
 from airflow.models import (
     DAG,
     Connection,
+    DagBag,
     DagRun,
     Pool,
     RenderedTaskInstanceFields,
@@ -55,6 +57,7 @@ from airflow.models import (
     Variable,
     XCom,
 )
+from airflow.models.dataset import DatasetDagRunQueue, DatasetTaskRef
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskfail import TaskFail
 from airflow.models.taskinstance import TaskInstance
@@ -1473,6 +1476,33 @@ class TestTaskInstance:
         ti.refresh_from_db()
         assert ti.state == State.SUCCESS
 
+    def test_outlet_datasets(self, create_task_instance):
+        """
+        Verify that when we have an outlet dataset on a task, and the task
+        completes successfully, a DatasetDagRunQueue is logged.
+        """
+        from airflow.example_dags import example_datasets
+        from airflow.example_dags.example_datasets import dag1
+
+        session = settings.Session()
+        dagbag = DagBag(dag_folder=example_datasets.__file__)
+        dagbag.collect_dags(only_if_updated=False, safe_mode=False)
+        dagbag.sync_to_db(session=session)
+        run_id = str(uuid4())
+        session.merge(DagRun(dag1.dag_id, run_id=run_id, run_type='anything'))
+        task = dag1.get_task('upstream_task_1')
+        task.bash_command = 'echo 1'  # make it go faster
+        ti = TaskInstance(task, run_id=run_id)
+        session.merge(ti)
+        session.commit()
+        ti._run_raw_task()
+        ti.refresh_from_db()
+        assert ti.state == State.SUCCESS
+        ddrq, ref = DatasetDagRunQueue, DatasetTaskRef
+        query = session.query(ddrq.target_dag_id).join(ref, ref.dataset_id == ddrq.dataset_id)
+        query = query.filter(ref.dag_id == dag1.dag_id, ref.task_id == 'upstream_task_1')
+        assert query.order_by(ddrq.target_dag_id).all() == [('dag3',), ('dag4',), ('dag5',)]
+
     @staticmethod
     def _test_previous_dates_setup(
         schedule_interval: Union[str, datetime.timedelta, None],
@@ -2317,6 +2347,7 @@ class TestRunRawTaskQueriesCount:
         db.clear_db_dags()
         db.clear_db_sla_miss()
         db.clear_db_import_errors()
+        db.clear_db_datasets()
 
     def setup_method(self) -> None:
         self._clean()
diff --git a/tests/test_utils/db.py b/tests/test_utils/db.py
index 948e434a9c..be91413b55 100644
--- a/tests/test_utils/db.py
+++ b/tests/test_utils/db.py
@@ -23,6 +23,7 @@ from airflow.models import (
     DagRun,
     DagTag,
     DagWarning,
+    Dataset,
     DbCallbackRequest,
     Log,
     Pool,
@@ -52,6 +53,11 @@ def clear_db_runs():
         session.query(TaskInstance).delete()
 
 
+def clear_db_datasets():
+    with create_session() as session:
+        session.query(Dataset).delete()
+
+
 def clear_db_dags():
     with create_session() as session:
         session.query(DagTag).delete()
diff --git a/tests/utils/test_db_cleanup.py b/tests/utils/test_db_cleanup.py
index 5c05545919..a7b48ff7e9 100644
--- a/tests/utils/test_db_cleanup.py
+++ b/tests/utils/test_db_cleanup.py
@@ -263,6 +263,9 @@ class TestDBCleanup:
             'dag_warning',  # self-maintaining
             'connection',  # leave alone
             'slot_pool',  # leave alone
+            'dataset_dag_ref',  # leave alone for now
+            'dataset_task_ref',  # leave alone for now
+            'dataset_dag_run_queue',  # self-managed
         }
 
         from airflow.utils.db_cleanup import config_dict
diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py
index 64d7d23f8f..8e8f799dd5 100644
--- a/tests/utils/test_helpers.py
+++ b/tests/utils/test_helpers.py
@@ -23,6 +23,7 @@ import pytest
 from airflow import AirflowException
 from airflow.utils import helpers, timezone
 from airflow.utils.helpers import (
+    at_most_one,
     build_airflow_url_with_query,
     exactly_one,
     merge_dicts,
@@ -30,6 +31,7 @@ from airflow.utils.helpers import (
     validate_group_key,
     validate_key,
 )
+from airflow.utils.types import NOTSET
 from tests.test_utils.config import conf_vars
 from tests.test_utils.db import clear_db_dags, clear_db_runs
 
@@ -264,6 +266,34 @@ class TestHelpers:
         with pytest.raises(ValueError):
             exactly_one([True, False])
 
+    def test_at_most_one(self):
+        """
+        Checks that when we set some elements to "truthy" and others to "falsy",
+        we get True exactly when at most one element is truthy.
+        We check for both True / False, and truthy / falsy values 'a' and '', and verify that
+        they can safely be used in any combination.
+        NOTSET values should be ignored.
+        """
+
+        def assert_at_most_one(true=0, truthy=0, false=0, falsy=0, notset=0):
+            sample = []
+            for truth_value, num in [
+                (True, true),
+                (False, false),
+                ('a', truthy),
+                ('', falsy),
+                (NOTSET, notset),
+            ]:
+                if num:
+                    sample.extend([truth_value] * num)
+            if sample:
+                expected = true + truthy <= 1
+                assert at_most_one(*sample) is expected
+
+        # each row is a (true, truthy, false, falsy, notset) count combination
+        for row in product(range(4), range(4), range(4), range(4), range(4)):
+            assert_at_most_one(*row)
+
     @pytest.mark.parametrize(
         'mode, expected',
         [

Reply via email to