Repository: incubator-airflow Updated Branches: refs/heads/master fc26cade8 -> 6c93460b9
[AIRFLOW-1852] Allow hostname to be overridable. This allows hostnames to be overridable to facilitate service discovery requirements in common production deployments. Closes #3036 from thekashifmalik/hostnames Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/6c93460b Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/6c93460b Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/6c93460b Branch: refs/heads/master Commit: 6c93460b98f5a046662ca02956eac418907af765 Parents: fc26cad Author: Trevor Joynson (trevorj) <[email protected]> Authored: Tue Feb 20 16:56:25 2018 -0800 Committer: Joy Gao <[email protected]> Committed: Tue Feb 20 16:56:44 2018 -0800 ---------------------------------------------------------------------- airflow/bin/cli.py | 3 +- airflow/config_templates/default_airflow.cfg | 4 ++ airflow/jobs.py | 5 +- airflow/models.py | 6 +-- airflow/security/utils.py | 4 +- airflow/utils/net.py | 40 ++++++++++++++ airflow/www/app.py | 4 +- airflow/www/views.py | 5 +- tests/jobs.py | 5 +- tests/utils/test_net.py | 55 ++++++++++++++++++++ .../api/experimental/test_kerberos_endpoints.py | 5 +- 11 files changed, 121 insertions(+), 15 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c93460b/airflow/bin/cli.py ---------------------------------------------------------------------- diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index bf21a54..98b4321 100755 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -54,6 +54,7 @@ from airflow.models import (DagModel, DagBag, TaskInstance, from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS) from airflow.utils import db as db_utils +from airflow.utils.net import get_hostname from airflow.utils.log.logging_mixin import (LoggingMixin, redirect_stderr, redirect_stdout, set_context) from airflow.www.app import cached_app @@ -437,7 +438,7 @@ def run(args, dag=None): ti.init_run_context(raw=args.raw) - hostname = socket.getfqdn() + hostname = get_hostname() log.info("Running %s on host %s", ti, hostname) if args.interactive: http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c93460b/airflow/config_templates/default_airflow.cfg ---------------------------------------------------------------------- diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 0f6d56f..5356af7 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -56,6 +56,10 @@ logging_config_class = log_format = [%%(asctime)s] {{%%(filename)s:%%(lineno)d}} %%(levelname)s - %%(message)s simple_log_format = %%(asctime)s %%(levelname)s - %%(message)s +# Hostname override by providing a path to a callable. +# hostname_callable = socket:getfqdn + + # Default timezone in case supplied date times are naive # can be utc (default), system, or any IANA timezone string (e.g. Europe/Amsterdam) default_timezone = utc http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c93460b/airflow/jobs.py ---------------------------------------------------------------------- diff --git a/airflow/jobs.py b/airflow/jobs.py index 00d6b22..5a2ab5b 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -59,6 +59,7 @@ from airflow.utils.email import send_email from airflow.utils.log.logging_mixin import LoggingMixin, set_context, StreamLogWriter from airflow.utils.state import State from airflow.utils.configuration import tmp_configuration_copy +from airflow.utils.net import get_hostname Base = models.Base ID_LEN = models.ID_LEN @@ -99,7 +100,7 @@ class BaseJob(Base, LoggingMixin): executor=executors.GetDefaultExecutor(), heartrate=conf.getfloat('scheduler', 'JOB_HEARTBEAT_SEC'), *args, **kwargs): - self.hostname = socket.getfqdn() + self.hostname = get_hostname() self.executor = executor self.executor_class = executor.__class__.__name__ self.start_date = timezone.utcnow() @@ -2569,7 +2570,7 @@ class LocalTaskJob(BaseJob): self.task_instance.refresh_from_db() ti = self.task_instance - fqdn = socket.getfqdn() + fqdn = get_hostname() same_hostname = fqdn == ti.hostname same_process = ti.pid == os.getpid() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c93460b/airflow/models.py ---------------------------------------------------------------------- diff --git a/airflow/models.py b/airflow/models.py index 5e4474f..e436974 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -39,7 +39,6 @@ import os import pickle import re import signal -import socket import sys import textwrap import traceback @@ -84,6 +83,7 @@ from airflow.utils.state import State from airflow.utils.timeout import timeout from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule +from airflow.utils.net import get_hostname from airflow.utils.log.logging_mixin import LoggingMixin install_aliases() @@ -1363,7 +1363,7 @@ class TaskInstance(Base, LoggingMixin): self.test_mode = test_mode self.refresh_from_db(session=session, lock_for_update=True) self.job_id = job_id - self.hostname = socket.getfqdn() + self.hostname = get_hostname() self.operator = task.__class__.__name__ if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS: @@ -1480,7 +1480,7 @@ class TaskInstance(Base, LoggingMixin): self.test_mode = test_mode self.refresh_from_db(session=session) self.job_id = job_id - self.hostname = socket.getfqdn() + self.hostname = get_hostname() self.operator = task.__class__.__name__ context = {} http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c93460b/airflow/security/utils.py ---------------------------------------------------------------------- diff --git a/airflow/security/utils.py b/airflow/security/utils.py index b2de3e3..24d823b 100644 --- a/airflow/security/utils.py +++ b/airflow/security/utils.py @@ -20,6 +20,8 @@ import re import socket import airflow.configuration as conf +from airflow.utils.net import get_hostname + # Pattern to replace with hostname HOSTNAME_PATTERN = '_HOST' @@ -53,7 +55,7 @@ def replace_hostname_pattern(components, host=None): def get_localhost_name(): - return socket.getfqdn() + return get_hostname() def get_fqdn(hostname_or_ip=None): http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c93460b/airflow/utils/net.py ---------------------------------------------------------------------- diff --git a/airflow/utils/net.py b/airflow/utils/net.py new file mode 100644 index 0000000..e00629f --- /dev/null +++ b/airflow/utils/net.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import importlib +import socket +from airflow.configuration import (conf, AirflowConfigException) + + +def get_hostname(): + """ + Fetch the hostname using the callable from the config or using + `socket.getfqdn` as a fallback. + """ + # First we attempt to fetch the callable path from the config. + try: + callable_path = conf.get('core', 'hostname_callable') + except AirflowConfigException: + callable_path = None + + # Then we handle the case when the config is missing or empty. This is the + # default behavior. + if not callable_path: + return socket.getfqdn() + + # Since we have a callable path, we try to import and run it next. + module_path, attr_name = callable_path.split(':') + module = importlib.import_module(module_path) + callable = getattr(module, attr_name) + return callable() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c93460b/airflow/www/app.py ---------------------------------------------------------------------- diff --git a/airflow/www/app.py b/airflow/www/app.py index e7b4ca6..15399ea 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import socket import six from flask import Flask @@ -32,6 +31,7 @@ from airflow.logging_config import configure_logging from airflow import jobs from airflow import settings from airflow import configuration +from airflow.utils.net import get_hostname csrf = CSRFProtect() @@ -149,7 +149,7 @@ def create_app(config=None, testing=False): @app.context_processor def jinja_globals(): return { - 'hostname': socket.getfqdn(), + 'hostname': get_hostname(), } @app.teardown_appcontext http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c93460b/airflow/www/views.py ---------------------------------------------------------------------- diff --git a/airflow/www/views.py b/airflow/www/views.py index a41f5fe..8694cbf 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -83,6 +83,7 @@ from airflow.utils.db import create_session, provide_session from airflow.utils.helpers import alchemy_to_dict from airflow.utils.dates import infer_time_unit, scale_time_units, parse_execution_date from airflow.utils.timezone import datetime +from airflow.utils.net import get_hostname from airflow.www import utils as wwwutils from airflow.www.forms import DateTimeForm, DateTimeWithNumRunsForm from airflow.www.validators import GreaterEqualThan @@ -647,14 +648,14 @@ class Airflow(BaseView): @current_app.errorhandler(404) def circles(self): return render_template( - 'airflow/circles.html', hostname=socket.getfqdn()), 404 + 'airflow/circles.html', hostname=get_hostname()), 404 @current_app.errorhandler(500) def show_traceback(self): from airflow.utils import asciiart as ascii_ return render_template( 'airflow/traceback.html', - hostname=socket.getfqdn(), + hostname=get_hostname(), nukular=ascii_.nukular, info=traceback.format_exc()), 500 http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c93460b/tests/jobs.py ---------------------------------------------------------------------- diff --git a/tests/jobs.py b/tests/jobs.py index 1c87b8f..ace593a 100644 --- a/tests/jobs.py +++ b/tests/jobs.py @@ -44,6 +44,7 @@ from airflow.utils.db import provide_session from airflow.utils.state import State from airflow.utils.timeout import timeout from airflow.utils.dag_processing import SimpleDag, SimpleDagBag, list_py_file_paths +from airflow.utils.net import get_hostname from mock import Mock, patch from sqlalchemy.orm.session import make_transient @@ -841,7 +842,7 @@ class LocalTaskJobTest(unittest.TestCase): mock_pid.return_value = 1 ti.state = State.RUNNING - ti.hostname = socket.getfqdn() + ti.hostname = get_hostname() ti.pid = 1 session.merge(ti) session.commit() @@ -911,7 +912,7 @@ class LocalTaskJobTest(unittest.TestCase): session=session) ti = dr.get_task_instance(task_id=task.task_id, session=session) ti.state = State.RUNNING - ti.hostname = socket.getfqdn() + ti.hostname = get_hostname() ti.pid = 1 session.commit() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c93460b/tests/utils/test_net.py ---------------------------------------------------------------------- diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py new file mode 100644 index 0000000..c71612d --- /dev/null +++ b/tests/utils/test_net.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import mock + +from airflow.utils import net + + +def get_hostname(): + return 'awesomehostname' + + +class GetHostname(unittest.TestCase): + + @mock.patch('airflow.utils.net.socket') + @mock.patch('airflow.utils.net.conf') + def test_get_hostname_unset(self, patched_conf, patched_socket): + patched_conf.get = mock.Mock(return_value=None) + patched_socket.getfqdn = mock.Mock(return_value='first') + self.assertTrue(net.get_hostname() == 'first') + + @mock.patch('airflow.utils.net.conf') + def test_get_hostname_set(self, patched_conf): + patched_conf.get = mock.Mock( + return_value='tests.utils.test_net:get_hostname' + ) + self.assertTrue(net.get_hostname() == 'awesomehostname') + + @mock.patch('airflow.utils.net.conf') + def test_get_hostname_set_incorrect(self, patched_conf): + patched_conf.get = mock.Mock( + return_value='tests.utils.test_net' + ) + with self.assertRaises(ValueError): + net.get_hostname() + + @mock.patch('airflow.utils.net.conf') + def test_get_hostname_set_missing(self, patched_conf): + patched_conf.get = mock.Mock( + return_value='tests.utils.test_net:missing_func' + ) + with self.assertRaises(AttributeError): + net.get_hostname() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c93460b/tests/www/api/experimental/test_kerberos_endpoints.py ---------------------------------------------------------------------- diff --git a/tests/www/api/experimental/test_kerberos_endpoints.py b/tests/www/api/experimental/test_kerberos_endpoints.py index 2fce019..a23c10f 100644 --- a/tests/www/api/experimental/test_kerberos_endpoints.py +++ b/tests/www/api/experimental/test_kerberos_endpoints.py @@ -22,6 +22,7 @@ from datetime import datetime from airflow import configuration from airflow.api.auth.backend.kerberos_auth import client_auth +from airflow.utils.net import get_hostname from airflow.www import app as application @@ -57,7 +58,7 @@ class ApiKerberosTests(unittest.TestCase): ) self.assertEqual(401, response.status_code) - response.url = 'http://{}'.format(socket.getfqdn()) + response.url = 'http://{}'.format(get_hostname()) class Request(): headers = {} @@ -72,7 +73,7 @@ class ApiKerberosTests(unittest.TestCase): client_auth.mutual_authentication = 3 # case can influence the results - client_auth.hostname_override = socket.getfqdn() + client_auth.hostname_override = get_hostname() client_auth.handle_response(response) self.assertIn('Authorization', response.request.headers)
