Repository: incubator-airflow Updated Branches: refs/heads/master e05d3b4df -> 0dd00291d
[AIRFLOW-1345] Don't expire TIs on each scheduler loop. TIs get expired on commit, so any subsequent access to their attributes issues a new query to the DB, causing an n+1 query problem even when the TI is not scheduled. This change batches all of these queries, which makes scheduling substantially faster. Closes #2397 from saguziel/aguziel-commit-last Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/0dd00291 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/0dd00291 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/0dd00291 Branch: refs/heads/master Commit: 0dd00291d74e10a30ed328c8542416b78e24bc06 Parents: e05d3b4 Author: Alex Guziel <[email protected]> Authored: Fri Jul 14 16:38:25 2017 -0700 Committer: Alex Guziel <[email protected]> Committed: Fri Jul 14 16:38:25 2017 -0700 ---------------------------------------------------------------------- airflow/jobs.py | 243 ++++++++++++++++++++++++++++++++------------- tests/jobs.py | 271 +++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 441 insertions(+), 73 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0dd00291/airflow/jobs.py ---------------------------------------------------------------------- diff --git a/airflow/jobs.py b/airflow/jobs.py index e8431b7..6b63df0 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -1016,13 +1016,10 @@ class SchedulerJob(BaseJob): tis_changed, new_state)) @provide_session - def _execute_task_instances(self, - simple_dag_bag, - states, - session=None): + def _find_executable_task_instances(self, simple_dag_bag, states, session=None): """ - Fetches task instances from ORM in the specified states, figures - out pool limits, and sends them to the executor for execution.
+ Finds TIs that are ready for execution with respect to pool limits, + dag concurrency, executor state, and priority. :param simple_dag_bag: TaskInstances associated with DAGs in the simple_dag_bag will be fetched from the DB and executed @@ -1031,18 +1028,20 @@ class SchedulerJob(BaseJob): :type executor: BaseExecutor :param states: Execute TaskInstances in these states :type states: Tuple[State] - :return: None + :return: List[TaskInstance] """ + executable_tis = [] + # Get all the queued task instances from associated with scheduled - # DagRuns. + # DagRuns which are not backfilled, in the given states, + # and the dag is not pasued TI = models.TaskInstance DR = models.DagRun DM = models.DagModel - task_instances_to_examine = ( + ti_query = ( session .query(TI) .filter(TI.dag_id.in_(simple_dag_bag.dag_ids)) - .filter(TI.state.in_(states)) .outerjoin(DR, and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date)) @@ -1051,14 +1050,19 @@ class SchedulerJob(BaseJob): .outerjoin(DM, DM.dag_id==TI.dag_id) .filter(or_(DM.dag_id == None, not_(DM.is_paused))) - .all() ) + if None in states: + ti_query = ti_query.filter(or_(TI.state == None, TI.state.in_(states))) + else: + ti_query = ti_query.filter(TI.state.in_(states)) + + task_instances_to_examine = ti_query.all() - # Put one task instance on each line if len(task_instances_to_examine) == 0: - self.logger.info("No tasks to send to the executor") - return + self.logger.info("No tasks to consider for execution.") + return executable_tis + # Put one task instance on each line task_instance_str = "\n\t".join( ["{}".format(x) for x in task_instances_to_examine]) self.logger.info("Tasks up for execution:\n\t{}".format(task_instance_str)) @@ -1130,63 +1134,170 @@ class SchedulerJob(BaseJob): if self.executor.has_task(task_instance): - self.logger.debug("Not handling task {} as the executor reports it is running" + self.logger.debug(("Not handling task {} as the executor " + + "reports it is running") 
.format(task_instance.key)) continue + executable_tis.append(task_instance) + open_slots -= 1 + dag_id_to_possibly_running_task_count[dag_id] += 1 - command = " ".join(TI.generate_command( - task_instance.dag_id, - task_instance.task_id, - task_instance.execution_date, - local=True, - mark_success=False, - ignore_all_deps=False, - ignore_depends_on_past=False, - ignore_task_deps=False, - ignore_ti_state=False, - pool=task_instance.pool, - file_path=simple_dag_bag.get_dag(task_instance.dag_id).full_filepath, - pickle_id=simple_dag_bag.get_dag(task_instance.dag_id).pickle_id)) - - priority = task_instance.priority_weight - queue = task_instance.queue - self.logger.info("Sending to executor {} with priority {} and queue {}" - .format(task_instance.key, priority, queue)) - - # Set the state to queued - task_instance.refresh_from_db(lock_for_update=True, session=session) - if task_instance.state not in states: - self.logger.info("Task {} was set to {} outside this scheduler." - .format(task_instance.key, task_instance.state)) - session.commit() - continue + task_instance_str = "\n\t".join( + ["{}".format(x) for x in executable_tis]) + self.logger.info("Setting the follow tasks to queued state:\n\t{}" + .format(task_instance_str)) + return executable_tis - self.logger.info("Setting state of {} to {}".format( - task_instance.key, State.QUEUED)) - task_instance.state = State.QUEUED - task_instance.queued_dttm = (datetime.now() - if not task_instance.queued_dttm - else task_instance.queued_dttm) - session.merge(task_instance) - session.commit() - - # These attributes will be lost after the object expires, so save them. 
- task_id_ = task_instance.task_id - dag_id_ = task_instance.dag_id - execution_date_ = task_instance.execution_date - make_transient(task_instance) - task_instance.task_id = task_id_ - task_instance.dag_id = dag_id_ - task_instance.execution_date = execution_date_ - - self.executor.queue_command( - task_instance, - command, - priority=priority, - queue=queue) + @provide_session + def _change_state_for_executable_task_instances(self, task_instances, + acceptable_states, session=None): + """ + Changes the state of task instances in the list with one of the given states + to QUEUED atomically, and returns the TIs changed. + + :param task_instances: TaskInstances to change the state of + :type task_instances: List[TaskInstance] + :param acceptable_states: Filters the TaskInstances updated to be in these states + :type acceptable_states: Iterable[State] + :return: List[TaskInstance] + """ + if len(task_instances) == 0: + session.commit() + return [] - open_slots -= 1 - dag_id_to_possibly_running_task_count[dag_id] += 1 + TI = models.TaskInstance + filter_for_ti_state_change = ( + [and_( + TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.execution_date == ti.execution_date) + for ti in task_instances]) + ti_query = ( + session + .query(TI) + .filter(or_(*filter_for_ti_state_change))) + + if None in acceptable_states: + ti_query = ti_query.filter(or_(TI.state == None, TI.state.in_(acceptable_states))) + else: + ti_query = ti_query.filter(TI.state.in_(acceptable_states)) + + tis_to_set_to_queued = ( + ti_query + .with_for_update() + .all()) + if len(tis_to_set_to_queued) == 0: + self.logger.info("No tasks were able to have their state changed to queued.") + session.commit() + return [] + + # set TIs to queued state + for task_instance in tis_to_set_to_queued: + task_instance.state = State.QUEUED + task_instance.queued_dttm = (datetime.now() + if not task_instance.queued_dttm + else task_instance.queued_dttm) + session.merge(task_instance) + + # save which TIs we 
set before session expires them + filter_for_ti_enqueue = ([and_(TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.execution_date == ti.execution_date) + for ti in tis_to_set_to_queued]) + session.commit() + + # requery in batch since above was expired by commit + tis_to_be_queued = ( + session + .query(TI) + .filter(or_(*filter_for_ti_enqueue)) + .all()) + + task_instance_str = "\n\t".join( + ["{}".format(x) for x in tis_to_be_queued]) + self.logger.info("Setting the follow tasks to queued state:\n\t{}" + .format(task_instance_str)) + return tis_to_be_queued + + def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, task_instances): + """ + Takes task_instances, which should have been set to queued, and enqueues them + with the executor. + + :param task_instances: TaskInstances to enqueue + :type task_instances: List[TaskInstance] + :param simple_dag_bag: Should contains all of the task_instances' dags + :type simple_dag_bag: SimpleDagBag + """ + TI = models.TaskInstance + # actually enqueue them + for task_instance in task_instances: + command = " ".join(TI.generate_command( + task_instance.dag_id, + task_instance.task_id, + task_instance.execution_date, + local=True, + mark_success=False, + ignore_all_deps=False, + ignore_depends_on_past=False, + ignore_task_deps=False, + ignore_ti_state=False, + pool=task_instance.pool, + file_path=simple_dag_bag.get_dag(task_instance.dag_id).full_filepath, + pickle_id=simple_dag_bag.get_dag(task_instance.dag_id).pickle_id)) + + priority = task_instance.priority_weight + queue = task_instance.queue + self.logger.info("Sending {} to executor with priority {} and queue {}" + .format(task_instance.key, priority, queue)) + + # save attributes so sqlalchemy doesnt expire them + copy_dag_id = task_instance.dag_id + copy_task_id = task_instance.task_id + copy_execution_date = task_instance.execution_date + make_transient(task_instance) + task_instance.dag_id = copy_dag_id + task_instance.task_id = copy_task_id + 
task_instance.execution_date = copy_execution_date + + self.executor.queue_command( + task_instance, + command, + priority=priority, + queue=queue) + + @provide_session + def _execute_task_instances(self, + simple_dag_bag, + states, + session=None): + """ + Attempts to execute TaskInstances that should be executed by the scheduler. + + There are three steps: + 1. Pick TIs by priority with the constraint that they are in the expected states + and that we do exceed max_active_runs or pool limits. + 2. Change the state for the TIs above atomically. + 3. Enqueue the TIs in the executor. + + :param simple_dag_bag: TaskInstances associated with DAGs in the + simple_dag_bag will be fetched from the DB and executed + :type simple_dag_bag: SimpleDagBag + :param states: Execute TaskInstances in these states + :type states: Tuple[State] + :return: None + """ + executable_tis = self._find_executable_task_instances(simple_dag_bag, states, + session=session) + tis_with_state_changed = self._change_state_for_executable_task_instances( + executable_tis, + states, + session=session) + self._enqueue_task_instances_with_queued_state( + simple_dag_bag, + tis_with_state_changed) + session.commit() + return len(tis_with_state_changed) def _process_dags(self, dagbag, dags, tis_out): """ http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0dd00291/tests/jobs.py ---------------------------------------------------------------------- diff --git a/tests/jobs.py b/tests/jobs.py index 13bd9f5..e987e0c 100644 --- a/tests/jobs.py +++ b/tests/jobs.py @@ -28,7 +28,7 @@ from tempfile import mkdtemp from airflow import AirflowException, settings, models from airflow.bin import cli -from airflow.executors import SequentialExecutor +from airflow.executors import BaseExecutor, SequentialExecutor from airflow.jobs import BackfillJob, SchedulerJob, LocalTaskJob from airflow.models import DAG, DagModel, DagBag, DagRun, Pool, TaskInstance as TI from airflow.operators.dummy_operator import 
DummyOperator @@ -38,7 +38,7 @@ from airflow.utils.state import State from airflow.utils.timeout import timeout from airflow.utils.dag_processing import SimpleDagBag, list_py_file_paths -from mock import patch +from mock import Mock, patch from sqlalchemy.orm.session import make_transient from tests.executors.test_executor import TestExecutor @@ -718,8 +718,266 @@ class SchedulerJobTest(unittest.TestCase): ti1.refresh_from_db() self.assertEquals(State.SCHEDULED, ti1.state) - def test_concurrency(self): - dag_id = 'SchedulerJobTest.test_concurrency' + def test_find_executable_task_instances_backfill_nodagrun(self): + dag_id = 'SchedulerJobTest.test_find_executable_task_instances_backfill_nodagrun' + task_id_1 = 'dummy' + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) + task1 = DummyOperator(dag=dag, task_id=task_id_1) + dagbag = SimpleDagBag([dag]) + + scheduler = SchedulerJob(**self.default_scheduler_args) + session = settings.Session() + + dr1 = scheduler.create_dag_run(dag) + dr2 = scheduler.create_dag_run(dag) + dr2.run_id = BackfillJob.ID_PREFIX + 'asdf' + + ti_no_dagrun = TI(task1, DEFAULT_DATE - datetime.timedelta(days=1)) + ti_backfill = TI(task1, dr2.execution_date) + ti_with_dagrun = TI(task1, dr1.execution_date) + # ti_with_paused + ti_no_dagrun.state = State.SCHEDULED + ti_backfill.state = State.SCHEDULED + ti_with_dagrun.state = State.SCHEDULED + + session.merge(dr2) + session.merge(ti_no_dagrun) + session.merge(ti_backfill) + session.merge(ti_with_dagrun) + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + + self.assertEqual(2, len(res)) + res_keys = map(lambda x: x.key, res) + self.assertIn(ti_no_dagrun.key, res_keys) + self.assertIn(ti_with_dagrun.key, res_keys) + + def test_find_executable_task_instances_pool(self): + dag_id = 'SchedulerJobTest.test_find_executable_task_instances_pool' + task_id_1 = 'dummy' + task_id_2 = 'dummydummy' + dag = 
DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) + task1 = DummyOperator(dag=dag, task_id=task_id_1, pool='a') + task2 = DummyOperator(dag=dag, task_id=task_id_2, pool='b') + dagbag = SimpleDagBag([dag]) + + scheduler = SchedulerJob(**self.default_scheduler_args) + session = settings.Session() + + dr1 = scheduler.create_dag_run(dag) + dr2 = scheduler.create_dag_run(dag) + + tis = ([ + TI(task1, dr1.execution_date), + TI(task2, dr1.execution_date), + TI(task1, dr2.execution_date), + TI(task2, dr2.execution_date) + ]) + for ti in tis: + ti.state = State.SCHEDULED + session.merge(ti) + pool = models.Pool(pool='a', slots=1, description='haha') + pool2 = models.Pool(pool='b', slots=100, description='haha') + session.add(pool) + session.add(pool2) + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + session.commit() + self.assertEqual(3, len(res)) + res_keys = [] + for ti in res: + res_keys.append(ti.key) + self.assertIn(tis[0].key, res_keys) + self.assertIn(tis[1].key, res_keys) + self.assertIn(tis[3].key, res_keys) + + def test_find_executable_task_instances_none(self): + dag_id = 'SchedulerJobTest.test_find_executable_task_instances_none' + task_id_1 = 'dummy' + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) + task1 = DummyOperator(dag=dag, task_id=task_id_1) + dagbag = SimpleDagBag([dag]) + + scheduler = SchedulerJob(**self.default_scheduler_args) + session = settings.Session() + + dr1 = scheduler.create_dag_run(dag) + session.commit() + + self.assertEqual(0, len(scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session))) + + def test_find_executable_task_instances_concurrency(self): + dag_id = 'SchedulerJobTest.test_find_executable_task_instances_concurrency' + task_id_1 = 'dummy' + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) + task1 = DummyOperator(dag=dag, task_id=task_id_1) + dagbag = 
SimpleDagBag([dag]) + + scheduler = SchedulerJob(**self.default_scheduler_args) + session = settings.Session() + + dr1 = scheduler.create_dag_run(dag) + dr2 = scheduler.create_dag_run(dag) + dr3 = scheduler.create_dag_run(dag) + + ti1 = TI(task1, dr1.execution_date) + ti2 = TI(task1, dr2.execution_date) + ti3 = TI(task1, dr3.execution_date) + ti1.state = State.RUNNING + ti2.state = State.SCHEDULED + ti3.state = State.SCHEDULED + session.merge(ti1) + session.merge(ti2) + session.merge(ti3) + + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + + self.assertEqual(1, len(res)) + res_keys = map(lambda x: x.key, res) + self.assertIn(ti2.key, res_keys) + + ti2.state = State.RUNNING + session.merge(ti2) + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + + self.assertEqual(0, len(res)) + + def test_change_state_for_executable_task_instances_no_tis(self): + scheduler = SchedulerJob(**self.default_scheduler_args) + session = settings.Session() + res = scheduler._change_state_for_executable_task_instances([], [State.NONE], session) + self.assertEqual(0, len(res)) + + def test_change_state_for_executable_task_instances_no_tis_with_state(self): + dag_id = 'SchedulerJobTest.test_change_state_for__no_tis_with_state' + task_id_1 = 'dummy' + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) + task1 = DummyOperator(dag=dag, task_id=task_id_1) + dagbag = SimpleDagBag([dag]) + + scheduler = SchedulerJob(**self.default_scheduler_args) + session = settings.Session() + + dr1 = scheduler.create_dag_run(dag) + dr2 = scheduler.create_dag_run(dag) + dr3 = scheduler.create_dag_run(dag) + + ti1 = TI(task1, dr1.execution_date) + ti2 = TI(task1, dr2.execution_date) + ti3 = TI(task1, dr3.execution_date) + ti1.state = State.SCHEDULED + ti2.state = State.SCHEDULED + ti3.state = State.SCHEDULED + session.merge(ti1) + 
session.merge(ti2) + session.merge(ti3) + + session.commit() + + res = scheduler._change_state_for_executable_task_instances( + [ti1, ti2, ti3], + [State.RUNNING], + session) + self.assertEqual(0, len(res)) + + def test_change_state_for_executable_task_instances_none_state(self): + dag_id = 'SchedulerJobTest.test_change_state_for__none_state' + task_id_1 = 'dummy' + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) + task1 = DummyOperator(dag=dag, task_id=task_id_1) + dagbag = SimpleDagBag([dag]) + + scheduler = SchedulerJob(**self.default_scheduler_args) + session = settings.Session() + + dr1 = scheduler.create_dag_run(dag) + dr2 = scheduler.create_dag_run(dag) + dr3 = scheduler.create_dag_run(dag) + + ti1 = TI(task1, dr1.execution_date) + ti2 = TI(task1, dr2.execution_date) + ti3 = TI(task1, dr3.execution_date) + ti1.state = State.SCHEDULED + ti2.state = State.QUEUED + ti3.state = State.NONE + session.merge(ti1) + session.merge(ti2) + session.merge(ti3) + + session.commit() + + res = scheduler._change_state_for_executable_task_instances( + [ti1, ti2, ti3], + [State.NONE, State.SCHEDULED], + session) + self.assertEqual(2, len(res)) + ti1.refresh_from_db() + ti3.refresh_from_db() + self.assertEqual(State.QUEUED, ti1.state) + self.assertEqual(State.QUEUED, ti3.state) + + def test_enqueue_task_instances_with_queued_state(self): + dag_id = 'SchedulerJobTest.test_enqueue_task_instances_with_queued_state' + task_id_1 = 'dummy' + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) + task1 = DummyOperator(dag=dag, task_id=task_id_1) + dagbag = SimpleDagBag([dag]) + + scheduler = SchedulerJob(**self.default_scheduler_args) + session = settings.Session() + + dr1 = scheduler.create_dag_run(dag) + + ti1 = TI(task1, dr1.execution_date) + session.merge(ti1) + session.commit() + + with patch.object(BaseExecutor, 'queue_command') as mock_queue_command: + scheduler._enqueue_task_instances_with_queued_state(dagbag, [ti1]) + + mock_queue_command.assert_called() + + def 
test_execute_task_instances_nothing(self): + dag_id = 'SchedulerJobTest.test_execute_task_instances_nothing' + task_id_1 = 'dummy' + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) + task1 = DummyOperator(dag=dag, task_id=task_id_1) + dagbag = SimpleDagBag([]) + + scheduler = SchedulerJob(**self.default_scheduler_args) + session = settings.Session() + + dr1 = scheduler.create_dag_run(dag) + ti1 = TI(task1, dr1.execution_date) + ti1.state = State.SCHEDULED + session.merge(ti1) + session.commit() + + self.assertEqual(0, scheduler._execute_task_instances(dagbag, states=[State.SCHEDULED])) + + def test_execute_task_instances(self): + dag_id = 'SchedulerJobTest.test_execute_task_instances' task_id_1 = 'dummy_task' task_id_2 = 'dummy_task_nonexistent_queue' # important that len(tasks) is less than concurrency @@ -765,7 +1023,7 @@ class SchedulerJobTest(unittest.TestCase): self.assertEqual(State.RUNNING, dr2.state) - scheduler._execute_task_instances(dagbag, [State.SCHEDULED]) + res = scheduler._execute_task_instances(dagbag, [State.SCHEDULED]) # check that concurrency is respected ti1.refresh_from_db() @@ -777,8 +1035,7 @@ class SchedulerJobTest(unittest.TestCase): self.assertEqual(State.RUNNING, ti1.state) self.assertEqual(State.RUNNING, ti2.state) six.assertCountEqual(self, [State.QUEUED, State.SCHEDULED], [ti3.state, ti4.state]) - - session.close() + self.assertEqual(1, res) def test_change_state_for_tis_without_dagrun(self): dag = DAG(
