This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6405d8f  Better multiple_outputs inferral for @task.python (#20800)
6405d8f is described below

commit 6405d8f804e7cbd1748aa7eed65f2bbf0fcf022e
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Jan 27 14:52:26 2022 +0800

    Better multiple_outputs inferral for @task.python (#20800)
---
 airflow/decorators/base.py                 |  25 +++---
 tests/decorators/test_python.py            | 119 ++++++++++-------------------
 tests/decorators/test_python_virtualenv.py |  48 ++++++------
 tests/operators/test_python.py             |  25 +++---
 tests/sensors/test_python.py               |  98 ++++++++++++------------
 5 files changed, 142 insertions(+), 173 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 2e157a3..dec09df 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -38,6 +38,7 @@ from typing import (
 )
 
 import attr
+import typing_extensions
 
 from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
@@ -233,19 +234,21 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
 
     @multiple_outputs.default
     def _infer_multiple_outputs(self):
-        return_type = self.function_signature.return_annotation
-
-        # If the return type annotation is already the builtins ``dict`` type, use it for the inference.
-        if return_type == dict:
-            ttype = return_type
-        # Checking if Python 3.6, ``__origin__`` attribute does not exist until 3.7; need to use ``__extra__``
-        # TODO: Remove check when support for Python 3.6 is dropped in Airflow 2.3.
-        elif sys.version_info < (3, 7):
-            ttype = getattr(return_type, "__extra__", None)
+        try:
+            return_type = typing_extensions.get_type_hints(self.function).get("return", Any)
+        except Exception:  # Can't evaluate retrurn type.
+            return False
+
+        # Get the non-subscripted type. The ``__origin__`` attribute is not
+        # stable until 3.7, but we need to use ``__extra__`` instead.
+        # TODO: Remove the ``__extra__`` branch when support for Python 3.6 is
+        # dropped in Airflow 2.3.
+        if sys.version_info < (3, 7):
+            ttype = getattr(return_type, "__extra__", return_type)
         else:
-            ttype = getattr(return_type, "__origin__", None)
+            ttype = getattr(return_type, "__origin__", return_type)
 
-        return return_type is not inspect.Signature.empty and ttype in (dict, Dict)
+        return ttype == dict or ttype == Dict
 
     def __attrs_post_init__(self):
         self.kwargs.setdefault('task_id', self.function.__name__)
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index ea1a23a..0c93b49 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -16,24 +16,24 @@
 # specific language governing permissions and limitations
 # under the License.
 import sys
-import unittest.mock
 from collections import namedtuple
 from datetime import date, timedelta
+from typing import Dict  # noqa: F401  # This is used by annotation tests.
 from typing import Tuple
 
 import pytest
-from parameterized import parameterized
 
 from airflow.decorators import task as task_decorator
 from airflow.exceptions import AirflowException
-from airflow.models import DAG, DagRun, TaskInstance as TI
+from airflow.models import DAG
 from airflow.models.baseoperator import MappedOperator
 from airflow.models.xcom_arg import XComArg
 from airflow.utils import timezone
-from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.types import DagRunType
+from tests.operators.test_python import Call, assert_calls_equal, build_recording_function
+from tests.test_utils.db import clear_db_runs
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
 END_DATE = timezone.datetime(2016, 1, 2)
@@ -48,66 +48,19 @@ TI_CONTEXT_ENV_VARS = [
 ]
 
 
-class Call:
-    def __init__(self, *args, **kwargs):
-        self.args = args
-        self.kwargs = kwargs
+class TestAirflowTaskDecorator:
+    def setup_class(self):
+        clear_db_runs()
 
-
-def build_recording_function(calls_collection):
-    """
-    We can not use a Mock instance as a PythonOperator callable function or some tests fail with a
-    TypeError: Object of type Mock is not JSON serializable
-    Then using this custom function recording custom Call objects for further testing
-    (replacing Mock.assert_called_with assertion method)
-    """
-
-    def recording_function(*args, **kwargs):
-        calls_collection.append(Call(*args, **kwargs))
-
-    return recording_function
-
-
-class TestPythonBase(unittest.TestCase):
-    """Base test class for TestPythonOperator and TestPythonSensor classes"""
-
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-
-        with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TI).delete()
-
-    def setUp(self):
-        super().setUp()
-        self.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
-        self.addCleanup(self.dag.clear)
-        self.clear_run()
-        self.addCleanup(self.clear_run)
-
-    def tearDown(self):
-        super().tearDown()
-
-        with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TI).delete()
-
-    def clear_run(self):
+    def setup_method(self):
+        self.dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE})
         self.run = False
 
-    def _assert_calls_equal(self, first, second):
-        assert isinstance(first, Call)
-        assert isinstance(second, Call)
-        assert first.args == second.args
-        # eliminate context (conf, dag_run, task_instance, etc.)
-        test_args = ["an_int", "a_date", "a_templated_string"]
-        first.kwargs = {key: value for (key, value) in first.kwargs.items() if key in test_args}
-        second.kwargs = {key: value for (key, value) in second.kwargs.items() if key in test_args}
-        assert first.kwargs == second.kwargs
-
+    def teardown_method(self):
+        self.dag.clear()
+        self.run = False
+        clear_db_runs()
 
-class TestAirflowTaskDecorator(TestPythonBase):
     def test_python_operator_python_callable_is_callable(self):
         """Tests that @task will only instantiate if
         the python_callable argument is callable."""
@@ -115,22 +68,34 @@ class TestAirflowTaskDecorator(TestPythonBase):
         with pytest.raises(TypeError):
             task_decorator(not_callable, dag=self.dag)
 
-    @parameterized.expand([["dict"], ["dict[str, int]"], ["Dict"], ["Dict[str, int]"]])
-    def test_infer_multiple_outputs_using_dict_typing(self, test_return_annotation):
-        if sys.version_info < (3, 9) and test_return_annotation == "dict[str, int]":
-            raise pytest.skip("dict[...] not a supported typing prior to Python 3.9")
-
-            @task_decorator
-            def identity_dict(x: int, y: int) -> eval(test_return_annotation):
-                return {"x": x, "y": y}
-
-            assert identity_dict(5, 5).operator.multiple_outputs is True
-
-            @task_decorator
-            def identity_dict_stringified(x: int, y: int) -> test_return_annotation:
-                return {"x": x, "y": y}
+    @pytest.mark.parametrize(
+        "resolve",
+        [
+            pytest.param(eval, id="eval"),
+            pytest.param(lambda t: t, id="stringify"),
+        ],
+    )
+    @pytest.mark.parametrize(
+        "annotation",
+        [
+            "dict",
+            pytest.param(
+                "dict[str, int]",
+                marks=pytest.mark.skipif(
+                    sys.version_info < (3, 9),
+                    reason="PEP 585 is implemented in Python 3.9",
+                ),
+            ),
+            "Dict",
+            "Dict[str, int]",
+        ],
+    )
+    def test_infer_multiple_outputs_using_dict_typing(self, resolve, annotation):
+        @task_decorator
+        def identity_dict(x: int, y: int) -> resolve(annotation):
+            return {"x": x, "y": y}
 
-            assert identity_dict_stringified(5, 5).operator.multiple_outputs is True
+        assert identity_dict(5, 5).operator.multiple_outputs is True
 
     def test_infer_multiple_outputs_using_other_typing(self):
         @task_decorator
@@ -288,7 +253,7 @@ class TestAirflowTaskDecorator(TestPythonBase):
 
         ds_templated = DEFAULT_DATE.date().isoformat()
         assert len(recorded_calls) == 1
-        self._assert_calls_equal(
+        assert_calls_equal(
             recorded_calls[0],
             Call(
                 4,
@@ -319,7 +284,7 @@ class TestAirflowTaskDecorator(TestPythonBase):
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
         assert len(recorded_calls) == 1
-        self._assert_calls_equal(
+        assert_calls_equal(
             recorded_calls[0],
             Call(
                 an_int=4,
diff --git a/tests/decorators/test_python_virtualenv.py b/tests/decorators/test_python_virtualenv.py
index bce9c53..01e646c 100644
--- a/tests/decorators/test_python_virtualenv.py
+++ b/tests/decorators/test_python_virtualenv.py
@@ -25,8 +25,6 @@ import pytest
 from airflow.decorators import task
 from airflow.utils import timezone
 
-from .test_python import TestPythonBase
-
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
 END_DATE = timezone.datetime(2016, 1, 2)
 INTERVAL = timedelta(hours=12)
@@ -43,31 +41,31 @@ TI_CONTEXT_ENV_VARS = [
 PYTHON_VERSION = sys.version_info[0]
 
 
-class TestPythonVirtualenvDecorator(TestPythonBase):
-    def test_add_dill(self):
+class TestPythonVirtualenvDecorator:
+    def test_add_dill(self, dag_maker):
         @task.virtualenv(use_dill=True, system_site_packages=False)
         def f():
             """Ensure dill is correctly installed."""
             import dill  # noqa: F401
 
-        with self.dag:
+        with dag_maker():
             ret = f()
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-    def test_no_requirements(self):
+    def test_no_requirements(self, dag_maker):
         """Tests that the python callable is invoked on task run."""
 
         @task.virtualenv()
         def f():
             pass
 
-        with self.dag:
+        with dag_maker():
             ret = f()
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-    def test_no_system_site_packages(self):
+    def test_no_system_site_packages(self, dag_maker):
         @task.virtualenv(system_site_packages=False, python_version=PYTHON_VERSION, use_dill=True)
         def f():
             try:
@@ -76,12 +74,12 @@ class TestPythonVirtualenvDecorator(TestPythonBase):
                 return True
             raise Exception
 
-        with self.dag:
+        with dag_maker():
             ret = f()
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-    def test_system_site_packages(self):
+    def test_system_site_packages(self, dag_maker):
         @task.virtualenv(
             system_site_packages=False,
             requirements=['funcsigs'],
@@ -91,12 +89,12 @@ class TestPythonVirtualenvDecorator(TestPythonBase):
         def f():
             import funcsigs  # noqa: F401
 
-        with self.dag:
+        with dag_maker():
             ret = f()
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-    def test_with_requirements_pinned(self):
+    def test_with_requirements_pinned(self, dag_maker):
         @task.virtualenv(
             system_site_packages=False,
             requirements=['funcsigs==0.4'],
@@ -109,12 +107,12 @@ class TestPythonVirtualenvDecorator(TestPythonBase):
             if funcsigs.__version__ != '0.4':
                 raise Exception
 
-        with self.dag:
+        with dag_maker():
             ret = f()
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-    def test_unpinned_requirements(self):
+    def test_unpinned_requirements(self, dag_maker):
         @task.virtualenv(
             system_site_packages=False,
             requirements=['funcsigs', 'dill'],
@@ -124,23 +122,23 @@ class TestPythonVirtualenvDecorator(TestPythonBase):
         def f():
             import funcsigs  # noqa: F401
 
-        with self.dag:
+        with dag_maker():
             ret = f()
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-    def test_fail(self):
+    def test_fail(self, dag_maker):
         @task.virtualenv()
         def f():
             raise Exception
 
-        with self.dag:
+        with dag_maker():
             ret = f()
 
         with pytest.raises(CalledProcessError):
             ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-    def test_python_3(self):
+    def test_python_3(self, dag_maker):
         @task.virtualenv(python_version=3, use_dill=False, requirements=['dill'])
         def f():
             import sys
@@ -152,12 +150,12 @@ class TestPythonVirtualenvDecorator(TestPythonBase):
                 return
             raise Exception
 
-        with self.dag:
+        with dag_maker():
             ret = f()
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-    def test_with_args(self):
+    def test_with_args(self, dag_maker):
         @task.virtualenv
         def f(a, b, c=False, d=False):
             if a == 0 and b == 1 and c and not d:
@@ -165,27 +163,27 @@ class TestPythonVirtualenvDecorator(TestPythonBase):
             else:
                 raise Exception
 
-        with self.dag:
+        with dag_maker():
             ret = f(0, 1, c=True)
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-    def test_return_none(self):
+    def test_return_none(self, dag_maker):
         @task.virtualenv
         def f():
             return None
 
-        with self.dag:
+        with dag_maker():
             ret = f()
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-    def test_nonimported_as_arg(self):
+    def test_nonimported_as_arg(self, dag_maker):
         @task.virtualenv
         def f(_):
             return None
 
-        with self.dag:
+        with dag_maker():
             ret = f(datetime.datetime.utcnow())
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index dda716e..24331ef 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -84,6 +84,17 @@ def build_recording_function(calls_collection):
     return recording_function
 
 
+def assert_calls_equal(first: Call, second: Call) -> None:
+    assert isinstance(first, Call)
+    assert isinstance(second, Call)
+    assert first.args == second.args
+    # eliminate context (conf, dag_run, task_instance, etc.)
+    test_args = ["an_int", "a_date", "a_templated_string"]
+    first.kwargs = {key: value for (key, value) in first.kwargs.items() if key in test_args}
+    second.kwargs = {key: value for (key, value) in second.kwargs.items() if key in test_args}
+    assert first.kwargs == second.kwargs
+
+
 class TestPythonBase(unittest.TestCase):
     """Base test class for TestPythonOperator and TestPythonSensor classes"""
 
@@ -112,16 +123,6 @@ class TestPythonBase(unittest.TestCase):
     def clear_run(self):
         self.run = False
 
-    def _assert_calls_equal(self, first, second):
-        assert isinstance(first, Call)
-        assert isinstance(second, Call)
-        assert first.args == second.args
-        # eliminate context (conf, dag_run, task_instance, etc.)
-        test_args = ["an_int", "a_date", "a_templated_string"]
-        first.kwargs = {key: value for (key, value) in first.kwargs.items() if key in test_args}
-        second.kwargs = {key: value for (key, value) in second.kwargs.items() if key in test_args}
-        assert first.kwargs == second.kwargs
-
 
 class TestPythonOperator(TestPythonBase):
     def do_run(self):
@@ -176,7 +177,7 @@ class TestPythonOperator(TestPythonBase):
 
         ds_templated = DEFAULT_DATE.date().isoformat()
         assert 1 == len(recorded_calls)
-        self._assert_calls_equal(
+        assert_calls_equal(
             recorded_calls[0],
             Call(
                 4,
@@ -213,7 +214,7 @@ class TestPythonOperator(TestPythonBase):
         task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
         assert 1 == len(recorded_calls)
-        self._assert_calls_equal(
+        assert_calls_equal(
             recorded_calls[0],
             Call(
                 an_int=4,
diff --git a/tests/sensors/test_python.py b/tests/sensors/test_python.py
index 63ea6ce..3c21ef5 100644
--- a/tests/sensors/test_python.py
+++ b/tests/sensors/test_python.py
@@ -27,33 +27,35 @@ from airflow.sensors.python import PythonSensor
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
 from airflow.utils.types import DagRunType
-from tests.operators.test_python import Call, TestPythonBase, build_recording_function
+from tests.operators.test_python import Call, assert_calls_equal, build_recording_function
 
 DEFAULT_DATE = datetime(2015, 1, 1)
 
 
-class TestPythonSensor(TestPythonBase):
-    def test_python_sensor_true(self):
-        op = PythonSensor(task_id='python_sensor_check_true', python_callable=lambda: True, dag=self.dag)
+class TestPythonSensor:
+    def test_python_sensor_true(self, dag_maker):
+        with dag_maker():
+            op = PythonSensor(task_id='python_sensor_check_true', python_callable=lambda: True)
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
-    def test_python_sensor_false(self):
-        op = PythonSensor(
-            task_id='python_sensor_check_false',
-            timeout=0.01,
-            poke_interval=0.01,
-            python_callable=lambda: False,
-            dag=self.dag,
-        )
+    def test_python_sensor_false(self, dag_maker):
+        with dag_maker():
+            op = PythonSensor(
+                task_id='python_sensor_check_false',
+                timeout=0.01,
+                poke_interval=0.01,
+                python_callable=lambda: False,
+            )
         with pytest.raises(AirflowSensorTimeout):
             op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
-    def test_python_sensor_raise(self):
-        op = PythonSensor(task_id='python_sensor_check_raise', python_callable=lambda: 1 / 0, dag=self.dag)
+    def test_python_sensor_raise(self, dag_maker):
+        with dag_maker():
+            op = PythonSensor(task_id='python_sensor_check_raise', python_callable=lambda: 1 / 0)
         with pytest.raises(ZeroDivisionError):
             op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
-    def test_python_callable_arguments_are_templatized(self):
+    def test_python_callable_arguments_are_templatized(self, dag_maker):
         """Test PythonSensor op_args are templatized"""
         recorded_calls = []
 
@@ -62,18 +64,18 @@ class TestPythonSensor(TestPythonBase):
         Named = namedtuple('Named', ['var1', 'var2'])
         named_tuple = Named('{{ ds }}', 'unchanged')
 
-        task = PythonSensor(
-            task_id='python_sensor',
-            timeout=0.01,
-            poke_interval=0.3,
-            # a Mock instance cannot be used as a callable function or test fails with a
-            # TypeError: Object of type Mock is not JSON serializable
-            python_callable=build_recording_function(recorded_calls),
-            op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple],
-            dag=self.dag,
-        )
-
-        self.dag.create_dagrun(
+        with dag_maker() as dag:
+            task = PythonSensor(
+                task_id='python_sensor',
+                timeout=0.01,
+                poke_interval=0.3,
+                # a Mock instance cannot be used as a callable function or test fails with a
+                # TypeError: Object of type Mock is not JSON serializable
+                python_callable=build_recording_function(recorded_calls),
+                op_args=[4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", named_tuple],
+            )
+
+        dag.create_dagrun(
             run_type=DagRunType.MANUAL,
             execution_date=DEFAULT_DATE,
             data_interval=(DEFAULT_DATE, DEFAULT_DATE),
@@ -84,36 +86,36 @@ class TestPythonSensor(TestPythonBase):
             task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
         ds_templated = DEFAULT_DATE.date().isoformat()
-        self._assert_calls_equal(
+        assert_calls_equal(
             recorded_calls[0],
             Call(
                 4,
                 date(2019, 1, 1),
-                f"dag {self.dag.dag_id} ran on {ds_templated}.",
+                f"dag {dag.dag_id} ran on {ds_templated}.",
                 Named(ds_templated, 'unchanged'),
             ),
         )
 
-    def test_python_callable_keyword_arguments_are_templatized(self):
+    def test_python_callable_keyword_arguments_are_templatized(self, dag_maker):
         """Test PythonSensor op_kwargs are templatized"""
         recorded_calls = []
 
-        task = PythonSensor(
-            task_id='python_sensor',
-            timeout=0.01,
-            poke_interval=0.01,
-            # a Mock instance cannot be used as a callable function or test fails with a
-            # TypeError: Object of type Mock is not JSON serializable
-            python_callable=build_recording_function(recorded_calls),
-            op_kwargs={
-                'an_int': 4,
-                'a_date': date(2019, 1, 1),
-                'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}.",
-            },
-            dag=self.dag,
-        )
-
-        self.dag.create_dagrun(
+        with dag_maker() as dag:
+            task = PythonSensor(
+                task_id='python_sensor',
+                timeout=0.01,
+                poke_interval=0.01,
+                # a Mock instance cannot be used as a callable function or test fails with a
+                # TypeError: Object of type Mock is not JSON serializable
+                python_callable=build_recording_function(recorded_calls),
+                op_kwargs={
+                    'an_int': 4,
+                    'a_date': date(2019, 1, 1),
+                    'a_templated_string': "dag {{dag.dag_id}} ran on {{ds}}.",
+                },
+            )
+
+        dag.create_dagrun(
             run_type=DagRunType.MANUAL,
             execution_date=DEFAULT_DATE,
             data_interval=(DEFAULT_DATE, DEFAULT_DATE),
@@ -123,11 +125,11 @@ class TestPythonSensor(TestPythonBase):
         with pytest.raises(AirflowSensorTimeout):
             task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-        self._assert_calls_equal(
+        assert_calls_equal(
             recorded_calls[0],
             Call(
                 an_int=4,
                 a_date=date(2019, 1, 1),
-                a_templated_string=f"dag {self.dag.dag_id} ran on {DEFAULT_DATE.date().isoformat()}.",
+                a_templated_string=f"dag {dag.dag_id} ran on {DEFAULT_DATE.date().isoformat()}.",
             ),
         )

Reply via email to