jedcunningham commented on code in PR #37016:
URL: https://github.com/apache/airflow/pull/37016#discussion_r1496912178
##########
airflow/models/dataset.py:
##########
@@ -336,3 +337,49 @@ def __repr__(self) -> str:
]:
args.append(f"{attr}={getattr(self, attr)!r}")
return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetBooleanCondition:
+ """
+ Base class for boolean logic for dataset triggers.
+
+ :meta private:
+ """
+
+ agg_func: Callable
+
+ def __init__(self, *objects):
+ self.objects = objects
+
+ def evaluate(self, statuses: dict[str, bool]):
Review Comment:
```suggestion
def evaluate(self, statuses: dict[str, bool]) -> bool:
```
##########
tests/cli/commands/test_dag_command.py:
##########
Review Comment:
How are these changes related?
##########
tests/datasets/test_dataset.py:
##########
@@ -40,17 +46,224 @@ def test_invalid_uris(uri):
Dataset(uri=uri)
[email protected]_test
def test_uri_with_scheme():
dataset = Dataset(uri="s3://example_dataset")
EmptyOperator(task_id="task1", outlets=[dataset])
[email protected]_test
def test_uri_without_scheme():
dataset = Dataset(uri="example_dataset")
EmptyOperator(task_id="task1", outlets=[dataset])
[email protected]_test
def test_fspath():
uri = "s3://example_dataset"
dataset = Dataset(uri=uri)
assert os.fspath(dataset) == uri
+
+
[email protected]_test
[email protected](
+ "inputs, scenario, expected",
+ [
+ # Scenarios for DatasetAny
+ ((True, True, True), "any", True),
+ ((True, True, False), "any", True),
+ ((True, False, True), "any", True),
+ ((True, False, False), "any", True),
+ ((False, False, True), "any", True),
+ ((False, True, False), "any", True),
+ ((False, True, True), "any", True),
+ ((False, False, False), "any", False),
+ # Scenarios for DatasetAll
+ ((True, True, True), "all", True),
+ ((True, True, False), "all", False),
+ ((True, False, True), "all", False),
+ ((True, False, False), "all", False),
+ ((False, False, True), "all", False),
+ ((False, True, False), "all", False),
+ ((False, True, True), "all", False),
+ ((False, False, False), "all", False),
+ ],
+)
+def test_dataset_logical_conditions_evaluation_and_serialization(inputs,
scenario, expected):
+ class_ = DatasetAny if scenario == "any" else DatasetAll
+ datasets = [Dataset(uri=f"s3://abc/{i}") for i in range(123, 126)]
+ condition = class_(*datasets)
+
+ statuses = {dataset.uri: status for dataset, status in zip(datasets,
inputs)}
+ assert (
+ condition.evaluate(statuses) == expected
+ ), f"Condition evaluation failed for inputs {inputs} and scenario
'{scenario}'"
+
+ # Serialize and deserialize the condition to test persistence
+ serialized = BaseSerialization.serialize(condition)
+ deserialized = BaseSerialization.deserialize(serialized)
+ assert deserialized.evaluate(statuses) == expected, "Serialization
round-trip failed"
+
+
[email protected]_test
[email protected](
+ "status_values, expected_evaluation",
+ [
+ ((False, True, True), False), # DatasetAll requires all conditions to
be True, but d1 is False
+ ((True, True, True), True), # All conditions are True
+ ((True, False, True), True), # d1 is True, and DatasetAny condition
(d2 or d3 being True) is met
+ ((True, False, False), False), # d1 is True, but neither d2 nor d3
meet the DatasetAny condition
+ ],
+)
+def test_nested_dataset_conditions_with_serialization(status_values,
expected_evaluation):
+ # Define datasets
+ d1 = Dataset(uri="s3://abc/123")
+ d2 = Dataset(uri="s3://abc/124")
+ d3 = Dataset(uri="s3://abc/125")
+
+ # Create a nested condition: DatasetAll with d1 and DatasetAny with d2 and
d3
+ nested_condition = DatasetAll(d1, DatasetAny(d2, d3))
+
+ statuses = {
+ d1.uri: status_values[0],
+ d2.uri: status_values[1],
+ d3.uri: status_values[2],
+ }
+
+ assert nested_condition.evaluate(statuses) == expected_evaluation,
"Initial evaluation mismatch"
+
+ serialized_condition = BaseSerialization.serialize(nested_condition)
+ deserialized_condition =
BaseSerialization.deserialize(serialized_condition)
+
+ assert (
+ deserialized_condition.evaluate(statuses) == expected_evaluation
+ ), "Post-serialization evaluation mismatch"
+
+
[email protected]
+def create_test_datasets(session):
+ """Fixture to create test datasets and corresponding models."""
+ datasets = [Dataset(uri=f"hello{i}") for i in range(1, 3)]
+ for dataset in datasets:
+ session.add(DatasetModel(uri=dataset.uri))
+ session.commit()
+ return datasets
+
+
[email protected]_test
+def test_dataset_trigger_setup_and_serialization(session, dag_maker,
create_test_datasets):
+ datasets = create_test_datasets
+
+ # Create DAG with dataset triggers
+ with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+ EmptyOperator(task_id="hello")
+
+ # Verify dataset triggers are set up correctly
+ assert isinstance(
+ dag.dataset_triggers, DatasetAny
+ ), "DAG dataset triggers should be an instance of DatasetAny"
+
+ # Serialize and deserialize DAG dataset triggers
+ serialized_trigger = SerializedDAG.serialize(dag.dataset_triggers)
+ deserialized_trigger = SerializedDAG.deserialize(serialized_trigger)
+
+ # Verify serialization and deserialization integrity
+ assert isinstance(
+ deserialized_trigger, DatasetAny
+ ), "Deserialized trigger should maintain type DatasetAny"
+ assert (
+ deserialized_trigger.objects == dag.dataset_triggers.objects
+ ), "Deserialized trigger objects should match original"
+
+
[email protected]_test
+def test_dataset_dag_run_queue_processing(session, dag_maker,
create_test_datasets):
+ datasets = create_test_datasets
+ dataset_models = session.query(DatasetModel).all()
+
+ with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+ EmptyOperator(task_id="hello")
+
+ # Add DatasetDagRunQueue entries to simulate dataset event processing
+ for dm in dataset_models:
+ session.add(DatasetDagRunQueue(dataset_id=dm.id,
target_dag_id=dag.dag_id))
+ session.commit()
+
+ # Fetch and evaluate dataset triggers for all DAGs affected by dataset
events
+ records = session.scalars(select(DatasetDagRunQueue)).all()
+ dag_statuses = defaultdict(lambda: defaultdict(bool))
+ for record in records:
+ dag_statuses[record.target_dag_id][record.dataset.uri] = True
+
+ serialized_dags = session.execute(
+
select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
+ ).fetchall()
+
+ for (serialized_dag,) in serialized_dags:
+ dag = SerializedDAG.deserialize(serialized_dag.data)
+ for dataset_uri, status in dag_statuses[dag.dag_id].items():
+ assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG
trigger evaluation failed"
+
+
[email protected]_test
[email protected]("create_test_datasets")
+def test_additional_dag_with_no_triggers(dag_maker):
Review Comment:
This test feels a little odd without some additional assertions — could we add some?
##########
airflow/models/dataset.py:
##########
@@ -336,3 +337,49 @@ def __repr__(self) -> str:
]:
args.append(f"{attr}={getattr(self, attr)!r}")
return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetBooleanCondition:
+ """
+ Base class for boolean logic for dataset triggers.
+
+ :meta private:
+ """
+
+ agg_func: Callable
+
+ def __init__(self, *objects):
+ self.objects = objects
+
+ def evaluate(self, statuses: dict[str, bool]):
+ return self.agg_func(self.eval_one(x, statuses) for x in self.objects)
+
+ def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses):
Review Comment:
```suggestion
def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses) ->
bool:
```
##########
tests/datasets/test_dataset.py:
##########
@@ -40,17 +46,224 @@ def test_invalid_uris(uri):
Dataset(uri=uri)
[email protected]_test
def test_uri_with_scheme():
dataset = Dataset(uri="s3://example_dataset")
EmptyOperator(task_id="task1", outlets=[dataset])
[email protected]_test
def test_uri_without_scheme():
dataset = Dataset(uri="example_dataset")
EmptyOperator(task_id="task1", outlets=[dataset])
[email protected]_test
def test_fspath():
uri = "s3://example_dataset"
dataset = Dataset(uri=uri)
assert os.fspath(dataset) == uri
+
+
[email protected]_test
[email protected](
+ "inputs, scenario, expected",
+ [
+ # Scenarios for DatasetAny
+ ((True, True, True), "any", True),
+ ((True, True, False), "any", True),
+ ((True, False, True), "any", True),
+ ((True, False, False), "any", True),
+ ((False, False, True), "any", True),
+ ((False, True, False), "any", True),
+ ((False, True, True), "any", True),
+ ((False, False, False), "any", False),
+ # Scenarios for DatasetAll
+ ((True, True, True), "all", True),
+ ((True, True, False), "all", False),
+ ((True, False, True), "all", False),
+ ((True, False, False), "all", False),
+ ((False, False, True), "all", False),
+ ((False, True, False), "all", False),
+ ((False, True, True), "all", False),
+ ((False, False, False), "all", False),
+ ],
+)
+def test_dataset_logical_conditions_evaluation_and_serialization(inputs,
scenario, expected):
+ class_ = DatasetAny if scenario == "any" else DatasetAll
+ datasets = [Dataset(uri=f"s3://abc/{i}") for i in range(123, 126)]
+ condition = class_(*datasets)
+
+ statuses = {dataset.uri: status for dataset, status in zip(datasets,
inputs)}
+ assert (
+ condition.evaluate(statuses) == expected
+ ), f"Condition evaluation failed for inputs {inputs} and scenario
'{scenario}'"
+
+ # Serialize and deserialize the condition to test persistence
+ serialized = BaseSerialization.serialize(condition)
+ deserialized = BaseSerialization.deserialize(serialized)
+ assert deserialized.evaluate(statuses) == expected, "Serialization
round-trip failed"
+
+
[email protected]_test
[email protected](
+ "status_values, expected_evaluation",
+ [
+ ((False, True, True), False), # DatasetAll requires all conditions to
be True, but d1 is False
+ ((True, True, True), True), # All conditions are True
+ ((True, False, True), True), # d1 is True, and DatasetAny condition
(d2 or d3 being True) is met
+ ((True, False, False), False), # d1 is True, but neither d2 nor d3
meet the DatasetAny condition
+ ],
+)
+def test_nested_dataset_conditions_with_serialization(status_values,
expected_evaluation):
+ # Define datasets
+ d1 = Dataset(uri="s3://abc/123")
+ d2 = Dataset(uri="s3://abc/124")
+ d3 = Dataset(uri="s3://abc/125")
+
+ # Create a nested condition: DatasetAll with d1 and DatasetAny with d2 and
d3
+ nested_condition = DatasetAll(d1, DatasetAny(d2, d3))
+
+ statuses = {
+ d1.uri: status_values[0],
+ d2.uri: status_values[1],
+ d3.uri: status_values[2],
+ }
+
+ assert nested_condition.evaluate(statuses) == expected_evaluation,
"Initial evaluation mismatch"
+
+ serialized_condition = BaseSerialization.serialize(nested_condition)
+ deserialized_condition =
BaseSerialization.deserialize(serialized_condition)
+
+ assert (
+ deserialized_condition.evaluate(statuses) == expected_evaluation
+ ), "Post-serialization evaluation mismatch"
+
+
[email protected]
+def create_test_datasets(session):
+ """Fixture to create test datasets and corresponding models."""
+ datasets = [Dataset(uri=f"hello{i}") for i in range(1, 3)]
+ for dataset in datasets:
+ session.add(DatasetModel(uri=dataset.uri))
+ session.commit()
+ return datasets
+
+
[email protected]_test
+def test_dataset_trigger_setup_and_serialization(session, dag_maker,
create_test_datasets):
+ datasets = create_test_datasets
+
+ # Create DAG with dataset triggers
+ with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+ EmptyOperator(task_id="hello")
+
+ # Verify dataset triggers are set up correctly
+ assert isinstance(
+ dag.dataset_triggers, DatasetAny
+ ), "DAG dataset triggers should be an instance of DatasetAny"
+
+ # Serialize and deserialize DAG dataset triggers
+ serialized_trigger = SerializedDAG.serialize(dag.dataset_triggers)
+ deserialized_trigger = SerializedDAG.deserialize(serialized_trigger)
+
+ # Verify serialization and deserialization integrity
+ assert isinstance(
+ deserialized_trigger, DatasetAny
+ ), "Deserialized trigger should maintain type DatasetAny"
+ assert (
+ deserialized_trigger.objects == dag.dataset_triggers.objects
+ ), "Deserialized trigger objects should match original"
+
+
[email protected]_test
+def test_dataset_dag_run_queue_processing(session, dag_maker,
create_test_datasets):
+ datasets = create_test_datasets
+ dataset_models = session.query(DatasetModel).all()
+
+ with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+ EmptyOperator(task_id="hello")
+
+ # Add DatasetDagRunQueue entries to simulate dataset event processing
+ for dm in dataset_models:
+ session.add(DatasetDagRunQueue(dataset_id=dm.id,
target_dag_id=dag.dag_id))
+ session.commit()
+
+ # Fetch and evaluate dataset triggers for all DAGs affected by dataset
events
+ records = session.scalars(select(DatasetDagRunQueue)).all()
+ dag_statuses = defaultdict(lambda: defaultdict(bool))
+ for record in records:
+ dag_statuses[record.target_dag_id][record.dataset.uri] = True
+
+ serialized_dags = session.execute(
+
select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
+ ).fetchall()
+
+ for (serialized_dag,) in serialized_dags:
+ dag = SerializedDAG.deserialize(serialized_dag.data)
+ for dataset_uri, status in dag_statuses[dag.dag_id].items():
+ assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG
trigger evaluation failed"
+
+
[email protected]_test
[email protected]("create_test_datasets")
+def test_additional_dag_with_no_triggers(dag_maker):
+ # Create an additional DAG to ensure it's not affected by dataset triggers
+ with dag_maker(dag_id="dag2"):
+ EmptyOperator(task_id="hello2")
+
+
[email protected]
+def setup_datasets_and_models(session):
+ """Fixture to create datasets and corresponding models."""
+ # Create Dataset instances
+ d1 = Dataset(uri="hello1")
+ d2 = Dataset(uri="hello2")
+
+ # Create and add DatasetModel instances to the session
+ dm1 = DatasetModel(uri=d1.uri)
+ dm2 = DatasetModel(uri=d2.uri)
+ session.add_all([dm1, dm2])
+ session.commit()
+
+ return d1, d2
+
+
[email protected]_test
+def test_dag_with_complex_dataset_triggers(session, dag_maker,
setup_datasets_and_models):
+ d1, d2 = setup_datasets_and_models
+
+ # Setup a DAG with complex dataset triggers (DatasetAny with DatasetAll)
+ with dag_maker(schedule=DatasetAny(d1, DatasetAll(d2, d1))) as dag:
+ EmptyOperator(task_id="hello")
+
+ assert isinstance(
+ dag.dataset_triggers, DatasetAny
+ ), "DAG's dataset trigger should be an instance of DatasetAny"
+ assert any(
+ isinstance(trigger, DatasetAll) for trigger in
dag.dataset_triggers.objects
+ ), "DAG's dataset trigger should include DatasetAll"
+
+ serialized_triggers = SerializedDAG.serialize(dag.dataset_triggers)
+
+ deserialized_triggers = SerializedDAG.deserialize(serialized_triggers)
+
+ assert isinstance(
+ deserialized_triggers, DatasetAny
+ ), "Deserialized triggers should be an instance of DatasetAny"
+ assert any(
+ isinstance(trigger, DatasetAll) for trigger in
deserialized_triggers.objects
+ ), "Deserialized triggers should include DatasetAll"
+
+ serialized_dag_dict = SerializedDAG.to_dict(dag)["dag"]
+ assert "dataset_triggers" in serialized_dag_dict, "Serialized DAG should
contain 'dataset_triggers'"
+ assert isinstance(
+ serialized_dag_dict["dataset_triggers"], dict
+ ), "Serialized 'dataset_triggers' should be a dict"
+
+
[email protected](autouse=True)
+def clear_datasets():
Review Comment:
nit: move this to the top.
##########
tests/datasets/test_dataset.py:
##########
@@ -40,17 +46,224 @@ def test_invalid_uris(uri):
Dataset(uri=uri)
[email protected]_test
def test_uri_with_scheme():
dataset = Dataset(uri="s3://example_dataset")
EmptyOperator(task_id="task1", outlets=[dataset])
[email protected]_test
def test_uri_without_scheme():
dataset = Dataset(uri="example_dataset")
EmptyOperator(task_id="task1", outlets=[dataset])
[email protected]_test
def test_fspath():
uri = "s3://example_dataset"
dataset = Dataset(uri=uri)
assert os.fspath(dataset) == uri
+
+
[email protected]_test
[email protected](
+ "inputs, scenario, expected",
+ [
+ # Scenarios for DatasetAny
+ ((True, True, True), "any", True),
+ ((True, True, False), "any", True),
+ ((True, False, True), "any", True),
+ ((True, False, False), "any", True),
+ ((False, False, True), "any", True),
+ ((False, True, False), "any", True),
+ ((False, True, True), "any", True),
+ ((False, False, False), "any", False),
+ # Scenarios for DatasetAll
+ ((True, True, True), "all", True),
+ ((True, True, False), "all", False),
+ ((True, False, True), "all", False),
+ ((True, False, False), "all", False),
+ ((False, False, True), "all", False),
+ ((False, True, False), "all", False),
+ ((False, True, True), "all", False),
+ ((False, False, False), "all", False),
+ ],
+)
+def test_dataset_logical_conditions_evaluation_and_serialization(inputs,
scenario, expected):
+ class_ = DatasetAny if scenario == "any" else DatasetAll
+ datasets = [Dataset(uri=f"s3://abc/{i}") for i in range(123, 126)]
+ condition = class_(*datasets)
+
+ statuses = {dataset.uri: status for dataset, status in zip(datasets,
inputs)}
+ assert (
+ condition.evaluate(statuses) == expected
+ ), f"Condition evaluation failed for inputs {inputs} and scenario
'{scenario}'"
+
+ # Serialize and deserialize the condition to test persistence
+ serialized = BaseSerialization.serialize(condition)
+ deserialized = BaseSerialization.deserialize(serialized)
+ assert deserialized.evaluate(statuses) == expected, "Serialization
round-trip failed"
+
+
[email protected]_test
[email protected](
+ "status_values, expected_evaluation",
+ [
+ ((False, True, True), False), # DatasetAll requires all conditions to
be True, but d1 is False
+ ((True, True, True), True), # All conditions are True
+ ((True, False, True), True), # d1 is True, and DatasetAny condition
(d2 or d3 being True) is met
+ ((True, False, False), False), # d1 is True, but neither d2 nor d3
meet the DatasetAny condition
+ ],
+)
+def test_nested_dataset_conditions_with_serialization(status_values,
expected_evaluation):
+ # Define datasets
+ d1 = Dataset(uri="s3://abc/123")
+ d2 = Dataset(uri="s3://abc/124")
+ d3 = Dataset(uri="s3://abc/125")
+
+ # Create a nested condition: DatasetAll with d1 and DatasetAny with d2 and
d3
+ nested_condition = DatasetAll(d1, DatasetAny(d2, d3))
+
+ statuses = {
+ d1.uri: status_values[0],
+ d2.uri: status_values[1],
+ d3.uri: status_values[2],
+ }
+
+ assert nested_condition.evaluate(statuses) == expected_evaluation,
"Initial evaluation mismatch"
+
+ serialized_condition = BaseSerialization.serialize(nested_condition)
+ deserialized_condition =
BaseSerialization.deserialize(serialized_condition)
+
+ assert (
+ deserialized_condition.evaluate(statuses) == expected_evaluation
+ ), "Post-serialization evaluation mismatch"
+
+
[email protected]
+def create_test_datasets(session):
+ """Fixture to create test datasets and corresponding models."""
+ datasets = [Dataset(uri=f"hello{i}") for i in range(1, 3)]
+ for dataset in datasets:
+ session.add(DatasetModel(uri=dataset.uri))
+ session.commit()
+ return datasets
+
+
[email protected]_test
+def test_dataset_trigger_setup_and_serialization(session, dag_maker,
create_test_datasets):
+ datasets = create_test_datasets
+
+ # Create DAG with dataset triggers
+ with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+ EmptyOperator(task_id="hello")
+
+ # Verify dataset triggers are set up correctly
+ assert isinstance(
+ dag.dataset_triggers, DatasetAny
+ ), "DAG dataset triggers should be an instance of DatasetAny"
+
+ # Serialize and deserialize DAG dataset triggers
+ serialized_trigger = SerializedDAG.serialize(dag.dataset_triggers)
+ deserialized_trigger = SerializedDAG.deserialize(serialized_trigger)
+
+ # Verify serialization and deserialization integrity
+ assert isinstance(
+ deserialized_trigger, DatasetAny
+ ), "Deserialized trigger should maintain type DatasetAny"
+ assert (
+ deserialized_trigger.objects == dag.dataset_triggers.objects
+ ), "Deserialized trigger objects should match original"
+
+
[email protected]_test
+def test_dataset_dag_run_queue_processing(session, dag_maker,
create_test_datasets):
+ datasets = create_test_datasets
+ dataset_models = session.query(DatasetModel).all()
+
+ with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+ EmptyOperator(task_id="hello")
+
+ # Add DatasetDagRunQueue entries to simulate dataset event processing
+ for dm in dataset_models:
+ session.add(DatasetDagRunQueue(dataset_id=dm.id,
target_dag_id=dag.dag_id))
+ session.commit()
+
+ # Fetch and evaluate dataset triggers for all DAGs affected by dataset
events
+ records = session.scalars(select(DatasetDagRunQueue)).all()
+ dag_statuses = defaultdict(lambda: defaultdict(bool))
+ for record in records:
+ dag_statuses[record.target_dag_id][record.dataset.uri] = True
+
+ serialized_dags = session.execute(
+
select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
+ ).fetchall()
+
+ for (serialized_dag,) in serialized_dags:
+ dag = SerializedDAG.deserialize(serialized_dag.data)
+ for dataset_uri, status in dag_statuses[dag.dag_id].items():
+ assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG
trigger evaluation failed"
+
+
[email protected]_test
[email protected]("create_test_datasets")
+def test_additional_dag_with_no_triggers(dag_maker):
+ # Create an additional DAG to ensure it's not affected by dataset triggers
+ with dag_maker(dag_id="dag2"):
+ EmptyOperator(task_id="hello2")
+
+
[email protected]
+def setup_datasets_and_models(session):
Review Comment:
Not sure this needs to be a fixture; it's only used once and isn't that big.
##########
tests/datasets/test_dataset.py:
##########
@@ -18,13 +18,19 @@
from __future__ import annotations
import os
+from collections import defaultdict
import pytest
+from sqlalchemy.sql import select
from airflow.datasets import Dataset
+from airflow.models.dataset import DatasetAll, DatasetAny, DatasetDagRunQueue,
DatasetModel
+from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
+from airflow.serialization.serialized_objects import BaseSerialization,
SerializedDAG
[email protected]_test
Review Comment:
Why do all of these need to be db tests now?
##########
tests/datasets/test_dataset.py:
##########
@@ -40,17 +46,224 @@ def test_invalid_uris(uri):
Dataset(uri=uri)
[email protected]_test
def test_uri_with_scheme():
dataset = Dataset(uri="s3://example_dataset")
EmptyOperator(task_id="task1", outlets=[dataset])
[email protected]_test
def test_uri_without_scheme():
dataset = Dataset(uri="example_dataset")
EmptyOperator(task_id="task1", outlets=[dataset])
[email protected]_test
def test_fspath():
uri = "s3://example_dataset"
dataset = Dataset(uri=uri)
assert os.fspath(dataset) == uri
+
+
[email protected]_test
[email protected](
+ "inputs, scenario, expected",
+ [
+ # Scenarios for DatasetAny
+ ((True, True, True), "any", True),
+ ((True, True, False), "any", True),
+ ((True, False, True), "any", True),
+ ((True, False, False), "any", True),
+ ((False, False, True), "any", True),
+ ((False, True, False), "any", True),
+ ((False, True, True), "any", True),
+ ((False, False, False), "any", False),
+ # Scenarios for DatasetAll
+ ((True, True, True), "all", True),
+ ((True, True, False), "all", False),
+ ((True, False, True), "all", False),
+ ((True, False, False), "all", False),
+ ((False, False, True), "all", False),
+ ((False, True, False), "all", False),
+ ((False, True, True), "all", False),
+ ((False, False, False), "all", False),
+ ],
+)
+def test_dataset_logical_conditions_evaluation_and_serialization(inputs,
scenario, expected):
+ class_ = DatasetAny if scenario == "any" else DatasetAll
+ datasets = [Dataset(uri=f"s3://abc/{i}") for i in range(123, 126)]
+ condition = class_(*datasets)
+
+ statuses = {dataset.uri: status for dataset, status in zip(datasets,
inputs)}
+ assert (
+ condition.evaluate(statuses) == expected
+ ), f"Condition evaluation failed for inputs {inputs} and scenario
'{scenario}'"
+
+ # Serialize and deserialize the condition to test persistence
+ serialized = BaseSerialization.serialize(condition)
+ deserialized = BaseSerialization.deserialize(serialized)
+ assert deserialized.evaluate(statuses) == expected, "Serialization
round-trip failed"
+
+
[email protected]_test
[email protected](
+ "status_values, expected_evaluation",
+ [
+ ((False, True, True), False), # DatasetAll requires all conditions to
be True, but d1 is False
+ ((True, True, True), True), # All conditions are True
+ ((True, False, True), True), # d1 is True, and DatasetAny condition
(d2 or d3 being True) is met
+ ((True, False, False), False), # d1 is True, but neither d2 nor d3
meet the DatasetAny condition
+ ],
+)
+def test_nested_dataset_conditions_with_serialization(status_values,
expected_evaluation):
+ # Define datasets
+ d1 = Dataset(uri="s3://abc/123")
+ d2 = Dataset(uri="s3://abc/124")
+ d3 = Dataset(uri="s3://abc/125")
+
+ # Create a nested condition: DatasetAll with d1 and DatasetAny with d2 and
d3
+ nested_condition = DatasetAll(d1, DatasetAny(d2, d3))
+
+ statuses = {
+ d1.uri: status_values[0],
+ d2.uri: status_values[1],
+ d3.uri: status_values[2],
+ }
+
+ assert nested_condition.evaluate(statuses) == expected_evaluation,
"Initial evaluation mismatch"
+
+ serialized_condition = BaseSerialization.serialize(nested_condition)
+ deserialized_condition =
BaseSerialization.deserialize(serialized_condition)
+
+ assert (
+ deserialized_condition.evaluate(statuses) == expected_evaluation
+ ), "Post-serialization evaluation mismatch"
+
+
[email protected]
+def create_test_datasets(session):
+ """Fixture to create test datasets and corresponding models."""
+ datasets = [Dataset(uri=f"hello{i}") for i in range(1, 3)]
+ for dataset in datasets:
+ session.add(DatasetModel(uri=dataset.uri))
+ session.commit()
+ return datasets
+
+
[email protected]_test
+def test_dataset_trigger_setup_and_serialization(session, dag_maker,
create_test_datasets):
+ datasets = create_test_datasets
+
+ # Create DAG with dataset triggers
+ with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+ EmptyOperator(task_id="hello")
+
+ # Verify dataset triggers are set up correctly
+ assert isinstance(
+ dag.dataset_triggers, DatasetAny
+ ), "DAG dataset triggers should be an instance of DatasetAny"
+
+ # Serialize and deserialize DAG dataset triggers
+ serialized_trigger = SerializedDAG.serialize(dag.dataset_triggers)
+ deserialized_trigger = SerializedDAG.deserialize(serialized_trigger)
+
+ # Verify serialization and deserialization integrity
+ assert isinstance(
+ deserialized_trigger, DatasetAny
+ ), "Deserialized trigger should maintain type DatasetAny"
+ assert (
+ deserialized_trigger.objects == dag.dataset_triggers.objects
+ ), "Deserialized trigger objects should match original"
+
+
[email protected]_test
+def test_dataset_dag_run_queue_processing(session, dag_maker,
create_test_datasets):
+ datasets = create_test_datasets
+ dataset_models = session.query(DatasetModel).all()
+
+ with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+ EmptyOperator(task_id="hello")
+
+ # Add DatasetDagRunQueue entries to simulate dataset event processing
+ for dm in dataset_models:
+ session.add(DatasetDagRunQueue(dataset_id=dm.id,
target_dag_id=dag.dag_id))
+ session.commit()
+
+ # Fetch and evaluate dataset triggers for all DAGs affected by dataset
events
+ records = session.scalars(select(DatasetDagRunQueue)).all()
+ dag_statuses = defaultdict(lambda: defaultdict(bool))
+ for record in records:
+ dag_statuses[record.target_dag_id][record.dataset.uri] = True
+
+ serialized_dags = session.execute(
+
select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
+ ).fetchall()
+
+ for (serialized_dag,) in serialized_dags:
+ dag = SerializedDAG.deserialize(serialized_dag.data)
+ for dataset_uri, status in dag_statuses[dag.dag_id].items():
+ assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG
trigger evaluation failed"
+
+
[email protected]_test
[email protected]("create_test_datasets")
+def test_additional_dag_with_no_triggers(dag_maker):
+ # Create an additional DAG to ensure it's not affected by dataset triggers
+ with dag_maker(dag_id="dag2"):
+ EmptyOperator(task_id="hello2")
+
+
[email protected]
+def setup_datasets_and_models(session):
+ """Fixture to create datasets and corresponding models."""
+ # Create Dataset instances
+ d1 = Dataset(uri="hello1")
+ d2 = Dataset(uri="hello2")
+
+ # Create and add DatasetModel instances to the session
+ dm1 = DatasetModel(uri=d1.uri)
+ dm2 = DatasetModel(uri=d2.uri)
+ session.add_all([dm1, dm2])
+ session.commit()
+
+ return d1, d2
+
+
[email protected]_test
+def test_dag_with_complex_dataset_triggers(session, dag_maker,
setup_datasets_and_models):
+ d1, d2 = setup_datasets_and_models
+
+ # Setup a DAG with complex dataset triggers (DatasetAny with DatasetAll)
+ with dag_maker(schedule=DatasetAny(d1, DatasetAll(d2, d1))) as dag:
+ EmptyOperator(task_id="hello")
+
+ assert isinstance(
+ dag.dataset_triggers, DatasetAny
+ ), "DAG's dataset trigger should be an instance of DatasetAny"
+ assert any(
+ isinstance(trigger, DatasetAll) for trigger in
dag.dataset_triggers.objects
+ ), "DAG's dataset trigger should include DatasetAll"
+
+ serialized_triggers = SerializedDAG.serialize(dag.dataset_triggers)
+
+ deserialized_triggers = SerializedDAG.deserialize(serialized_triggers)
+
+ assert isinstance(
+ deserialized_triggers, DatasetAny
+ ), "Deserialized triggers should be an instance of DatasetAny"
+ assert any(
+ isinstance(trigger, DatasetAll) for trigger in
deserialized_triggers.objects
+ ), "Deserialized triggers should include DatasetAll"
+
+ serialized_dag_dict = SerializedDAG.to_dict(dag)["dag"]
+ assert "dataset_triggers" in serialized_dag_dict, "Serialized DAG should
contain 'dataset_triggers'"
+ assert isinstance(
+ serialized_dag_dict["dataset_triggers"], dict
+ ), "Serialized 'dataset_triggers' should be a dict"
+
+
[email protected](autouse=True)
+def clear_datasets():
+ from tests.test_utils.db import clear_db_datasets
+
+ clear_db_datasets()
Review Comment:
```suggestion
clear_db_datasets()
yield
clear_db_datasets()
```
Should we also clear when we finish up?
##########
airflow/models/dataset.py:
##########
@@ -336,3 +337,49 @@ def __repr__(self) -> str:
]:
args.append(f"{attr}={getattr(self, attr)!r}")
return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetBooleanCondition:
+ """
+ Base class for boolean logic for dataset triggers.
+
+ :meta private:
+ """
+
+ agg_func: Callable
+
+ def __init__(self, *objects):
Review Comment:
```suggestion
def __init__(self, *objects) -> None:
```
##########
airflow/timetables/datasets.py:
##########
@@ -52,24 +55,20 @@ def deserialize(cls, data: dict[str, typing.Any]) ->
Timetable:
from airflow.serialization.serialized_objects import decode_timetable
return cls(
- timetable=decode_timetable(data["timetable"]),
datasets=[Dataset(**d) for d in data["datasets"]]
+ timetable=decode_timetable(data["timetable"]),
+ datasets=[], # don't need the datasets after deserialization
Review Comment:
Why don't we need them?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]