This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f971232ab4 Add conditional logic for dataset triggering (#37016)
f971232ab4 is described below
commit f971232ab4a636a1f54a349041a7e22476b8b2dc
Author: Daniel Standish <[email protected]>
AuthorDate: Wed Feb 21 11:24:21 2024 -0800
Add conditional logic for dataset triggering (#37016)
Add conditional logic for dataset-triggered dags so that we can schedule
based on dataset1 OR dataset2.
This PR only implements the underlying classes, DatasetAny and DatasetAll.
In a followup PR we will add more convenient syntax for this, specifically the
| and & symbols, e.g. (dataset1 | dataset2) & dataset3.
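For illustration, a minimal sketch of how the new classes are meant to be
used (the dag_id, dataset URIs, and surrounding DAG arguments are invented
for this example; only the DatasetAny/DatasetAll usage comes from this PR):

    from datetime import datetime

    from airflow.datasets import Dataset
    from airflow.models.dag import DAG
    from airflow.models.dataset import DatasetAll, DatasetAny
    from airflow.operators.empty import EmptyOperator

    dataset1 = Dataset("s3://bucket/key/1")
    dataset2 = Dataset("s3://bucket/key/2")
    dataset3 = Dataset("s3://bucket/key/3")

    # Run when dataset1 is updated, or when both dataset2 and dataset3
    # have been updated.
    with DAG(
        dag_id="example_conditional_dataset_schedule",
        schedule=DatasetAny(dataset1, DatasetAll(dataset2, dataset3)),
        start_date=datetime(2024, 1, 1),
        catchup=False,
    ):
        EmptyOperator(task_id="do_something")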
---------
Co-authored-by: Ankit Chaurasia <[email protected]>
Co-authored-by: Jed Cunningham <[email protected]>
Co-authored-by: Wei Lee <[email protected]>
---
airflow/models/dag.py | 102 +++++++++-----
airflow/models/dataset.py | 49 ++++++-
airflow/serialization/enums.py | 2 +
airflow/serialization/schema.json | 36 ++++-
airflow/serialization/serialized_objects.py | 29 +++-
airflow/timetables/datasets.py | 32 +++--
tests/cli/commands/test_dag_command.py | 12 +-
tests/datasets/test_dataset.py | 196 ++++++++++++++++++++++++++
tests/serialization/test_dag_serialization.py | 14 +-
tests/timetables/test_datasets_timetable.py | 1 -
10 files changed, 409 insertions(+), 64 deletions(-)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index dd43568657..237759010a 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -18,7 +18,6 @@
from __future__ import annotations
import asyncio
-import collections
import copy
import functools
import itertools
@@ -31,7 +30,7 @@ import time
import traceback
import warnings
import weakref
-from collections import deque
+from collections import abc, defaultdict, deque
from contextlib import ExitStack
from datetime import datetime, timedelta
from inspect import signature
@@ -99,6 +98,13 @@ from airflow.models.baseoperator import BaseOperator
from airflow.models.dagcode import DagCode
from airflow.models.dagpickle import DagPickle
from airflow.models.dagrun import RUN_ID_REGEX, DagRun
+from airflow.models.dataset import (
+ DatasetAll,
+ DatasetAny,
+ DatasetBooleanCondition,
+ DatasetDagRunQueue,
+ DatasetModel,
+)
from airflow.models.param import DagParam, ParamsDict
from airflow.models.taskinstance import (
Context,
@@ -462,7 +468,7 @@ class DAG(LoggingMixin):
on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
doc_md: str | None = None,
- params: collections.abc.MutableMapping | None = None,
+ params: abc.MutableMapping | None = None,
access_control: dict | None = None,
is_paused_upon_creation: bool | None = None,
jinja_environment_kwargs: dict | None = None,
@@ -580,25 +586,28 @@ class DAG(LoggingMixin):
self.timetable: Timetable
self.schedule_interval: ScheduleInterval
- self.dataset_triggers: Collection[Dataset] = []
-
+ self.dataset_triggers: DatasetBooleanCondition | None = None
+ if isinstance(schedule, (DatasetAll, DatasetAny)):
+ self.dataset_triggers = schedule
if isinstance(schedule, Collection) and not isinstance(schedule, str):
from airflow.datasets import Dataset
if not all(isinstance(x, Dataset) for x in schedule):
raise ValueError("All elements in 'schedule' should be datasets")
- self.dataset_triggers = list(schedule)
+ self.dataset_triggers = DatasetAll(*schedule)
elif isinstance(schedule, Timetable):
timetable = schedule
elif schedule is not NOTSET:
schedule_interval = schedule
- if self.dataset_triggers:
+ if isinstance(schedule, DatasetOrTimeSchedule):
+ self.timetable = schedule
+ self.dataset_triggers = self.timetable.datasets
+ self.schedule_interval = self.timetable.summary
+ elif self.dataset_triggers:
self.timetable = DatasetTriggeredTimetable()
self.schedule_interval = self.timetable.summary
elif timetable:
- if isinstance(timetable, DatasetOrTimeSchedule):
- self.dataset_triggers = timetable.datasets
self.timetable = timetable
self.schedule_interval = self.timetable.summary
else:
@@ -3156,8 +3165,8 @@ class DAG(LoggingMixin):
TaskOutletDatasetReference,
)
- dag_references = collections.defaultdict(set)
- outlet_references = collections.defaultdict(set)
+ dag_references = defaultdict(set)
+ outlet_references = defaultdict(set)
# We can't use a set here as we want to preserve order
outlet_datasets: dict[DatasetModel, None] = {}
input_datasets: dict[DatasetModel, None] = {}
@@ -3168,12 +3177,13 @@ class DAG(LoggingMixin):
# later we'll persist them to the database.
for dag in dags:
curr_orm_dag = existing_dags.get(dag.dag_id)
- if not dag.dataset_triggers:
+ if dag.dataset_triggers is None:
if curr_orm_dag and curr_orm_dag.schedule_dataset_references:
curr_orm_dag.schedule_dataset_references = []
- for dataset in dag.dataset_triggers:
- dag_references[dag.dag_id].add(dataset.uri)
- input_datasets[DatasetModel.from_public(dataset)] = None
+ else:
+ for dataset in dag.dataset_triggers.all_datasets().values():
+ dag_references[dag.dag_id].add(dataset.uri)
+ input_datasets[DatasetModel.from_public(dataset)] = None
curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
for task in dag.tasks:
dataset_outlets = [x for x in task.outlets or [] if isinstance(x, Dataset)]
@@ -3229,7 +3239,7 @@ class DAG(LoggingMixin):
for obj in dag_refs_stored - dag_refs_needed:
session.delete(obj)
- existing_task_outlet_refs_dict = collections.defaultdict(set)
+ existing_task_outlet_refs_dict = defaultdict(set)
for dag_id, orm_dag in existing_dags.items():
for todr in orm_dag.task_outlet_dataset_references:
existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr)
@@ -3512,7 +3522,7 @@ class DagOwnerAttributes(Base):
@classmethod
def get_all(cls, session) -> dict[str, dict[str, str]]:
- dag_links: dict = collections.defaultdict(dict)
+ dag_links: dict = defaultdict(dict)
for obj in session.scalars(select(cls)):
dag_links[obj.dag_id].update({obj.owner: obj.link})
return dag_links
@@ -3781,23 +3791,43 @@ class DagModel(Base):
you should ensure that any scheduling decisions are made in a single transaction -- as soon as the
transaction is committed it will be unlocked.
"""
- from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ
-
- # these dag ids are triggered by datasets, and they are ready to go.
- dataset_triggered_dag_info = {
- x.dag_id: (x.first_queued_time, x.last_queued_time)
- for x in session.execute(
- select(
- DagScheduleDatasetReference.dag_id,
- func.max(DDRQ.created_at).label("last_queued_time"),
- func.min(DDRQ.created_at).label("first_queued_time"),
- )
- .join(DagScheduleDatasetReference.queue_records, isouter=True)
- .group_by(DagScheduleDatasetReference.dag_id)
- .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
- )
- }
- dataset_triggered_dag_ids = set(dataset_triggered_dag_info)
+ from airflow.models.serialized_dag import SerializedDagModel
+
+ def dag_ready(dag_id: str, cond: DatasetBooleanCondition, statuses: dict) -> bool | None:
+ # if dag was serialized before 2.9 and we *just* upgraded,
+ # we may be dealing with old version. In that case,
+ # just wait for the dag to be reserialized.
+ try:
+ return cond.evaluate(statuses)
+ except AttributeError:
+ log.warning("dag '%s' has old serialization; skipping dag run creation.", dag_id)
+ return None
+
+ # this loads all the DDRQ records.... may need to limit num dags
+ all_records = session.scalars(select(DatasetDagRunQueue)).all()
+ by_dag = defaultdict(list)
+ for r in all_records:
+ by_dag[r.target_dag_id].append(r)
+ del all_records
+ dag_statuses = {}
+ for dag_id, records in by_dag.items():
+ dag_statuses[dag_id] = {x.dataset.uri: True for x in records}
+ ser_dags = session.scalars(
+ select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
+ ).all()
+ for ser_dag in ser_dags:
+ dag_id = ser_dag.dag_id
+ statuses = dag_statuses[dag_id]
+ if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses):
+ del by_dag[dag_id]
+ del dag_statuses[dag_id]
+ del dag_statuses
+ dataset_triggered_dag_info = {}
+ for dag_id, records in by_dag.items():
+ times = sorted(x.created_at for x in records)
+ dataset_triggered_dag_info[dag_id] = (times[0], times[-1])
+ del by_dag
+ dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
if dataset_triggered_dag_ids:
exclusion_list = set(
session.scalars(
@@ -3908,7 +3938,7 @@ def dag(
on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
doc_md: str | None = None,
- params: collections.abc.MutableMapping | None = None,
+ params: abc.MutableMapping | None = None,
access_control: dict | None = None,
is_paused_upon_creation: bool | None = None,
jinja_environment_kwargs: dict | None = None,
@@ -4030,7 +4060,7 @@ class DagContext:
"""
- _context_managed_dags: collections.deque[DAG] = deque()
+ _context_managed_dags: deque[DAG] = deque()
autoregistered_dags: set[tuple[DAG, ModuleType]] = set()
current_autoregister_module_name: str | None = None
diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py
index d9dd8e4bb5..bf28777358 100644
--- a/airflow/models/dataset.py
+++ b/airflow/models/dataset.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+from typing import Callable, Iterable
from urllib.parse import urlsplit
import sqlalchemy_jsonfield
@@ -208,7 +209,7 @@ class DatasetDagRunQueue(Base):
dataset_id = Column(Integer, primary_key=True, nullable=False)
target_dag_id = Column(StringID(), primary_key=True, nullable=False)
created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
-
+ dataset = relationship("DatasetModel", viewonly=True)
__tablename__ = "dataset_dag_run_queue"
__table_args__ = (
PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey"),
@@ -336,3 +337,49 @@ class DatasetEvent(Base):
]:
args.append(f"{attr}={getattr(self, attr)!r}")
return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetBooleanCondition:
+ """
+ Base class for boolean logic for dataset triggers.
+
+ :meta private:
+ """
+
+ agg_func: Callable[[Iterable], bool]
+
+ def __init__(self, *objects) -> None:
+ self.objects = objects
+
+ def evaluate(self, statuses: dict[str, bool]) -> bool:
+ return self.agg_func(self.eval_one(x, statuses) for x in self.objects)
+
+ def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses) -> bool:
+ if isinstance(obj, Dataset):
+ return statuses.get(obj.uri, False)
+ return obj.evaluate(statuses=statuses)
+
+ def all_datasets(self) -> dict[str, Dataset]:
+ uris = {}
+ for x in self.objects:
+ if isinstance(x, Dataset):
+ if x.uri not in uris:
+ uris[x.uri] = x
+ else:
+ # keep the first instance
+ for k, v in x.all_datasets().items():
+ if k not in uris:
+ uris[k] = v
+ return uris
+
+
+class DatasetAny(DatasetBooleanCondition):
+ """Use to combine datasets schedule references in an "and" relationship."""
+
+ agg_func = any
+
+
+class DatasetAll(DatasetBooleanCondition):
+ """Use to combine datasets schedule references in an "or" relationship."""
+
+ agg_func = all
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index 4f95c849c8..963dec580e 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -50,6 +50,8 @@ class DagAttributeTypes(str, Enum):
PARAM = "param"
XCOM_REF = "xcomref"
DATASET = "dataset"
+ DATASET_ANY = "dataset_any"
+ DATASET_ALL = "dataset_all"
SIMPLE_TASK_INSTANCE = "simple_task_instance"
BASE_JOB = "Job"
TASK_INSTANCE = "task_instance"
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index ae7121fd14..71ee0c8006 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -81,6 +81,36 @@
],
"additionalProperties": false
},
+ "typed_dataset_cond": {
+ "type": "object",
+ "properties": {
+ "__type": {
+ "anyOf": [{
+ "type": "string",
+ "constant": "dataset_or"
+ },
+ {
+ "type": "string",
+ "constant": "dataset_and"
+ }
+ ]
+ },
+ "__var": {
+ "type": "array",
+ "items": {
+ "anyOf": [
+ {"$ref": "#/definitions/typed_dataset"},
+ { "$ref": "#/definitions/typed_dataset_cond"}
+ ]
+ }
+ }
+ },
+ "required": [
+ "__type",
+ "__var"
+ ],
+ "additionalProperties": false
+ },
"dict": {
"description": "A python dictionary containing values of any type",
"type": "object"
@@ -119,9 +149,9 @@
]
},
"dataset_triggers": {
- "type": "array",
- "items": { "$ref": "#/definitions/typed_dataset" }
- },
+ "$ref": "#/definitions/typed_dataset_cond"
+
+},
"owner_links": { "type": "object" },
"timetable": {
"type": "object",
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 7adddbab10..5e6073233e 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -42,6 +42,7 @@ from airflow.models.baseoperator import BaseOperator
from airflow.models.connection import Connection
from airflow.models.dag import DAG, DagModel, create_timetable
from airflow.models.dagrun import DagRun
+from airflow.models.dataset import DatasetAll, DatasetAny
from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import Param, ParamsDict
@@ -404,6 +405,8 @@ class BaseSerialization:
serialized_object[key] = cls.serialize(value)
elif key == "timetable" and value is not None:
serialized_object[key] = encode_timetable(value)
+ elif key == "dataset_triggers":
+ serialized_object[key] = cls.serialize(value)
else:
value = cls.serialize(value)
if isinstance(value, dict) and Encoding.TYPE in value:
@@ -497,6 +500,22 @@ class BaseSerialization:
return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
elif isinstance(var, Dataset):
return cls._encode({"uri": var.uri, "extra": var.extra},
type_=DAT.DATASET)
+ elif isinstance(var, DatasetAll):
+ return cls._encode(
+ [
+ cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
+ for x in var.objects
+ ],
+ type_=DAT.DATASET_ALL,
+ )
+ elif isinstance(var, DatasetAny):
+ return cls._encode(
+ [
+ cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
+ for x in var.objects
+ ],
+ type_=DAT.DATASET_ANY,
+ )
elif isinstance(var, SimpleTaskInstance):
return cls._encode(
cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
@@ -587,6 +606,10 @@ class BaseSerialization:
return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG.
elif type_ == DAT.DATASET:
return Dataset(**var)
+ elif type_ == DAT.DATASET_ANY:
+ return DatasetAny(*(cls.deserialize(x) for x in var))
+ elif type_ == DAT.DATASET_ALL:
+ return DatasetAll(*(cls.deserialize(x) for x in var))
elif type_ == DAT.SIMPLE_TASK_INSTANCE:
return SimpleTaskInstance(**cls.deserialize(var))
elif type_ == DAT.CONNECTION:
@@ -763,12 +786,14 @@ class DependencyDetector:
"""Detect dependencies set directly on the DAG object."""
if not dag:
return
- for x in dag.dataset_triggers:
+ if not dag.dataset_triggers:
+ return
+ for uri in dag.dataset_triggers.all_datasets().keys():
yield DagDependency(
source="dataset",
target=dag.dag_id,
dependency_type="dataset",
- dependency_id=x.uri,
+ dependency_id=uri,
)
diff --git a/airflow/timetables/datasets.py b/airflow/timetables/datasets.py
index 4904c64e9c..c755df964e 100644
--- a/airflow/timetables/datasets.py
+++ b/airflow/timetables/datasets.py
@@ -17,28 +17,31 @@
from __future__ import annotations
-import collections.abc
import typing
-import attrs
-
-from airflow.datasets import Dataset
from airflow.exceptions import AirflowTimetableInvalid
+from airflow.models.dataset import DatasetAll, DatasetBooleanCondition
from airflow.timetables.simple import DatasetTriggeredTimetable as DatasetTriggeredSchedule
from airflow.utils.types import DagRunType
if typing.TYPE_CHECKING:
+ from collections.abc import Collection
+
import pendulum
+ from airflow.datasets import Dataset
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
class DatasetOrTimeSchedule(DatasetTriggeredSchedule):
"""Combine time-based scheduling with event-based scheduling."""
- def __init__(self, timetable: Timetable, datasets: collections.abc.Collection[Dataset]) -> None:
+ def __init__(self, timetable: Timetable, datasets: Collection[Dataset] | DatasetBooleanCondition) -> None:
self.timetable = timetable
- self.datasets = datasets
+ if isinstance(datasets, DatasetBooleanCondition):
+ self.datasets = datasets
+ else:
+ self.datasets = DatasetAll(*datasets)
self.description = f"Triggered by datasets or {timetable.description}"
self.periodic = timetable.periodic
@@ -52,24 +55,23 @@ class DatasetOrTimeSchedule(DatasetTriggeredSchedule):
from airflow.serialization.serialized_objects import decode_timetable
return cls(
- timetable=decode_timetable(data["timetable"]), datasets=[Dataset(**d) for d in data["datasets"]]
+ timetable=decode_timetable(data["timetable"]),
+ # don't need the datasets after deserialization
+ # they are already stored on dataset_triggers attr on DAG
+ # and this is what scheduler looks at
+ datasets=[],
)
def serialize(self) -> dict[str, typing.Any]:
from airflow.serialization.serialized_objects import encode_timetable
- return {
- "timetable": encode_timetable(self.timetable),
- "datasets": [attrs.asdict(e) for e in self.datasets],
- }
+ return {"timetable": encode_timetable(self.timetable)}
def validate(self) -> None:
if isinstance(self.timetable, DatasetTriggeredSchedule):
raise AirflowTimetableInvalid("cannot nest dataset timetables")
- if not isinstance(self.datasets, collections.abc.Collection) or not all(
- isinstance(d, Dataset) for d in self.datasets
- ):
- raise AirflowTimetableInvalid("all elements in 'event' must be datasets")
+ if not isinstance(self.datasets, DatasetBooleanCondition):
+ raise AirflowTimetableInvalid("all elements in 'datasets' must be datasets")
@property
def summary(self) -> str:
diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py
index 0df2c36f7d..ca47309721 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -392,16 +392,24 @@ class TestCliDags:
disable_retry=False,
)
- @mock.patch("workday.AfterWorkdayTimetable")
+ @mock.patch("workday.AfterWorkdayTimetable.get_next_workday")
@mock.patch("airflow.models.taskinstance.TaskInstance.dry_run")
@mock.patch("airflow.cli.commands.dag_command.DagRun")
- def test_backfill_with_custom_timetable(self, mock_dagrun, mock_dry_run, mock_AfterWorkdayTimetable):
+ def test_backfill_with_custom_timetable(self, mock_dagrun, mock_dry_run, mock_get_next_workday):
"""
when calling `dags backfill` on dag with custom timetable, the DagRun object should be created with
data_intervals.
"""
+
start_date = DEFAULT_DATE + timedelta(days=1)
end_date = start_date + timedelta(days=1)
+ workdays = [
+ start_date,
+ start_date + timedelta(days=1),
+ start_date + timedelta(days=2),
+ ]
+ mock_get_next_workday.side_effect = workdays
+
cli_args = self.parser.parse_args(
[
"dags",
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index 9e9ca99513..dfc8b82ba1 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -18,11 +18,25 @@
from __future__ import annotations
import os
+from collections import defaultdict
import pytest
+from sqlalchemy.sql import select
from airflow.datasets import Dataset
+from airflow.models.dataset import DatasetAll, DatasetAny, DatasetDagRunQueue, DatasetModel
+from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
+from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG
+
+
+@pytest.fixture
+def clear_datasets():
+ from tests.test_utils.db import clear_db_datasets
+
+ clear_db_datasets()
+ yield
+ clear_db_datasets()
@pytest.mark.parametrize(
@@ -54,3 +68,185 @@ def test_fspath():
uri = "s3://example_dataset"
dataset = Dataset(uri=uri)
assert os.fspath(dataset) == uri
+
+
+@pytest.mark.parametrize(
+ "inputs, scenario, expected",
+ [
+ # Scenarios for DatasetAny
+ ((True, True, True), "any", True),
+ ((True, True, False), "any", True),
+ ((True, False, True), "any", True),
+ ((True, False, False), "any", True),
+ ((False, False, True), "any", True),
+ ((False, True, False), "any", True),
+ ((False, True, True), "any", True),
+ ((False, False, False), "any", False),
+ # Scenarios for DatasetAll
+ ((True, True, True), "all", True),
+ ((True, True, False), "all", False),
+ ((True, False, True), "all", False),
+ ((True, False, False), "all", False),
+ ((False, False, True), "all", False),
+ ((False, True, False), "all", False),
+ ((False, True, True), "all", False),
+ ((False, False, False), "all", False),
+ ],
+)
+def test_dataset_logical_conditions_evaluation_and_serialization(inputs, scenario, expected):
+ class_ = DatasetAny if scenario == "any" else DatasetAll
+ datasets = [Dataset(uri=f"s3://abc/{i}") for i in range(123, 126)]
+ condition = class_(*datasets)
+
+ statuses = {dataset.uri: status for dataset, status in zip(datasets, inputs)}
+ assert (
+ condition.evaluate(statuses) == expected
+ ), f"Condition evaluation failed for inputs {inputs} and scenario
'{scenario}'"
+
+ # Serialize and deserialize the condition to test persistence
+ serialized = BaseSerialization.serialize(condition)
+ deserialized = BaseSerialization.deserialize(serialized)
+ assert deserialized.evaluate(statuses) == expected, "Serialization round-trip failed"
+
+
+@pytest.mark.parametrize(
+ "status_values, expected_evaluation",
+ [
+ ((False, True, True), False), # DatasetAll requires all conditions to be True, but d1 is False
+ ((True, True, True), True), # All conditions are True
+ ((True, False, True), True), # d1 is True, and DatasetAny condition (d2 or d3 being True) is met
+ ((True, False, False), False), # d1 is True, but neither d2 nor d3 meet the DatasetAny condition
+ ],
+)
+def test_nested_dataset_conditions_with_serialization(status_values, expected_evaluation):
+ # Define datasets
+ d1 = Dataset(uri="s3://abc/123")
+ d2 = Dataset(uri="s3://abc/124")
+ d3 = Dataset(uri="s3://abc/125")
+
+ # Create a nested condition: DatasetAll with d1 and DatasetAny with d2 and d3
+ nested_condition = DatasetAll(d1, DatasetAny(d2, d3))
+
+ statuses = {
+ d1.uri: status_values[0],
+ d2.uri: status_values[1],
+ d3.uri: status_values[2],
+ }
+
+ assert nested_condition.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch"
+
+ serialized_condition = BaseSerialization.serialize(nested_condition)
+ deserialized_condition = BaseSerialization.deserialize(serialized_condition)
+
+ assert (
+ deserialized_condition.evaluate(statuses) == expected_evaluation
+ ), "Post-serialization evaluation mismatch"
+
+
+@pytest.fixture
+def create_test_datasets(session):
+ """Fixture to create test datasets and corresponding models."""
+ datasets = [Dataset(uri=f"hello{i}") for i in range(1, 3)]
+ for dataset in datasets:
+ session.add(DatasetModel(uri=dataset.uri))
+ session.commit()
+ return datasets
+
+
+@pytest.mark.db_test
+@pytest.mark.usefixtures("clear_datasets")
+def test_dataset_trigger_setup_and_serialization(session, dag_maker, create_test_datasets):
+ datasets = create_test_datasets
+
+ # Create DAG with dataset triggers
+ with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+ EmptyOperator(task_id="hello")
+
+ # Verify dataset triggers are set up correctly
+ assert isinstance(
+ dag.dataset_triggers, DatasetAny
+ ), "DAG dataset triggers should be an instance of DatasetAny"
+
+ # Serialize and deserialize DAG dataset triggers
+ serialized_trigger = SerializedDAG.serialize(dag.dataset_triggers)
+ deserialized_trigger = SerializedDAG.deserialize(serialized_trigger)
+
+ # Verify serialization and deserialization integrity
+ assert isinstance(
+ deserialized_trigger, DatasetAny
+ ), "Deserialized trigger should maintain type DatasetAny"
+ assert (
+ deserialized_trigger.objects == dag.dataset_triggers.objects
+ ), "Deserialized trigger objects should match original"
+
+
+@pytest.mark.db_test
+@pytest.mark.usefixtures("clear_datasets")
+def test_dataset_dag_run_queue_processing(session, clear_datasets, dag_maker, create_test_datasets):
+ datasets = create_test_datasets
+ dataset_models = session.query(DatasetModel).all()
+
+ with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+ EmptyOperator(task_id="hello")
+
+ # Add DatasetDagRunQueue entries to simulate dataset event processing
+ for dm in dataset_models:
+ session.add(DatasetDagRunQueue(dataset_id=dm.id, target_dag_id=dag.dag_id))
+ session.commit()
+
+ # Fetch and evaluate dataset triggers for all DAGs affected by dataset events
+ records = session.scalars(select(DatasetDagRunQueue)).all()
+ dag_statuses = defaultdict(lambda: defaultdict(bool))
+ for record in records:
+ dag_statuses[record.target_dag_id][record.dataset.uri] = True
+
+ serialized_dags = session.execute(
+ select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
+ ).fetchall()
+
+ for (serialized_dag,) in serialized_dags:
+ dag = SerializedDAG.deserialize(serialized_dag.data)
+ for dataset_uri, status in dag_statuses[dag.dag_id].items():
+ assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG trigger evaluation failed"
+
+
+@pytest.mark.db_test
+@pytest.mark.usefixtures("clear_datasets")
+def test_dag_with_complex_dataset_triggers(session, dag_maker):
+ # Create Dataset instances
+ d1 = Dataset(uri="hello1")
+ d2 = Dataset(uri="hello2")
+
+ # Create and add DatasetModel instances to the session
+ dm1 = DatasetModel(uri=d1.uri)
+ dm2 = DatasetModel(uri=d2.uri)
+ session.add_all([dm1, dm2])
+ session.commit()
+
+ # Setup a DAG with complex dataset triggers (DatasetAny with DatasetAll)
+ with dag_maker(schedule=DatasetAny(d1, DatasetAll(d2, d1))) as dag:
+ EmptyOperator(task_id="hello")
+
+ assert isinstance(
+ dag.dataset_triggers, DatasetAny
+ ), "DAG's dataset trigger should be an instance of DatasetAny"
+ assert any(
+ isinstance(trigger, DatasetAll) for trigger in dag.dataset_triggers.objects
+ ), "DAG's dataset trigger should include DatasetAll"
+
+ serialized_triggers = SerializedDAG.serialize(dag.dataset_triggers)
+
+ deserialized_triggers = SerializedDAG.deserialize(serialized_triggers)
+
+ assert isinstance(
+ deserialized_triggers, DatasetAny
+ ), "Deserialized triggers should be an instance of DatasetAny"
+ assert any(
+ isinstance(trigger, DatasetAll) for trigger in deserialized_triggers.objects
+ ), "Deserialized triggers should include DatasetAll"
+
+ serialized_dag_dict = SerializedDAG.to_dict(dag)["dag"]
+ assert "dataset_triggers" in serialized_dag_dict, "Serialized DAG should
contain 'dataset_triggers'"
+ assert isinstance(
+ serialized_dag_dict["dataset_triggers"], dict
+ ), "Serialized 'dataset_triggers' should be a dict"
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 8a122592fd..2adc956b6f 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -60,6 +60,7 @@ from airflow.sensors.bash import BashSensor
from airflow.serialization.enums import Encoding
from airflow.serialization.json_schema import load_dag_schema_dict
from airflow.serialization.serialized_objects import (
+ BaseSerialization,
DagDependency,
DependencyDetector,
SerializedBaseOperator,
@@ -212,7 +213,6 @@ serialized_simple_dag_ground_truth = {
},
],
"schedule_interval": {"__type": "timedelta", "__var": 86400.0},
- "dataset_triggers": [],
"timezone": "UTC",
"_access_control": {
"__type": "dict",
@@ -551,11 +551,17 @@ class TestStringifiedDAGs:
"params",
"_processor_dags_folder",
}
+ compare_serialization_list = {
+ "dataset_triggers",
+ }
fields_to_check = dag.get_serialized_fields() - exclusion_list
for field in fields_to_check:
- assert getattr(serialized_dag, field) == getattr(
- dag, field
- ), f"{dag.dag_id}.{field} does not match"
+ actual = getattr(serialized_dag, field)
+ expected = getattr(dag, field)
+ if field in compare_serialization_list:
+ actual = BaseSerialization.serialize(actual)
+ expected = BaseSerialization.serialize(expected)
+ assert actual == expected, f"{dag.dag_id}.{field} does not match"
# _processor_dags_folder is only populated at serialization time
# it's only used when relying on serialized dag to determine a dag's relative path
assert dag._processor_dags_folder is None
diff --git a/tests/timetables/test_datasets_timetable.py b/tests/timetables/test_datasets_timetable.py
index 8e293888ca..ce58c42a6b 100644
--- a/tests/timetables/test_datasets_timetable.py
+++ b/tests/timetables/test_datasets_timetable.py
@@ -127,7 +127,6 @@ def test_serialization(dataset_timetable: DatasetOrTimeSchedule, monkeypatch: An
serialized = dataset_timetable.serialize()
assert serialized == {
"timetable": "mock_serialized_timetable",
- "datasets": [{"uri": "test_dataset", "extra": None}],
}