This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v1-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit b95126ba14c861d7bcbb1dbf9248aec73a8b343b Author: Kaxil Naik <[email protected]> AuthorDate: Mon Jul 20 12:45:18 2020 +0100 Update Serialized DAGs in Webserver when DAGs are Updated (#9851) Before this change, if DAG Serialization was enabled the Webserver would not update the DAGs once they are fetched from DB. The default worker_refresh_interval was `30` so whenever the gunicorn workers were restarted, they used to pull the updated DAGs when needed. This change will allow us to have a larger worker_refresh_interval (e.g. 30 mins or even 1 day) (cherry picked from commit 84b85d8acc181edfe1fdd21b82c1773c19c47044) --- airflow/config_templates/config.yml | 8 +++ airflow/config_templates/default_airflow.cfg | 4 ++ airflow/models/dagbag.py | 40 +++++++++++---- airflow/models/serialized_dag.py | 14 ++++++ airflow/settings.py | 5 ++ docs/dag-serialization.rst | 11 ++++- tests/models/test_dagbag.py | 45 +++++++++++++++++ tests/test_utils/asserts.py | 73 ++++++++++++++++++++++++++++ 8 files changed, 188 insertions(+), 12 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 75c47cb..9535d5b 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -447,6 +447,14 @@ type: string example: ~ default: "30" + - name: min_serialized_dag_fetch_interval + description: | + Fetching serialized DAG can not be faster than a minimum interval to reduce database + read rate. This config controls when your DAGs are updated in the Webserver + version_added: 1.10.12 + type: string + example: ~ + default: "10" - name: store_dag_code + description: | + Whether to persist DAG files code in DB. 
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 3a9bba2..9729403 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -234,6 +234,10 @@ store_serialized_dags = False # Updating serialized DAG can not be faster than a minimum interval to reduce database write rate. min_serialized_dag_update_interval = 30 +# Fetching serialized DAG can not be faster than a minimum interval to reduce database +# read rate. This config controls when your DAGs are updated in the Webserver +min_serialized_dag_fetch_interval = 10 + # Whether to persist DAG files code in DB. # If set to True, Webserver reads file contents from DB instead of # trying to access files in a DAG folder. Defaults to same as the diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 48cbd3e..1b8be89 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -28,7 +28,7 @@ import sys import textwrap import zipfile from collections import namedtuple -from datetime import datetime +from datetime import datetime, timedelta from croniter import CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError, croniter import six @@ -102,6 +102,7 @@ class DagBag(BaseDagBag, LoggingMixin): self.import_errors = {} self.has_logged = False self.store_serialized_dags = store_serialized_dags + self.dags_last_fetched = {} self.collect_dags( dag_folder=dag_folder, @@ -127,20 +128,26 @@ class DagBag(BaseDagBag, LoggingMixin): """ from airflow.models.dag import DagModel # Avoid circular import - # Only read DAGs from DB if this dagbag is store_serialized_dags. 
if self.store_serialized_dags: # Import here so that serialized dag is only imported when serialization is enabled from airflow.models.serialized_dag import SerializedDagModel if dag_id not in self.dags: # Load from DB if not (yet) in the bag - row = SerializedDagModel.get(dag_id) - if not row: - return None - - dag = row.dag - for subdag in dag.subdags: - self.dags[subdag.dag_id] = subdag - self.dags[dag.dag_id] = dag + self._add_dag_from_db(dag_id=dag_id) + return self.dags.get(dag_id) + + # If DAG is in the DagBag, check the following + # 1. if time has come to check if DAG is updated (controlled by min_serialized_dag_fetch_secs) + # 2. check the last_updated column in SerializedDag table to see if Serialized DAG is updated + # 3. if (2) is yes, fetch the Serialized DAG. + min_serialized_dag_fetch_secs = timedelta(seconds=settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL) + if ( + dag_id in self.dags_last_fetched and + timezone.utcnow() > self.dags_last_fetched[dag_id] + min_serialized_dag_fetch_secs + ): + sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime(dag_id=dag_id) + if sd_last_updated_datetime > self.dags_last_fetched[dag_id]: + self._add_dag_from_db(dag_id=dag_id) return self.dags.get(dag_id) @@ -178,6 +185,19 @@ class DagBag(BaseDagBag, LoggingMixin): del self.dags[dag_id] return self.dags.get(dag_id) + def _add_dag_from_db(self, dag_id): + """Add DAG to DagBag from DB""" + from airflow.models.serialized_dag import SerializedDagModel + row = SerializedDagModel.get(dag_id) + if not row: + raise ValueError("DAG '{}' not found in serialized_dag table".format(dag_id)) + + dag = row.dag + for subdag in dag.subdags: + self.dags[subdag.dag_id] = subdag + self.dags[dag.dag_id] = dag + self.dags_last_fetched[dag.dag_id] = timezone.utcnow() + def process_file(self, filepath, only_if_updated=True, safe_mode=True): """ Given a path to a python module or zip file, this method imports diff --git a/airflow/models/serialized_dag.py 
b/airflow/models/serialized_dag.py index 1313cac..d29e43c 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -219,3 +219,17 @@ class SerializedDagModel(Base): DagModel.root_dag_id).filter(DagModel.dag_id == dag_id).scalar() return session.query(cls).filter(cls.dag_id == root_dag_id).one_or_none() + + @classmethod + @db.provide_session + def get_last_updated_datetime(cls, dag_id, session): + """ + Get the date when the Serialized DAG associated to DAG was last updated + in serialized_dag table + + :param dag_id: DAG ID + :type dag_id: str + :param session: ORM Session + :type session: Session + """ + return session.query(cls.last_updated).filter(cls.dag_id == dag_id).scalar() diff --git a/airflow/settings.py b/airflow/settings.py index 0158ec8..e39c960 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -428,6 +428,11 @@ STORE_SERIALIZED_DAGS = conf.getboolean('core', 'store_serialized_dags', fallbac MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint( 'core', 'min_serialized_dag_update_interval', fallback=30) +# Fetching serialized DAG can not be faster than a minimum interval to reduce database +# read rate. This config controls when your DAGs are updated in the Webserver +MIN_SERIALIZED_DAG_FETCH_INTERVAL = conf.getint( + 'core', 'min_serialized_dag_fetch_interval', fallback=10) + # Whether to persist DAG files code in DB. If set to True, Webserver reads file contents # from DB instead of trying to access files in a DAG folder. # Defaults to same as the store_serialized_dags setting. 
diff --git a/docs/dag-serialization.rst b/docs/dag-serialization.rst index 0edd644..e2fcf14 100644 --- a/docs/dag-serialization.rst +++ b/docs/dag-serialization.rst @@ -57,14 +57,21 @@ Add the following settings in ``airflow.cfg``: [core] store_serialized_dags = True + store_dag_code = True + + # You can also update the following default configurations based on your needs min_serialized_dag_update_interval = 30 + min_serialized_dag_fetch_interval = 10 * ``store_serialized_dags``: This flag decides whether to serialise DAGs and persist them in DB. If set to True, Webserver reads from DB instead of parsing DAG files -* ``min_serialized_dag_update_interval``: This flag sets the minimum interval (in seconds) after which - the serialized DAG in DB should be updated. This helps in reducing database write rate. * ``store_dag_code``: This flag decides whether to persist DAG files code in DB. If set to True, Webserver reads file contents from DB instead of trying to access files in a DAG folder. +* ``min_serialized_dag_update_interval``: This flag sets the minimum interval (in seconds) after which + the serialized DAG in DB should be updated. This helps in reducing database write rate. +* ``min_serialized_dag_fetch_interval``: This flag controls how often a SerializedDAG will be re-fetched + from the DB when it's already loaded in the DagBag in the Webserver. Setting this higher will reduce + load on the DB, but at the expense of displaying a possibly stale cached version of the DAG. If you are updating Airflow from <1.10.7, please do not forget to run ``airflow upgradedb``. 
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index 04c2372..b9d18ac 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -19,6 +19,7 @@ import inspect import os +import six import shutil import textwrap import unittest @@ -26,15 +27,19 @@ from datetime import datetime from tempfile import NamedTemporaryFile, mkdtemp from mock import patch, ANY +from freezegun import freeze_time from airflow import models from airflow.configuration import conf from airflow.utils.dag_processing import SimpleTaskInstance from airflow.models import DagModel, DagBag, TaskInstance as TI +from airflow.models.serialized_dag import SerializedDagModel +from airflow.utils.dates import timezone as tz from airflow.utils.db import create_session from airflow.utils.state import State from airflow.utils.timezone import utc from tests.models import TEST_DAGS_FOLDER, DEFAULT_DATE +from tests.test_utils.asserts import assert_queries_count from tests.test_utils.config import conf_vars import airflow.example_dags @@ -650,3 +655,43 @@ class DagBagTest(unittest.TestCase): # clean up with create_session() as session: session.query(DagModel).filter(DagModel.dag_id == 'test_deactivate_unknown_dags').delete() + + @patch("airflow.models.dagbag.settings.STORE_SERIALIZED_DAGS", True) + @patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL", 5) + @patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL", 5) + def test_get_dag_with_dag_serialization(self): + """ + Test that Serialized DAG is updated in DagBag when it is updated in + Serialized DAG table after 'min_serialized_dag_fetch_interval' seconds are passed. 
+ """ + + with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 0)): + example_bash_op_dag = DagBag(include_examples=True).dags.get("example_bash_operator") + SerializedDagModel.write_dag(dag=example_bash_op_dag) + + dag_bag = DagBag(store_serialized_dags=True) + ser_dag_1 = dag_bag.get_dag("example_bash_operator") + ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"] + self.assertEqual(example_bash_op_dag.tags, ser_dag_1.tags) + self.assertEqual(ser_dag_1_update_time, tz.datetime(2020, 1, 5, 0, 0, 0)) + + # Check that if min_serialized_dag_fetch_interval has not passed we do not fetch the DAG + # from DB + with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 4)): + with assert_queries_count(0): + self.assertEqual(dag_bag.get_dag("example_bash_operator").tags, ["example"]) + + # Make a change in the DAG and write Serialized DAG to the DB + with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 6)): + example_bash_op_dag.tags += ["new_tag"] + SerializedDagModel.write_dag(dag=example_bash_op_dag) + + # Since min_serialized_dag_fetch_interval is passed verify that calling 'dag_bag.get_dag' + # fetches the Serialized DAG from DB + with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 8)): + with assert_queries_count(2): + updated_ser_dag_1 = dag_bag.get_dag("example_bash_operator") + updated_ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"] + + six.assertCountEqual(self, updated_ser_dag_1.tags, ["example", "new_tag"]) + self.assertGreater(updated_ser_dag_1_update_time, ser_dag_1_update_time) diff --git a/tests/test_utils/asserts.py b/tests/test_utils/asserts.py new file mode 100644 index 0000000..ca3cf2f --- /dev/null +++ b/tests/test_utils/asserts.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +import re +from contextlib import contextmanager + +from sqlalchemy import event + +# Long import to not create a copy of the reference, but to refer to one place. +import airflow.settings + +log = logging.getLogger(__name__) + + +def assert_equal_ignore_multiple_spaces(case, first, second, msg=None): + def _trim(s): + return re.sub(r"\s+", " ", s.strip()) + return case.assertEqual(_trim(first), _trim(second), msg) + + +class CountQueriesResult: + def __init__(self): + self.count = 0 + + +class CountQueries: + """ + Counts the number of queries sent to Airflow Database in a given context. + + Does not support multiple processes. When a new process is started in context, its queries will + not be included. 
+ """ + def __init__(self): + self.result = CountQueriesResult() + + def __enter__(self): + event.listen(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute) + return self.result + + def __exit__(self, type_, value, traceback): + event.remove(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute) + log.debug("Queries count: %d", self.result.count) + + def after_cursor_execute(self, *args, **kwargs): + self.result.count += 1 + + +count_queries = CountQueries # pylint: disable=invalid-name + + +@contextmanager +def assert_queries_count(expected_count, message_fmt=None): + with count_queries() as result: + yield None + message_fmt = message_fmt or "The expected number of db queries is {expected_count}. " \ + "The current number is {current_count}." + message = message_fmt.format(current_count=result.count, expected_count=expected_count) + assert expected_count == result.count, message
