This is an automated email from the ASF dual-hosted git repository. pierrejeambrun pushed a commit to branch v2-5-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 061338fad1a9ec4bf12b1aad482b3a72f7d3551c Author: Andrey Anshin <[email protected]> AuthorDate: Thu Dec 22 12:32:06 2022 +0400 Refactor python operators/sensor tests (#28493) (cherry picked from commit 884fca8d114ce8e0c982747937a1014f3b5e7491) --- tests/conftest.py | 8 +- tests/decorators/test_python.py | 143 +--- tests/decorators/test_python_virtualenv.py | 13 - tests/operators/test_python.py | 1003 ++++++++++------------------ tests/sensors/test_python.py | 124 +--- 5 files changed, 426 insertions(+), 865 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0d4d1170f0..d71d8eb0f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,7 @@ import subprocess import sys from contextlib import ExitStack, suppress from datetime import datetime, timedelta +from typing import TYPE_CHECKING import freezegun import pytest @@ -46,6 +47,9 @@ from tests.test_utils.perf.perf_kit.sqlalchemy import ( # noqa isort:skip trace_queries, ) +if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstance + @pytest.fixture() def reset_environment(): @@ -741,7 +745,7 @@ def create_task_instance(dag_maker, create_dummy_dag): run_type=None, data_interval=None, **kwargs, - ): + ) -> TaskInstance: if execution_date is None: from airflow.utils import timezone @@ -775,7 +779,7 @@ def create_task_instance_of_operator(dag_maker): execution_date=None, session=None, **operator_kwargs, - ): + ) -> TaskInstance: with dag_maker(dag_id=dag_id, session=session): operator_class(**operator_kwargs) if execution_date is None: diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 47a908db77..1bbad51a0b 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -37,41 +37,20 @@ from airflow.utils import timezone from airflow.utils.state import State from airflow.utils.task_group import TaskGroup from airflow.utils.types import DagRunType -from tests.operators.test_python import Call, assert_calls_equal, 
build_recording_function -from tests.test_utils.db import clear_db_runs +from tests.operators.test_python import BasePythonTest DEFAULT_DATE = timezone.datetime(2016, 1, 1) -END_DATE = timezone.datetime(2016, 1, 2) -INTERVAL = timedelta(hours=12) -FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1) -TI_CONTEXT_ENV_VARS = [ - "AIRFLOW_CTX_DAG_ID", - "AIRFLOW_CTX_TASK_ID", - "AIRFLOW_CTX_EXECUTION_DATE", - "AIRFLOW_CTX_DAG_RUN_ID", -] - -class TestAirflowTaskDecorator: - def setup_class(self): - clear_db_runs() - - def setup_method(self): - self.dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE}) - self.run = False - - def teardown_method(self): - self.dag.clear() - self.run = False - clear_db_runs() +class TestAirflowTaskDecorator(BasePythonTest): + default_date = DEFAULT_DATE def test_python_operator_python_callable_is_callable(self): """Tests that @task will only instantiate if the python_callable argument is callable.""" not_callable = {} with pytest.raises(TypeError): - task_decorator(not_callable, dag=self.dag) + task_decorator(not_callable) @pytest.mark.parametrize( "resolve", @@ -155,13 +134,7 @@ class TestAirflowTaskDecorator: with self.dag: res = identity2(8, 4) - dr = self.dag.create_dagrun( - run_id=DagRunType.MANUAL.value, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, - ) - + dr = self.create_dag_run() res.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) ti = dr.get_task_instances()[0] @@ -179,13 +152,7 @@ class TestAirflowTaskDecorator: with self.dag: ident = identity_tuple(35, 36) - dr = self.dag.create_dagrun( - run_id=DagRunType.MANUAL.value, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, - ) - + dr = self.create_dag_run() ident.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) ti = dr.get_task_instances()[0] @@ -227,15 +194,9 @@ class TestAirflowTaskDecorator: with self.dag: ret = add_number(2) - 
self.dag.create_dagrun( - run_id=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - state=State.RUNNING, - ) + self.create_dag_run() with pytest.raises(AirflowException): - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_fail_multiple_outputs_no_dict(self): @@ -245,84 +206,53 @@ class TestAirflowTaskDecorator: with self.dag: ret = add_number(2) - self.dag.create_dagrun( - run_id=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - state=State.RUNNING, - ) + self.create_dag_run() with pytest.raises(AirflowException): - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) def test_python_callable_arguments_are_templatized(self): """Test @task op_args are templatized""" - recorded_calls = [] + + @task_decorator + def arg_task(*args): + raise RuntimeError("Should not executed") # Create a named tuple and ensure it is still preserved # after the rendering is done Named = namedtuple("Named", ["var1", "var2"]) named_tuple = Named("{{ ds }}", "unchanged") - task = task_decorator( - # a Mock instance cannot be used as a callable function or test fails with a - # TypeError: Object of type Mock is not JSON serializable - build_recording_function(recorded_calls), - dag=self.dag, - ) - ret = task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple) - - self.dag.create_dagrun( - run_id=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - start_date=DEFAULT_DATE, - state=State.RUNNING, - ) - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + with self.dag: + ret = arg_task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple) - ds_templated = DEFAULT_DATE.date().isoformat() - assert len(recorded_calls) == 1 - assert_calls_equal( - recorded_calls[0], - Call( - 4, - date(2019, 1, 1), - f"dag {self.dag.dag_id} ran on {ds_templated}.", - Named(ds_templated, "unchanged"), - ), - ) + dr = 
self.create_dag_run() + ti = TaskInstance(task=ret.operator, run_id=dr.run_id) + rendered_op_args = ti.render_templates().op_args + assert len(rendered_op_args) == 4 + assert rendered_op_args[0] == 4 + assert rendered_op_args[1] == date(2019, 1, 1) + assert rendered_op_args[2] == f"dag {self.dag_id} ran on {self.ds_templated}." + assert rendered_op_args[3] == Named(self.ds_templated, "unchanged") def test_python_callable_keyword_arguments_are_templatized(self): """Test PythonOperator op_kwargs are templatized""" - recorded_calls = [] - task = task_decorator( - # a Mock instance cannot be used as a callable function or test fails with a - # TypeError: Object of type Mock is not JSON serializable - build_recording_function(recorded_calls), - dag=self.dag, - ) - ret = task(an_int=4, a_date=date(2019, 1, 1), a_templated_string="dag {{dag.dag_id}} ran on {{ds}}.") - self.dag.create_dagrun( - run_id=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - start_date=DEFAULT_DATE, - state=State.RUNNING, - ) - ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + @task_decorator + def kwargs_task(an_int, a_date, a_templated_string): + raise RuntimeError("Should not executed") - assert len(recorded_calls) == 1 - assert_calls_equal( - recorded_calls[0], - Call( - an_int=4, - a_date=date(2019, 1, 1), - a_templated_string=f"dag {self.dag.dag_id} ran on {DEFAULT_DATE.date().isoformat()}.", - ), - ) + with self.dag: + ret = kwargs_task( + an_int=4, a_date=date(2019, 1, 1), a_templated_string="dag {{dag.dag_id}} ran on {{ds}}." + ) + + dr = self.create_dag_run() + ti = TaskInstance(task=ret.operator, run_id=dr.run_id) + rendered_op_kwargs = ti.render_templates().op_kwargs + assert rendered_op_kwargs["an_int"] == 4 + assert rendered_op_kwargs["a_date"] == date(2019, 1, 1) + assert rendered_op_kwargs["a_templated_string"] == f"dag {self.dag_id} ran on {self.ds_templated}." 
def test_manual_task_id(self): """Test manually setting task_id""" @@ -415,6 +345,7 @@ class TestAirflowTaskDecorator: def do_run(): return 4 + self.dag.default_args["owner"] = "airflow" with self.dag: ret = do_run() assert ret.operator.owner == "airflow" diff --git a/tests/decorators/test_python_virtualenv.py b/tests/decorators/test_python_virtualenv.py index 032ec34aa5..88121c5db3 100644 --- a/tests/decorators/test_python_virtualenv.py +++ b/tests/decorators/test_python_virtualenv.py @@ -19,7 +19,6 @@ from __future__ import annotations import datetime import sys -from datetime import timedelta from subprocess import CalledProcessError import pytest @@ -28,18 +27,6 @@ from airflow.decorators import task from airflow.utils import timezone DEFAULT_DATE = timezone.datetime(2016, 1, 1) -END_DATE = timezone.datetime(2016, 1, 2) -INTERVAL = timedelta(hours=12) -FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1) - -TI_CONTEXT_ENV_VARS = [ - "AIRFLOW_CTX_DAG_ID", - "AIRFLOW_CTX_TASK_ID", - "AIRFLOW_CTX_EXECUTION_DATE", - "AIRFLOW_CTX_DAG_RUN_ID", -] - - PYTHON_VERSION = sys.version_info[0] diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py index c011c5cb35..8f6c089f08 100644 --- a/tests/operators/test_python.py +++ b/tests/operators/test_python.py @@ -20,16 +20,18 @@ from __future__ import annotations import copy import logging import os +import re import sys -import unittest.mock import warnings from collections import namedtuple from datetime import date, datetime, timedelta from subprocess import CalledProcessError +from unittest import mock import pytest +from slugify import slugify -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning from airflow.models import DAG, DagRun, TaskInstance as TI from airflow.models.baseoperator import BaseOperator from airflow.models.taskinstance import clear_task_instances, set_current_context @@ -45,88 +47,108 @@ from airflow.utils import 
timezone from airflow.utils.context import AirflowContextDeprecationWarning, Context from airflow.utils.python_virtualenv import prepare_virtualenv from airflow.utils.session import create_session -from airflow.utils.state import State +from airflow.utils.state import DagRunState, State from airflow.utils.trigger_rule import TriggerRule -from airflow.utils.types import DagRunType +from airflow.utils.types import NOTSET, DagRunType from tests.test_utils import AIRFLOW_MAIN_FOLDER from tests.test_utils.db import clear_db_runs DEFAULT_DATE = timezone.datetime(2016, 1, 1) -END_DATE = timezone.datetime(2016, 1, 2) -INTERVAL = timedelta(hours=12) -FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1) - -TI_CONTEXT_ENV_VARS = [ - "AIRFLOW_CTX_DAG_ID", - "AIRFLOW_CTX_TASK_ID", - "AIRFLOW_CTX_EXECUTION_DATE", - "AIRFLOW_CTX_DAG_RUN_ID", -] - TEMPLATE_SEARCHPATH = os.path.join(AIRFLOW_MAIN_FOLDER, "tests", "config_templates") +LOGGER_NAME = "airflow.task.operators" -class Call: - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs - - -def build_recording_function(calls_collection): - """ - We can not use a Mock instance as a PythonOperator callable function or some tests fail with a - TypeError: Object of type Mock is not JSON serializable - Then using this custom function recording custom Call objects for further testing - (replacing Mock.assert_called_with assertion method) - """ - - def recording_function(*args, **kwargs): - calls_collection.append(Call(*args, **kwargs)) - - return recording_function - +class BasePythonTest: + """Base test class for TestPythonOperator and TestPythonSensor classes""" -def assert_calls_equal(first: Call, second: Call) -> None: - assert isinstance(first, Call) - assert isinstance(second, Call) - assert first.args == second.args - # eliminate context (conf, dag_run, task_instance, etc.) 
- test_args = ["an_int", "a_date", "a_templated_string"] - first.kwargs = {key: value for (key, value) in first.kwargs.items() if key in test_args} - second.kwargs = {key: value for (key, value) in second.kwargs.items() if key in test_args} - assert first.kwargs == second.kwargs + opcls: type[BaseOperator] + dag_id: str + task_id: str + run_id: str + dag: DAG + ds_templated: str + default_date: datetime = DEFAULT_DATE + + @pytest.fixture(autouse=True) + def base_tests_setup(self, request, create_task_instance_of_operator, dag_maker): + self.dag_id = f"dag_{slugify(request.cls.__name__)}" + self.task_id = f"task_{slugify(request.node.name, max_length=40)}" + self.run_id = f"run_{slugify(request.node.name, max_length=40)}" + self.ds_templated = self.default_date.date().isoformat() + self.ti_maker = create_task_instance_of_operator + self.dag_maker = dag_maker + self.dag = self.dag_maker(self.dag_id, template_searchpath=TEMPLATE_SEARCHPATH).dag + clear_db_runs() + yield + clear_db_runs() + + @staticmethod + def assert_expected_task_states(dag_run: DagRun, expected_states: dict): + """Helper function that asserts `TaskInstances` of a given `task_id` are in a given state.""" + asserts = [] + for ti in dag_run.get_task_instances(): + try: + expected = expected_states[ti.task_id] + except KeyError: + asserts.append(f"Unexpected task id {ti.task_id!r} found, expected {expected_states.keys()}") + continue + + if ti.state != expected: + asserts.append(f"Task {ti.task_id!r} has state {ti.state!r} instead of expected {expected!r}") + if asserts: + pytest.fail("\n".join(asserts)) + + @staticmethod + def default_kwargs(**kwargs): + """Default arguments for specific Operator.""" + return kwargs + + def create_dag_run(self) -> DagRun: + return self.dag.create_dagrun( + state=DagRunState.RUNNING, + start_date=self.dag_maker.start_date, + session=self.dag_maker.session, + execution_date=self.default_date, + run_type=DagRunType.MANUAL, + ) + def create_ti(self, fn, **kwargs) -> TI: + 
"""Create TaskInstance for class defined Operator.""" + return self.ti_maker( + self.opcls, + python_callable=fn, + **self.default_kwargs(**kwargs), + dag_id=self.dag_id, + task_id=self.task_id, + execution_date=self.default_date, + ) -class TestPythonBase(unittest.TestCase): - """Base test class for TestPythonOperator and TestPythonSensor classes""" + def run_as_operator(self, fn, **kwargs): + """Run task by direct call ``run`` method.""" + with self.dag: + task = self.opcls(task_id=self.task_id, python_callable=fn, **self.default_kwargs(**kwargs)) - @classmethod - def setUpClass(cls): - super().setUpClass() + task.run(start_date=self.default_date, end_date=self.default_date) + return task - with create_session() as session: - session.query(DagRun).delete() - session.query(TI).delete() + def run_as_task(self, fn, **kwargs): + """Create TaskInstance and run it.""" + ti = self.create_ti(fn, **kwargs) + ti.run() + return ti.task - def setUp(self): - super().setUp() - self.dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE}) - self.addCleanup(self.dag.clear) - self.clear_run() - self.addCleanup(self.clear_run) + def render_templates(self, fn, **kwargs): + """Create TaskInstance and render templates without actual run.""" + return self.create_ti(fn, **kwargs).render_templates() - def tearDown(self): - super().tearDown() - with create_session() as session: - session.query(DagRun).delete() - session.query(TI).delete() +class TestPythonOperator(BasePythonTest): + opcls = PythonOperator - def clear_run(self): + @pytest.fixture(autouse=True) + def setup_tests(self): self.run = False - -class TestPythonOperator(TestPythonBase): def do_run(self): self.run = True @@ -135,105 +157,58 @@ class TestPythonOperator(TestPythonBase): def test_python_operator_run(self): """Tests that the python callable is invoked on task run.""" - task = PythonOperator(python_callable=self.do_run, task_id="python_operator", dag=self.dag) + ti = 
self.create_ti(self.do_run) assert not self.is_run() - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + ti.run() assert self.is_run() - def test_python_operator_python_callable_is_callable(self): - """Tests that PythonOperator will only instantiate if - the python_callable argument is callable.""" - not_callable = {} - with pytest.raises(AirflowException): - PythonOperator(python_callable=not_callable, task_id="python_operator", dag=self.dag) - not_callable = None - with pytest.raises(AirflowException): - PythonOperator(python_callable=not_callable, task_id="python_operator", dag=self.dag) + @pytest.mark.parametrize("not_callable", [{}, None]) + def test_python_operator_python_callable_is_callable(self, not_callable): + """Tests that PythonOperator will only instantiate if the python_callable argument is callable.""" + with pytest.raises(AirflowException, match="`python_callable` param must be callable"): + PythonOperator(python_callable=not_callable, task_id="python_operator") def test_python_callable_arguments_are_templatized(self): """Test PythonOperator op_args are templatized""" - recorded_calls = [] - # Create a named tuple and ensure it is still preserved # after the rendering is done Named = namedtuple("Named", ["var1", "var2"]) named_tuple = Named("{{ ds }}", "unchanged") - task = PythonOperator( - task_id="python_operator", - # a Mock instance cannot be used as a callable function or test fails with a - # TypeError: Object of type Mock is not JSON serializable - python_callable=build_recording_function(recorded_calls), + task = self.render_templates( + lambda: 0, op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple], - dag=self.dag, - ) - - self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - start_date=DEFAULT_DATE, - state=State.RUNNING, - ) - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - ds_templated = 
DEFAULT_DATE.date().isoformat() - assert 1 == len(recorded_calls) - assert_calls_equal( - recorded_calls[0], - Call( - 4, - date(2019, 1, 1), - f"dag {self.dag.dag_id} ran on {ds_templated}.", - Named(ds_templated, "unchanged"), - ), ) + rendered_op_args = task.op_args + assert len(rendered_op_args) == 4 + assert rendered_op_args[0] == 4 + assert rendered_op_args[1] == date(2019, 1, 1) + assert rendered_op_args[2] == f"dag {self.dag_id} ran on {self.ds_templated}." + assert rendered_op_args[3] == Named(self.ds_templated, "unchanged") def test_python_callable_keyword_arguments_are_templatized(self): """Test PythonOperator op_kwargs are templatized""" - recorded_calls = [] - - task = PythonOperator( - task_id="python_operator", - # a Mock instance cannot be used as a callable function or test fails with a - # TypeError: Object of type Mock is not JSON serializable - python_callable=build_recording_function(recorded_calls), + task = self.render_templates( + lambda: 0, op_kwargs={ "an_int": 4, "a_date": date(2019, 1, 1), "a_templated_string": "dag {{dag.dag_id}} ran on {{ds}}.", }, - dag=self.dag, - ) - - self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - start_date=DEFAULT_DATE, - state=State.RUNNING, - ) - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - assert 1 == len(recorded_calls) - assert_calls_equal( - recorded_calls[0], - Call( - an_int=4, - a_date=date(2019, 1, 1), - a_templated_string=f"dag {self.dag.dag_id} ran on {DEFAULT_DATE.date().isoformat()}.", - ), ) + rendered_op_kwargs = task.op_kwargs + assert rendered_op_kwargs["an_int"] == 4 + assert rendered_op_kwargs["a_date"] == date(2019, 1, 1) + assert rendered_op_kwargs["a_templated_string"] == f"dag {self.dag_id} ran on {self.ds_templated}." 
def test_python_operator_shallow_copy_attr(self): def not_callable(x): - return x + assert False, "Should not be triggered" original_task = PythonOperator( python_callable=not_callable, - task_id="python_operator", op_kwargs={"certain_attrs": ""}, - dag=self.dag, + task_id=self.task_id, ) new_task = copy.deepcopy(original_task) # shallow copy op_kwargs @@ -242,383 +217,213 @@ class TestPythonOperator(TestPythonBase): assert id(original_task.python_callable) == id(new_task.python_callable) def test_conflicting_kwargs(self): - self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - state=State.RUNNING, - external_trigger=False, - ) - # dag is not allowed since it is a reserved keyword def func(dag): - # An ValueError should be triggered since we're using dag as a - # reserved keyword + # An ValueError should be triggered since we're using dag as a reserved keyword raise RuntimeError(f"Should not be triggered, dag: {dag}") - python_operator = PythonOperator( - task_id="python_operator", op_args=[1], python_callable=func, dag=self.dag - ) - - with pytest.raises(ValueError) as ctx: - python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - assert "dag" in str(ctx.value), "'dag' not found in the exception" + ti = self.create_ti(func, op_args=[1]) + error_message = re.escape("The key 'dag' in args is a part of kwargs and therefore reserved.") + with pytest.raises(ValueError, match=error_message): + ti.run() def test_provide_context_does_not_fail(self): - """ - ensures that provide_context doesn't break dags in 2.0 - """ - self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - state=State.RUNNING, - external_trigger=False, - ) + """Ensures that provide_context doesn't break dags in 2.0.""" def func(custom, dag): assert 1 == custom, "custom should be 1" assert dag is not None, "dag should be set" - python_operator = PythonOperator( - 
task_id="python_operator", - op_kwargs={"custom": 1}, - python_callable=func, - provide_context=True, - dag=self.dag, - ) - python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + with pytest.warns(RemovedInAirflow3Warning): + self.run_as_task(func, op_kwargs={"custom": 1}, provide_context=True) def test_context_with_conflicting_op_args(self): - self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - state=State.RUNNING, - external_trigger=False, - ) - def func(custom, dag): assert 1 == custom, "custom should be 1" assert dag is not None, "dag should be set" - python_operator = PythonOperator( - task_id="python_operator", op_kwargs={"custom": 1}, python_callable=func, dag=self.dag - ) - python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.run_as_task(func, op_kwargs={"custom": 1}) def test_context_with_kwargs(self): - self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - state=State.RUNNING, - external_trigger=False, - ) - def func(**context): # check if context is being set assert len(context) > 0, "Context has not been injected" - python_operator = PythonOperator( - task_id="python_operator", op_kwargs={"custom": 1}, python_callable=func, dag=self.dag - ) - python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - def test_return_value_log_with_show_return_value_in_logs_default(self): - self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - state=State.RUNNING, - external_trigger=False, - ) - - def func(): - return "test_return_value" - - python_operator = PythonOperator(task_id="python_operator", python_callable=func, dag=self.dag) - - with self.assertLogs("airflow.task.operators", level=logging.INFO) as cm: - python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.run_as_task(func, op_kwargs={"custom": 1}) - assert ( - 
"INFO:airflow.task.operators:Done. Returned value was: test_return_value" in cm.output - ), "Return value should be shown" - - def test_return_value_log_with_show_return_value_in_logs_false(self): - self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - state=State.RUNNING, - external_trigger=False, - ) + @pytest.mark.parametrize( + "show_return_value_in_logs, should_shown", + [ + pytest.param(NOTSET, True, id="default"), + pytest.param(True, True, id="show"), + pytest.param(False, False, id="hide"), + ], + ) + def test_return_value_log(self, show_return_value_in_logs, should_shown, caplog): + caplog.set_level(logging.INFO, logger=LOGGER_NAME) def func(): return "test_return_value" - python_operator = PythonOperator( - task_id="python_operator", - python_callable=func, - dag=self.dag, - show_return_value_in_logs=False, - ) - - with self.assertLogs("airflow.task.operators", level=logging.INFO) as cm: - python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - assert ( - "INFO:airflow.task.operators:Done. Returned value was: test_return_value" not in cm.output - ), "Return value should not be shown" - assert ( - "INFO:airflow.task.operators:Done. Returned value not shown" in cm.output - ), "Log message that the option is turned off should be shown" + if show_return_value_in_logs is NOTSET: + self.run_as_task(func) + else: + self.run_as_task(func, show_return_value_in_logs=show_return_value_in_logs) + if should_shown: + assert "Done. Returned value was: test_return_value" in caplog.messages + assert "Done. Returned value not shown" not in caplog.messages + else: + assert "Done. Returned value was: test_return_value" not in caplog.messages + assert "Done. 
Returned value not shown" in caplog.messages -class TestBranchOperator(unittest.TestCase): - @classmethod - def setUpClass(cls): - super().setUpClass() - - with create_session() as session: - session.query(DagRun).delete() - session.query(TI).delete() - - def setUp(self): - self.dag = DAG( - "branch_operator_test", - default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, - schedule=INTERVAL, - ) - self.branch_1 = EmptyOperator(task_id="branch_1", dag=self.dag) - self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag) - self.branch_3 = None +class TestBranchOperator(BasePythonTest): + opcls = BranchPythonOperator - def tearDown(self): - super().tearDown() - - with create_session() as session: - session.query(DagRun).delete() - session.query(TI).delete() + @pytest.fixture(autouse=True) + def setup_tests(self): + self.branch_1 = EmptyOperator(task_id="branch_1") + self.branch_2 = EmptyOperator(task_id="branch_2") def test_with_dag_run(self): - branch_op = BranchPythonOperator( - task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_1" - ) - - self.branch_1.set_upstream(branch_op) - self.branch_2.set_upstream(branch_op) - self.dag.clear() + with self.dag: + branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_1") + branch_op >> [self.branch_1, self.branch_2] - dr = self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, + dr = self.create_dag_run() + branch_op.run(start_date=self.default_date, end_date=self.default_date) + self.assert_expected_task_states( + dr, {self.task_id: State.SUCCESS, "branch_1": State.NONE, "branch_2": State.SKIPPED} ) - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - tis = dr.get_task_instances() - for ti in tis: - if ti.task_id == "make_choice": - assert ti.state == State.SUCCESS - elif ti.task_id == "branch_1": - assert ti.state == State.NONE - elif ti.task_id == "branch_2": - 
assert ti.state == State.SKIPPED - else: - raise ValueError(f"Invalid task id {ti.task_id} found!") - def test_with_skip_in_branch_downstream_dependencies(self): - branch_op = BranchPythonOperator( - task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_1" - ) - - branch_op >> self.branch_1 >> self.branch_2 - branch_op >> self.branch_2 - self.dag.clear() + with self.dag: + branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_1") + branch_op >> self.branch_1 >> self.branch_2 + branch_op >> self.branch_2 - dr = self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, + dr = self.create_dag_run() + branch_op.run(start_date=self.default_date, end_date=self.default_date) + self.assert_expected_task_states( + dr, {self.task_id: State.SUCCESS, "branch_1": State.NONE, "branch_2": State.NONE} ) - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - tis = dr.get_task_instances() - for ti in tis: - if ti.task_id == "make_choice": - assert ti.state == State.SUCCESS - elif ti.task_id == "branch_1": - assert ti.state == State.NONE - elif ti.task_id == "branch_2": - assert ti.state == State.NONE - else: - raise ValueError(f"Invalid task id {ti.task_id} found!") - def test_with_skip_in_branch_downstream_dependencies2(self): - branch_op = BranchPythonOperator( - task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_2" - ) - - branch_op >> self.branch_1 >> self.branch_2 - branch_op >> self.branch_2 - self.dag.clear() + with self.dag: + branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_2") + branch_op >> self.branch_1 >> self.branch_2 + branch_op >> self.branch_2 - dr = self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, + dr = self.create_dag_run() + branch_op.run(start_date=self.default_date, 
end_date=self.default_date) + self.assert_expected_task_states( + dr, {self.task_id: State.SUCCESS, "branch_1": State.SKIPPED, "branch_2": State.NONE} ) - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - tis = dr.get_task_instances() - for ti in tis: - if ti.task_id == "make_choice": - assert ti.state == State.SUCCESS - elif ti.task_id == "branch_1": - assert ti.state == State.SKIPPED - elif ti.task_id == "branch_2": - assert ti.state == State.NONE - else: - raise ValueError(f"Invalid task id {ti.task_id} found!") - def test_xcom_push(self): - branch_op = BranchPythonOperator( - task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_1" - ) - - self.branch_1.set_upstream(branch_op) - self.branch_2.set_upstream(branch_op) - self.dag.clear() - - dr = self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, - ) - - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + with self.dag: + branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_1") + branch_op >> [self.branch_1, self.branch_2] - tis = dr.get_task_instances() - for ti in tis: - if ti.task_id == "make_choice": - assert ti.xcom_pull(task_ids="make_choice") == "branch_1" + dr = self.create_dag_run() + branch_op.run(start_date=self.default_date, end_date=self.default_date) + for ti in dr.get_task_instances(): + if ti.task_id == self.task_id: + assert ti.xcom_pull(task_ids=self.task_id) == "branch_1" + break + else: + pytest.fail(f"{self.task_id!r} not found.") def test_clear_skipped_downstream_task(self): """ After a downstream task is skipped by BranchPythonOperator, clearing the skipped task should not cause it to be executed. 
""" - branch_op = BranchPythonOperator( - task_id="make_choice", dag=self.dag, python_callable=lambda: "branch_1" - ) - branches = [self.branch_1, self.branch_2] - branch_op >> branches - self.dag.clear() - - dr = self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, - ) - - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + with self.dag: + branch_op = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: "branch_1") + branches = [self.branch_1, self.branch_2] + branch_op >> branches + dr = self.create_dag_run() + branch_op.run(start_date=self.default_date, end_date=self.default_date) for task in branches: - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + task.run(start_date=self.default_date, end_date=self.default_date) - tis = dr.get_task_instances() - for ti in tis: - if ti.task_id == "make_choice": - assert ti.state == State.SUCCESS - elif ti.task_id == "branch_1": - assert ti.state == State.SUCCESS - elif ti.task_id == "branch_2": - assert ti.state == State.SKIPPED - else: - raise ValueError(f"Invalid task id {ti.task_id} found!") + expected_states = { + self.task_id: State.SUCCESS, + "branch_1": State.SUCCESS, + "branch_2": State.SKIPPED, + } - children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()] + self.assert_expected_task_states(dr, expected_states) # Clear the children tasks. + tis = dr.get_task_instances() + children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()] with create_session() as session: - clear_task_instances(children_tis, session=session, dag=self.dag) + clear_task_instances(children_tis, session=session, dag=branch_op.dag) # Run the cleared tasks again. 
for task in branches: - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + task.run(start_date=self.default_date, end_date=self.default_date) # Check if the states are correct after children tasks are cleared. - for ti in dr.get_task_instances(): - if ti.task_id == "make_choice": - assert ti.state == State.SUCCESS - elif ti.task_id == "branch_1": - assert ti.state == State.SUCCESS - elif ti.task_id == "branch_2": - assert ti.state == State.SKIPPED - else: - raise ValueError(f"Invalid task id {ti.task_id} found!") + self.assert_expected_task_states(dr, expected_states) def test_raise_exception_on_no_accepted_type_return(self): - branch_op = BranchPythonOperator(task_id="make_choice", dag=self.dag, python_callable=lambda: 5) - self.dag.clear() - with pytest.raises(AirflowException) as ctx: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - assert "must be either None, a task ID, or an Iterable of IDs" in str(ctx.value) + ti = self.create_ti(lambda: 5) + with pytest.raises(AirflowException, match="must be either None, a task ID, or an Iterable of IDs"): + ti.run() def test_raise_exception_on_invalid_task_id(self): - branch_op = BranchPythonOperator( - task_id="make_choice", dag=self.dag, python_callable=lambda: "some_task_id" - ) - self.dag.clear() - with pytest.raises(AirflowException) as ctx: - branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - assert "Invalid tasks found: {'some_task_id'}" in str(ctx.value) + ti = self.create_ti(lambda: "some_task_id") + with pytest.raises(AirflowException, match="Invalid tasks found: {'some_task_id'}"): + ti.run() + @pytest.mark.parametrize( + "choice,expected_states", + [ + ("task1", [State.SUCCESS, State.SUCCESS, State.SUCCESS]), + ("join", [State.SUCCESS, State.SKIPPED, State.SUCCESS]), + ], + ) + def test_empty_branch(self, choice, expected_states): + """ + Tests that BranchPythonOperator handles empty branches properly. 
+ """ + with self.dag: + branch = BranchPythonOperator(task_id=self.task_id, python_callable=lambda: choice) + task1 = EmptyOperator(task_id="task1") + join = EmptyOperator(task_id="join", trigger_rule="none_failed_min_one_success") -class TestShortCircuitOperator: - def setup(self): - with create_session() as session: - session.query(DagRun).delete() - session.query(TI).delete() + branch >> [task1, join] + task1 >> join - self.dag = DAG( - "short_circuit_op_test", - start_date=DEFAULT_DATE, - schedule=INTERVAL, - ) + dr = self.create_dag_run() + task_ids = [self.task_id, "task1", "join"] + tis = {ti.task_id: ti for ti in dr.task_instances} - with self.dag: - self.op1 = EmptyOperator(task_id="op1") - self.op2 = EmptyOperator(task_id="op2") - self.op1.set_downstream(self.op2) + for task_id in task_ids: # Mimic the specific order the scheduling would run the tests. + task_instance = tis[task_id] + task_instance.refresh_from_task(self.dag.get_task(task_id)) + task_instance.run() - def teardown(self): - with create_session() as session: - session.query(DagRun).delete() - session.query(TI).delete() + def get_state(ti): + ti.refresh_from_db() + return ti.state - def _assert_expected_task_states(self, dagrun, expected_states): - """Helper function that asserts `TaskInstances` of a given `task_id` are in a given state.""" + assert [get_state(tis[task_id]) for task_id in task_ids] == expected_states - tis = dagrun.get_task_instances() - for ti in tis: - try: - expected_state = expected_states[ti.task_id] - except KeyError: - raise ValueError(f"Invalid task id {ti.task_id} found!") - else: - assert ti.state == expected_state + +class TestShortCircuitOperator(BasePythonTest): + opcls = ShortCircuitOperator + + @pytest.fixture(autouse=True) + def setup_tests(self): + self.task_id = "short_circuit" + self.op1 = EmptyOperator(task_id="op1") + self.op2 = EmptyOperator(task_id="op2") all_downstream_skipped_states = { "short_circuit": State.SUCCESS, @@ -725,62 +530,41 @@ class 
TestShortCircuitOperator: Checking the behavior of the ShortCircuitOperator in several scenarios enabling/disabling the skipping of downstream tasks, both short-circuiting modes, and various trigger rules of downstream tasks. """ - - self.short_circuit = ShortCircuitOperator( - task_id="short_circuit", - python_callable=lambda: callable_return, - ignore_downstream_trigger_rules=test_ignore_downstream_trigger_rules, - dag=self.dag, - ) - self.short_circuit.set_downstream(self.op1) - self.op2.trigger_rule = test_trigger_rule - self.dag.clear() - - dagrun = self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, - ) - - self.short_circuit.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - self.op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - self.op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - assert self.short_circuit.ignore_downstream_trigger_rules == test_ignore_downstream_trigger_rules - assert self.short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS + with self.dag: + short_circuit = ShortCircuitOperator( + task_id="short_circuit", + python_callable=lambda: callable_return, + ignore_downstream_trigger_rules=test_ignore_downstream_trigger_rules, + ) + short_circuit >> self.op1 >> self.op2 + self.op2.trigger_rule = test_trigger_rule + + dr = self.create_dag_run() + short_circuit.run(start_date=self.default_date, end_date=self.default_date) + self.op1.run(start_date=self.default_date, end_date=self.default_date) + self.op2.run(start_date=self.default_date, end_date=self.default_date) + + assert short_circuit.ignore_downstream_trigger_rules == test_ignore_downstream_trigger_rules + assert short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS assert self.op1.trigger_rule == TriggerRule.ALL_SUCCESS assert self.op2.trigger_rule == test_trigger_rule - - self._assert_expected_task_states(dagrun, expected_task_states) + self.assert_expected_task_states(dr, 
expected_task_states) def test_clear_skipped_downstream_task(self): """ After a downstream task is skipped by ShortCircuitOperator, clearing the skipped task should not cause it to be executed. """ - - self.short_circuit = ShortCircuitOperator( - task_id="short_circuit", - python_callable=lambda: False, - dag=self.dag, - ) - self.short_circuit.set_downstream(self.op1) - self.dag.clear() - - dagrun = self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, - ) - - self.short_circuit.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - self.op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - self.op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - assert self.short_circuit.ignore_downstream_trigger_rules - assert self.short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS + with self.dag: + short_circuit = ShortCircuitOperator(task_id="short_circuit", python_callable=lambda: False) + short_circuit >> self.op1 >> self.op2 + dr = self.create_dag_run() + + short_circuit.run(start_date=self.default_date, end_date=self.default_date) + self.op1.run(start_date=self.default_date, end_date=self.default_date) + self.op2.run(start_date=self.default_date, end_date=self.default_date) + assert short_circuit.ignore_downstream_trigger_rules + assert short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS assert self.op1.trigger_rule == TriggerRule.ALL_SUCCESS assert self.op2.trigger_rule == TriggerRule.ALL_SUCCESS @@ -789,82 +573,45 @@ class TestShortCircuitOperator: "op1": State.SKIPPED, "op2": State.SKIPPED, } - self._assert_expected_task_states(dagrun, expected_states) + self.assert_expected_task_states(dr, expected_states) # Clear downstream task "op1" that was previously executed. 
- tis = dagrun.get_task_instances() - + tis = dr.get_task_instances() with create_session() as session: - clear_task_instances([ti for ti in tis if ti.task_id == "op1"], session=session, dag=self.dag) - - self.op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - self._assert_expected_task_states(dagrun, expected_states) + clear_task_instances( + [ti for ti in tis if ti.task_id == "op1"], session=session, dag=short_circuit.dag + ) + self.op1.run(start_date=self.default_date, end_date=self.default_date) + self.assert_expected_task_states(dr, expected_states) def test_xcom_push(self): - short_op_push_xcom = ShortCircuitOperator( - task_id="push_xcom_from_shortcircuit", dag=self.dag, python_callable=lambda: "signature" - ) - - short_op_no_push_xcom = ShortCircuitOperator( - task_id="do_not_push_xcom_from_shortcircuit", dag=self.dag, python_callable=lambda: False - ) - - self.dag.clear() - dr = self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, - ) + with self.dag: + short_op_push_xcom = ShortCircuitOperator( + task_id="push_xcom_from_shortcircuit", python_callable=lambda: "signature" + ) + short_op_no_push_xcom = ShortCircuitOperator( + task_id="do_not_push_xcom_from_shortcircuit", python_callable=lambda: False + ) - short_op_push_xcom.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - short_op_no_push_xcom.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + dr = self.create_dag_run() + short_op_push_xcom.run(start_date=self.default_date, end_date=self.default_date) + short_op_no_push_xcom.run(start_date=self.default_date, end_date=self.default_date) tis = dr.get_task_instances() - xcom_value_short_op_push_xcom = tis[0].xcom_pull( - task_ids="push_xcom_from_shortcircuit", key="return_value" - ) - assert xcom_value_short_op_push_xcom == "signature" - - xcom_value_short_op_no_push_xcom = tis[0].xcom_pull( - task_ids="do_not_push_xcom_from_shortcircuit", key="return_value" - 
) - assert xcom_value_short_op_no_push_xcom is None + assert tis[0].xcom_pull(task_ids=short_op_push_xcom.task_id, key="return_value") == "signature" + assert tis[0].xcom_pull(task_ids=short_op_no_push_xcom.task_id, key="return_value") is None virtualenv_string_args: list[str] = [] -class TestPythonVirtualenvOperator(unittest.TestCase): - def setUp(self): - super().setUp() - self.dag = DAG( - "test_dag", - default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, - template_searchpath=TEMPLATE_SEARCHPATH, - schedule=INTERVAL, - ) - self.dag.create_dagrun( - run_type=DagRunType.MANUAL, - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING, - ) - self.addCleanup(self.dag.clear) - - def tearDown(self): - super().tearDown() - with create_session() as session: - session.query(DagRun).delete() - session.query(TI).delete() - - def _run_as_operator(self, fn, python_version=sys.version_info[0], **kwargs): +class TestPythonVirtualenvOperator(BasePythonTest): + opcls = PythonVirtualenvOperator - task = PythonVirtualenvOperator( - python_callable=fn, python_version=python_version, task_id="task", dag=self.dag, **kwargs - ) - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - return task + @staticmethod + def default_kwargs(*, python_version=sys.version_info[0], **kwargs): + kwargs["python_version"] = python_version + return kwargs def test_template_fields(self): assert set(PythonOperator.template_fields).issubset(PythonVirtualenvOperator.template_fields) @@ -874,7 +621,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase): """Ensure dill is correctly installed.""" import dill # noqa: F401 - self._run_as_operator(f, use_dill=True, system_site_packages=False) + self.run_as_task(f, use_dill=True, system_site_packages=False) def test_no_requirements(self): """Tests that the python callable is invoked on task run.""" @@ -882,7 +629,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase): def f(): pass - self._run_as_operator(f) + 
self.run_as_task(f) def test_no_system_site_packages(self): def f(): @@ -892,13 +639,13 @@ class TestPythonVirtualenvOperator(unittest.TestCase): return True raise Exception - self._run_as_operator(f, system_site_packages=False, requirements=["dill"]) + self.run_as_task(f, system_site_packages=False, requirements=["dill"]) def test_system_site_packages(self): def f(): import funcsigs # noqa: F401 - self._run_as_operator(f, requirements=["funcsigs"], system_site_packages=True) + self.run_as_task(f, requirements=["funcsigs"], system_site_packages=True) def test_with_requirements_pinned(self): def f(): @@ -907,44 +654,44 @@ class TestPythonVirtualenvOperator(unittest.TestCase): if funcsigs.__version__ != "0.4": raise Exception - self._run_as_operator(f, requirements=["funcsigs==0.4"]) + self.run_as_task(f, requirements=["funcsigs==0.4"]) def test_unpinned_requirements(self): def f(): import funcsigs # noqa: F401 - self._run_as_operator(f, requirements=["funcsigs", "dill"], system_site_packages=False) + self.run_as_task(f, requirements=["funcsigs", "dill"], system_site_packages=False) def test_range_requirements(self): def f(): import funcsigs # noqa: F401 - self._run_as_operator(f, requirements=["funcsigs>1.0", "dill"], system_site_packages=False) + self.run_as_task(f, requirements=["funcsigs>1.0", "dill"], system_site_packages=False) def test_requirements_file(self): def f(): import funcsigs # noqa: F401 - self._run_as_operator(f, requirements="requirements.txt", system_site_packages=False) + self.run_as_operator(f, requirements="requirements.txt", system_site_packages=False) - @unittest.mock.patch("airflow.operators.python.prepare_virtualenv") + @mock.patch("airflow.operators.python.prepare_virtualenv") def test_pip_install_options(self, mocked_prepare_virtualenv): def f(): import funcsigs # noqa: F401 mocked_prepare_virtualenv.side_effect = prepare_virtualenv - self._run_as_operator( + self.run_as_task( f, requirements=["funcsigs==0.4"], system_site_packages=False, 
pip_install_options=["--no-deps"], ) mocked_prepare_virtualenv.assert_called_with( - venv_directory=unittest.mock.ANY, - python_bin=unittest.mock.ANY, + venv_directory=mock.ANY, + python_bin=mock.ANY, system_site_packages=False, - requirements_file_path=unittest.mock.ANY, + requirements_file_path=mock.ANY, pip_install_options=["--no-deps"], ) @@ -954,7 +701,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase): assert funcsigs.__version__ == "1.0.2" - self._run_as_operator( + self.run_as_operator( f, requirements="requirements.txt", use_dill=True, @@ -967,7 +714,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase): raise Exception with pytest.raises(CalledProcessError): - self._run_as_operator(f) + self.run_as_task(f) def test_python_3(self): def f(): @@ -980,13 +727,13 @@ class TestPythonVirtualenvOperator(unittest.TestCase): return raise Exception - self._run_as_operator(f, python_version=3, use_dill=False, requirements=["dill"]) + self.run_as_task(f, python_version=3, use_dill=False, requirements=["dill"]) def test_without_dill(self): def f(a): return a - self._run_as_operator(f, system_site_packages=False, use_dill=False, op_args=[4]) + self.run_as_task(f, system_site_packages=False, use_dill=False, op_args=[4]) def test_string_args(self): def f(): @@ -995,7 +742,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase): if virtualenv_string_args[0] != virtualenv_string_args[2]: raise Exception - self._run_as_operator(f, string_args=[1, 2, 1]) + self.run_as_task(f, string_args=[1, 2, 1]) def test_with_args(self): def f(a, b, c=False, d=False): @@ -1004,37 +751,38 @@ class TestPythonVirtualenvOperator(unittest.TestCase): else: raise Exception - self._run_as_operator(f, op_args=[0, 1], op_kwargs={"c": True}) + self.run_as_task(f, op_args=[0, 1], op_kwargs={"c": True}) def test_return_none(self): def f(): return None - task = self._run_as_operator(f) + task = self.run_as_task(f) assert task.execute_callable() is None def test_return_false(self): def 
f(): return False - task = self._run_as_operator(f) + task = self.run_as_task(f) assert task.execute_callable() is False def test_lambda(self): with pytest.raises(AirflowException): - PythonVirtualenvOperator(python_callable=lambda x: 4, task_id="task", dag=self.dag) + PythonVirtualenvOperator(python_callable=lambda x: 4, task_id=self.task_id) def test_nonimported_as_arg(self): def f(_): return None - self._run_as_operator(f, op_args=[datetime.utcnow()]) + self.run_as_task(f, op_args=[datetime.utcnow()]) def test_context(self): def f(templates_dict): return templates_dict["ds"] - self._run_as_operator(f, templates_dict={"ds": "{{ ds }}"}) + task = self.run_as_task(f, templates_dict={"ds": "{{ ds }}"}) + assert task.templates_dict == {"ds": self.ds_templated} # This tests might take longer than default 60 seconds as it is serializing a lot of # context using dill (which is slow apparently). @@ -1078,7 +826,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase): ): pass - self._run_as_operator(f, use_dill=True, system_site_packages=True, requirements=None) + self.run_as_operator(f, use_dill=True, system_site_packages=True, requirements=None) @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning") def test_pendulum_context(self): @@ -1112,7 +860,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase): ): pass - self._run_as_operator(f, use_dill=True, system_site_packages=False, requirements=["pendulum"]) + self.run_as_task(f, use_dill=True, system_site_packages=False, requirements=["pendulum"]) @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning") def test_base_context(self): @@ -1140,7 +888,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase): ): pass - self._run_as_operator(f, use_dill=True, system_site_packages=False, requirements=None) + self.run_as_task(f, use_dill=True, system_site_packages=False, requirements=None) def test_deepcopy(self): """Test that 
PythonVirtualenvOperator are deep-copyable.""" @@ -1148,13 +896,37 @@ class TestPythonVirtualenvOperator(unittest.TestCase): def f(): return 1 - task = PythonVirtualenvOperator( - python_callable=f, - task_id="task", - dag=self.dag, - ) + task = PythonVirtualenvOperator(python_callable=f, task_id="task") copy.deepcopy(task) + def test_virtualenv_serializable_context_fields(self, create_task_instance): + """Ensure all template context fields are listed in the operator. + + This exists mainly so when a field is added to the context, we remember to + also add it to PythonVirtualenvOperator. + """ + # These are intentionally NOT serialized into the virtual environment: + # * Variables pointing to the task instance itself. + # * Variables that are accessor instances. + intentionally_excluded_context_keys = [ + "task_instance", + "ti", + "var", # Accessor for Variable; var->json and var->value. + "conn", # Accessor for Connection. + ] + + ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None) + context = ti.get_template_context() + + declared_keys = { + *PythonVirtualenvOperator.BASE_SERIALIZABLE_CONTEXT_KEYS, + *PythonVirtualenvOperator.PENDULUM_SERIALIZABLE_CONTEXT_KEYS, + *PythonVirtualenvOperator.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS, + *intentionally_excluded_context_keys, + } + + assert set(context) == declared_keys + DEFAULT_ARGS = { "owner": "test", @@ -1221,7 +993,7 @@ def get_all_the_context(**context): assert context == current_context._context -@pytest.fixture() +@pytest.fixture def clear_db(): clear_db_runs() yield @@ -1239,76 +1011,3 @@ class TestCurrentContextRuntime: with DAG(dag_id="edge_case_context_dag", default_args=DEFAULT_ARGS): op = PythonOperator(python_callable=get_all_the_context, task_id="get_all_the_context") op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) - - -@pytest.mark.parametrize( - "choice,expected_states", - [ - ("task1", [State.SUCCESS, State.SUCCESS, State.SUCCESS]), - ("join", [State.SUCCESS, 
State.SKIPPED, State.SUCCESS]), - ], -) -def test_empty_branch(dag_maker, choice, expected_states): - """ - Tests that BranchPythonOperator handles empty branches properly. - """ - with dag_maker( - "test_empty_branch", - start_date=DEFAULT_DATE, - ) as dag: - branch = BranchPythonOperator(task_id="branch", python_callable=lambda: choice) - task1 = EmptyOperator(task_id="task1") - join = EmptyOperator(task_id="join", trigger_rule="none_failed_min_one_success") - - branch >> [task1, join] - task1 >> join - - dag.clear(start_date=DEFAULT_DATE) - dag_run = dag_maker.create_dagrun() - - task_ids = ["branch", "task1", "join"] - tis = {ti.task_id: ti for ti in dag_run.task_instances} - - for task_id in task_ids: # Mimic the specific order the scheduling would run the tests. - task_instance = tis[task_id] - task_instance.refresh_from_task(dag.get_task(task_id)) - task_instance.run() - - def get_state(ti): - ti.refresh_from_db() - return ti.state - - assert [get_state(tis[task_id]) for task_id in task_ids] == expected_states - - -def test_virtualenv_serializable_context_fields(create_task_instance): - """Ensure all template context fields are listed in the operator. - - This exists mainly so when a field is added to the context, we remember to - also add it to PythonVirtualenvOperator. - """ - # These are intentionally NOT serialized into the virtual environment: - # * Variables pointing to the task instance itself. - # * Variables that are accessor instances. - intentionally_excluded_context_keys = [ - "task_instance", - "ti", - "var", # Accessor for Variable; var->json and var->value. - "conn", # Accessor for Connection. 
- ] - - ti = create_task_instance( - dag_id="test_virtualenv_serializable_context_fields", - task_id="test_virtualenv_serializable_context_fields_task", - schedule=None, - ) - context = ti.get_template_context() - - declared_keys = { - *PythonVirtualenvOperator.BASE_SERIALIZABLE_CONTEXT_KEYS, - *PythonVirtualenvOperator.PENDULUM_SERIALIZABLE_CONTEXT_KEYS, - *PythonVirtualenvOperator.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS, - *intentionally_excluded_context_keys, - } - - assert set(context) == declared_keys diff --git a/tests/sensors/test_python.py b/tests/sensors/test_python.py index f3c258185d..73a0ddffe4 100644 --- a/tests/sensors/test_python.py +++ b/tests/sensors/test_python.py @@ -24,112 +24,52 @@ import pytest from airflow.exceptions import AirflowSensorTimeout from airflow.sensors.python import PythonSensor -from airflow.utils.state import State -from airflow.utils.timezone import datetime -from airflow.utils.types import DagRunType -from tests.operators.test_python import Call, assert_calls_equal, build_recording_function +from tests.operators.test_python import BasePythonTest -DEFAULT_DATE = datetime(2015, 1, 1) +class TestPythonSensor(BasePythonTest): + opcls = PythonSensor -class TestPythonSensor: - def test_python_sensor_true(self, dag_maker): - with dag_maker(): - op = PythonSensor(task_id="python_sensor_check_true", python_callable=lambda: True) - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + def test_python_sensor_true(self): + self.run_as_task(fn=lambda: True) - def test_python_sensor_false(self, dag_maker): - with dag_maker(): - op = PythonSensor( - task_id="python_sensor_check_false", - timeout=0.01, - poke_interval=0.01, - python_callable=lambda: False, - ) + def test_python_sensor_false(self): with pytest.raises(AirflowSensorTimeout): - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + self.run_as_task(lambda: False, timeout=0.01, poke_interval=0.01) - def test_python_sensor_raise(self, 
dag_maker): - with dag_maker(): - op = PythonSensor(task_id="python_sensor_check_raise", python_callable=lambda: 1 / 0) + def test_python_sensor_raise(self): with pytest.raises(ZeroDivisionError): - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + self.run_as_task(lambda: 1 / 0) - def test_python_callable_arguments_are_templatized(self, dag_maker): + def test_python_callable_arguments_are_templatized(self): """Test PythonSensor op_args are templatized""" - recorded_calls = [] - # Create a named tuple and ensure it is still preserved # after the rendering is done Named = namedtuple("Named", ["var1", "var2"]) named_tuple = Named("{{ ds }}", "unchanged") - with dag_maker() as dag: - task = PythonSensor( - task_id="python_sensor", - timeout=0.01, - poke_interval=0.3, - # a Mock instance cannot be used as a callable function or test fails with a - # TypeError: Object of type Mock is not JSON serializable - python_callable=build_recording_function(recorded_calls), - op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple], - ) - - dag.create_dagrun( - run_type=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - start_date=DEFAULT_DATE, - state=State.RUNNING, - ) - with pytest.raises(AirflowSensorTimeout): - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - ds_templated = DEFAULT_DATE.date().isoformat() - assert_calls_equal( - recorded_calls[0], - Call( - 4, - date(2019, 1, 1), - f"dag {dag.dag_id} ran on {ds_templated}.", - Named(ds_templated, "unchanged"), - ), + task = self.render_templates( + lambda: 0, + op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple], ) - - def test_python_callable_keyword_arguments_are_templatized(self, dag_maker): + rendered_op_args = task.op_args + assert len(rendered_op_args) == 4 + assert rendered_op_args[0] == 4 + assert rendered_op_args[1] == date(2019, 1, 1) + assert rendered_op_args[2] == f"dag 
{self.dag_id} ran on {self.ds_templated}." + assert rendered_op_args[3] == Named(self.ds_templated, "unchanged") + + def test_python_callable_keyword_arguments_are_templatized(self): """Test PythonSensor op_kwargs are templatized""" - recorded_calls = [] - - with dag_maker() as dag: - task = PythonSensor( - task_id="python_sensor", - timeout=0.01, - poke_interval=0.01, - # a Mock instance cannot be used as a callable function or test fails with a - # TypeError: Object of type Mock is not JSON serializable - python_callable=build_recording_function(recorded_calls), - op_kwargs={ - "an_int": 4, - "a_date": date(2019, 1, 1), - "a_templated_string": "dag {{dag.dag_id}} ran on {{ds}}.", - }, - ) - - dag.create_dagrun( - run_type=DagRunType.MANUAL, - execution_date=DEFAULT_DATE, - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - start_date=DEFAULT_DATE, - state=State.RUNNING, - ) - with pytest.raises(AirflowSensorTimeout): - task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - - assert_calls_equal( - recorded_calls[0], - Call( - an_int=4, - a_date=date(2019, 1, 1), - a_templated_string=f"dag {dag.dag_id} ran on {DEFAULT_DATE.date().isoformat()}.", - ), + task = self.render_templates( + lambda: 0, + op_kwargs={ + "an_int": 4, + "a_date": date(2019, 1, 1), + "a_templated_string": "dag {{dag.dag_id}} ran on {{ds}}.", + }, ) + rendered_op_kwargs = task.op_kwargs + assert rendered_op_kwargs["an_int"] == 4 + assert rendered_op_kwargs["a_date"] == date(2019, 1, 1) + assert rendered_op_kwargs["a_templated_string"] == f"dag {self.dag_id} ran on {self.ds_templated}."
