JonnyIncognito commented on a change in pull request #6210: [AIRFLOW-5567] 
BaseReschedulePokeOperator
URL: https://github.com/apache/airflow/pull/6210#discussion_r339979856
 
 

 ##########
 File path: tests/models/test_base_async_operator.py
 ##########
 @@ -0,0 +1,288 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Tests for BaseAsyncOperator"""
+
+import random
+import unittest
+import uuid
+from datetime import timedelta
+from unittest.mock import Mock  # pylint: disable=ungrouped-imports
+
+from freezegun import freeze_time
+from parameterized import parameterized
+
+from airflow import DAG, settings
+from airflow.exceptions import AirflowSensorTimeout
+from airflow.models import DagRun, TaskInstance, TaskReschedule
+from airflow.models.base_async_operator import BaseAsyncOperator
+from airflow.models.xcom import XCOM_EXTERNAL_RESOURCE_ID_KEY
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.utils import timezone
+from airflow.utils.state import State
+from airflow.utils.timezone import datetime
+
+DEFAULT_DATE = datetime(2015, 1, 1)
+TEST_DAG_ID = 'unit_test_dag'
+DUMMY_OP = 'dummy_op'
+ASYNC_OP = 'async_op'
+
+
+def _job_id():
+    """yield a random job id."""
+    return 'job_id-{}'.format(uuid.uuid4())
+
+
+ALL_ID_TYPES = [
+    (_job_id(),),
+    (random.randint(0, 10**10),),
+    ([_job_id(), _job_id()],),
+    ({'job1': _job_id()},),
+    (None,)
+]
+
+
+class DummyAsyncOperator(BaseAsyncOperator):
+    """
+    Test subclass of BaseAsyncOperator
+    """
+    def __init__(self, return_value=False,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.return_value = return_value
+
+    def poke(self, context):
+        """successful on first poke"""
+        return self.return_value
+
+    def submit_request(self, context):
+        """pretend to submit a job w/ random id"""
+        return _job_id()
+
+    def process_result(self, context):
+        """attempt to get the external resource_id"""
+        return self.get_external_resource_id(context)
+
+
+class TestBaseAsyncOperator(unittest.TestCase):
+    """Test cases for BaseAsyncOperator."""
+    def setUp(self):
+        args = {
+            'owner': 'airflow',
+            'start_date': DEFAULT_DATE
+        }
+        self.dag = DAG(TEST_DAG_ID, default_args=args)
+
+        session = settings.Session()
+        session.query(TaskReschedule).delete()
+        session.query(DagRun).delete()
+        session.query(TaskInstance).delete()
+        session.commit()
+
+    def _make_dag_run(self):
+        return self.dag.create_dagrun(
+            run_id='manual__',
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING
+        )
+
+    def _make_async_op(self, return_value, resource_id=None, **kwargs):
+        poke_interval = 'poke_interval'
+        timeout = 'timeout'
+        if poke_interval not in kwargs:
+            kwargs[poke_interval] = 0
+        if timeout not in kwargs:
+            kwargs[timeout] = 0
+
+        async_op = DummyAsyncOperator(
+            task_id=ASYNC_OP,
+            return_value=return_value,
+            resource_id=resource_id,
+            dag=self.dag,
+            **kwargs
+        )
+
+        dummy_op = DummyOperator(
+            task_id=DUMMY_OP,
+            dag=self.dag
+        )
+        dummy_op.set_upstream(async_op)
+        return async_op
+
+    @classmethod
+    def _run(cls, task):
+        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
+
+    def test_ok(self):
+        """ Test normal behavior"""
+        async_op = self._make_async_op(True)
+        dr = self._make_dag_run()
+
+        self._run(async_op)
+        tis = dr.get_task_instances()
+        self.assertEqual(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == ASYNC_OP:
+                self.assertEqual(ti.state, State.SUCCESS)
+            if ti.task_id == DUMMY_OP:
+                self.assertEqual(ti.state, State.NONE)
+
+    def test_poke_fail(self):
+        """ Test failure in poke"""
+        async_op = self._make_async_op(False)
+        dr = self._make_dag_run()
+
+        with self.assertRaises(AirflowSensorTimeout):
+            self._run(async_op)
+        tis = dr.get_task_instances()
+        self.assertEqual(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == ASYNC_OP:
+                self.assertEqual(ti.state, State.FAILED)
+            if ti.task_id == DUMMY_OP:
+                self.assertEqual(ti.state, State.NONE)
+
+    @parameterized.expand(ALL_ID_TYPES)
+    def test_set_get_external_resource_id(self, resource_id):
+        """ test resource id mechanism """
+        async_op = self._make_async_op(
+            return_value=None,
+            poke_interval=10,
+            timeout=25)
+
+        context = TaskInstance(task=async_op,
+                               
execution_date=DEFAULT_DATE).get_template_context()
+        async_op.set_external_resource_id(context, resource_id)
+        self.assertEqual(resource_id, 
async_op.get_external_resource_id(context))
+
+    def test_xcom(self):
+        """test xcom is set w/ job id. """
+        async_op = self._make_async_op(
+            return_value=None,
+            poke_interval=10,
+            timeout=25)
+        async_op.process_result = Mock()
+        async_op.poke = Mock(side_effect=[True])
+
+        dr = self._make_dag_run()
+
+        date1 = timezone.utcnow()
+        with freeze_time(date1):
+            self._run(async_op)
+        tis = dr.get_task_instances()
+
+        # Check that XCom was set to job_id.
+        for ti in tis:
+            if ti.task_id == ASYNC_OP:
+                resource_id = ti.xcom_pull(task_ids=ASYNC_OP,
+                                           key=XCOM_EXTERNAL_RESOURCE_ID_KEY)
+                self.assertIsNotNone(resource_id)
+                self.assertTrue(resource_id.startswith('job_id'))
+
+    def test_ok_with_reschedule(self):
 
 Review comment:
   @Fokko, the ability to retain state is very important for a rescheduled
task; I agree with @dstandish on this. Rescheduled tasks are mostly used where
long-lived external resources are created. For sensors we've dodged the issue
because they tend to use XCom to fetch the state from a prior task. But if
we're now going further and having multi-phased / composed tasks, there needs
to be a general Airflow mechanism for that state, rather than hacking around
it with e.g. cloud-provider-specific things like labels.
   
   > Keeping state makes the function impure by definition.
   
   It depends on the scope of the idempotency. I can see why this argument
applies to existing operators, which are supposed to do the same thing on every
call of `execute()`. However, operators with multiple phases should be
idempotent as a whole by the time the task succeeds, across multiple calls to
`execute()`; i.e. the scope is a bit wider.
   
   How about we only clear the state when needed? Options:
   
   1. Delegate clearing to the task (the default implementation always clears).
If you look at my proposal in the other thread
([link](https://github.com/apache/airflow/pull/6210#discussion_r339016148)), the
task could clear all state in the initial phase but not in subsequent phases /
reschedules.
   
   1. Clear state only if there was no prior transition to
`State.UP_FOR_RESCHEDULE`. I've seen code that detects this, but it would add a
cost for every user
([link](https://github.com/apache/airflow/blob/master/airflow/sensors/base_sensor_operator.py#L104)).
   
   1. Add some concept of a reschedule counter to the task instance, and clear
all state only on the first execution. If it were more along the lines of a
phase counter, it could later be used by tasks that have more than two phases
to figure out which code path to follow on each `execute()`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to