This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 66a7428e75 Get dataset-driven scheduling working (#24743)
66a7428e75 is described below
commit 66a7428e754d0059b7bf735c7a83a4b6675c66ef
Author: Daniel Standish <[email protected]>
AuthorDate: Thu Jul 7 14:49:55 2022 -0700
Get dataset-driven scheduling working (#24743)
AIP-48
(https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-48+Data+Dependency+Management+and+Data+Driven+Scheduling)
allows DAGs to reference upstream "datasets" and tasks to reference
downstream datasets through outlets. When all of a DAG's upstream datasets
have been updated, a DAG run is triggered.
This is essentially the initial commit for AIP-48. We will need to iterate
and tweak things as we go, but it is a starting point that implements the
basic functionality of the AIP.
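A condensed sketch of the usage pattern this introduces, adapted from the
example_datasets.py example DAG added below (the 'producer'/'consumer' DAG
ids here are illustrative, not part of the commit):

    from datetime import datetime

    from airflow.models import DAG, Dataset
    from airflow.operators.bash import BashOperator

    # A dataset is identified by a URI; `extra` is optional metadata.
    example_dataset = Dataset('s3://dag1/output_1.txt', extra={'hi': 'bye'})

    # Producer: declaring the dataset as a task outlet means a successful
    # task run records an update to that dataset.
    with DAG(dag_id='producer', start_date=datetime(2020, 1, 1),
             schedule_interval='@daily', catchup=False):
        BashOperator(task_id='produce', bash_command='sleep 5',
                     outlets=[example_dataset])

    # Consumer: scheduled on the dataset rather than on time; a run is
    # created once all datasets in `schedule_on` have been updated.
    with DAG(dag_id='consumer', start_date=datetime(2020, 1, 1),
             schedule_on=[example_dataset], catchup=False):
        BashOperator(task_id='consume', bash_command='sleep 5')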
---
airflow/api_connexion/openapi/v1.yaml | 1 +
airflow/example_dags/example_datasets.py | 133 +++++++++++++++++++++
.../versions/0114_2_4_0_add_dataset_model.py | 73 ++++++++++-
airflow/models/dag.py | 95 +++++++++++++--
airflow/models/dagrun.py | 77 +++++++++++-
airflow/models/dataset.py | 132 +++++++++++++++++++-
airflow/models/taskinstance.py | 18 ++-
airflow/serialization/enums.py | 1 +
airflow/serialization/schema.json | 37 ++++++
airflow/serialization/serialized_objects.py | 10 ++
airflow/utils/helpers.py | 19 +++
airflow/utils/types.py | 1 +
docs/apache-airflow/concepts/datasets.rst | 46 +++++++
docs/apache-airflow/concepts/index.rst | 1 +
tests/models/test_dag.py | 40 ++++++-
tests/models/test_dagrun.py | 65 +++++++++-
tests/models/test_taskinstance.py | 31 +++++
tests/test_utils/db.py | 6 +
tests/utils/test_db_cleanup.py | 3 +
tests/utils/test_helpers.py | 30 +++++
20 files changed, 797 insertions(+), 22 deletions(-)
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index 6ce6c51608..d9c0cc6ed8 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -2487,6 +2487,7 @@ components:
- backfill
- manual
- scheduled
+ - dataset_triggered
state:
$ref: '#/components/schemas/DagState'
readOnly: true
diff --git a/airflow/example_dags/example_datasets.py b/airflow/example_dags/example_datasets.py
new file mode 100644
index 0000000000..e04479b0f5
--- /dev/null
+++ b/airflow/example_dags/example_datasets.py
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example DAG for demonstrating behavior of Datasets feature.
+
+Notes on usage:
+
+Turn on all the dags.
+
+DAG dag1 should run because it's on a schedule.
+
+After dag1 runs, dag3 should be triggered immediately because its only
+dataset dependency is managed by dag1.
+
+No other dags should be triggered. Note that even though dag4 depends on
+the dataset in dag1, it will not be triggered until dag2 runs (and dag2 is
+left with no schedule so that we can trigger it manually).
+
+Next, trigger dag2. After dag2 finishes, dag4 should run.
+
+Dags 5 and 6 should not run because they depend on datasets that never get updated.
+
+"""
+from datetime import datetime
+
+from airflow.models import DAG, Dataset
+from airflow.operators.bash import BashOperator
+
+# [START dataset_def]
+dag1_dataset = Dataset('s3://dag1/output_1.txt', extra={'hi': 'bye'})
+# [END dataset_def]
+dag2_dataset = Dataset('s3://dag2/output_1.txt', extra={'hi': 'bye'})
+
+dag1 = DAG(
+ dag_id='dag1',
+ catchup=False,
+ start_date=datetime(2020, 1, 1),
+ schedule_interval='@daily',
+ tags=['upstream'],
+)
+
+# [START task_outlet]
+BashOperator(outlets=[dag1_dataset], task_id='upstream_task_1', bash_command="sleep 5", dag=dag1)
+# [END task_outlet]
+
+with DAG(
+ dag_id='dag2',
+ catchup=False,
+ start_date=datetime(2020, 1, 1),
+ schedule_interval=None,
+ tags=['upstream'],
+) as dag2:
+ BashOperator(
+ outlets=[dag2_dataset],
+ task_id='upstream_task_2',
+ bash_command="sleep 5",
+ )
+
+# [START dag_dep]
+dag3 = DAG(
+ dag_id='dag3',
+ catchup=False,
+ start_date=datetime(2020, 1, 1),
+ schedule_on=[dag1_dataset],
+ tags=['downstream'],
+)
+# [END dag_dep]
+
+BashOperator(
+ outlets=[Dataset('s3://downstream_1_task/dataset_other.txt')],
+ task_id='downstream_1',
+ bash_command="sleep 5",
+ dag=dag3,
+)
+
+with DAG(
+ dag_id='dag4',
+ catchup=False,
+ start_date=datetime(2020, 1, 1),
+ schedule_on=[dag1_dataset, dag2_dataset],
+ tags=['downstream'],
+) as dag4:
+ BashOperator(
+ outlets=[Dataset('s3://downstream_2_task/dataset_other_unknown.txt')],
+ task_id='downstream_2',
+ bash_command="sleep 5",
+ )
+
+with DAG(
+ dag_id='dag5',
+ catchup=False,
+ start_date=datetime(2020, 1, 1),
+ schedule_on=[
+ dag1_dataset,
+ Dataset('s3://this-dataset-doesnt-get-triggered'),
+ ],
+ tags=['downstream'],
+) as dag5:
+ BashOperator(
+ outlets=[Dataset('s3://downstream_2_task/dataset_other_unknown.txt')],
+ task_id='downstream_3',
+ bash_command="sleep 5",
+ )
+
+with DAG(
+ dag_id='dag6',
+ catchup=False,
+ start_date=datetime(2020, 1, 1),
+ schedule_on=[
+ Dataset('s3://unrelated/dataset3.txt'),
+ Dataset('s3://unrelated/dataset_other_unknown.txt'),
+ ],
+ tags=['unrelated'],
+) as dag6:
+ BashOperator(
+ task_id='unrelated_task',
+ outlets=[Dataset('s3://unrelated_task/dataset_other_unknown.txt')],
+ bash_command="sleep 5",
+ )
diff --git a/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py b/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py
index 838f8780b0..9cfca3766c 100644
--- a/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py
+++ b/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py
@@ -26,9 +26,9 @@ Create Date: 2022-06-22 14:37:20.880672
import sqlalchemy as sa
from alembic import op
-from sqlalchemy import Integer, String
+from sqlalchemy import Integer, String, func
-from airflow.migrations.db_types import TIMESTAMP
+from airflow.migrations.db_types import TIMESTAMP, StringID
from airflow.utils.sqlalchemy import ExtendedJSON
revision = '0038cd0c28b4'
@@ -38,8 +38,7 @@ depends_on = None
airflow_version = '2.4.0'
-def upgrade():
- """Apply Add Dataset model"""
+def _create_dataset_table():
op.create_table(
'dataset',
sa.Column('id', Integer, primary_key=True, autoincrement=True),
@@ -64,6 +63,72 @@ def upgrade():
op.create_index('idx_uri_unique', 'dataset', ['uri'], unique=True)
+def _create_dataset_dag_ref_table():
+ op.create_table(
+ 'dataset_dag_ref',
+ sa.Column('dataset_id', Integer, primary_key=True, nullable=False),
+ sa.Column('dag_id', String(250), primary_key=True, nullable=False),
+ sa.Column('created_at', TIMESTAMP, default=func.now, nullable=False),
+ sa.Column('updated_at', TIMESTAMP, default=func.now, nullable=False),
+ sa.ForeignKeyConstraint(
+ ('dataset_id',),
+ ['dataset.id'],
+ name="datasetdagref_dataset_fkey",
+ ondelete="CASCADE",
+ ),
+ sqlite_autoincrement=True, # ensures PK values not reused
+ )
+
+
+def _create_dataset_task_ref_table():
+ op.create_table(
+ 'dataset_task_ref',
+ sa.Column('dataset_id', Integer, primary_key=True, nullable=False),
+ sa.Column('dag_id', String(250), primary_key=True, nullable=False),
+ sa.Column('task_id', String(250), primary_key=True, nullable=False),
+ sa.Column('created_at', TIMESTAMP, default=func.now, nullable=False),
+ sa.Column('updated_at', TIMESTAMP, default=func.now, nullable=False),
+ sa.ForeignKeyConstraint(
+ ('dataset_id',),
+ ['dataset.id'],
+ name="datasettaskref_dataset_fkey",
+ ondelete="CASCADE",
+ ),
+ )
+
+
+def _create_dataset_dag_run_queue_table():
+ op.create_table(
+ 'dataset_dag_run_queue',
+ sa.Column('dataset_id', Integer, primary_key=True, nullable=False),
+        sa.Column('target_dag_id', StringID(), primary_key=True, nullable=False),
+ sa.Column('created_at', TIMESTAMP, default=func.now, nullable=False),
+ sa.ForeignKeyConstraint(
+ ('dataset_id',),
+ ['dataset.id'],
+ name="ddrq_dataset_fkey",
+ ondelete="CASCADE",
+ ),
+ sa.ForeignKeyConstraint(
+ ('target_dag_id',),
+ ['dag.dag_id'],
+ name="ddrq_dag_fkey",
+ ondelete="CASCADE",
+ ),
+ )
+
+
+def upgrade():
+ """Apply Add Dataset model"""
+ _create_dataset_table()
+ _create_dataset_dag_ref_table()
+ _create_dataset_task_ref_table()
+ _create_dataset_dag_run_queue_table()
+
+
def downgrade():
"""Unapply Add Dataset model"""
+ op.drop_table('dataset_dag_ref')
+ op.drop_table('dataset_task_ref')
+ op.drop_table('dataset_dag_run_queue')
op.drop_table('dataset')
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index b16126e569..4cef709b16 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -39,6 +39,7 @@ from typing import (
FrozenSet,
Iterable,
List,
+ NamedTuple,
Optional,
Sequence,
Set,
@@ -83,7 +84,7 @@ from airflow.utils import timezone
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.dates import cron_presets, date_range as utils_date_range
from airflow.utils.file import correct_maybe_zipped
-from airflow.utils.helpers import exactly_one, validate_key
+from airflow.utils.helpers import at_most_one, exactly_one, validate_key
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks
@@ -92,6 +93,7 @@ from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
if TYPE_CHECKING:
from airflow.decorators import TaskDecoratorCollection
+ from airflow.models.dataset import Dataset
from airflow.models.slamiss import SlaMiss
from airflow.utils.task_group import TaskGroup
@@ -279,6 +281,7 @@ class DAG(LoggingMixin):
to render templates as native Python types. If False, a Jinja
``Environment`` is used to render templates as string values.
:param tags: List of tags to help filtering DAGs in the UI.
+    :param schedule_on: List of upstream datasets for use in triggering DAG runs.
"""
_comps = {
@@ -338,6 +341,7 @@ class DAG(LoggingMixin):
jinja_environment_kwargs: Optional[Dict] = None,
render_template_as_native_obj: bool = False,
tags: Optional[List[str]] = None,
+ schedule_on: Optional[Sequence["Dataset"]] = None,
):
from airflow.utils.task_group import TaskGroup
@@ -415,17 +419,29 @@ class DAG(LoggingMixin):
if 'end_date' in self.default_args:
self.default_args['end_date'] = timezone.convert_to_utc(self.default_args['end_date'])
- # Calculate the DAG's timetable.
- if timetable is None:
- self.timetable = create_timetable(schedule_interval, self.timezone)
- if isinstance(schedule_interval, ArgNotSet):
- schedule_interval = DEFAULT_SCHEDULE_INTERVAL
- self.schedule_interval: ScheduleInterval = schedule_interval
- elif isinstance(schedule_interval, ArgNotSet):
+ # sort out DAG's scheduling behavior
+ scheduling_args = [schedule_interval, timetable, schedule_on]
+ if not at_most_one(*scheduling_args):
+ raise ValueError(
+                "At most one allowed for args 'schedule_interval', 'timetable', and 'schedule_on'."
+ )
+
+ self.timetable: Timetable
+ self.schedule_interval: ScheduleInterval
+        self.schedule_on: Optional[List["Dataset"]] = list(schedule_on) if schedule_on else None
+ if schedule_on:
+ if not isinstance(schedule_on, Sequence):
+            raise ValueError("Param `schedule_on` must be Sequence[Dataset]")
+ self.schedule_interval = None
+ self.timetable = NullTimetable()
+ elif timetable:
self.timetable = timetable
self.schedule_interval = self.timetable.summary
else:
-            raise TypeError("cannot specify both 'schedule_interval' and 'timetable'")
+ if isinstance(schedule_interval, ArgNotSet):
+ schedule_interval = DEFAULT_SCHEDULE_INTERVAL
+ self.schedule_interval = schedule_interval
+ self.timetable = create_timetable(schedule_interval, self.timezone)
if isinstance(template_searchpath, str):
template_searchpath = [template_searchpath]
@@ -2418,6 +2434,7 @@ class DAG(LoggingMixin):
log.info("Sync %s DAGs", len(dags))
dag_by_ids = {dag.dag_id: dag for dag in dags}
+
dag_ids = set(dag_by_ids.keys())
query = (
session.query(DagModel)
@@ -2509,6 +2526,65 @@ class DAG(LoggingMixin):
DagCode.bulk_sync_to_db(filelocs, session=session)
+    from airflow.models.dataset import Dataset, DatasetDagRef, DatasetTaskRef
+
+ class OutletRef(NamedTuple):
+ dag_id: str
+ task_id: str
+ uri: str
+
+ class InletRef(NamedTuple):
+ dag_id: str
+ uri: str
+
+ dag_references = set()
+ outlet_references = set()
+ outlet_datasets = set()
+ input_datasets = set()
+ for dag in dags:
+ for dataset in dag.schedule_on or []:
+ dag_references.add(InletRef(dag.dag_id, dataset.uri))
+ input_datasets.add(dataset)
+ for task in dag.tasks:
+ for obj in getattr(task, '_outlets', []): # type: Dataset
+ if isinstance(obj, Dataset):
+                    outlet_references.add(OutletRef(task.dag_id, task.task_id, obj.uri))
+ outlet_datasets.add(obj)
+ all_datasets = outlet_datasets.union(input_datasets)
+
+ # store datasets
+ stored_datasets = {}
+ for dataset in all_datasets:
+        stored_dataset = session.query(Dataset).filter(Dataset.uri == dataset.uri).first()
+ if stored_dataset:
+ stored_datasets[stored_dataset.uri] = stored_dataset
+ else:
+ session.add(dataset)
+ stored_datasets[dataset.uri] = dataset
+
+    session.flush()  # this is required to ensure each dataset has its PK loaded
+
+ del all_datasets
+
+ # store dag-schedule-on-dataset references
+ for dag_ref in dag_references:
+ session.merge(
+ DatasetDagRef(
+ dataset_id=stored_datasets[dag_ref.uri].id,
+ dag_id=dag_ref.dag_id,
+ )
+ )
+
+ # store task-outlet-dataset references
+ for outlet_ref in outlet_references:
+ session.merge(
+ DatasetTaskRef(
+ dataset_id=stored_datasets[outlet_ref.uri].id,
+ dag_id=outlet_ref.dag_id,
+ task_id=outlet_ref.task_id,
+ )
+ )
+
# Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller
# decide when to commit
session.flush()
@@ -2731,7 +2807,6 @@ class DagModel(Base):
schedule_interval = Column(Interval)
# Timetable/Schedule Interval description
timetable_description = Column(String(1000), nullable=True)
-
# Tags for view filter
tags = relationship('DagTag', cascade='all, delete, delete-orphan', backref=backref('dag'))
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index a22b7d34c9..31888fe22e 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -55,7 +55,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import joinedload, relationship, synonym
from sqlalchemy.orm.session import Session
-from sqlalchemy.sql.expression import false, select, true
+from sqlalchemy.sql.expression import case, false, select, true
from airflow import settings
from airflow.callbacks.callback_requests import DagCallbackRequest
@@ -631,8 +631,83 @@ class DagRun(Base, LoggingMixin):
session.merge(self)
# We do not flush here for performance reasons(It increases queries count by +20)
+ self._process_dataset_dagrun_events(session=session)
+
return schedulable_tis, callback
+ def _process_dataset_dagrun_events(self, *, session=NEW_SESSION):
+ """
+ Looks at all outlet datasets that have been updated by this dag,
+ and creates DAG runs that have all dataset deps fulfilled.
+ """
+        from airflow.models.dataset import Dataset, DatasetDagRef, DatasetTaskRef
+
+ has_dataset_outlets = False
+ if self.dag:
+ for _, task in self.dag.task_dict.items():
+ if has_dataset_outlets is True:
+ break
+ for obj in getattr(task, '_outlets', []):
+ if isinstance(obj, Dataset):
+ has_dataset_outlets = True
+ break
+ dependent_dag_ids = []
+ if self.dag and has_dataset_outlets:
+ dependent_dag_ids = [
+ x.dag_id
+ for x in session.query(DatasetDagRef.dag_id)
+ .filter(DatasetTaskRef.dag_id == self.dag_id)
+ .all()
+ ]
+
+ from airflow.models.dataset import DatasetDagRunQueue as DDRQ
+ from airflow.models.serialized_dag import SerializedDagModel
+
+ dag_ids_to_trigger = None
+ if dependent_dag_ids:
+ dag_ids_to_trigger = [
+ x.dag_id
+ for x in session.query(
+ DatasetDagRef.dag_id,
+ )
+ .join(
+ DDRQ,
+ and_(
+ DDRQ.dataset_id == DatasetDagRef.dataset_id,
+ DDRQ.target_dag_id == DatasetDagRef.dag_id,
+ ),
+ isouter=True,
+ )
+ .filter(DatasetDagRef.dag_id.in_(dependent_dag_ids))
+ .group_by(DatasetDagRef.dag_id)
+            .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
+ .all()
+ ]
+
+ if dag_ids_to_trigger:
+ dags_to_purge_from_queue = set()
+ for target_dag_id in dag_ids_to_trigger:
+ row = SerializedDagModel.get(target_dag_id, session)
+ if not row:
+                self.log.warning("Could not find serialized DAG %s", target_dag_id)
+ continue
+ dag = row.dag
+ if dag.schedule_on:
+ dag.create_dagrun(
+ run_type=DagRunType.MANUAL,
+ run_id=self.generate_run_id(
+                        DagRunType.DATASET_TRIGGERED, execution_date=timezone.utcnow()
+ ),
+ state=DagRunState.QUEUED,
+ session=session,
+ )
+ else:
+ self.log.warning(
+                    "DAG %s no longer has a dataset scheduling dep; purging queue records.", dag.dag_id
+ )
+ dags_to_purge_from_queue.add(target_dag_id)
+            session.query(DDRQ).filter(DDRQ.target_dag_id.in_(dags_to_purge_from_queue)).delete()
+
@provide_session
def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision:
tis = self.get_task_instances(session=session, state=State.task_states)
diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py
index 49f5dfd1f6..256ed2293e 100644
--- a/airflow/models/dataset.py
+++ b/airflow/models/dataset.py
@@ -17,9 +17,10 @@
# under the License.
from urllib.parse import urlparse
-from sqlalchemy import Column, Index, Integer, String
+from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, PrimaryKeyConstraint, String
+from sqlalchemy.orm import relationship
-from airflow.models.base import Base
+from airflow.models.base import ID_LEN, Base, StringID
from airflow.utils import timezone
from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime
@@ -49,6 +50,9 @@ class Dataset(Base):
created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+ dag_references = relationship("DatasetDagRef", back_populates="dataset")
+ task_references = relationship("DatasetTaskRef", back_populates="dataset")
+
__tablename__ = "dataset"
__table_args__ = (
Index('idx_uri_unique', uri, unique=True),
@@ -66,10 +70,132 @@ class Dataset(Base):
super().__init__(uri=uri, **kwargs)
def __eq__(self, other):
- return self.uri == other.uri
+ if isinstance(other, self.__class__):
+ return self.uri == other.uri
+ else:
+ return NotImplemented
def __hash__(self):
return hash(self.uri)
def __repr__(self):
return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"
+
+
+class DatasetDagRef(Base):
+ """References from a DAG to an upstream dataset."""
+
+ dataset_id = Column(Integer, primary_key=True, nullable=False)
+ dag_id = Column(String(ID_LEN), primary_key=True, nullable=False)
+ created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+
+ dataset = relationship('Dataset')
+
+ __tablename__ = "dataset_dag_ref"
+ __table_args__ = (
+        PrimaryKeyConstraint(dataset_id, dag_id, name="datasetdagref_pkey", mssql_clustered=True),
+ ForeignKeyConstraint(
+ (dataset_id,),
+ ["dataset.id"],
+ name='datasetdagref_dataset_fkey',
+ ondelete="CASCADE",
+ ),
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+            return self.dataset_id == other.dataset_id and self.dag_id == other.dag_id
+ else:
+ return NotImplemented
+
+ def __hash__(self):
+ return hash(self.__mapper__.primary_key)
+
+ def __repr__(self):
+ args = []
+ for attr in [x.name for x in self.__mapper__.primary_key]:
+ args.append(f"{attr}={getattr(self, attr)!r}")
+ return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetTaskRef(Base):
+ """References from a task to a downstream dataset."""
+
+ dataset_id = Column(Integer, primary_key=True, nullable=False)
+ dag_id = Column(String(ID_LEN), primary_key=True, nullable=False)
+ task_id = Column(String(ID_LEN), primary_key=True, nullable=False)
+ created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+
+ dataset = relationship("Dataset", back_populates="task_references")
+
+ __tablename__ = "dataset_task_ref"
+ __table_args__ = (
+ ForeignKeyConstraint(
+ (dataset_id,),
+ ["dataset.id"],
+ name='datasettaskref_dataset_fkey',
+ ondelete="CASCADE",
+ ),
+        PrimaryKeyConstraint(dataset_id, dag_id, task_id, name="datasettaskref_pkey", mssql_clustered=True),
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return (
+ self.dataset_id == other.dataset_id
+ and self.dag_id == other.dag_id
+ and self.task_id == other.task_id
+ )
+ else:
+ return NotImplemented
+
+ def __hash__(self):
+ return hash(self.__mapper__.primary_key)
+
+ def __repr__(self):
+ args = []
+ for attr in [x.name for x in self.__mapper__.primary_key]:
+ args.append(f"{attr}={getattr(self, attr)!r}")
+ return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetDagRunQueue(Base):
+ """Model for storing dataset events that need processing."""
+
+ dataset_id = Column(Integer, primary_key=True, nullable=False)
+ target_dag_id = Column(StringID(), primary_key=True, nullable=False)
+ created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+
+ __tablename__ = "dataset_dag_run_queue"
+ __table_args__ = (
+        PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey", mssql_clustered=True),
+ ForeignKeyConstraint(
+ (dataset_id,),
+ ["dataset.id"],
+ name='ddrq_dataset_fkey',
+ ondelete="CASCADE",
+ ),
+ ForeignKeyConstraint(
+ (target_dag_id,),
+ ["dag.dag_id"],
+ name='ddrq_dag_fkey',
+ ondelete="CASCADE",
+ ),
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+            return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id
+ else:
+ return NotImplemented
+
+ def __hash__(self):
+ return hash(self.__mapper__.primary_key)
+
+ def __repr__(self):
+ args = []
+ for attr in [x.name for x in self.__mapper__.primary_key]:
+ args.append(f"{attr}={getattr(self, attr)!r}")
+ return f"{self.__class__.__name__}({', '.join(args)})"
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 8950385248..18bae29a0c 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -95,6 +95,7 @@ from airflow.exceptions import (
XComForMappingNotPushed,
)
from airflow.models.base import Base, StringID
+from airflow.models.dataset import DatasetDagRunQueue
from airflow.models.log import Log
from airflow.models.param import ParamsDict
from airflow.models.taskfail import TaskFail
@@ -1512,9 +1513,24 @@ class TaskInstance(Base, LoggingMixin):
if not test_mode:
session.add(Log(self.state, self))
session.merge(self)
-
+ self._create_dataset_dag_run_queue_records(session=session)
session.commit()
+ def _create_dataset_dag_run_queue_records(self, *, session):
+ from airflow.models import Dataset
+
+ for obj in getattr(self.task, '_outlets', []):
+ self.log.debug("outlet obj %s", obj)
+ if isinstance(obj, Dataset):
+                dataset = session.query(Dataset).filter(Dataset.uri == obj.uri).one_or_none()
+ if not dataset:
+ self.log.warning("Dataset %s not found", obj)
+ continue
+ downstream_dag_ids = [x.dag_id for x in dataset.dag_references]
+ self.log.debug("downstream dag ids %s", downstream_dag_ids)
+ for dag_id in downstream_dag_ids:
+                session.merge(DatasetDagRunQueue(dataset_id=dataset.id, target_dag_id=dag_id))
+
def _execute_task_with_callbacks(self, context, test_mode=False):
"""Prepare Task for Execution"""
from airflow.models.renderedtifields import RenderedTaskInstanceFields
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index f4227a6f7a..420b3e015b 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -49,3 +49,4 @@ class DagAttributeTypes(str, Enum):
EDGE_INFO = 'edgeinfo'
PARAM = 'param'
XCOM_REF = 'xcomref'
+ DATASET = 'dataset'
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index 423950bb5e..1550387eed 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -53,6 +53,34 @@
{ "type": "integer" }
]
},
+ "dataset": {
+ "type": "object",
+ "properties": {
+ "uri": { "type": "string" },
+ "extra": {
+ "anyOf": [
+ {"type": "null"},
+ { "$ref": "#/definitions/dict" }
+ ]
+ }
+ },
+ "required": [ "uri", "extra" ]
+ },
+ "typed_dataset": {
+ "type": "object",
+ "properties": {
+ "__type": {
+ "type": "string",
+ "constant": "dataset"
+ },
+ "__var": { "$ref": "#/definitions/dataset" }
+ },
+ "required": [
+ "__type",
+ "__var"
+ ],
+ "additionalProperties": false
+ },
"dict": {
"description": "A python dictionary containing values of any type",
"type": "object"
@@ -90,6 +118,15 @@
{ "$ref": "#/definitions/typed_relativedelta" }
]
},
+ "schedule_on":{
+ "anyOf": [
+ { "type": "null" },
+ {
+ "type": "array",
+ "items": { "$ref": "#/definitions/typed_dataset" }
+ }
+ ]
+ },
"timetable": {
"type": "object",
"properties": {
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index bd0430ba26..f4a4257f57 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -33,6 +33,7 @@ from pendulum.tz.timezone import FixedTimezone, Timezone
from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.exceptions import AirflowException, SerializationError
+from airflow.models import Dataset
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.connection import Connection
from airflow.models.dag import DAG, create_timetable
@@ -370,6 +371,8 @@ class BaseSerialization:
return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
elif isinstance(var, XComArg):
return cls._encode(cls._serialize_xcomarg(var), type_=DAT.XCOM_REF)
+ elif isinstance(var, Dataset):
+        return cls._encode(dict(uri=var.uri, extra=var.extra), type_=DAT.DATASET)
else:
log.debug('Cast type %s to str in serialization.', type(var))
return str(var)
@@ -415,6 +418,8 @@ class BaseSerialization:
return cls._deserialize_param(var)
elif type_ == DAT.XCOM_REF:
return cls._deserialize_xcomref(var)
+ elif type_ == DAT.DATASET:
+ return Dataset(**var)
else:
raise TypeError(f'Invalid type {type_!s} in deserialization.')
@@ -746,6 +751,9 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
v = {arg: cls._deserialize(value) for arg, value in v.items()}
elif k in cls._decorated_fields or k not in op.get_serialized_fields():
v = cls._deserialize(v)
+ elif k in ("_outlets", "_inlets"):
+ v = cls._deserialize(v)
+
# else use v as it is
setattr(op, k, v)
@@ -1024,6 +1032,8 @@ class SerializedDAG(DAG, BaseSerialization):
v = cls._deserialize(v)
elif k == "params":
v = cls._deserialize_params_dict(v)
+ elif k == "schedule_on":
+ v = cls._deserialize(v)
# else use v as it is
setattr(dag, k, v)
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index e8366805b4..769f7986fb 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -39,6 +39,7 @@ from typing import (
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.utils.module_loading import import_string
+from airflow.utils.types import NOTSET
if TYPE_CHECKING:
import jinja2
@@ -314,6 +315,24 @@ def exactly_one(*args) -> bool:
return sum(map(bool, args)) == 1
+def at_most_one(*args) -> bool:
+ """
+ Returns True if at most one of *args is "truthy", and False otherwise.
+
+ NOTSET is treated the same as None.
+
+ If user supplies an iterable, we raise ValueError and force them to unpack.
+ """
+
+ def is_set(val):
+ if val is NOTSET:
+ return False
+ else:
+ return bool(val)
+
+ return sum(map(is_set, args)) in (0, 1)
+
+
def prune_dict(val: Any, mode='strict'):
"""
Given dict ``val``, returns new dict based on ``val`` with all
diff --git a/airflow/utils/types.py b/airflow/utils/types.py
index 04688a7f65..313780bb2f 100644
--- a/airflow/utils/types.py
+++ b/airflow/utils/types.py
@@ -45,6 +45,7 @@ class DagRunType(str, enum.Enum):
BACKFILL_JOB = "backfill"
SCHEDULED = "scheduled"
MANUAL = "manual"
+ DATASET_TRIGGERED = "dataset_triggered"
def __str__(self) -> str:
return self.value
diff --git a/docs/apache-airflow/concepts/datasets.rst b/docs/apache-airflow/concepts/datasets.rst
new file mode 100644
index 0000000000..211349db56
--- /dev/null
+++ b/docs/apache-airflow/concepts/datasets.rst
@@ -0,0 +1,46 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Datasets
+========
+
+.. versionadded:: 2.4
+
+With datasets, instead of running a DAG on a schedule, a DAG can be configured to run when a dataset has been updated.
+
+To use this feature, define a dataset:
+
+.. exampleinclude:: /../../airflow/example_dags/example_datasets.py
+ :language: python
+ :start-after: [START dataset_def]
+ :end-before: [END dataset_def]
+
+Then reference the dataset as a task outlet:
+
+.. exampleinclude:: /../../airflow/example_dags/example_datasets.py
+ :language: python
+ :start-after: [START task_outlet]
+ :end-before: [END task_outlet]
+
+Finally, define a DAG and reference this dataset in the DAG's ``schedule_on`` parameter:
+
+.. exampleinclude:: /../../airflow/example_dags/example_datasets.py
+ :language: python
+ :start-after: [START dag_dep]
+ :end-before: [END dag_dep]
+
+You can reference multiple datasets in the DAG's ``schedule_on`` param. Once there has been an update to all of the upstream datasets, the DAG will be triggered. This means that the DAG will run as frequently as its least-frequently-updated dataset.
diff --git a/docs/apache-airflow/concepts/index.rst b/docs/apache-airflow/concepts/index.rst
index 122dc760fe..9297c54e00 100644
--- a/docs/apache-airflow/concepts/index.rst
+++ b/docs/apache-airflow/concepts/index.rst
@@ -38,6 +38,7 @@ Here you can find detailed documentation about each one of Airflow's core concep
operators
dynamic-task-mapping
sensors
+ datasets
deferring
smart-sensors
taskflow
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 0164ce0f87..50c49f3819 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -43,9 +43,10 @@ from airflow import models, settings
from airflow.configuration import conf
from airflow.decorators import task as task_decorator
from airflow.exceptions import AirflowException, DuplicateTaskIdFound, ParamValidationError
-from airflow.models import DAG, DagModel, DagRun, DagTag, TaskFail, TaskInstance as TI
+from airflow.models import DAG, DagModel, DagRun, DagTag, Dataset, TaskFail, TaskInstance as TI
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import dag as dag_decorator
+from airflow.models.dataset import DatasetTaskRef
from airflow.models.param import DagParam, Param, ParamsDict
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
@@ -821,6 +822,43 @@ class TestDag(unittest.TestCase):
assert not model.has_import_errors
session.close()
+ def test_bulk_write_to_db_datasets_schedule_on(self):
+ """
+        Ensure that datasets referenced in a dag are correctly loaded into the database.
+ """
+ # todo: clear db
+ dag_id1 = 'test_dataset_dag1'
+ dag_id2 = 'test_dataset_dag2'
+ task_id = 'test_dataset_task'
+ uri1 = 's3://dataset1'
+ d1 = Dataset(uri1, extra={"not": "used"})
+ d2 = Dataset('s3://dataset2')
+ d3 = Dataset('s3://dataset3')
+ dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule_on=[d1])
+ EmptyOperator(task_id=task_id, dag=dag1, outlets=[d2, d3])
+ dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE)
+        EmptyOperator(task_id=task_id, dag=dag2, outlets=[Dataset(uri1, extra={"should": "be used"})])
+ session = settings.Session()
+ dag1.clear()
+ DAG.bulk_write_to_db([dag1, dag2], session)
+ session.commit()
+ stored_datasets = {x.uri: x for x in session.query(Dataset).all()}
+ d1 = stored_datasets[d1.uri]
+ d2 = stored_datasets[d2.uri]
+ d3 = stored_datasets[d3.uri]
+ assert stored_datasets[uri1].extra == {"should": "be used"}
+ assert [x.dag_id for x in d1.dag_references] == [dag_id1]
+        assert [(x.task_id, x.dag_id) for x in d1.task_references] == [(task_id, dag_id2)]
+ assert set(
+            session.query(DatasetTaskRef.task_id, DatasetTaskRef.dag_id, DatasetTaskRef.dataset_id)
+ .filter(DatasetTaskRef.dag_id.in_((dag_id1, dag_id2)))
+ .all()
+ ) == {
+ (task_id, dag_id1, d2.id),
+ (task_id, dag_id1, d3.id),
+ (task_id, dag_id2, d1.id),
+ }
+
def test_sync_to_db(self):
dag = DAG(
'dag',
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 6c3cc1c91c..397f519a50 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -20,6 +20,7 @@ import datetime
from typing import Mapping, Optional
from unittest import mock
from unittest.mock import call
+from uuid import uuid4
import pendulum
import pytest
@@ -28,10 +29,21 @@ from sqlalchemy.orm.session import Session
from airflow import settings
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.decorators import task
-from airflow.models import DAG, DagBag, DagModel, DagRun, TaskInstance as TI, clear_task_instances
+from airflow.models import (
+ DAG,
+ DagBag,
+ DagModel,
+ DagRun,
+ TaskInstance,
+ TaskInstance as TI,
+ clear_task_instances,
+)
from airflow.models.baseoperator import BaseOperator
+from airflow.models.dataset import Dataset, DatasetDagRunQueue
+from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskmap import TaskMap
from airflow.models.xcom_arg import XComArg
+from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import ShortCircuitOperator
from airflow.serialization.serialized_objects import SerializedDAG
@@ -41,7 +53,13 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import DagRunType
from tests.models import DEFAULT_DATE as _DEFAULT_DATE
-from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, clear_db_variables
+from tests.test_utils.db import (
+ clear_db_dags,
+ clear_db_datasets,
+ clear_db_pools,
+ clear_db_runs,
+ clear_db_variables,
+)
from tests.test_utils.mock_operators import MockOperator
DEFAULT_DATE = pendulum.instance(_DEFAULT_DATE)
@@ -55,12 +73,14 @@ class TestDagRun:
clear_db_pools()
clear_db_dags()
clear_db_variables()
+ clear_db_datasets()
def teardown_method(self) -> None:
clear_db_runs()
clear_db_pools()
clear_db_dags()
clear_db_variables()
+ clear_db_datasets()
def create_dag_run(
self,
@@ -1291,6 +1311,47 @@ def test_mapped_task_upstream_failed(dag_maker, session):
assert dr.state == DagRunState.FAILED
+def test_dataset_dagruns_triggered(session):
+ unique_id = str(uuid4())
+ session = settings.Session()
+ dag1 = DAG(dag_id=f"datasets-{unique_id}-1", start_date=timezone.utcnow())
+ dataset1 = Dataset(uri=f"s3://{unique_id}-1")
+ dataset2 = Dataset(uri=f"s3://{unique_id}-2")
+    dag2 = DAG(dag_id=f"datasets-{unique_id}-2", schedule_on=[dataset1, dataset2])
+ dag3 = DAG(dag_id=f"datasets-{unique_id}-3", schedule_on=[dataset1])
+    task = BashOperator(task_id="task", bash_command="echo 1", dag=dag1, outlets=[dataset1])
+ # BashOperator(task_id="task", bash_command="echo 1", dag=dag2)
+ # BashOperator(task_id="task", bash_command="echo 1", dag=dag3)
+ DAG.bulk_write_to_db(dags=[dag1, dag2, dag3], session=session)
+ session.commit()
+ dr = DagRun(dag1.dag_id, run_id=unique_id, run_type='anything')
+ dr.dag = dag1
+ session.add(dr)
+ session.add(TaskInstance(task=task, run_id=unique_id, state=State.SUCCESS))
+ session.commit()
+
+ session.bulk_save_objects(
+ [
+ *[SerializedDagModel(dag) for dag in [dag1, dag2, dag3]],
+            DatasetDagRunQueue(dataset_id=dataset1.id, target_dag_id=dag2.dag_id),
+            DatasetDagRunQueue(dataset_id=dataset1.id, target_dag_id=dag3.dag_id),
+ ]
+ )
+ session.commit()
+ session.expunge_all()
+ dr.update_state(session=session)
+ session.commit()
+
+    # dag3 should be triggered since it only depends on dataset1, and it's been queued
+    assert session.query(DagRun).filter(DagRun.dag_id == dag3.dag_id).one() is not None
+    # dag3 DDRQ record should still be there since the dag run was *not* triggered
+    assert session.query(DatasetDagRunQueue).filter(DagRun.dag_id == dag3.dag_id).one() is not None
+    # dag2 should not be triggered since it depends on both dataset 1 and 2
+    assert session.query(DagRun).filter(DagRun.dag_id == dag2.dag_id).one_or_none() is None
+    # dag2 DDRQ record should be deleted since the dag run was triggered
+    assert session.query(DatasetDagRunQueue).filter(DagRun.dag_id == dag2.dag_id).one_or_none() is None
+
+
def test_mapped_task_all_finish_before_downstream(dag_maker, session):
result = None
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index da3d138306..0bb38071e8 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -27,6 +27,7 @@ from traceback import format_exception
from typing import List, Optional, Union, cast
from unittest import mock
from unittest.mock import call, mock_open, patch
+from uuid import uuid4
import pendulum
import pytest
@@ -47,6 +48,7 @@ from airflow.exceptions import (
from airflow.models import (
DAG,
Connection,
+ DagBag,
DagRun,
Pool,
RenderedTaskInstanceFields,
@@ -55,6 +57,7 @@ from airflow.models import (
Variable,
XCom,
)
+from airflow.models.dataset import DatasetDagRunQueue, DatasetTaskRef
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskfail import TaskFail
from airflow.models.taskinstance import TaskInstance
@@ -1473,6 +1476,33 @@ class TestTaskInstance:
ti.refresh_from_db()
assert ti.state == State.SUCCESS
+ def test_outlet_datasets(self, create_task_instance):
+ """
+ Verify that when we have an outlet dataset on a task, and the task
+ completes successfully, a DatasetDagRunQueue is logged.
+ """
+ from airflow.example_dags import example_datasets
+ from airflow.example_dags.example_datasets import dag1
+
+ session = settings.Session()
+ dagbag = DagBag(dag_folder=example_datasets.__file__)
+ dagbag.collect_dags(only_if_updated=False, safe_mode=False)
+ dagbag.sync_to_db(session=session)
+ run_id = str(uuid4())
+ dr = DagRun(dag1.dag_id, run_id=run_id, run_type='anything')
+ session.merge(dr)
+ task = dag1.get_task('upstream_task_1')
+ task.bash_command = 'echo 1' # make it go faster
+ ti = TaskInstance(task, run_id=run_id)
+ session.merge(ti)
+ session.commit()
+ ti._run_raw_task()
+ ti.refresh_from_db()
+ assert ti.state == State.SUCCESS
+ assert session.query(DatasetDagRunQueue.target_dag_id).filter(
+            DatasetTaskRef.dag_id == dag1.dag_id, DatasetTaskRef.task_id == 'upstream_task_1'
+ ).all() == [('dag3',), ('dag4',), ('dag5',)]
+
@staticmethod
def _test_previous_dates_setup(
schedule_interval: Union[str, datetime.timedelta, None],
@@ -2317,6 +2347,7 @@ class TestRunRawTaskQueriesCount:
db.clear_db_dags()
db.clear_db_sla_miss()
db.clear_db_import_errors()
+ db.clear_db_datasets()
def setup_method(self) -> None:
self._clean()
diff --git a/tests/test_utils/db.py b/tests/test_utils/db.py
index 948e434a9c..be91413b55 100644
--- a/tests/test_utils/db.py
+++ b/tests/test_utils/db.py
@@ -23,6 +23,7 @@ from airflow.models import (
DagRun,
DagTag,
DagWarning,
+ Dataset,
DbCallbackRequest,
Log,
Pool,
@@ -52,6 +53,11 @@ def clear_db_runs():
session.query(TaskInstance).delete()
+def clear_db_datasets():
+ with create_session() as session:
+ session.query(Dataset).delete()
+
+
def clear_db_dags():
with create_session() as session:
session.query(DagTag).delete()
diff --git a/tests/utils/test_db_cleanup.py b/tests/utils/test_db_cleanup.py
index 5c05545919..a7b48ff7e9 100644
--- a/tests/utils/test_db_cleanup.py
+++ b/tests/utils/test_db_cleanup.py
@@ -263,6 +263,9 @@ class TestDBCleanup:
'dag_warning', # self-maintaining
'connection', # leave alone
'slot_pool', # leave alone
+ 'dataset_dag_ref', # leave alone for now
+ 'dataset_task_ref', # leave alone for now
+ 'dataset_dag_run_queue', # self-managed
}
from airflow.utils.db_cleanup import config_dict
diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py
index 64d7d23f8f..8e8f799dd5 100644
--- a/tests/utils/test_helpers.py
+++ b/tests/utils/test_helpers.py
@@ -23,6 +23,7 @@ import pytest
from airflow import AirflowException
from airflow.utils import helpers, timezone
from airflow.utils.helpers import (
+ at_most_one,
build_airflow_url_with_query,
exactly_one,
merge_dicts,
@@ -30,6 +31,7 @@ from airflow.utils.helpers import (
validate_group_key,
validate_key,
)
+from airflow.utils.types import NOTSET
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_runs
@@ -264,6 +266,34 @@ class TestHelpers:
with pytest.raises(ValueError):
exactly_one([True, False])
+ def test_at_most_one(self):
+ """
+        Checks that when we set ``true_count`` elements to "truthy", and others to "falsy",
+ we get the expected return.
+        We check for both True / False, and truthy / falsy values 'a' and '', and verify that
+ they can safely be used in any combination.
+ NOTSET values should be ignored.
+ """
+
+ def assert_at_most_one(true=0, truthy=0, false=0, falsy=0, notset=0):
+ sample = []
+ for truth_value, num in [
+ (True, true),
+ (False, false),
+ ('a', truthy),
+ ('', falsy),
+ (NOTSET, notset),
+ ]:
+ if num:
+ sample.extend([truth_value] * num)
+ if sample:
+ expected = True if true + truthy in (0, 1) else False
+ assert at_most_one(*sample) is expected
+
+ for row in product(range(4), range(4), range(4), range(4), range(4)):
+ print(row)
+ assert_at_most_one(*row)
+
@pytest.mark.parametrize(
'mode, expected',
[