jedcunningham commented on code in PR #46021:
URL: https://github.com/apache/airflow/pull/46021#discussion_r1929067488


##########
tests/models/test_dagrun.py:
##########
@@ -1412,382 +1386,68 @@ def task_2(arg2): ...
     assert len(decision.schedulable_tis) == 2
 
 
-def test_mapped_literal_length_with_no_change_at_runtime_doesnt_call_verify_integrity(dag_maker, session):
-    """
-    Test that when there's no change to mapped task indexes at runtime, the dagrun.verify_integrity
-    is not called
-    """
-    from airflow.models import Variable
-
-    Variable.set(key="arg1", value=[1, 2, 3])
-
-    @task
-    def task_1():
-        return Variable.get("arg1", deserialize_json=True)
-
-    with dag_maker(session=session) as dag:
-
-        @task
-        def task_2(arg2): ...
-
-        task_2.expand(arg2=task_1())
-
-    dr = dag_maker.create_dagrun()
-    ti = dr.get_task_instance(task_id="task_1")
-    ti.run()
-    dr.task_instance_scheduling_decisions()
-    tis = dr.get_task_instances()
-    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
-    assert sorted(indices) == [
-        (0, State.NONE),
-        (1, State.NONE),
-        (2, State.NONE),
-    ]
-
-    # Now "clear" and no change to length
-    dag.clear()
-    Variable.set(key="arg1", value=[1, 2, 3])
-
-    with dag:
-        task_2.expand(arg2=task_1()).operator
-
-    # At this point, we need to test that the change works on the serialized
-    # DAG (which is what the scheduler operates on)
-    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
-    dr.dag = serialized_dag
-
-    # Run the first task again to get the new lengths
-    ti = dr.get_task_instance(task_id="task_1")
-    task1 = dag.get_task("task_1")
-    ti.refresh_from_task(task1)
-    ti.run()
-
-    # this would be called by the localtask job
-    # Verify that DagRun.verify_integrity is not called
-    with mock.patch("airflow.models.dagrun.DagRun.verify_integrity") as mock_verify_integrity:
-        dr.task_instance_scheduling_decisions()
-        mock_verify_integrity.assert_not_called()
-
-
-def test_calls_to_verify_integrity_with_mapped_task_increase_at_runtime(dag_maker, session):
-    """
-    Test increase in mapped task at runtime with calls to dagrun.verify_integrity
-    """
-    from airflow.models import Variable
-
-    Variable.set(key="arg1", value=[1, 2, 3])
-
-    @task
-    def task_1():
-        return Variable.get("arg1", deserialize_json=True)
-
-    with dag_maker(session=session) as dag:
-
-        @task
-        def task_2(arg2): ...
-
-        task_2.expand(arg2=task_1())
-
-    dr = dag_maker.create_dagrun()
-    ti = dr.get_task_instance(task_id="task_1")
-    ti.run()
-    dr.task_instance_scheduling_decisions()
-    tis = dr.get_task_instances()
-    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
-    assert sorted(indices) == [
-        (0, State.NONE),
-        (1, State.NONE),
-        (2, State.NONE),
-    ]
-    # Now "clear" and "increase" the length of literal
-    dag.clear()
-    Variable.set(key="arg1", value=[1, 2, 3, 4, 5])
-
-    with dag:
-        task_2.expand(arg2=task_1()).operator
-
-    # At this point, we need to test that the change works on the serialized
-    # DAG (which is what the scheduler operates on)
-    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
-    dr.dag = serialized_dag
-
-    # Run the first task again to get the new lengths
-    ti = dr.get_task_instance(task_id="task_1")
-    task1 = dag.get_task("task_1")
-    ti.refresh_from_task(task1)
-    ti.run()
-    task2 = dag.get_task("task_2")
-    for ti in dr.get_task_instances():
-        if ti.map_index < 0:
-            ti.task = task1
-        else:
-            ti.task = task2
-        session.merge(ti)
-    session.flush()
-    # create the additional task
-    dr.task_instance_scheduling_decisions()
-    # Run verify_integrity as a whole and assert new tasks were added
-    dr.verify_integrity()
-    tis = dr.get_task_instances()
-    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
-    assert sorted(indices) == [
-        (0, State.NONE),
-        (1, State.NONE),
-        (2, State.NONE),
-        (3, State.NONE),
-        (4, State.NONE),
-    ]
-    ti3 = dr.get_task_instance(task_id="task_2", map_index=3)
-    ti3.task = task2
-    ti3.state = TaskInstanceState.FAILED
-    session.merge(ti3)
-    session.flush()
-    # assert repeated calls did not change the instances
-    dr.verify_integrity()
-    tis = dr.get_task_instances()
-    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
-    assert sorted(indices) == [
-        (0, State.NONE),
-        (1, State.NONE),
-        (2, State.NONE),
-        (3, TaskInstanceState.FAILED),
-        (4, State.NONE),
-    ]
-
-
-def test_calls_to_verify_integrity_with_mapped_task_reduction_at_runtime(dag_maker, session):
+@pytest.mark.need_serialized_dag
+def test_calls_to_verify_integrity_with_mapped_task_zero_length_at_runtime(dag_maker, session, caplog):
     """
-    Test reduction in mapped task at runtime with calls to dagrun.verify_integrity
+    Test zero length reduction in mapped task at runtime with calls to dagrun.verify_integrity
     """
-    from airflow.models import Variable
-
-    Variable.set(key="arg1", value=[1, 2, 3])
-
-    @task
-    def task_1():
-        return Variable.get("arg1", deserialize_json=True)
+    import logging
 
     with dag_maker(session=session) as dag:
 
         @task
-        def task_2(arg2): ...
-
-        task_2.expand(arg2=task_1())
-
-    dr = dag_maker.create_dagrun()
-    ti = dr.get_task_instance(task_id="task_1")
-    ti.run()
-    dr.task_instance_scheduling_decisions()
-    tis = dr.get_task_instances()
-    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
-    assert sorted(indices) == [
-        (0, State.NONE),
-        (1, State.NONE),
-        (2, State.NONE),
-    ]
-    # Now "clear" and "reduce" the length of literal
-    dag.clear()
-    Variable.set(key="arg1", value=[1])
-
-    with dag:
-        task_2.expand(arg2=task_1()).operator
-
-    # At this point, we need to test that the change works on the serialized
-    # DAG (which is what the scheduler operates on)
-    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
-    dr.dag = serialized_dag
-
-    # Run the first task again to get the new lengths
-    ti = dr.get_task_instance(task_id="task_1")
-    task1 = dag.get_task("task_1")
-    ti.refresh_from_task(task1)
-    ti.run()
-    task2 = dag.get_task("task_2")
-    for ti in dr.get_task_instances():
-        if ti.map_index < 0:
-            ti.task = task1
-        else:
-            ti.task = task2
-            ti.state = TaskInstanceState.SUCCESS
-        session.merge(ti)
-    session.flush()
-
-    # Run verify_integrity as a whole and assert some tasks were removed
-    dr.verify_integrity()
-    tis = dr.get_task_instances()
-    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
-    assert sorted(indices) == [
-        (0, TaskInstanceState.SUCCESS),
-        (1, TaskInstanceState.REMOVED),
-        (2, TaskInstanceState.REMOVED),
-    ]
-
-    # assert repeated calls did not change the instances
-    dr.verify_integrity()
-    tis = dr.get_task_instances()
-    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
-    assert sorted(indices) == [
-        (0, TaskInstanceState.SUCCESS),
-        (1, TaskInstanceState.REMOVED),
-        (2, TaskInstanceState.REMOVED),
-    ]
-
-
-def test_calls_to_verify_integrity_with_mapped_task_with_no_changes_at_runtime(dag_maker, session):
-    """
-    Test no change in mapped task at runtime with calls to dagrun.verify_integrity
-    """
-    from airflow.models import Variable
-
-    Variable.set(key="arg1", value=[1, 2, 3])
-
-    @task
-    def task_1():
-        return Variable.get("arg1", deserialize_json=True)
-
-    with dag_maker(session=session) as dag:
+        def task_1():
+            # return Variable.get("arg1", deserialize_json=True)
+            ...
 
         @task
         def task_2(arg2): ...
 
         task_2.expand(arg2=task_1())
 
-    dr = dag_maker.create_dagrun()
-    ti = dr.get_task_instance(task_id="task_1")
-    ti.run()
-    dr.task_instance_scheduling_decisions()
-    tis = dr.get_task_instances()
-    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
-    assert sorted(indices) == [
-        (0, State.NONE),
-        (1, State.NONE),
-        (2, State.NONE),
-    ]
-    # Now "clear" and return the same length
-    dag.clear()
-    Variable.set(key="arg1", value=[1, 2, 3])
-
-    with dag:
-        task_2.expand(arg2=task_1()).operator
-
-    # At this point, we need to test that the change works on the serialized
-    # DAG (which is what the scheduler operates on)
-    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
-    dr.dag = serialized_dag
-
-    # Run the first task again to get the new lengths
-    ti = dr.get_task_instance(task_id="task_1")
-    task1 = dag.get_task("task_1")
-    ti.refresh_from_task(task1)
-    ti.run()
-    task2 = dag.get_task("task_2")
-    for ti in dr.get_task_instances():
-        if ti.map_index < 0:
-            ti.task = task1
-        else:
-            ti.task = task2
-            ti.state = TaskInstanceState.SUCCESS
-        session.merge(ti)
+    dr: DagRun = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id="task_1", session=session)
+    assert ti
+    # "Run" task_1
+    ti.state = TaskInstanceState.SUCCESS
+    # Operate as if we are doing this
+    # Variable.set(key="arg1", value=[1, 2, 3])

Review Comment:
   ```suggestion
       # Operate as if TI ran after: Variable.set(key="arg1", value=[1, 2, 3])
   ```



##########
tests/models/test_dagrun.py:
##########
@@ -784,38 +781,46 @@ def test_depends_on_past(self, dagbag, session, prev_ti_state, is_ti_success):
             (None, False),
         ],
     )
-    def test_wait_for_downstream(self, dagbag, session, prev_ti_state, is_ti_success):
+    @pytest.mark.need_serialized_dag
+    def test_wait_for_downstream(self, dag_maker, session, prev_ti_state, is_ti_success):
         dag_id = "test_wait_for_downstream"
-        dag = dagbag.get_dag(dag_id)
+
+        with dag_maker(dag_id=dag_id, session=session) as dag:
+            dag_wfd_upstream = EmptyOperator(
+                task_id="upstream_task",
+                wait_for_downstream=True,
+            )
+            dag_wfd_downstream = EmptyOperator(
+                task_id="downstream_task",
+            )

Review Comment:
   ```suggestion
               dag_wfd_downstream = EmptyOperator(task_id="downstream_task")
   ```
   
   This saves two lines :)



##########
tests/models/test_dagrun.py:
##########
@@ -1337,42 +1317,36 @@ def task_2(arg2): ...
 
         task_2.expand(arg2=task_1())
 
-    dr = dag_maker.create_dagrun()
-    ti = dr.get_task_instance(task_id="task_1")
-    ti.run()
-    dr.task_instance_scheduling_decisions()
-    tis = dr.get_task_instances()
-    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
-    assert sorted(indices) == [
-        (0, State.NONE),
-        (1, State.NONE),
-        (2, State.NONE),
-    ]
+    dr: DagRun = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id="task_1", session=session)
+    assert ti
+    ti.state = TaskInstanceState.SUCCESS
+    # Behave as if: Variable.set(key="arg1", value=[1, 2, 3])

Review Comment:
   ```suggestion
       # Behave as if TI ran after: Variable.set(key="arg1", value=[1, 2, 3])
   ```



##########
tests/models/test_dagrun.py:
##########
@@ -745,34 +738,38 @@ def mutate_task_instance(task_instance):
             (None, False),
         ],
     )
-    def test_depends_on_past(self, dagbag, session, prev_ti_state, is_ti_success):
-        dag_id = "test_depends_on_past"
+    @pytest.mark.need_serialized_dag
+    def test_depends_on_past(self, dag_maker, session, prev_ti_state, is_ti_success):

Review Comment:
   nit: consider renaming `is_ti_success` to `is_ti_scheduled` (or something similar).



##########
tests/models/test_dagrun.py:
##########
@@ -1273,62 +1264,51 @@ def task_2(arg2): ...
 
         task_2.expand(arg2=task_1())
 
-    dr = dag_maker.create_dagrun()
-    ti = dr.get_task_instance(task_id="task_1")
-    ti.run()
-    dr.task_instance_scheduling_decisions()
-    tis = dr.get_task_instances()
-    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
-    assert sorted(indices) == [
-        (0, State.NONE),
-        (1, State.NONE),
-        (2, State.NONE),
-    ]
+    dr: DagRun = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id="task_1", session=session)
+    assert ti
+    ti.state = TaskInstanceState.SUCCESS
+    # Behave as if: Variable.set(key="arg1", value=[1, 2, 3])

Review Comment:
   ```suggestion
       # Behave as if TI ran after: Variable.set(key="arg1", value=[1, 2, 3])
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to