This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6405d8f Better multiple_outputs inferral for @task.python (#20800)
6405d8f is described below
commit 6405d8f804e7cbd1748aa7eed65f2bbf0fcf022e
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Jan 27 14:52:26 2022 +0800
Better multiple_outputs inferral for @task.python (#20800)
---
airflow/decorators/base.py | 25 +++---
tests/decorators/test_python.py | 119 ++++++++++-------------------
tests/decorators/test_python_virtualenv.py | 48 ++++++------
tests/operators/test_python.py | 25 +++---
tests/sensors/test_python.py | 98 ++++++++++++------------
5 files changed, 142 insertions(+), 173 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 2e157a3..dec09df 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -38,6 +38,7 @@ from typing import (
)
import attr
+import typing_extensions
from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
@@ -233,19 +234,21 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
@multiple_outputs.default
def _infer_multiple_outputs(self):
- return_type = self.function_signature.return_annotation
-
- # If the return type annotation is already the builtins ``dict`` type, use it for the inference.
- if return_type == dict:
- ttype = return_type
- # Checking if Python 3.6, ``__origin__`` attribute does not exist until 3.7; need to use ``__extra__``
- # TODO: Remove check when support for Python 3.6 is dropped in Airflow 2.3.
- elif sys.version_info < (3, 7):
- ttype = getattr(return_type, "__extra__", None)
+ try:
+ return_type = typing_extensions.get_type_hints(self.function).get("return", Any)
+ except Exception: # Can't evaluate return type.
+ return False
+
+ # Get the non-subscripted type. The ``__origin__`` attribute is not
+ # stable until 3.7, so on Python 3.6 we use ``__extra__`` instead.
+ # TODO: Remove the ``__extra__`` branch when support for Python 3.6 is
+ # dropped in Airflow 2.3.
+ if sys.version_info < (3, 7):
+ ttype = getattr(return_type, "__extra__", return_type)
else:
- ttype = getattr(return_type, "__origin__", None)
+ ttype = getattr(return_type, "__origin__", return_type)
- return return_type is not inspect.Signature.empty and ttype in (dict, Dict)
+ return ttype == dict or ttype == Dict
def __attrs_post_init__(self):
self.kwargs.setdefault('task_id', self.function.__name__)
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index ea1a23a..0c93b49 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -16,24 +16,24 @@
# specific language governing permissions and limitations
# under the License.
import sys
-import unittest.mock
from collections import namedtuple
from datetime import date, timedelta
+from typing import Dict # noqa: F401 # This is used by annotation tests.
from typing import Tuple
import pytest
-from parameterized import parameterized
from airflow.decorators import task as task_decorator
from airflow.exceptions import AirflowException
-from airflow.models import DAG, DagRun, TaskInstance as TI
+from airflow.models import DAG
from airflow.models.baseoperator import MappedOperator
from airflow.models.xcom_arg import XComArg
from airflow.utils import timezone
-from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.task_group import TaskGroup
from airflow.utils.types import DagRunType
+from tests.operators.test_python import Call, assert_calls_equal, build_recording_function
+from tests.test_utils.db import clear_db_runs
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
END_DATE = timezone.datetime(2016, 1, 2)
@@ -48,66 +48,19 @@ TI_CONTEXT_ENV_VARS = [
]
-class Call:
- def __init__(self, *args, **kwargs):
- self.args = args
- self.kwargs = kwargs
+class TestAirflowTaskDecorator:
+ def setup_class(self):
+ clear_db_runs()
-
-def build_recording_function(calls_collection):
- """
- We can not use a Mock instance as a PythonOperator callable function or some tests fail with a
- TypeError: Object of type Mock is not JSON serializable
- Then using this custom function recording custom Call objects for further testing
- (replacing Mock.assert_called_with assertion method)
- """
-
- def recording_function(*args, **kwargs):
- calls_collection.append(Call(*args, **kwargs))
-
- return recording_function
-
-
-class TestPythonBase(unittest.TestCase):
- """Base test class for TestPythonOperator and TestPythonSensor classes"""
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
- with create_session() as session:
- session.query(DagRun).delete()
- session.query(TI).delete()
-
- def setUp(self):
- super().setUp()
- self.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
- self.addCleanup(self.dag.clear)
- self.clear_run()
- self.addCleanup(self.clear_run)
-
- def tearDown(self):
- super().tearDown()
-
- with create_session() as session:
- session.query(DagRun).delete()
- session.query(TI).delete()
-
- def clear_run(self):
+ def setup_method(self):
+ self.dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE})
self.run = False
- def _assert_calls_equal(self, first, second):
- assert isinstance(first, Call)
- assert isinstance(second, Call)
- assert first.args == second.args
- # eliminate context (conf, dag_run, task_instance, etc.)
- test_args = ["an_int", "a_date", "a_templated_string"]
- first.kwargs = {key: value for (key, value) in first.kwargs.items() if key in test_args}
- second.kwargs = {key: value for (key, value) in second.kwargs.items() if key in test_args}
- assert first.kwargs == second.kwargs
-
+ def teardown_method(self):
+ self.dag.clear()
+ self.run = False
+ clear_db_runs()
-class TestAirflowTaskDecorator(TestPythonBase):
def test_python_operator_python_callable_is_callable(self):
"""Tests that @task will only instantiate if
the python_callable argument is callable."""
@@ -115,22 +68,34 @@ class TestAirflowTaskDecorator(TestPythonBase):
with pytest.raises(TypeError):
task_decorator(not_callable, dag=self.dag)
- @parameterized.expand([["dict"], ["dict[str, int]"], ["Dict"], ["Dict[str, int]"]])
- def test_infer_multiple_outputs_using_dict_typing(self, test_return_annotation):
- if sys.version_info < (3, 9) and test_return_annotation == "dict[str, int]":
- raise pytest.skip("dict[...] not a supported typing prior to Python 3.9")
-
- @task_decorator
- def identity_dict(x: int, y: int) -> eval(test_return_annotation):
- return {"x": x, "y": y}
-
- assert identity_dict(5, 5).operator.multiple_outputs is True
-
- @task_decorator
- def identity_dict_stringified(x: int, y: int) -> test_return_annotation:
- return {"x": x, "y": y}
+ @pytest.mark.parametrize(
+ "resolve",
+ [
+ pytest.param(eval, id="eval"),
+ pytest.param(lambda t: t, id="stringify"),
+ ],
+ )
+ @pytest.mark.parametrize(
+ "annotation",
+ [
+ "dict",
+ pytest.param(
+ "dict[str, int]",
+ marks=pytest.mark.skipif(
+ sys.version_info < (3, 9),
+ reason="PEP 585 is implemented in Python 3.9",
+ ),
+ ),
+ "Dict",
+ "Dict[str, int]",
+ ],
+ )
+ def test_infer_multiple_outputs_using_dict_typing(self, resolve, annotation):
+ @task_decorator
+ def identity_dict(x: int, y: int) -> resolve(annotation):
+ return {"x": x, "y": y}
- assert identity_dict_stringified(5, 5).operator.multiple_outputs is True
+ assert identity_dict(5, 5).operator.multiple_outputs is True
def test_infer_multiple_outputs_using_other_typing(self):
@task_decorator
@@ -288,7 +253,7 @@ class TestAirflowTaskDecorator(TestPythonBase):
ds_templated = DEFAULT_DATE.date().isoformat()
assert len(recorded_calls) == 1
- self._assert_calls_equal(
+ assert_calls_equal(
recorded_calls[0],
Call(
4,
@@ -319,7 +284,7 @@ class TestAirflowTaskDecorator(TestPythonBase):
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
assert len(recorded_calls) == 1
- self._assert_calls_equal(
+ assert_calls_equal(
recorded_calls[0],
Call(
an_int=4,
diff --git a/tests/decorators/test_python_virtualenv.py b/tests/decorators/test_python_virtualenv.py
index bce9c53..01e646c 100644
--- a/tests/decorators/test_python_virtualenv.py
+++ b/tests/decorators/test_python_virtualenv.py
@@ -25,8 +25,6 @@ import pytest
from airflow.decorators import task
from airflow.utils import timezone
-from .test_python import TestPythonBase
-
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
END_DATE = timezone.datetime(2016, 1, 2)
INTERVAL = timedelta(hours=12)
@@ -43,31 +41,31 @@ TI_CONTEXT_ENV_VARS = [
PYTHON_VERSION = sys.version_info[0]
-class TestPythonVirtualenvDecorator(TestPythonBase):
- def test_add_dill(self):
+class TestPythonVirtualenvDecorator:
+ def test_add_dill(self, dag_maker):
@task.virtualenv(use_dill=True, system_site_packages=False)
def f():
"""Ensure dill is correctly installed."""
import dill # noqa: F401
- with self.dag:
+ with dag_maker():
ret = f()
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- def test_no_requirements(self):
+ def test_no_requirements(self, dag_maker):
"""Tests that the python callable is invoked on task run."""
@task.virtualenv()
def f():
pass
- with self.dag:
+ with dag_maker():
ret = f()
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- def test_no_system_site_packages(self):
+ def test_no_system_site_packages(self, dag_maker):
@task.virtualenv(system_site_packages=False, python_version=PYTHON_VERSION, use_dill=True)
def f():
try:
@@ -76,12 +74,12 @@ class TestPythonVirtualenvDecorator(TestPythonBase):
return True
raise Exception
- with self.dag:
+ with dag_maker():
ret = f()
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- def test_system_site_packages(self):
+ def test_system_site_packages(self, dag_maker):
@task.virtualenv(
system_site_packages=False,
requirements=['funcsigs'],
@@ -91,12 +89,12 @@ class TestPythonVirtualenvDecorator(TestPythonBase):
def f():
import funcsigs # noqa: F401
- with self.dag:
+ with dag_maker():
ret = f()
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- def test_with_requirements_pinned(self):
+ def test_with_requirements_pinned(self, dag_maker):
@task.virtualenv(
system_site_packages=False,
requirements=['funcsigs==0.4'],
@@ -109,12 +107,12 @@ class TestPythonVirtualenvDecorator(TestPythonBase):
if funcsigs.__version__ != '0.4':
raise Exception
- with self.dag:
+ with dag_maker():
ret = f()
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- def test_unpinned_requirements(self):
+ def test_unpinned_requirements(self, dag_maker):
@task.virtualenv(
system_site_packages=False,
requirements=['funcsigs', 'dill'],
@@ -124,23 +122,23 @@ class TestPythonVirtualenvDecorator(TestPythonBase):
def f():
import funcsigs # noqa: F401
- with self.dag:
+ with dag_maker():
ret = f()
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- def test_fail(self):
+ def test_fail(self, dag_maker):
@task.virtualenv()
def f():
raise Exception
- with self.dag:
+ with dag_maker():
ret = f()
with pytest.raises(CalledProcessError):
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- def test_python_3(self):
+ def test_python_3(self, dag_maker):
@task.virtualenv(python_version=3, use_dill=False, requirements=['dill'])
def f():
import sys
@@ -152,12 +150,12 @@ class TestPythonVirtualenvDecorator(TestPythonBase):
return
raise Exception
- with self.dag:
+ with dag_maker():
ret = f()
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- def test_with_args(self):
+ def test_with_args(self, dag_maker):
@task.virtualenv
def f(a, b, c=False, d=False):
if a == 0 and b == 1 and c and not d:
@@ -165,27 +163,27 @@ class TestPythonVirtualenvDecorator(TestPythonBase):
else:
raise Exception
- with self.dag:
+ with dag_maker():
ret = f(0, 1, c=True)
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- def test_return_none(self):
+ def test_return_none(self, dag_maker):
@task.virtualenv
def f():
return None
- with self.dag:
+ with dag_maker():
ret = f()
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- def test_nonimported_as_arg(self):
+ def test_nonimported_as_arg(self, dag_maker):
@task.virtualenv
def f(_):
return None
- with self.dag:
+ with dag_maker():
ret = f(datetime.datetime.utcnow())
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index dda716e..24331ef 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -84,6 +84,17 @@ def build_recording_function(calls_collection):
return recording_function
+def assert_calls_equal(first: Call, second: Call) -> None:
+ assert isinstance(first, Call)
+ assert isinstance(second, Call)
+ assert first.args == second.args
+ # eliminate context (conf, dag_run, task_instance, etc.)
+ test_args = ["an_int", "a_date", "a_templated_string"]
+ first.kwargs = {key: value for (key, value) in first.kwargs.items() if key in test_args}
+ second.kwargs = {key: value for (key, value) in second.kwargs.items() if key in test_args}
+ assert first.kwargs == second.kwargs
+
+
class TestPythonBase(unittest.TestCase):
"""Base test class for TestPythonOperator and TestPythonSensor classes"""
@@ -112,16 +123,6 @@ class TestPythonBase(unittest.TestCase):
def clear_run(self):
self.run = False
- def _assert_calls_equal(self, first, second):
- assert isinstance(first, Call)
- assert isinstance(second, Call)
- assert first.args == second.args
- # eliminate context (conf, dag_run, task_instance, etc.)
- test_args = ["an_int", "a_date", "a_templated_string"]
- first.kwargs = {key: value for (key, value) in first.kwargs.items() if key in test_args}
- second.kwargs = {key: value for (key, value) in second.kwargs.items() if key in test_args}
- assert first.kwargs == second.kwargs
-
class TestPythonOperator(TestPythonBase):
def do_run(self):
@@ -176,7 +177,7 @@ class TestPythonOperator(TestPythonBase):
ds_templated = DEFAULT_DATE.date().isoformat()
assert 1 == len(recorded_calls)
- self._assert_calls_equal(
+ assert_calls_equal(
recorded_calls[0],
Call(
4,
@@ -213,7 +214,7 @@ class TestPythonOperator(TestPythonBase):
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
assert 1 == len(recorded_calls)
- self._assert_calls_equal(
+ assert_calls_equal(
recorded_calls[0],
Call(
an_int=4,
diff --git a/tests/sensors/test_python.py b/tests/sensors/test_python.py
index 63ea6ce..3c21ef5 100644
--- a/tests/sensors/test_python.py
+++ b/tests/sensors/test_python.py
@@ -27,33 +27,35 @@ from airflow.sensors.python import PythonSensor
from airflow.utils.state import State
from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
-from tests.operators.test_python import Call, TestPythonBase, build_recording_function
+from tests.operators.test_python import Call, assert_calls_equal, build_recording_function
DEFAULT_DATE = datetime(2015, 1, 1)
-class TestPythonSensor(TestPythonBase):
- def test_python_sensor_true(self):
- op = PythonSensor(task_id='python_sensor_check_true', python_callable=lambda: True, dag=self.dag)
+class TestPythonSensor:
+ def test_python_sensor_true(self, dag_maker):
+ with dag_maker():
+ op = PythonSensor(task_id='python_sensor_check_true', python_callable=lambda: True)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
- def test_python_sensor_false(self):
- op = PythonSensor(
- task_id='python_sensor_check_false',
- timeout=0.01,
- poke_interval=0.01,
- python_callable=lambda: False,
- dag=self.dag,
- )
+ def test_python_sensor_false(self, dag_maker):
+ with dag_maker():
+ op = PythonSensor(
+ task_id='python_sensor_check_false',
+ timeout=0.01,
+ poke_interval=0.01,
+ python_callable=lambda: False,
+ )
with pytest.raises(AirflowSensorTimeout):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
- def test_python_sensor_raise(self):
- op = PythonSensor(task_id='python_sensor_check_raise', python_callable=lambda: 1 / 0, dag=self.dag)
+ def test_python_sensor_raise(self, dag_maker):
+ with dag_maker():
+ op = PythonSensor(task_id='python_sensor_check_raise', python_callable=lambda: 1 / 0)
with pytest.raises(ZeroDivisionError):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
- def test_python_callable_arguments_are_templatized(self):
+ def test_python_callable_arguments_are_templatized(self, dag_maker):
"""Test PythonSensor op_args are templatized"""
recorded_calls = []
@@ -62,18 +64,18 @@ class TestPythonSensor(TestPythonBase):
Named = namedtuple('Named', ['var1', 'var2'])
named_tuple = Named('{{ ds }}', 'unchanged')
- task = PythonSensor(
- task_id='python_sensor',
- timeout=0.01,
- poke_interval=0.3,
- # a Mock instance cannot be used as a callable function or test fails with a
- # TypeError: Object of type Mock is not JSON serializable
- python_callable=build_recording_function(recorded_calls),
- op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple],
- dag=self.dag,
- )
-
- self.dag.create_dagrun(
+ with dag_maker() as dag:
+ task = PythonSensor(
+ task_id='python_sensor',
+ timeout=0.01,
+ poke_interval=0.3,
+ # a Mock instance cannot be used as a callable function or test fails with a
+ # TypeError: Object of type Mock is not JSON serializable
+ python_callable=build_recording_function(recorded_calls),
+ op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple],
+ )
+
+ dag.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=DEFAULT_DATE,
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
@@ -84,36 +86,36 @@ class TestPythonSensor(TestPythonBase):
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
ds_templated = DEFAULT_DATE.date().isoformat()
- self._assert_calls_equal(
+ assert_calls_equal(
recorded_calls[0],
Call(
4,
date(2019, 1, 1),
- f"dag {self.dag.dag_id} ran on {ds_templated}.",
+ f"dag {dag.dag_id} ran on {ds_templated}.",
Named(ds_templated, 'unchanged'),
),
)
- def test_python_callable_keyword_arguments_are_templatized(self):
+ def test_python_callable_keyword_arguments_are_templatized(self, dag_maker):
"""Test PythonSensor op_kwargs are templatized"""
recorded_calls = []
- task = PythonSensor(
- task_id='python_sensor',
- timeout=0.01,
- poke_interval=0.01,
- # a Mock instance cannot be used as a callable function or test fails with a
- # TypeError: Object of type Mock is not JSON serializable
- python_callable=build_recording_function(recorded_calls),
- op_kwargs={
- 'an_int': 4,
- 'a_date': date(2019, 1, 1),
- 'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}.",
- },
- dag=self.dag,
- )
-
- self.dag.create_dagrun(
+ with dag_maker() as dag:
+ task = PythonSensor(
+ task_id='python_sensor',
+ timeout=0.01,
+ poke_interval=0.01,
+ # a Mock instance cannot be used as a callable function or test fails with a
+ # TypeError: Object of type Mock is not JSON serializable
+ python_callable=build_recording_function(recorded_calls),
+ op_kwargs={
+ 'an_int': 4,
+ 'a_date': date(2019, 1, 1),
+ 'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}.",
+ },
+ )
+
+ dag.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=DEFAULT_DATE,
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
@@ -123,11 +125,11 @@ class TestPythonSensor(TestPythonBase):
with pytest.raises(AirflowSensorTimeout):
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- self._assert_calls_equal(
+ assert_calls_equal(
recorded_calls[0],
Call(
an_int=4,
a_date=date(2019, 1, 1),
- a_templated_string=f"dag {self.dag.dag_id} ran on {DEFAULT_DATE.date().isoformat()}.",
+ a_templated_string=f"dag {dag.dag_id} ran on {DEFAULT_DATE.date().isoformat()}.",
),
)