This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new bb5307850d Fix short circuit operator re teardowns (#32538)
bb5307850d is described below

commit bb5307850d93df16f442bfce2f388c90d63c6f9a
Author: Daniel Standish <[email protected]>
AuthorDate: Thu Jul 13 17:52:27 2023 -0700

    Fix short circuit operator re teardowns (#32538)
    
    Short circuit operator should not skip teardown tasks.  Note that if
    there's a setup downstream that is skipped as a result, the teardown
    may then end up skipped by the scheduler in accordance with trigger
    rules, but that is handled by the scheduler.
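
    For illustration only (this example DAG is not part of the commit, and the
    task IDs are hypothetical), a minimal sketch of the fixed behavior: the
    short-circuiting task skips "work" but no longer includes the teardown
    "cleanup" among the tasks it skips itself.

        from datetime import datetime

        from airflow import DAG
        from airflow.operators.python import PythonOperator, ShortCircuitOperator

        with DAG(
            dag_id="short_circuit_teardown_demo",  # hypothetical dag_id
            start_date=datetime(2023, 1, 1),
            schedule=None,
        ):
            # "setup"/"cleanup" form a setup/teardown pair; "check" always short-circuits
            setup = PythonOperator(task_id="setup", python_callable=print).as_setup()
            check = ShortCircuitOperator(task_id="check", python_callable=lambda: False)
            work = PythonOperator(task_id="work", python_callable=print)
            cleanup = PythonOperator(task_id="cleanup", python_callable=print).as_teardown(setups=setup)
            # "work" is skipped; "cleanup" is excluded from the operator's skip call,
            # and whether it ultimately runs is left to the scheduler and trigger rules
            setup >> check >> work >> cleanup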
    
    Also, did a little optimization so that most of the time we don't need
    to create an intermediate list of to-be-skipped tasks in this
    operator's execute (only when debug logging is enabled do we need to
    "materialize" the list first).  Also consolidated the logic so that we
    only have one call to `skip` instead of two different ones, one on
    each side of an if / else.
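
    A standalone sketch of that pattern (the helper names below are
    hypothetical, not Airflow API): the to-be-skipped tasks stay a lazy
    generator and are only materialized into a list when debug logging is
    enabled, so a single `skip` call handles both cases.

        import logging

        log = logging.getLogger(__name__)

        def non_teardowns(tasks):
            # lazily yield only the tasks that are not teardowns; no intermediate list
            for t in tasks:
                if not t.is_teardown:
                    yield t

        def skip_downstream(downstream_tasks, skip):
            # downstream_tasks: a list of task-like objects; skip: a SkipMixin.skip-style callable
            to_skip = non_teardowns(downstream_tasks)
            if log.getEffectiveLevel() <= logging.DEBUG:
                # only materialize the generator when we actually want to log its contents
                to_skip = list(non_teardowns(downstream_tasks))
                log.debug("Downstream task IDs %s", to_skip)
            # a single skip call, whether to_skip is a generator or a list
            skip(tasks=to_skip)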
---
 airflow/models/baseoperator.py |   5 +-
 airflow/operators/python.py    |  46 ++++++++-----
 tests/operators/test_python.py | 142 ++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 173 insertions(+), 20 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index ccbcb12efb..40fe3b8f0b 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -75,7 +75,7 @@ from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwar
 from airflow.models.param import ParamsDict
 from airflow.models.pool import Pool
 from airflow.models.taskinstance import TaskInstance, clear_task_instances
-from airflow.models.taskmixin import DAGNode, DependencyMixin
+from airflow.models.taskmixin import DependencyMixin
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
@@ -100,6 +100,7 @@ if TYPE_CHECKING:
     import jinja2  # Slow import.
 
     from airflow.models.dag import DAG
+    from airflow.models.operator import Operator
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.models.xcom_arg import XComArg
     from airflow.utils.task_group import TaskGroup
@@ -1373,7 +1374,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
                 self.log.info("Rendering template for %s", field)
                 self.log.info(content)
 
-    def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
+    def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]:
         """
         Get list of the direct relatives to the current task, upstream or
         downstream.
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index a0ea5be80c..2db5ca69c3 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import inspect
+import logging
 import os
 import pickle
 import shutil
@@ -30,7 +31,7 @@ from collections.abc import Container
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from textwrap import dedent
-from typing import Any, Callable, Collection, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Mapping, Sequence, cast
 
 import dill
 
@@ -49,6 +50,9 @@ from airflow.utils.operator_helpers import KeywordParameters
 from airflow.utils.process_utils import execute_in_subprocess
 from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script
 
+if TYPE_CHECKING:
+    from pendulum.datetime import DateTime
+
 
 def task(python_callable: Callable | None = None, multiple_outputs: bool | None = None, **kwargs):
     """Deprecated. Use :func:`airflow.decorators.task` instead.
@@ -251,27 +255,35 @@ class ShortCircuitOperator(PythonOperator, SkipMixin):
             self.log.info("Proceeding with downstream tasks...")
             return condition
 
-        downstream_tasks = context["task"].get_flat_relatives(upstream=False)
-        self.log.debug("Downstream task IDs %s", downstream_tasks)
+        if not self.downstream_task_ids:
+            self.log.info("No downstream tasks; nothing to do.")
+            return
 
-        if downstream_tasks:
-            dag_run = context["dag_run"]
-            execution_date = dag_run.execution_date
+        dag_run = context["dag_run"]
 
+        def get_tasks_to_skip():
             if self.ignore_downstream_trigger_rules is True:
-                self.log.info("Skipping all downstream tasks...")
-                self.skip(dag_run, execution_date, downstream_tasks, map_index=context["ti"].map_index)
+                tasks = context["task"].get_flat_relatives(upstream=False)
             else:
-                self.log.info("Skipping downstream tasks while respecting 
trigger rules...")
-                # Explicitly setting the state of the direct, downstream 
task(s) to "skipped" and letting the
-                # Scheduler handle the remaining downstream task(s) 
appropriately.
-                self.skip(
-                    dag_run,
-                    execution_date,
-                    context["task"].get_direct_relatives(upstream=False),
-                    map_index=context["ti"].map_index,
-                )
+                tasks = context["task"].get_direct_relatives(upstream=False)
+            for t in tasks:
+                if not t.is_teardown:
+                    yield t
+
+        to_skip = get_tasks_to_skip()
 
+        # this lets us avoid an intermediate list unless debug logging is enabled
+        if self.log.getEffectiveLevel() <= logging.DEBUG:
+            self.log.debug("Downstream task IDs %s", to_skip := list(get_tasks_to_skip()))
+
+        self.log.info("Skipping downstream tasks")
+
+        self.skip(
+            dag_run=dag_run,
+            execution_date=cast("DateTime", dag_run.execution_date),
+            tasks=to_skip,
+            map_index=context["ti"].map_index,
+        )
         self.log.info("Done.")
 
 
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 498be2eb52..1df74fef4f 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -26,7 +26,9 @@ import warnings
 from collections import namedtuple
 from datetime import date, datetime, timedelta
 from subprocess import CalledProcessError
+from typing import Generator
 from unittest import mock
+from unittest.mock import MagicMock
 
 import pytest
 from slugify import slugify
@@ -35,7 +37,7 @@ from airflow.decorators import task_group
 from airflow.exceptions import AirflowException, DeserializingResultError, RemovedInAirflow3Warning
 from airflow.models import DAG, DagRun, TaskInstance as TI
 from airflow.models.baseoperator import BaseOperator
-from airflow.models.taskinstance import clear_task_instances, set_current_context
+from airflow.models.taskinstance import TaskInstance, clear_task_instances, set_current_context
 from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import (
     BranchPythonOperator,
@@ -1145,3 +1147,141 @@ class TestCurrentContextRuntime:
         with DAG(dag_id="edge_case_context_dag", default_args=DEFAULT_ARGS, schedule="@once"):
             op = PythonOperator(python_callable=get_all_the_context, task_id="get_all_the_context")
             op.run(ignore_first_depends_on_past=True, ignore_ti_state=True)
+
+
+class TestShortCircuitWithTeardown:
+    @pytest.mark.parametrize(
+        "ignore_downstream_trigger_rules, with_teardown, should_skip, 
expected",
+        [
+            (False, True, True, ["op2"]),
+            (False, True, False, []),
+            (False, False, True, ["op2"]),
+            (False, False, False, []),
+            (True, True, True, ["op2", "op3"]),
+            (True, True, False, []),
+            (True, False, True, ["op2", "op3", "op4"]),
+            (True, False, False, []),
+        ],
+    )
+    def test_short_circuit_with_teardowns(
+        self, dag_maker, ignore_downstream_trigger_rules, should_skip, with_teardown, expected
+    ):
+        with dag_maker() as dag:
+            op1 = ShortCircuitOperator(
+                task_id="op1",
+                python_callable=lambda: not should_skip,
+                ignore_downstream_trigger_rules=ignore_downstream_trigger_rules,
+            )
+            op2 = PythonOperator(task_id="op2", python_callable=print)
+            op3 = PythonOperator(task_id="op3", python_callable=print)
+            op4 = PythonOperator(task_id="op4", python_callable=print)
+            if with_teardown:
+                op4.as_teardown()
+            op1 >> op2 >> op3 >> op4
+            op1.skip = MagicMock()
+            dagrun = dag_maker.create_dagrun()
+            tis = dagrun.get_task_instances()
+            ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
+            ti._run_raw_task()
+            expected_tasks = {dag.task_dict[x] for x in expected}
+        if should_skip:
+            # we can't use assert_called_with because it's a set and therefore not ordered
+            actual_skipped = set(op1.skip.call_args.kwargs["tasks"])
+            assert actual_skipped == expected_tasks
+        else:
+            op1.skip.assert_not_called()
+
+    @pytest.mark.parametrize("config", ["sequence", "parallel"])
+    def test_short_circuit_with_teardowns_complicated(self, dag_maker, config):
+        with dag_maker():
+            s1 = PythonOperator(task_id="s1", python_callable=print).as_setup()
+            s2 = PythonOperator(task_id="s2", python_callable=print).as_setup()
+            op1 = ShortCircuitOperator(
+                task_id="op1",
+                python_callable=lambda: False,
+            )
+            op2 = PythonOperator(task_id="op2", python_callable=print)
+            t1 = PythonOperator(task_id="t1", python_callable=print).as_teardown(setups=s1)
+            t2 = PythonOperator(task_id="t2", python_callable=print).as_teardown(setups=s2)
+            if config == "sequence":
+                s1 >> op1 >> s2 >> op2 >> [t1, t2]
+            elif config == "parallel":
+                s1 >> op1 >> s2 >> op2 >> t2 >> t1
+            else:
+                raise ValueError("unexpected")
+            op1.skip = MagicMock()
+            dagrun = dag_maker.create_dagrun()
+            tis = dagrun.get_task_instances()
+            ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
+            ti._run_raw_task()
+            # we can't use assert_called_with because it's a set and therefore not ordered
+            actual_skipped = set(op1.skip.call_args.kwargs["tasks"])
+            assert actual_skipped == {s2, op2}
+
+    def test_short_circuit_with_teardowns_complicated_2(self, dag_maker):
+        with dag_maker():
+            s1 = PythonOperator(task_id="s1", python_callable=print).as_setup()
+            s2 = PythonOperator(task_id="s2", python_callable=print).as_setup()
+            op1 = ShortCircuitOperator(
+                task_id="op1",
+                python_callable=lambda: False,
+            )
+            op2 = PythonOperator(task_id="op2", python_callable=print)
+            op3 = PythonOperator(task_id="op3", python_callable=print)
+            t1 = PythonOperator(task_id="t1", python_callable=print).as_teardown(setups=s1)
+            t2 = PythonOperator(task_id="t2", python_callable=print).as_teardown(setups=s2)
+            s1 >> op1 >> op3 >> t1
+            s2 >> op2 >> t2
+
+            # this is the weird, maybe nonsensical part
+            # in this case we don't want to skip t2 since it should run
+            op1 >> t2
+            op1.skip = MagicMock()
+            dagrun = dag_maker.create_dagrun()
+            tis = dagrun.get_task_instances()
+            ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
+            ti._run_raw_task()
+            # we can't use assert_called_with because it's a set and therefore not ordered
+            actual_kwargs = op1.skip.call_args.kwargs
+            actual_skipped = set(actual_kwargs["tasks"])
+            assert actual_kwargs["execution_date"] == dagrun.logical_date
+            assert actual_skipped == {op3}
+
+    @pytest.mark.parametrize("level", [logging.DEBUG, logging.INFO])
+    def test_short_circuit_with_teardowns_debug_level(self, dag_maker, level, clear_db):
+        """
+        When logging is debug we convert to a list to log the tasks skipped
+        before passing them to the skip method.
+        """
+        with dag_maker():
+            s1 = PythonOperator(task_id="s1", python_callable=print).as_setup()
+            s2 = PythonOperator(task_id="s2", python_callable=print).as_setup()
+            op1 = ShortCircuitOperator(
+                task_id="op1",
+                python_callable=lambda: False,
+            )
+            op1.log.setLevel(level)
+            op2 = PythonOperator(task_id="op2", python_callable=print)
+            op3 = PythonOperator(task_id="op3", python_callable=print)
+            t1 = PythonOperator(task_id="t1", python_callable=print).as_teardown(setups=s1)
+            t2 = PythonOperator(task_id="t2", python_callable=print).as_teardown(setups=s2)
+            s1 >> op1 >> op3 >> t1
+            s2 >> op2 >> t2
+
+            # this is the weird, maybe nonsensical part
+            # in this case we don't want to skip t2 since it should run
+            op1 >> t2
+            op1.skip = MagicMock()
+            dagrun = dag_maker.create_dagrun()
+            tis = dagrun.get_task_instances()
+            ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
+            ti._run_raw_task()
+            # we can't use assert_called_with because it's a set and therefore not ordered
+            actual_kwargs = op1.skip.call_args.kwargs
+            actual_skipped = actual_kwargs["tasks"]
+            if level <= logging.DEBUG:
+                assert isinstance(actual_skipped, list)
+            else:
+                assert isinstance(actual_skipped, Generator)
+            assert set(actual_skipped) == {op3}
+            assert actual_kwargs["execution_date"] == dagrun.logical_date
