Repository: incubator-airflow
Updated Branches: refs/heads/master a92330e4f -> 042c3f2ae
[AIRFLOW-2430] Extend query batching to additional slow queries Closes #3324 from gsilk/batch-inserts Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/042c3f2a Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/042c3f2a Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/042c3f2a Branch: refs/heads/master Commit: 042c3f2aeec7e4da335ae900c5b7499250304175 Parents: a92330e Author: Gabriel Silk <[email protected]> Authored: Sun May 13 20:54:00 2018 +0200 Committer: Fokko Driesprong <[email protected]> Committed: Sun May 13 20:54:00 2018 +0200 ---------------------------------------------------------------------- airflow/config_templates/default_airflow.cfg | 13 +++- airflow/config_templates/default_test.cfg | 2 +- airflow/jobs.py | 93 ++++++++++++----------- airflow/utils/helpers.py | 24 +++++- tests/utils/test_helpers.py | 30 ++++++++ 5 files changed, 113 insertions(+), 49 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/042c3f2a/airflow/config_templates/default_airflow.cfg ---------------------------------------------------------------------- diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index b07f1f6..ee28cc5 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -421,9 +421,16 @@ scheduler_zombie_task_threshold = 300 catchup_by_default = True # This changes the batch size of queries in the scheduling main loop. -# This depends on query length limits and how long you are willing to hold locks. 
-# 0 for no limit -max_tis_per_query = 0 +# If this is too high, SQL query performance may be impacted by one +# or more of the following: +# - reversion to full table scan +# - complexity of query predicate +# - excessive locking +# +# Additionally, you may hit the maximum allowable query length for your db. +# +# Set this to 0 for no limit (not advised) +max_tis_per_query = 512 # Statsd (https://github.com/etsy/statsd) integration settings statsd_on = False http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/042c3f2a/airflow/config_templates/default_test.cfg ---------------------------------------------------------------------- diff --git a/airflow/config_templates/default_test.cfg b/airflow/config_templates/default_test.cfg index 7c569cd..7360619 100644 --- a/airflow/config_templates/default_test.cfg +++ b/airflow/config_templates/default_test.cfg @@ -108,7 +108,7 @@ max_threads = 2 catchup_by_default = True scheduler_zombie_task_threshold = 300 dag_dir_list_interval = 0 -max_tis_per_query = 0 +max_tis_per_query = 512 [admin] hide_sensitive_variable_fields = True http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/042c3f2a/airflow/jobs.py ---------------------------------------------------------------------- diff --git a/airflow/jobs.py b/airflow/jobs.py index 045b4b7..7f4f470 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -51,7 +51,7 @@ from airflow.models import DAG, DagRun from airflow.settings import Stats from airflow.task.task_runner import get_task_runner from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS -from airflow.utils import asciiart, timezone +from airflow.utils import asciiart, helpers, timezone from airflow.utils.dag_processing import (AbstractDagFileProcessor, DagFileProcessorManager, SimpleDag, @@ -110,6 +110,7 @@ class BaseJob(Base, LoggingMixin): self.latest_heartbeat = timezone.utcnow() self.heartrate = heartrate self.unixname = getpass.getuser() + self.max_tis_per_query = 
conf.getint('scheduler', 'max_tis_per_query') super(BaseJob, self).__init__(*args, **kwargs) def is_alive(self): @@ -254,21 +255,30 @@ class BaseJob(Base, LoggingMixin): if ti.key not in queued_tis and ti.key not in running_tis: tis_to_reset.append(ti) - filter_for_tis = ([and_(TI.dag_id == ti.dag_id, - TI.task_id == ti.task_id, - TI.execution_date == ti.execution_date) - for ti in tis_to_reset]) if len(tis_to_reset) == 0: return [] - reset_tis = ( - session - .query(TI) - .filter(or_(*filter_for_tis), TI.state.in_(resettable_states)) - .with_for_update() - .all()) - for ti in reset_tis: - ti.state = State.NONE - session.merge(ti) + + def query(result, items): + filter_for_tis = ([and_(TI.dag_id == ti.dag_id, + TI.task_id == ti.task_id, + TI.execution_date == ti.execution_date) + for ti in items]) + reset_tis = ( + session + .query(TI) + .filter(or_(*filter_for_tis), TI.state.in_(resettable_states)) + .with_for_update() + .all()) + for ti in reset_tis: + ti.state = State.NONE + session.merge(ti) + return result + reset_tis + + reset_tis = helpers.reduce_in_chunks(query, + tis_to_reset, + [], + self.max_tis_per_query) + task_instance_str = '\n\t'.join( ["{}".format(x) for x in reset_tis]) session.commit() @@ -579,7 +589,6 @@ class SchedulerJob(BaseJob): # files have finished parsing. 
self.min_file_parsing_loop_time = min_file_parsing_loop_time - self.max_tis_per_query = conf.getint('scheduler', 'max_tis_per_query') if run_duration is None: self.run_duration = conf.getint('scheduler', 'run_duration') @@ -1261,19 +1270,28 @@ class SchedulerJob(BaseJob): filter_for_ti_enqueue = ([and_(TI.dag_id == ti.dag_id, TI.task_id == ti.task_id, TI.execution_date == ti.execution_date) - for ti in tis_to_set_to_queued]) + for ti in tis_to_set_to_queued]) session.commit() - # requery in batch since above was expired by commit - tis_to_be_queued = ( - session - .query(TI) - .filter(or_(*filter_for_ti_enqueue)) - .all()) + # requery in batches since above was expired by commit + + def query(result, items): + tis_to_be_queued = ( + session + .query(TI) + .filter(or_(*items)) + .all()) + task_instance_str = "\n\t".join( + ["{}".format(x) for x in tis_to_be_queued]) + self.log.info("Setting the follow tasks to queued state:\n\t%s", + task_instance_str) + return result + tis_to_be_queued + + tis_to_be_queued = helpers.reduce_in_chunks(query, + filter_for_ti_enqueue, + [], + self.max_tis_per_query) - task_instance_str = "\n\t".join( - ["{}".format(x) for x in tis_to_be_queued]) - self.log.info("Setting the follow tasks to queued state:\n\t%s", task_instance_str) return tis_to_be_queued def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, task_instances): @@ -1349,32 +1367,19 @@ class SchedulerJob(BaseJob): """ executable_tis = self._find_executable_task_instances(simple_dag_bag, states, session=session) - if self.max_tis_per_query == 0: + + def query(result, items): tis_with_state_changed = self._change_state_for_executable_task_instances( - executable_tis, + items, states, session=session) self._enqueue_task_instances_with_queued_state( simple_dag_bag, tis_with_state_changed) session.commit() - return len(tis_with_state_changed) - else: - # makes chunks of max_tis_per_query size - chunks = ([executable_tis[i:i + self.max_tis_per_query] - for i in 
range(0, len(executable_tis), self.max_tis_per_query)]) - total_tis_queued = 0 - for chunk in chunks: - tis_with_state_changed = self._change_state_for_executable_task_instances( - chunk, - states, - session=session) - self._enqueue_task_instances_with_queued_state( - simple_dag_bag, - tis_with_state_changed) - session.commit() - total_tis_queued += len(tis_with_state_changed) - return total_tis_queued + return result + len(tis_with_state_changed) + + return helpers.reduce_in_chunks(query, executable_tis, 0, self.max_tis_per_query) def _process_dags(self, dagbag, dags, tis_out): """ http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/042c3f2a/airflow/utils/helpers.py ---------------------------------------------------------------------- diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index d2affe5..3389788 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -27,7 +27,7 @@ import psutil from builtins import input from past.builtins import basestring from datetime import datetime -import getpass +from functools import reduce import imp import os import re @@ -123,6 +123,28 @@ def as_tuple(obj): return tuple([obj]) +def chunks(items, chunk_size): + """ + Yield successive chunks of a given size from a list of items + """ + if (chunk_size <= 0): + raise ValueError('Chunk size must be a positive integer') + for i in range(0, len(items), chunk_size): + yield items[i:i + chunk_size] + + +def reduce_in_chunks(fn, iterable, initializer, chunk_size=0): + """ + Reduce the given list of items by splitting it into chunks + of the given size and passing each chunk through the reducer + """ + if len(iterable) == 0: + return initializer + if chunk_size == 0: + chunk_size = len(iterable) + return reduce(fn, chunks(iterable, chunk_size), initializer) + + def as_flattened_list(iterable): """ Return an iterable with one level flattened http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/042c3f2a/tests/utils/test_helpers.py 
---------------------------------------------------------------------- diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 0a07536..1005671 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -86,6 +86,36 @@ class TestHelpers(unittest.TestCase): except OSError: pass + def test_chunks(self): + with self.assertRaises(ValueError): + [i for i in helpers.chunks([1, 2, 3], 0)] + + with self.assertRaises(ValueError): + [i for i in helpers.chunks([1, 2, 3], -3)] + + self.assertEqual([i for i in helpers.chunks([], 5)], []) + self.assertEqual([i for i in helpers.chunks([1], 1)], [[1]]) + self.assertEqual([i for i in helpers.chunks([1, 2, 3], 2)], + [[1, 2], [3]]) + + def test_reduce_in_chunks(self): + self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + [y], + [1, 2, 3, 4, 5], + []), + [[1, 2, 3, 4, 5]]) + + self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + [y], + [1, 2, 3, 4, 5], + [], + 2), + [[1, 2], [3, 4], [5]]) + + self.assertEqual(helpers.reduce_in_chunks(lambda x, y: x + y[0] * y[1], + [1, 2, 3, 4], + 0, + 2), + 14) + if __name__ == '__main__': unittest.main()
