sunank200 commented on code in PR #37016:
URL: https://github.com/apache/airflow/pull/37016#discussion_r1498020329
##########
tests/datasets/test_dataset.py:
##########
@@ -40,17 +46,224 @@ def test_invalid_uris(uri):
Dataset(uri=uri)
+@pytest.mark.db_test
def test_uri_with_scheme():
dataset = Dataset(uri="s3://example_dataset")
EmptyOperator(task_id="task1", outlets=[dataset])
+@pytest.mark.db_test
def test_uri_without_scheme():
dataset = Dataset(uri="example_dataset")
EmptyOperator(task_id="task1", outlets=[dataset])
+@pytest.mark.db_test
def test_fspath():
uri = "s3://example_dataset"
dataset = Dataset(uri=uri)
assert os.fspath(dataset) == uri
+
+
+@pytest.mark.db_test
+@pytest.mark.parametrize(
+ "inputs, scenario, expected",
+ [
+ # Scenarios for DatasetAny
+ ((True, True, True), "any", True),
+ ((True, True, False), "any", True),
+ ((True, False, True), "any", True),
+ ((True, False, False), "any", True),
+ ((False, False, True), "any", True),
+ ((False, True, False), "any", True),
+ ((False, True, True), "any", True),
+ ((False, False, False), "any", False),
+ # Scenarios for DatasetAll
+ ((True, True, True), "all", True),
+ ((True, True, False), "all", False),
+ ((True, False, True), "all", False),
+ ((True, False, False), "all", False),
+ ((False, False, True), "all", False),
+ ((False, True, False), "all", False),
+ ((False, True, True), "all", False),
+ ((False, False, False), "all", False),
+ ],
+)
+def test_dataset_logical_conditions_evaluation_and_serialization(inputs,
scenario, expected):
+ class_ = DatasetAny if scenario == "any" else DatasetAll
+ datasets = [Dataset(uri=f"s3://abc/{i}") for i in range(123, 126)]
+ condition = class_(*datasets)
+
+ statuses = {dataset.uri: status for dataset, status in zip(datasets,
inputs)}
+ assert (
+ condition.evaluate(statuses) == expected
+ ), f"Condition evaluation failed for inputs {inputs} and scenario
'{scenario}'"
+
+ # Serialize and deserialize the condition to test persistence
+ serialized = BaseSerialization.serialize(condition)
+ deserialized = BaseSerialization.deserialize(serialized)
+ assert deserialized.evaluate(statuses) == expected, "Serialization
round-trip failed"
+
+
+@pytest.mark.db_test
+@pytest.mark.parametrize(
+ "status_values, expected_evaluation",
+ [
+ ((False, True, True), False), # DatasetAll requires all conditions to
be True, but d1 is False
+ ((True, True, True), True), # All conditions are True
+ ((True, False, True), True), # d1 is True, and DatasetAny condition
(d2 or d3 being True) is met
+ ((True, False, False), False), # d1 is True, but neither d2 nor d3
meet the DatasetAny condition
+ ],
+)
+def test_nested_dataset_conditions_with_serialization(status_values,
expected_evaluation):
+ # Define datasets
+ d1 = Dataset(uri="s3://abc/123")
+ d2 = Dataset(uri="s3://abc/124")
+ d3 = Dataset(uri="s3://abc/125")
+
+ # Create a nested condition: DatasetAll with d1 and DatasetAny with d2 and
d3
+ nested_condition = DatasetAll(d1, DatasetAny(d2, d3))
+
+ statuses = {
+ d1.uri: status_values[0],
+ d2.uri: status_values[1],
+ d3.uri: status_values[2],
+ }
+
+ assert nested_condition.evaluate(statuses) == expected_evaluation,
"Initial evaluation mismatch"
+
+ serialized_condition = BaseSerialization.serialize(nested_condition)
+ deserialized_condition =
BaseSerialization.deserialize(serialized_condition)
+
+ assert (
+ deserialized_condition.evaluate(statuses) == expected_evaluation
+ ), "Post-serialization evaluation mismatch"
+
+
+@pytest.fixture
+def create_test_datasets(session):
+ """Fixture to create test datasets and corresponding models."""
+ datasets = [Dataset(uri=f"hello{i}") for i in range(1, 3)]
+ for dataset in datasets:
+ session.add(DatasetModel(uri=dataset.uri))
+ session.commit()
+ return datasets
+
+
+@pytest.mark.db_test
+def test_dataset_trigger_setup_and_serialization(session, dag_maker,
create_test_datasets):
+ datasets = create_test_datasets
+
+ # Create DAG with dataset triggers
+ with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+ EmptyOperator(task_id="hello")
+
+ # Verify dataset triggers are set up correctly
+ assert isinstance(
+ dag.dataset_triggers, DatasetAny
+ ), "DAG dataset triggers should be an instance of DatasetAny"
+
+ # Serialize and deserialize DAG dataset triggers
+ serialized_trigger = SerializedDAG.serialize(dag.dataset_triggers)
+ deserialized_trigger = SerializedDAG.deserialize(serialized_trigger)
+
+ # Verify serialization and deserialization integrity
+ assert isinstance(
+ deserialized_trigger, DatasetAny
+ ), "Deserialized trigger should maintain type DatasetAny"
+ assert (
+ deserialized_trigger.objects == dag.dataset_triggers.objects
+ ), "Deserialized trigger objects should match original"
+
+
+@pytest.mark.db_test
+def test_dataset_dag_run_queue_processing(session, dag_maker,
create_test_datasets):
+ datasets = create_test_datasets
+ dataset_models = session.query(DatasetModel).all()
+
+ with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+ EmptyOperator(task_id="hello")
+
+ # Add DatasetDagRunQueue entries to simulate dataset event processing
+ for dm in dataset_models:
+ session.add(DatasetDagRunQueue(dataset_id=dm.id,
target_dag_id=dag.dag_id))
+ session.commit()
+
+ # Fetch and evaluate dataset triggers for all DAGs affected by dataset
events
+ records = session.scalars(select(DatasetDagRunQueue)).all()
+ dag_statuses = defaultdict(lambda: defaultdict(bool))
+ for record in records:
+ dag_statuses[record.target_dag_id][record.dataset.uri] = True
+
+ serialized_dags = session.execute(
+
select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
+ ).fetchall()
+
+ for (serialized_dag,) in serialized_dags:
+ dag = SerializedDAG.deserialize(serialized_dag.data)
+ for dataset_uri, status in dag_statuses[dag.dag_id].items():
+ assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG
trigger evaluation failed"
+
+
+@pytest.mark.db_test
+@pytest.mark.usefixtures("create_test_datasets")
+def test_additional_dag_with_no_triggers(dag_maker):
Review Comment:
I think the purpose of the `test_additional_dag_with_no_triggers` function was
to validate that a DAG without any dataset triggers is correctly configured
and runs independently of dataset-related events.
##########
tests/datasets/test_dataset.py:
##########
@@ -18,13 +18,19 @@
from __future__ import annotations
import os
+from collections import defaultdict
import pytest
+from sqlalchemy.sql import select
from airflow.datasets import Dataset
+from airflow.models.dataset import DatasetAll, DatasetAny, DatasetDagRunQueue,
DatasetModel
+from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
+from airflow.serialization.serialized_objects import BaseSerialization,
SerializedDAG
+@pytest.mark.db_test
Review Comment:
That's true, it's for `clear_datasets()`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]