This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bb5307850d Fix short circuit operator re teardowns (#32538)
bb5307850d is described below
commit bb5307850d93df16f442bfce2f388c90d63c6f9a
Author: Daniel Standish <[email protected]>
AuthorDate: Thu Jul 13 17:52:27 2023 -0700
Fix short circuit operator re teardowns (#32538)
Short circuit operator should not skip teardown tasks. Note that if a setup
task downstream is skipped as a result, its teardown may in turn end up
skipped in accordance with trigger rules, but that is handled by the
scheduler, not by this operator.
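For illustration, a minimal DAG sketch of the intended behavior, assuming the
Airflow 2.7 setup/teardown API (the dag id and task ids here are made up for
this example, not part of the commit):

    from datetime import datetime

    from airflow import DAG
    from airflow.operators.python import PythonOperator, ShortCircuitOperator

    with DAG(dag_id="short_circuit_teardown_demo", start_date=datetime(2023, 1, 1), schedule=None):
        check = ShortCircuitOperator(task_id="check", python_callable=lambda: False)
        work = PythonOperator(task_id="work", python_callable=print)
        # as_teardown() marks "cleanup" as a teardown, so the short circuit
        # no longer skips it; its fate is left to the scheduler and its trigger rule
        cleanup = PythonOperator(task_id="cleanup", python_callable=print).as_teardown()
        check >> work >> cleanup

Here "check" returns False, so "work" gets skipped, while "cleanup" is left alone.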
Also did a small optimization so that most of the time we don't need to
create an intermediate list of to-be-skipped tasks in this operator's execute
method (only when debug logging is enabled do we "materialize" the list
first). Also consolidated the logic so that there is only one call to `skip`
instead of two different ones, one on each side of an if / else.
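That lazy-skip pattern, reduced to a self-contained sketch (the names below
are stand-ins invented for illustration, not Airflow APIs):

    import logging

    logger = logging.getLogger("demo")
    downstream = ["op2", "op3", "teardown"]  # stand-ins for task objects

    def tasks_to_skip():
        # generator: nothing is materialized up front
        for t in downstream:
            if t != "teardown":  # stand-in for "not t.is_teardown"
                yield t

    to_skip = tasks_to_skip()
    if logger.getEffectiveLevel() <= logging.DEBUG:
        # only build a concrete list when it will actually be logged; the
        # walrus rebinds to_skip to a list made from a fresh generator
        logger.debug("Downstream task IDs %s", to_skip := list(tasks_to_skip()))
    print(list(to_skip))  # the single skip() call then consumes the iterable once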
---
airflow/models/baseoperator.py | 5 +-
airflow/operators/python.py | 46 ++++++++-----
tests/operators/test_python.py | 142 ++++++++++++++++++++++++++++++++++++++++-
3 files changed, 173 insertions(+), 20 deletions(-)
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index ccbcb12efb..40fe3b8f0b 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -75,7 +75,7 @@ from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwar
 from airflow.models.param import ParamsDict
 from airflow.models.pool import Pool
 from airflow.models.taskinstance import TaskInstance, clear_task_instances
-from airflow.models.taskmixin import DAGNode, DependencyMixin
+from airflow.models.taskmixin import DependencyMixin
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
@@ -100,6 +100,7 @@ if TYPE_CHECKING:
     import jinja2  # Slow import.

     from airflow.models.dag import DAG
+    from airflow.models.operator import Operator
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.models.xcom_arg import XComArg
     from airflow.utils.task_group import TaskGroup
@@ -1373,7 +1374,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
                 self.log.info("Rendering template for %s", field)
                 self.log.info(content)

-    def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
+    def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]:
         """
         Get list of the direct relatives to the current task, upstream or downstream.
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index a0ea5be80c..2db5ca69c3 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -18,6 +18,7 @@
 from __future__ import annotations

 import inspect
+import logging
 import os
 import pickle
 import shutil
@@ -30,7 +31,7 @@ from collections.abc import Container
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from textwrap import dedent
-from typing import Any, Callable, Collection, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Mapping, Sequence, cast

 import dill
@@ -49,6 +50,9 @@ from airflow.utils.operator_helpers import KeywordParameters
 from airflow.utils.process_utils import execute_in_subprocess
 from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script

+if TYPE_CHECKING:
+    from pendulum.datetime import DateTime
+

 def task(python_callable: Callable | None = None, multiple_outputs: bool | None = None, **kwargs):
     """Deprecated. Use :func:`airflow.decorators.task` instead.
@@ -251,27 +255,35 @@ class ShortCircuitOperator(PythonOperator, SkipMixin):
self.log.info("Proceeding with downstream tasks...")
return condition
- downstream_tasks = context["task"].get_flat_relatives(upstream=False)
- self.log.debug("Downstream task IDs %s", downstream_tasks)
+ if not self.downstream_task_ids:
+ self.log.info("No downstream tasks; nothing to do.")
+ return
- if downstream_tasks:
- dag_run = context["dag_run"]
- execution_date = dag_run.execution_date
+ dag_run = context["dag_run"]
+ def get_tasks_to_skip():
if self.ignore_downstream_trigger_rules is True:
- self.log.info("Skipping all downstream tasks...")
- self.skip(dag_run, execution_date, downstream_tasks,
map_index=context["ti"].map_index)
+ tasks = context["task"].get_flat_relatives(upstream=False)
else:
- self.log.info("Skipping downstream tasks while respecting
trigger rules...")
- # Explicitly setting the state of the direct, downstream
task(s) to "skipped" and letting the
- # Scheduler handle the remaining downstream task(s)
appropriately.
- self.skip(
- dag_run,
- execution_date,
- context["task"].get_direct_relatives(upstream=False),
- map_index=context["ti"].map_index,
- )
+ tasks = context["task"].get_direct_relatives(upstream=False)
+ for t in tasks:
+ if not t.is_teardown:
+ yield t
+
+ to_skip = get_tasks_to_skip()
+ # this let's us avoid an intermediate list unless debug logging
+ if self.log.getEffectiveLevel() <= logging.DEBUG:
+ self.log.debug("Downstream task IDs %s", to_skip :=
list(get_tasks_to_skip()))
+
+ self.log.info("Skipping downstream tasks")
+
+ self.skip(
+ dag_run=dag_run,
+ execution_date=cast("DateTime", dag_run.execution_date),
+ tasks=to_skip,
+ map_index=context["ti"].map_index,
+ )
self.log.info("Done.")
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 498be2eb52..1df74fef4f 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -26,7 +26,9 @@ import warnings
 from collections import namedtuple
 from datetime import date, datetime, timedelta
 from subprocess import CalledProcessError
+from typing import Generator
 from unittest import mock
+from unittest.mock import MagicMock

 import pytest
 from slugify import slugify
@@ -35,7 +37,7 @@ from airflow.decorators import task_group
 from airflow.exceptions import AirflowException, DeserializingResultError, RemovedInAirflow3Warning
 from airflow.models import DAG, DagRun, TaskInstance as TI
 from airflow.models.baseoperator import BaseOperator
-from airflow.models.taskinstance import clear_task_instances, set_current_context
+from airflow.models.taskinstance import TaskInstance, clear_task_instances, set_current_context
 from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import (
     BranchPythonOperator,
@@ -1145,3 +1147,141 @@ class TestCurrentContextRuntime:
         with DAG(dag_id="edge_case_context_dag", default_args=DEFAULT_ARGS, schedule="@once"):
             op = PythonOperator(python_callable=get_all_the_context, task_id="get_all_the_context")
             op.run(ignore_first_depends_on_past=True, ignore_ti_state=True)
+
+
+class TestShortCircuitWithTeardown:
+    @pytest.mark.parametrize(
+        "ignore_downstream_trigger_rules, with_teardown, should_skip, expected",
+        [
+            (False, True, True, ["op2"]),
+            (False, True, False, []),
+            (False, False, True, ["op2"]),
+            (False, False, False, []),
+            (True, True, True, ["op2", "op3"]),
+            (True, True, False, []),
+            (True, False, True, ["op2", "op3", "op4"]),
+            (True, False, False, []),
+        ],
+    )
+    def test_short_circuit_with_teardowns(
+        self, dag_maker, ignore_downstream_trigger_rules, should_skip, with_teardown, expected
+    ):
+        with dag_maker() as dag:
+            op1 = ShortCircuitOperator(
+                task_id="op1",
+                python_callable=lambda: not should_skip,
+                ignore_downstream_trigger_rules=ignore_downstream_trigger_rules,
+            )
+            op2 = PythonOperator(task_id="op2", python_callable=print)
+            op3 = PythonOperator(task_id="op3", python_callable=print)
+            op4 = PythonOperator(task_id="op4", python_callable=print)
+            if with_teardown:
+                op4.as_teardown()
+            op1 >> op2 >> op3 >> op4
+        op1.skip = MagicMock()
+        dagrun = dag_maker.create_dagrun()
+        tis = dagrun.get_task_instances()
+        ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
+        ti._run_raw_task()
+        expected_tasks = {dag.task_dict[x] for x in expected}
+        if should_skip:
+            # we can't use assert_called_with because it's a set and therefore not ordered
+            actual_skipped = set(op1.skip.call_args.kwargs["tasks"])
+            assert actual_skipped == expected_tasks
+        else:
+            op1.skip.assert_not_called()
+
+    @pytest.mark.parametrize("config", ["sequence", "parallel"])
+    def test_short_circuit_with_teardowns_complicated(self, dag_maker, config):
+        with dag_maker():
+            s1 = PythonOperator(task_id="s1", python_callable=print).as_setup()
+            s2 = PythonOperator(task_id="s2", python_callable=print).as_setup()
+            op1 = ShortCircuitOperator(
+                task_id="op1",
+                python_callable=lambda: False,
+            )
+            op2 = PythonOperator(task_id="op2", python_callable=print)
+            t1 = PythonOperator(task_id="t1", python_callable=print).as_teardown(setups=s1)
+            t2 = PythonOperator(task_id="t2", python_callable=print).as_teardown(setups=s2)
+            if config == "sequence":
+                s1 >> op1 >> s2 >> op2 >> [t1, t2]
+            elif config == "parallel":
+                s1 >> op1 >> s2 >> op2 >> t2 >> t1
+            else:
+                raise ValueError("unexpected")
+        op1.skip = MagicMock()
+        dagrun = dag_maker.create_dagrun()
+        tis = dagrun.get_task_instances()
+        ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
+        ti._run_raw_task()
+        # we can't use assert_called_with because it's a set and therefore not ordered
+        actual_skipped = set(op1.skip.call_args.kwargs["tasks"])
+        assert actual_skipped == {s2, op2}
+
+    def test_short_circuit_with_teardowns_complicated_2(self, dag_maker):
+        with dag_maker():
+            s1 = PythonOperator(task_id="s1", python_callable=print).as_setup()
+            s2 = PythonOperator(task_id="s2", python_callable=print).as_setup()
+            op1 = ShortCircuitOperator(
+                task_id="op1",
+                python_callable=lambda: False,
+            )
+            op2 = PythonOperator(task_id="op2", python_callable=print)
+            op3 = PythonOperator(task_id="op3", python_callable=print)
+            t1 = PythonOperator(task_id="t1", python_callable=print).as_teardown(setups=s1)
+            t2 = PythonOperator(task_id="t2", python_callable=print).as_teardown(setups=s2)
+            s1 >> op1 >> op3 >> t1
+            s2 >> op2 >> t2
+
+            # this is the weird, maybe nonsensical part
+            # in this case we don't want to skip t2 since it should run
+            op1 >> t2
+        op1.skip = MagicMock()
+        dagrun = dag_maker.create_dagrun()
+        tis = dagrun.get_task_instances()
+        ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
+        ti._run_raw_task()
+        # we can't use assert_called_with because it's a set and therefore not ordered
+        actual_kwargs = op1.skip.call_args.kwargs
+        actual_skipped = set(actual_kwargs["tasks"])
+        assert actual_kwargs["execution_date"] == dagrun.logical_date
+        assert actual_skipped == {op3}
+
+    @pytest.mark.parametrize("level", [logging.DEBUG, logging.INFO])
+    def test_short_circuit_with_teardowns_debug_level(self, dag_maker, level, clear_db):
+        """
+        When logging is debug we convert to a list to log the tasks skipped
+        before passing them to the skip method.
+        """
+        with dag_maker():
+            s1 = PythonOperator(task_id="s1", python_callable=print).as_setup()
+            s2 = PythonOperator(task_id="s2", python_callable=print).as_setup()
+            op1 = ShortCircuitOperator(
+                task_id="op1",
+                python_callable=lambda: False,
+            )
+            op1.log.setLevel(level)
+            op2 = PythonOperator(task_id="op2", python_callable=print)
+            op3 = PythonOperator(task_id="op3", python_callable=print)
+            t1 = PythonOperator(task_id="t1", python_callable=print).as_teardown(setups=s1)
+            t2 = PythonOperator(task_id="t2", python_callable=print).as_teardown(setups=s2)
+            s1 >> op1 >> op3 >> t1
+            s2 >> op2 >> t2
+
+            # this is the weird, maybe nonsensical part
+            # in this case we don't want to skip t2 since it should run
+            op1 >> t2
+        op1.skip = MagicMock()
+        dagrun = dag_maker.create_dagrun()
+        tis = dagrun.get_task_instances()
+        ti: TaskInstance = [x for x in tis if x.task_id == "op1"][0]
+        ti._run_raw_task()
+        # we can't use assert_called_with because it's a set and therefore not ordered
+        actual_kwargs = op1.skip.call_args.kwargs
+        actual_skipped = actual_kwargs["tasks"]
+        if level <= logging.DEBUG:
+            assert isinstance(actual_skipped, list)
+        else:
+            assert isinstance(actual_skipped, Generator)
+        assert set(actual_skipped) == {op3}
+        assert actual_kwargs["execution_date"] == dagrun.logical_date