This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0900055585 Refactor dataset class inheritance (#37590)
0900055585 is described below
commit 090005558558c19c21626f715042ef840a68c05a
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Thu Feb 22 14:29:52 2024 +0530
Refactor dataset class inheritance (#37590)
* Refactor DatasetAll and DatasetAny inheritance
They are moved from airflow.models.datasets to airflow.datasets since
the intention is to use them with Dataset, not DatasetModel. It is more
natural for users to import from the latter module instead.
A new (abstract) base class is added for the two classes, plus the OG
Dataset class, to inherit from. This allows us to replace a few
isinstance checks with simple polymorphism and make the logic a bit
simpler.
---------
Co-authored-by: Tzu-ping Chung <[email protected]>
Co-authored-by: Wei Lee <[email protected]>
---
airflow/datasets/__init__.py | 64 +++++++++++++++++++++++++++--
airflow/models/dag.py | 26 ++++--------
airflow/models/dataset.py | 47 ---------------------
airflow/serialization/serialized_objects.py | 5 +--
airflow/timetables/datasets.py | 13 ++++--
tests/datasets/test_dataset.py | 4 +-
6 files changed, 82 insertions(+), 77 deletions(-)
diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index eaa25d0a30..1d08d7d6d3 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -14,18 +14,35 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
from __future__ import annotations
import os
-from typing import Any, ClassVar
+from typing import Any, Callable, ClassVar, Iterable, Iterator, Protocol,
runtime_checkable
from urllib.parse import urlsplit
import attr
+__all__ = ["Dataset", "DatasetAll", "DatasetAny"]
+
+
+@runtime_checkable
+class BaseDatasetEventInput(Protocol):
+ """Protocol for all dataset triggers to use in ``DAG(schedule=...)``.
+
+ :meta private:
+ """
+
+ def evaluate(self, statuses: dict[str, bool]) -> bool:
+ raise NotImplementedError
+
+ def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+ raise NotImplementedError
+
@attr.define()
-class Dataset(os.PathLike):
- """A Dataset is used for marking data dependencies between workflows."""
+class Dataset(os.PathLike, BaseDatasetEventInput):
+ """A representation of data dependencies between workflows."""
uri: str = attr.field(validator=[attr.validators.min_len(1),
attr.validators.max_len(3000)])
extra: dict[str, Any] | None = None
@@ -44,7 +61,7 @@ class Dataset(os.PathLike):
if parsed.scheme and parsed.scheme.lower() == "airflow":
raise ValueError(f"{attr.name!r} scheme `airflow` is reserved")
- def __fspath__(self):
+ def __fspath__(self) -> str:
return self.uri
def __eq__(self, other):
@@ -55,3 +72,42 @@ class Dataset(os.PathLike):
def __hash__(self):
return hash(self.uri)
+
+ def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+ yield self.uri, self
+
+ def evaluate(self, statuses: dict[str, bool]) -> bool:
+ return statuses.get(self.uri, False)
+
+
+class _DatasetBooleanCondition(BaseDatasetEventInput):
+ """Base class for dataset boolean logic."""
+
+ agg_func: Callable[[Iterable], bool]
+
+ def __init__(self, *objects: BaseDatasetEventInput) -> None:
+ self.objects = objects
+
+ def evaluate(self, statuses: dict[str, bool]) -> bool:
+ return self.agg_func(x.evaluate(statuses=statuses) for x in
self.objects)
+
+ def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+ seen = set() # We want to keep the first instance.
+ for o in self.objects:
+ for k, v in o.iter_datasets():
+ if k in seen:
+ continue
+ yield k, v
+ seen.add(k)
+
+
+class DatasetAny(_DatasetBooleanCondition):
+    """Use to combine datasets schedule references in an "or" relationship."""
+
+ agg_func = any
+
+
+class DatasetAll(_DatasetBooleanCondition):
+    """Use to combine datasets schedule references in an "and" relationship."""
+
+ agg_func = all
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 237759010a..19bf428543 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -80,6 +80,7 @@ import airflow.templates
from airflow import settings, utils
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf as airflow_conf, secrets_backend_list
+from airflow.datasets import BaseDatasetEventInput, Dataset, DatasetAll
from airflow.datasets.manager import dataset_manager
from airflow.exceptions import (
AirflowDagInconsistent,
@@ -98,13 +99,7 @@ from airflow.models.baseoperator import BaseOperator
from airflow.models.dagcode import DagCode
from airflow.models.dagpickle import DagPickle
from airflow.models.dagrun import RUN_ID_REGEX, DagRun
-from airflow.models.dataset import (
- DatasetAll,
- DatasetAny,
- DatasetBooleanCondition,
- DatasetDagRunQueue,
- DatasetModel,
-)
+from airflow.models.dataset import DatasetDagRunQueue, DatasetModel
from airflow.models.param import DagParam, ParamsDict
from airflow.models.taskinstance import (
Context,
@@ -150,7 +145,6 @@ if TYPE_CHECKING:
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
- from airflow.datasets import Dataset
from airflow.decorators import TaskDecoratorCollection
from airflow.models.dagbag import DagBag
from airflow.models.operator import Operator
@@ -174,7 +168,7 @@ ScheduleInterval = Union[None, str, timedelta,
relativedelta]
# but Mypy cannot handle that right now. Track progress of PEP 661 for
progress.
# See also: https://discuss.python.org/t/9126/7
ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval]
-ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable,
Collection["Dataset"]]
+ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable,
BaseDatasetEventInput, Collection["Dataset"]]
SLAMissCallback = Callable[["DAG", str, str, List["SlaMiss"],
List[TaskInstance]], None]
@@ -586,12 +580,10 @@ class DAG(LoggingMixin):
self.timetable: Timetable
self.schedule_interval: ScheduleInterval
- self.dataset_triggers: DatasetBooleanCondition | None = None
- if isinstance(schedule, (DatasetAll, DatasetAny)):
+ self.dataset_triggers: BaseDatasetEventInput | None = None
+ if isinstance(schedule, BaseDatasetEventInput):
self.dataset_triggers = schedule
- if isinstance(schedule, Collection) and not isinstance(schedule, str):
- from airflow.datasets import Dataset
-
+ elif isinstance(schedule, Collection) and not isinstance(schedule,
str):
if not all(isinstance(x, Dataset) for x in schedule):
raise ValueError("All elements in 'schedule' should be
datasets")
self.dataset_triggers = DatasetAll(*schedule)
@@ -3181,7 +3173,7 @@ class DAG(LoggingMixin):
if curr_orm_dag and curr_orm_dag.schedule_dataset_references:
curr_orm_dag.schedule_dataset_references = []
else:
- for dataset in dag.dataset_triggers.all_datasets().values():
+ for _, dataset in dag.dataset_triggers.iter_datasets():
dag_references[dag.dag_id].add(dataset.uri)
input_datasets[DatasetModel.from_public(dataset)] = None
curr_outlet_references = curr_orm_dag and
curr_orm_dag.task_outlet_dataset_references
@@ -3793,14 +3785,14 @@ class DagModel(Base):
"""
from airflow.models.serialized_dag import SerializedDagModel
- def dag_ready(dag_id: str, cond: DatasetBooleanCondition, statuses:
dict) -> bool | None:
+ def dag_ready(dag_id: str, cond: BaseDatasetEventInput, statuses:
dict) -> bool | None:
# if dag was serialized before 2.9 and we *just* upgraded,
# we may be dealing with old version. In that case,
# just wait for the dag to be reserialized.
try:
return cond.evaluate(statuses)
except AttributeError:
- log.warning("dag '%s' has old serialization; skipping dag run
creation.", dag_id)
+ log.warning("dag '%s' has old serialization; skipping DAG run
creation.", dag_id)
return None
# this loads all the DDRQ records.... may need to limit num dags
diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py
index bf28777358..aa10eb3809 100644
--- a/airflow/models/dataset.py
+++ b/airflow/models/dataset.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-from typing import Callable, Iterable
from urllib.parse import urlsplit
import sqlalchemy_jsonfield
@@ -337,49 +336,3 @@ class DatasetEvent(Base):
]:
args.append(f"{attr}={getattr(self, attr)!r}")
return f"{self.__class__.__name__}({', '.join(args)})"
-
-
-class DatasetBooleanCondition:
- """
- Base class for boolean logic for dataset triggers.
-
- :meta private:
- """
-
- agg_func: Callable[[Iterable], bool]
-
- def __init__(self, *objects) -> None:
- self.objects = objects
-
- def evaluate(self, statuses: dict[str, bool]) -> bool:
- return self.agg_func(self.eval_one(x, statuses) for x in self.objects)
-
- def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses) ->
bool:
- if isinstance(obj, Dataset):
- return statuses.get(obj.uri, False)
- return obj.evaluate(statuses=statuses)
-
- def all_datasets(self) -> dict[str, Dataset]:
- uris = {}
- for x in self.objects:
- if isinstance(x, Dataset):
- if x.uri not in uris:
- uris[x.uri] = x
- else:
- # keep the first instance
- for k, v in x.all_datasets().items():
- if k not in uris:
- uris[k] = v
- return uris
-
-
-class DatasetAny(DatasetBooleanCondition):
- """Use to combine datasets schedule references in an "and" relationship."""
-
- agg_func = any
-
-
-class DatasetAll(DatasetBooleanCondition):
- """Use to combine datasets schedule references in an "or" relationship."""
-
- agg_func = all
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index 5e6073233e..552244d73b 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -35,14 +35,13 @@ from pendulum.tz.timezone import FixedTimezone, Timezone
from airflow.compat.functools import cache
from airflow.configuration import conf
-from airflow.datasets import Dataset
+from airflow.datasets import Dataset, DatasetAll, DatasetAny
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning,
SerializationError
from airflow.jobs.job import Job
from airflow.models.baseoperator import BaseOperator
from airflow.models.connection import Connection
from airflow.models.dag import DAG, DagModel, create_timetable
from airflow.models.dagrun import DagRun
-from airflow.models.dataset import DatasetAll, DatasetAny
from airflow.models.expandinput import EXPAND_INPUT_EMPTY,
create_expand_input, get_map_type_key
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import Param, ParamsDict
@@ -788,7 +787,7 @@ class DependencyDetector:
return
if not dag.dataset_triggers:
return
- for uri in dag.dataset_triggers.all_datasets().keys():
+ for uri, _ in dag.dataset_triggers.iter_datasets():
yield DagDependency(
source="dataset",
target=dag.dag_id,
diff --git a/airflow/timetables/datasets.py b/airflow/timetables/datasets.py
index c755df964e..dcc0652929 100644
--- a/airflow/timetables/datasets.py
+++ b/airflow/timetables/datasets.py
@@ -19,8 +19,8 @@ from __future__ import annotations
import typing
+from airflow.datasets import BaseDatasetEventInput, DatasetAll
from airflow.exceptions import AirflowTimetableInvalid
-from airflow.models.dataset import DatasetAll, DatasetBooleanCondition
from airflow.timetables.simple import DatasetTriggeredTimetable as
DatasetTriggeredSchedule
from airflow.utils.types import DagRunType
@@ -36,9 +36,14 @@ if typing.TYPE_CHECKING:
class DatasetOrTimeSchedule(DatasetTriggeredSchedule):
"""Combine time-based scheduling with event-based scheduling."""
- def __init__(self, timetable: Timetable, datasets: Collection[Dataset] |
DatasetBooleanCondition) -> None:
+ def __init__(
+ self,
+ *,
+ timetable: Timetable,
+ datasets: Collection[Dataset] | BaseDatasetEventInput,
+ ) -> None:
self.timetable = timetable
- if isinstance(datasets, DatasetBooleanCondition):
+ if isinstance(datasets, BaseDatasetEventInput):
self.datasets = datasets
else:
self.datasets = DatasetAll(*datasets)
@@ -70,7 +75,7 @@ class DatasetOrTimeSchedule(DatasetTriggeredSchedule):
def validate(self) -> None:
if isinstance(self.timetable, DatasetTriggeredSchedule):
raise AirflowTimetableInvalid("cannot nest dataset timetables")
- if not isinstance(self.datasets, DatasetBooleanCondition):
+ if not isinstance(self.datasets, BaseDatasetEventInput):
raise AirflowTimetableInvalid("all elements in 'datasets' must be
datasets")
@property
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index e10264b0e2..258e542cec 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -23,8 +23,8 @@ from collections import defaultdict
import pytest
from sqlalchemy.sql import select
-from airflow.datasets import Dataset
-from airflow.models.dataset import DatasetAll, DatasetAny, DatasetDagRunQueue,
DatasetModel
+from airflow.datasets import Dataset, DatasetAll, DatasetAny
+from airflow.models.dataset import DatasetDagRunQueue, DatasetModel
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
from airflow.serialization.serialized_objects import BaseSerialization,
SerializedDAG