ashb commented on a change in pull request #8962:
URL: https://github.com/apache/airflow/pull/8962#discussion_r429922157



##########
File path: tests/operators/test_python.py
##########
@@ -311,6 +315,350 @@ def func(**context):
         python_operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
 
+class TestAirflowTask(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+
+        with create_session() as session:
+            session.query(DagRun).delete()
+            session.query(TI).delete()
+
+    def setUp(self):
+        super().setUp()
+        self.dag = DAG(
+            'test_dag',
+            default_args={
+                'owner': 'airflow',
+                'start_date': DEFAULT_DATE})
+        self.addCleanup(self.dag.clear)
+
+    def tearDown(self):
+        super().tearDown()
+
+        with create_session() as session:
+            session.query(DagRun).delete()
+            session.query(TI).delete()
+
+    def _assert_calls_equal(self, first, second):
+        assert isinstance(first, Call)
+        assert isinstance(second, Call)
+        assert first.args == second.args
+        # eliminate context (conf, dag_run, task_instance, etc.)
+        test_args = ["an_int", "a_date", "a_templated_string"]
+        first.kwargs = {
+            key: value
+            for (key, value) in first.kwargs.items()
+            if key in test_args
+        }
+        second.kwargs = {
+            key: value
+            for (key, value) in second.kwargs.items()
+            if key in test_args
+        }
+        assert first.kwargs == second.kwargs
+
+    def test_python_operator_python_callable_is_callable(self):
+        """Tests that @task will only instantiate if
+        the python_callable argument is callable."""
+        not_callable = {}
+        with pytest.raises(AirflowException):
+            task_decorator(not_callable, dag=self.dag)
+
+    def test_python_callable_arguments_are_templatized(self):
+        """Test @task op_args are templatized"""
+        recorded_calls = []
+
+        # Create a named tuple and ensure it is still preserved
+        # after the rendering is done
+        Named = namedtuple('Named', ['var1', 'var2'])
+        named_tuple = Named('{{ ds }}', 'unchanged')
+
+        task = task_decorator(
+            # a Mock instance cannot be used as a callable function or test 
fails with a
+            # TypeError: Object of type Mock is not JSON serializable
+            build_recording_function(recorded_calls),
+            dag=self.dag)
+        task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on {{ds}}.", 
named_tuple)
+
+        self.dag.create_dagrun(
+            run_id=DagRunType.MANUAL.value,
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+            state=State.RUNNING
+        )
+        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        ds_templated = DEFAULT_DATE.date().isoformat()
+        assert len(recorded_calls) == 1
+        self._assert_calls_equal(
+            recorded_calls[0],
+            Call(4,
+                 date(2019, 1, 1),
+                 "dag {} ran on {}.".format(self.dag.dag_id, ds_templated),
+                 Named(ds_templated, 'unchanged'))
+        )
+
+    def test_python_callable_keyword_arguments_are_templatized(self):
+        """Test PythonOperator op_kwargs are templatized"""
+        recorded_calls = []
+
+        task = task_decorator(
+            # a Mock instance cannot be used as a callable function or test 
fails with a
+            # TypeError: Object of type Mock is not JSON serializable
+            build_recording_function(recorded_calls),
+            dag=self.dag
+        )
+        task(an_int=4, a_date=date(2019, 1, 1), a_templated_string="dag 
{{dag.dag_id}} ran on {{ds}}.")
+        self.dag.create_dagrun(
+            run_id=DagRunType.MANUAL.value,
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+            state=State.RUNNING
+        )
+        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        assert len(recorded_calls) == 1
+        self._assert_calls_equal(
+            recorded_calls[0],
+            Call(an_int=4,
+                 a_date=date(2019, 1, 1),
+                 a_templated_string="dag {} ran on {}.".format(
+                     self.dag.dag_id, DEFAULT_DATE.date().isoformat()))
+        )
+
+    def test_copy_in_dag(self):
+        """Test copy method to reuse tasks in a DAG"""
+
+        @task_decorator
+        def do_run():
+            return 4
+        with self.dag:
+            do_run()
+            assert ['do_run'] == self.dag.task_ids
+            do_run_1 = do_run.copy()
+            do_run_2 = do_run.copy()
+        assert do_run_1.task_id == 'do_run__1'
+        assert do_run_2.task_id == 'do_run__2'
+
+    def test_copy(self):
+        """Test copy method outside of a DAG"""
+        @task_decorator
+        def do_run():
+            return 4
+
+        @task_decorator
+        def do__run():
+            return 4
+        do_run_1 = do_run.copy()
+        do_run_2 = do_run.copy()
+        do__run_1 = do__run.copy()
+        do__run_2 = do__run.copy()
+        with self.dag:
+            do_run()
+            assert ['do_run'] == self.dag.task_ids
+            do_run_1()
+            do_run_2()
+            do__run()
+            do__run_1()
+            do__run_2()
+
+        assert do_run_1.task_id == 'do_run__1'
+        assert do_run_2.task_id == 'do_run__2'
+        assert do__run_1.task_id == 'do__run__1'
+        assert do__run_2.task_id == 'do__run__2'
+
+    def test_copy_10(self):
+        """Test copy method outside of a DAG"""
+        @task_decorator
+        def __do_run():
+            return 4
+
+        with self.dag:
+            __do_run()
+            do_runs = [__do_run.copy() for _ in range(20)]
+
+        assert do_runs[-1].task_id == '__do_run__20'
+
+    def test_dict_outputs(self):
+        """Tests pushing multiple outputs as a dictionary"""
+
+        @task_decorator(multiple_outputs=True)
+        def return_dict(number: int):
+            return {
+                'number': number + 1,
+                43: 43

Review comment:
       Yes, I mean we should define (and test) what we want the behaviour to be in this case, i.e. when a `multiple_outputs=True` task returns a dict with a non-string key such as `43`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to