potiuk commented on a change in pull request #5162: [AIRFLOW-4358] Speed up test_jobs by not running tasks
URL: https://github.com/apache/airflow/pull/5162#discussion_r279494774
 
 

 ##########
 File path: tests/executors/test_executor.py
 ##########
 @@ -16,47 +16,73 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from collections import defaultdict
+
 from airflow.executors.base_executor import BaseExecutor
 from airflow.utils.state import State
-
-from airflow import settings
+from airflow.utils.db import create_session
 
 
 class TestExecutor(BaseExecutor):
     """
     TestExecutor is used for unit testing purposes.
     """
 
-    def __init__(self, do_update=False, *args, **kwargs):
+    def __init__(self, do_update=True, *args, **kwargs):
         self.do_update = do_update
         self._running = []
+
+        # A list of "batches" of tasks
         self.history = []
+        # All the tasks, in a stable sort order
+        self.sorted_tasks = []
+        self.mock_task_results = defaultdict(lambda: State.SUCCESS)
 
         super().__init__(*args, **kwargs)
 
-    def execute_async(self, key, command, queue=None):
-        self.log.debug("{} running task instances".format(len(self.running)))
-        self.log.debug("{} in queue".format(len(self.queued_tasks)))
-
     def heartbeat(self):
-        session = settings.Session()
-        if self.do_update:
+        if not self.do_update:
+            return
+
+        with create_session() as session:
             self.history.append(list(self.queued_tasks.values()))
-            while len(self._running) > 0:
-                ti = self._running.pop()
-                ti.set_state(State.SUCCESS, session)
-            for key, val in list(self.queued_tasks.items()):
-                (command, priority, queue, simple_ti) = val
-                ti = simple_ti.construct_task_instance()
-                ti.set_state(State.RUNNING, session)
-                self._running.append(ti)
-                self.queued_tasks.pop(key)
 
-        session.commit()
-        session.close()
+            # Create a stable/predictable sort order for events in self.history
+            # for tests!
+            def sort_by(item):
+                key, val = item
+                (dag_id, task_id, date, try_number) = key
+                (cmd, prio, queue, sti) = val
+                # Sort by priority (DESC), then date,task, try
+                return -prio, date, dag_id, task_id, try_number
+            sorted_queue = sorted(self.queued_tasks.items(), key=sort_by)
+
+            for (key, (_, _, _, simple_ti)) in sorted_queue:
+                self.queued_tasks.pop(key)
+                state = self.mock_task_results[key]
+                ti = simple_ti.construct_task_instance(session=session, 
lock_for_update=True)
+                ti.set_state(state, session=session)
+                self.change_state(key, state)
 
     def terminate(self):
         pass
 
     def end(self):
         self.sync()
+
+    def change_state(self, key, state):
+        super().change_state(key, state)
+        # The normal event buffer is cleared after reading, we want to keep
+        # a list of all events for testing
+        self.sorted_tasks.append((key, state))
+
+    def mock_task_fail(self, dag_id, task_id, date, try_number=1):
+        """
+        Set the mock outcome of running this particular task instances to
+        FAILED.
+
+        If the task identified by the tuple ``(dag_id, task_id, date,
+        try_number)`` is run by this executor it's state will be FALIED.
 
 Review comment:
   Typo: `FALIED` should be `FAILED`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to