dstandish commented on code in PR #37016:
URL: https://github.com/apache/airflow/pull/37016#discussion_r1498049067


##########
tests/datasets/test_dataset.py:
##########
@@ -40,17 +46,224 @@ def test_invalid_uris(uri):
         Dataset(uri=uri)
 
 
[email protected]_test
 def test_uri_with_scheme():
     dataset = Dataset(uri="s3://example_dataset")
     EmptyOperator(task_id="task1", outlets=[dataset])
 
 
[email protected]_test
 def test_uri_without_scheme():
     dataset = Dataset(uri="example_dataset")
     EmptyOperator(task_id="task1", outlets=[dataset])
 
 
[email protected]_test
 def test_fspath():
     uri = "s3://example_dataset"
     dataset = Dataset(uri=uri)
     assert os.fspath(dataset) == uri
+
+
[email protected]_test
[email protected](
+    "inputs, scenario, expected",
+    [
+        # Scenarios for DatasetAny
+        ((True, True, True), "any", True),
+        ((True, True, False), "any", True),
+        ((True, False, True), "any", True),
+        ((True, False, False), "any", True),
+        ((False, False, True), "any", True),
+        ((False, True, False), "any", True),
+        ((False, True, True), "any", True),
+        ((False, False, False), "any", False),
+        # Scenarios for DatasetAll
+        ((True, True, True), "all", True),
+        ((True, True, False), "all", False),
+        ((True, False, True), "all", False),
+        ((True, False, False), "all", False),
+        ((False, False, True), "all", False),
+        ((False, True, False), "all", False),
+        ((False, True, True), "all", False),
+        ((False, False, False), "all", False),
+    ],
+)
+def test_dataset_logical_conditions_evaluation_and_serialization(inputs, scenario, expected):
+    class_ = DatasetAny if scenario == "any" else DatasetAll
+    datasets = [Dataset(uri=f"s3://abc/{i}") for i in range(123, 126)]
+    condition = class_(*datasets)
+
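+    # Map each dataset URI to its simulated event status for this scenario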
+    statuses = {dataset.uri: status for dataset, status in zip(datasets, inputs)}
+    assert (
+        condition.evaluate(statuses) == expected
+    ), f"Condition evaluation failed for inputs {inputs} and scenario 
'{scenario}'"
+
+    # Serialize and deserialize the condition to test persistence
+    serialized = BaseSerialization.serialize(condition)
+    deserialized = BaseSerialization.deserialize(serialized)
+    assert deserialized.evaluate(statuses) == expected, "Serialization round-trip failed"
+
+
[email protected]_test
[email protected](
+    "status_values, expected_evaluation",
+    [
+        ((False, True, True), False),  # DatasetAll requires all conditions to be True, but d1 is False
+        ((True, True, True), True),  # All conditions are True
+        ((True, False, True), True),  # d1 is True, and DatasetAny condition (d2 or d3 being True) is met
+        ((True, False, False), False),  # d1 is True, but neither d2 nor d3 meets the DatasetAny condition
+    ],
+)
+def test_nested_dataset_conditions_with_serialization(status_values, expected_evaluation):
+    # Define datasets
+    d1 = Dataset(uri="s3://abc/123")
+    d2 = Dataset(uri="s3://abc/124")
+    d3 = Dataset(uri="s3://abc/125")
+
+    # Create a nested condition: DatasetAll with d1 and DatasetAny with d2 and d3
+    nested_condition = DatasetAll(d1, DatasetAny(d2, d3))
+
+    statuses = {
+        d1.uri: status_values[0],
+        d2.uri: status_values[1],
+        d3.uri: status_values[2],
+    }
+
+    assert nested_condition.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch"
+
+    serialized_condition = BaseSerialization.serialize(nested_condition)
+    deserialized_condition = BaseSerialization.deserialize(serialized_condition)
+
+    assert (
+        deserialized_condition.evaluate(statuses) == expected_evaluation
+    ), "Post-serialization evaluation mismatch"
+
+
[email protected]
+def create_test_datasets(session):
+    """Fixture to create test datasets and corresponding models."""
+    datasets = [Dataset(uri=f"hello{i}") for i in range(1, 3)]
+    for dataset in datasets:
+        session.add(DatasetModel(uri=dataset.uri))
+    session.commit()
+    return datasets
+
+
[email protected]_test
+def test_dataset_trigger_setup_and_serialization(session, dag_maker, create_test_datasets):
+    datasets = create_test_datasets
+
+    # Create DAG with dataset triggers
+    with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+        EmptyOperator(task_id="hello")
+
+    # Verify dataset triggers are set up correctly
+    assert isinstance(
+        dag.dataset_triggers, DatasetAny
+    ), "DAG dataset triggers should be an instance of DatasetAny"
+
+    # Serialize and deserialize DAG dataset triggers
+    serialized_trigger = SerializedDAG.serialize(dag.dataset_triggers)
+    deserialized_trigger = SerializedDAG.deserialize(serialized_trigger)
+
+    # Verify serialization and deserialization integrity
+    assert isinstance(
+        deserialized_trigger, DatasetAny
+    ), "Deserialized trigger should maintain type DatasetAny"
+    assert (
+        deserialized_trigger.objects == dag.dataset_triggers.objects
+    ), "Deserialized trigger objects should match original"
+
+
[email protected]_test
+def test_dataset_dag_run_queue_processing(session, dag_maker, create_test_datasets):
+    datasets = create_test_datasets
+    dataset_models = session.query(DatasetModel).all()
+
+    with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+        EmptyOperator(task_id="hello")
+
+    # Add DatasetDagRunQueue entries to simulate dataset event processing
+    for dm in dataset_models:
+        session.add(DatasetDagRunQueue(dataset_id=dm.id, target_dag_id=dag.dag_id))
+    session.commit()
+
+    # Fetch and evaluate dataset triggers for all DAGs affected by dataset events
+    records = session.scalars(select(DatasetDagRunQueue)).all()
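+    # Build a nested mapping of dag_id -> {dataset uri: True} from the queued records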
+    dag_statuses = defaultdict(lambda: defaultdict(bool))
+    for record in records:
+        dag_statuses[record.target_dag_id][record.dataset.uri] = True
+
+    serialized_dags = session.execute(
+        select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
+    ).fetchall()
+
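+    # Each result row is a one-tuple of SerializedDagModel; re-evaluate triggers after the DB round trip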
+    for (serialized_dag,) in serialized_dags:
+        dag = SerializedDAG.deserialize(serialized_dag.data)
+        for dataset_uri, status in dag_statuses[dag.dag_id].items():
+            assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG trigger evaluation failed"
+
+
[email protected]_test
[email protected]("create_test_datasets")
+def test_additional_dag_with_no_triggers(dag_maker):
+    # Create an additional DAG to ensure it's not affected by dataset triggers
+    with dag_maker(dag_id="dag2"):
+        EmptyOperator(task_id="hello2")
+
+
[email protected]
+def setup_datasets_and_models(session):

Review Comment:
   yup, removing


