Repository: incubator-airflow
Updated Branches:
  refs/heads/master 5de632e07 -> a7a518902
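The diff below replaces direct use of the stdlib `logging` module with a logger obtained from the new `airflow.utils.log.LoggingMixin`, and converts concatenated / `.format()` log messages to lazy `%s`-style arguments. As orientation, the pattern looks roughly like this. It is a condensed sketch assembled from the hunks that follow, not a verbatim excerpt (the `BaseTaskRunner` method signature is simplified here):

import logging
from builtins import object


class LoggingMixin(object):
    """Convenience super-class to have a logger configured with the class name."""

    @property
    def logger(self):
        try:
            return self._logger
        except AttributeError:
            # Lazily create a child of the root logger named "<module>.<ClassName>".
            self._logger = logging.root.getChild(
                self.__class__.__module__ + '.' + self.__class__.__name__)
            return self._logger


# Module-level usage (plugins_manager.py, kerberos.py, db.py, ...):
log = LoggingMixin().logger
log.debug('Importing plugin module %s', '/path/to/plugin.py')

# Class-level usage (BaseTaskRunner, DagFileProcessorManager, the task handlers, ...):
class BaseTaskRunner(LoggingMixin):
    def run_command(self, full_cmd):  # signature simplified for illustration
        self.logger.info('Running: %s', full_cmd)

Passing the arguments separately defers string interpolation until a record is actually emitted, which is why the hunks below consistently use log.debug('... %s', value) instead of string concatenation.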
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/plugins_manager.py ---------------------------------------------------------------------- diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index 83aae23..7c1d246 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -20,13 +20,14 @@ from __future__ import unicode_literals from builtins import object import imp import inspect -import logging import os import re import sys from airflow import configuration +from airflow.utils.log.LoggingMixin import LoggingMixin +log = LoggingMixin().logger class AirflowPluginException(Exception): pass @@ -72,7 +73,7 @@ for root, dirs, files in os.walk(plugins_folder, followlinks=True): if file_ext != '.py': continue - logging.debug('Importing plugin module ' + filepath) + log.debug('Importing plugin module %s', filepath) # normalize root path as namespace namespace = '_'.join([re.sub(norm_pattern, '__', root), mod_name]) @@ -87,12 +88,12 @@ for root, dirs, files in os.walk(plugins_folder, followlinks=True): plugins.append(obj) except Exception as e: - logging.exception(e) - logging.error('Failed to import plugin ' + filepath) + log.exception(e) + log.error('Failed to import plugin %s', filepath) def make_module(name, objects): - logging.debug('Creating module ' + name) + log.debug('Creating module %s', name) name = name.lower() module = imp.new_module(name) module._name = name.split('.')[-1] http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/security/kerberos.py ---------------------------------------------------------------------- diff --git a/airflow/security/kerberos.py b/airflow/security/kerberos.py index bac5c46..a9687b3 100644 --- a/airflow/security/kerberos.py +++ b/airflow/security/kerberos.py @@ -15,18 +15,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging +import socket import subprocess import sys import time -import socket -from airflow import configuration - -LOG = logging.getLogger(__name__) +from airflow import configuration, LoggingMixin NEED_KRB181_WORKAROUND = None +log = LoggingMixin().logger + def renew_from_kt(): # The config is specified in seconds. But we ask for that same amount in @@ -37,10 +36,10 @@ def renew_from_kt(): cmdv = [configuration.get('kerberos', 'kinit_path'), "-r", renewal_lifetime, "-k", # host ticket - "-t", configuration.get('kerberos', 'keytab'), # specify keytab - "-c", configuration.get('kerberos', 'ccache'), # specify credentials cache + "-t", configuration.get('kerberos', 'keytab'), # specify keytab + "-c", configuration.get('kerberos', 'ccache'), # specify credentials cache principal] - LOG.info("Reinitting kerberos from keytab: " + " ".join(cmdv)) + log.info("Reinitting kerberos from keytab: " + " ".join(cmdv)) subp = subprocess.Popen(cmdv, stdout=subprocess.PIPE, @@ -50,7 +49,7 @@ def renew_from_kt(): universal_newlines=True) subp.wait() if subp.returncode != 0: - LOG.error("Couldn't reinit from keytab! `kinit' exited with %s.\n%s\n%s" % ( + log.error("Couldn't reinit from keytab! 
`kinit' exited with %s.\n%s\n%s" % ( subp.returncode, "\n".join(subp.stdout.readlines()), "\n".join(subp.stderr.readlines()))) @@ -71,7 +70,7 @@ def perform_krb181_workaround(): "-c", configuration.get('kerberos', 'ccache'), "-R"] # Renew ticket_cache - LOG.info("Renewing kerberos ticket to work around kerberos 1.8.1: " + + log.info("Renewing kerberos ticket to work around kerberos 1.8.1: " + " ".join(cmdv)) ret = subprocess.call(cmdv) @@ -80,7 +79,7 @@ def perform_krb181_workaround(): principal = "%s/%s" % (configuration.get('kerberos', 'principal'), socket.getfqdn()) fmt_dict = dict(princ=principal, ccache=configuration.get('kerberos', 'principal')) - LOG.error("Couldn't renew kerberos ticket in order to work around " + log.error("Couldn't renew kerberos ticket in order to work around " "Kerberos 1.8.1 issue. Please check that the ticket for " "'%(princ)s' is still renewable:\n" " $ kinit -f -c %(ccache)s\n" @@ -105,8 +104,8 @@ def detect_conf_var(): def run(): - if configuration.get('kerberos','keytab') is None: - LOG.debug("Keytab renewer not starting, no keytab configured") + if configuration.get('kerberos', 'keytab') is None: + log.debug("Keytab renewer not starting, no keytab configured") sys.exit(0) while True: http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/settings.py ---------------------------------------------------------------------- diff --git a/airflow/settings.py b/airflow/settings.py index 9567020..cf1eca4 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -27,7 +27,9 @@ from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.pool import NullPool from airflow import configuration as conf +from airflow.utils.log.LoggingMixin import LoggingMixin +log = LoggingMixin().logger class DummyStatsLogger(object): @@ -130,8 +132,9 @@ def configure_logging(log_format=LOG_FORMAT): try: _configure_logging(logging_level) except ValueError: - logging.warning("Logging level {} is not defined. " - "Use default.".format(logging_level)) + logging.warning( + "Logging level %s is not defined. Use default.", logging_level + ) _configure_logging(logging.INFO) @@ -162,7 +165,7 @@ def configure_orm(disable_connection_pool=False): try: from airflow_local_settings import * - logging.info("Loaded airflow_local_settings.") + log.info("Loaded airflow_local_settings.") except: pass @@ -174,11 +177,13 @@ configure_orm() logging_config_path = conf.get('core', 'logging_config_path') try: from logging_config_path import LOGGING_CONFIG - logging.debug("Successfully imported user-defined logging config.") + log.debug("Successfully imported user-defined logging config.") except Exception as e: # Import default logging configurations. - logging.debug("Unable to load custom logging config file: {}." - " Using default airflow logging config instead".format(str(e))) + log.debug( + "Unable to load custom logging config file: %s. 
Using default airflow logging config instead", + e + ) from airflow.config_templates.default_airflow_logging import \ DEFAULT_LOGGING_CONFIG as LOGGING_CONFIG logging.config.dictConfig(LOGGING_CONFIG) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/task_runner/base_task_runner.py ---------------------------------------------------------------------- diff --git a/airflow/task_runner/base_task_runner.py b/airflow/task_runner/base_task_runner.py index 8ca8f1a..7794f4a 100644 --- a/airflow/task_runner/base_task_runner.py +++ b/airflow/task_runner/base_task_runner.py @@ -19,8 +19,9 @@ import json import subprocess import threading +from airflow.utils.log.LoggingMixin import LoggingMixin + from airflow import configuration as conf -from airflow.utils.logging import LoggingMixin from tempfile import mkstemp @@ -53,7 +54,7 @@ class BaseTaskRunner(LoggingMixin): # Add sudo commands to change user if we need to. Needed to handle SubDagOperator # case using a SequentialExecutor. if self.run_as_user and (self.run_as_user != getpass.getuser()): - self.logger.debug("Planning to run as the {} user".format(self.run_as_user)) + self.logger.debug("Planning to run as the %s user", self.run_as_user) cfg_dict = conf.as_dict(display_sensitive=True) cfg_subset = { 'core': cfg_dict.get('core', {}), @@ -94,7 +95,7 @@ class BaseTaskRunner(LoggingMixin): line = line.decode('utf-8') if len(line) == 0: break - self.logger.info(u'Subtask: {}'.format(line.rstrip('\n'))) + self.logger.info('Subtask: %s', line.rstrip('\n')) def run_command(self, run_with, join_args=False): """ @@ -111,7 +112,7 @@ class BaseTaskRunner(LoggingMixin): """ cmd = [" ".join(self._command)] if join_args else self._command full_cmd = run_with + cmd - self.logger.info('Running: {}'.format(full_cmd)) + self.logger.info('Running: %s', full_cmd) proc = subprocess.Popen( full_cmd, stdout=subprocess.PIPE, http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/utils/dag_processing.py ---------------------------------------------------------------------- diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index 2e975c1..6497fcc 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -17,19 +17,17 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals -import logging import os import re import time import zipfile - from abc import ABCMeta, abstractmethod from collections import defaultdict from datetime import datetime -from airflow.exceptions import AirflowException from airflow.dag.base_dag import BaseDag, BaseDagBag -from airflow.utils.logging import LoggingMixin +from airflow.exceptions import AirflowException +from airflow.utils.log.LoggingMixin import LoggingMixin class SimpleDag(BaseDag): @@ -207,7 +205,8 @@ def list_py_file_paths(directory, safe_mode=True): file_paths.append(file_path) except Exception: - logging.exception("Error while examining %s", f) + log = LoggingMixin().logger + log.exception("Error while examining %s", f) return file_paths @@ -444,7 +443,7 @@ class DagFileProcessorManager(LoggingMixin): if file_path in new_file_paths: filtered_processors[file_path] = processor else: - self.logger.warning("Stopping processor for {}".format(file_path)) + self.logger.warning("Stopping processor for %s", file_path) processor.stop() self._processors = filtered_processors @@ -512,17 +511,18 @@ class DagFileProcessorManager(LoggingMixin): log_directory = self._get_log_directory() 
latest_log_directory_path = os.path.join( self._child_process_log_directory, "latest") - if (os.path.isdir(log_directory)): + if os.path.isdir(log_directory): # if symlink exists but is stale, update it - if (os.path.islink(latest_log_directory_path)): - if(os.readlink(latest_log_directory_path) != log_directory): + if os.path.islink(latest_log_directory_path): + if os.readlink(latest_log_directory_path) != log_directory: os.unlink(latest_log_directory_path) os.symlink(log_directory, latest_log_directory_path) elif (os.path.isdir(latest_log_directory_path) or os.path.isfile(latest_log_directory_path)): - self.logger.warning("{} already exists as a dir/file. " - "Skip creating symlink." - .format(latest_log_directory_path)) + self.logger.warning( + "%s already exists as a dir/file. Skip creating symlink.", + latest_log_directory_path + ) else: os.symlink(log_directory, latest_log_directory_path) @@ -558,7 +558,7 @@ class DagFileProcessorManager(LoggingMixin): for file_path, processor in self._processors.items(): if processor.done: - self.logger.info("Processor for {} finished".format(file_path)) + self.logger.info("Processor for %s finished", file_path) now = datetime.now() finished_processors[file_path] = processor self._last_runtime[file_path] = (now - @@ -573,11 +573,10 @@ class DagFileProcessorManager(LoggingMixin): simple_dags = [] for file_path, processor in finished_processors.items(): if processor.result is None: - self.logger.warning("Processor for {} exited with return code " - "{}. See {} for details." - .format(processor.file_path, - processor.exit_code, - processor.log_file)) + self.logger.warning( + "Processor for %s exited with return code %s. See %s for details.", + processor.file_path, processor.exit_code, processor.log_file + ) else: for simple_dag in processor.result: simple_dags.append(simple_dag) @@ -607,12 +606,15 @@ class DagFileProcessorManager(LoggingMixin): set(files_paths_at_run_limit)) for file_path, processor in self._processors.items(): - self.logger.debug("File path {} is still being processed (started: {})" - .format(processor.file_path, - processor.start_time.isoformat())) + self.logger.debug( + "File path %s is still being processed (started: %s)", + processor.file_path, processor.start_time.isoformat() + ) - self.logger.debug("Queuing the following files for processing:\n\t{}" - .format("\n\t".join(files_paths_to_queue))) + self.logger.debug( + "Queuing the following files for processing:\n\t%s", + "\n\t".join(files_paths_to_queue) + ) self._file_path_queue.extend(files_paths_to_queue) @@ -624,9 +626,10 @@ class DagFileProcessorManager(LoggingMixin): processor = self._processor_factory(file_path, log_file_path) processor.start() - self.logger.info("Started a process (PID: {}) to generate " - "tasks for {} - logging into {}" - .format(processor.pid, file_path, log_file_path)) + self.logger.info( + "Started a process (PID: %s) to generate tasks for %s - logging into %s", + processor.pid, file_path, log_file_path + ) self._processors[file_path] = processor http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/utils/db.py ---------------------------------------------------------------------- diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 35c187c..b3c8a4d 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -19,13 +19,16 @@ from __future__ import unicode_literals from datetime import datetime from functools import wraps -import logging + import os from sqlalchemy import event, exc from sqlalchemy.pool import Pool 
from airflow import settings +from airflow.utils.log.LoggingMixin import LoggingMixin + +log = LoggingMixin().logger def provide_session(func): """ @@ -308,7 +311,8 @@ def upgradedb(): from alembic import command from alembic.config import Config - logging.info("Creating tables") + log.info("Creating tables") + current_dir = os.path.dirname(os.path.abspath(__file__)) package_dir = os.path.normpath(os.path.join(current_dir, '..')) directory = os.path.join(package_dir, 'migrations') @@ -326,7 +330,8 @@ def resetdb(): # alembic adds significant import time, so we import it lazily from alembic.migration import MigrationContext - logging.info("Dropping tables that exist") + log.info("Dropping tables that exist") + models.Base.metadata.drop_all(settings.engine) mc = MigrationContext.configure(settings.engine) if mc._version.exists(settings.engine): http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/utils/email.py ---------------------------------------------------------------------- diff --git a/airflow/utils/email.py b/airflow/utils/email.py index 57219c3..f252d55 100644 --- a/airflow/utils/email.py +++ b/airflow/utils/email.py @@ -21,7 +21,6 @@ from builtins import str from past.builtins import basestring import importlib -import logging import os import smtplib @@ -32,6 +31,7 @@ from email.utils import formatdate from airflow import configuration from airflow.exceptions import AirflowConfigException +from airflow.utils.log.LoggingMixin import LoggingMixin def send_email(to, subject, html_content, files=None, dryrun=False, cc=None, bcc=None, mime_subtype='mixed'): @@ -88,6 +88,8 @@ def send_email_smtp(to, subject, html_content, files=None, dryrun=False, cc=None def send_MIME_email(e_from, e_to, mime_msg, dryrun=False): + log = LoggingMixin().logger + SMTP_HOST = configuration.get('smtp', 'SMTP_HOST') SMTP_PORT = configuration.getint('smtp', 'SMTP_PORT') SMTP_STARTTLS = configuration.getboolean('smtp', 'SMTP_STARTTLS') @@ -99,7 +101,7 @@ def send_MIME_email(e_from, e_to, mime_msg, dryrun=False): SMTP_USER = configuration.get('smtp', 'SMTP_USER') SMTP_PASSWORD = configuration.get('smtp', 'SMTP_PASSWORD') except AirflowConfigException: - logging.debug("No user/password found for SMTP, so logging in with no authentication.") + log.debug("No user/password found for SMTP, so logging in with no authentication.") if not dryrun: s = smtplib.SMTP_SSL(SMTP_HOST, SMTP_PORT) if SMTP_SSL else smtplib.SMTP(SMTP_HOST, SMTP_PORT) @@ -107,7 +109,7 @@ def send_MIME_email(e_from, e_to, mime_msg, dryrun=False): s.starttls() if SMTP_USER and SMTP_PASSWORD: s.login(SMTP_USER, SMTP_PASSWORD) - logging.info("Sent an alert email to " + str(e_to)) + log.info("Sent an alert email to %s", e_to) s.sendmail(e_from, e_to, mime_msg.as_string()) s.quit() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/utils/log/LoggingMixin.py ---------------------------------------------------------------------- diff --git a/airflow/utils/log/LoggingMixin.py b/airflow/utils/log/LoggingMixin.py new file mode 100644 index 0000000..4572d63 --- /dev/null +++ b/airflow/utils/log/LoggingMixin.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import logging +from builtins import object + + +class LoggingMixin(object): + """ + Convenience super-class to have a logger configured with the class name + """ + + @property + def logger(self): + try: + return self._logger + except AttributeError: + self._logger = logging.root.getChild(self.__class__.__module__ + '.' + self.__class__.__name__) + return self._logger + + def set_logger_contexts(self, task_instance): + """ + Set the context for all handlers of current logger. + """ + for handler in self.logger.handlers: + try: + handler.set_context(task_instance) + except AttributeError: + pass http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/utils/log/file_task_handler.py ---------------------------------------------------------------------- diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 7392aae..b31c968 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -14,6 +14,7 @@ import logging import os +import requests from jinja2 import Template @@ -65,16 +66,16 @@ class FileTaskHandler(logging.Handler): def close(self): if self.handler is not None: self.handler.close() - + def _render_filename(self, ti, try_number): if self.filename_jinja_template: jinja_context = ti.get_template_context() jinja_context['try_number'] = try_number - return self.filename_jinja_template.render(**jinja_context) - - return self.filename_template.format(dag_id=ti.dag_id, + return self.filename_jinja_template.render(**jinja_context) + + return self.filename_template.format(dag_id=ti.dag_id, task_id=ti.task_id, - execution_date=ti.execution_date.isoformat(), + execution_date=ti.execution_date.isoformat(), try_number=try_number) def _read(self, ti, try_number): @@ -89,32 +90,37 @@ class FileTaskHandler(logging.Handler): # initializing the handler. Thus explicitly getting log location # is needed to get correct log path. log_relative_path = self._render_filename(ti, try_number + 1) - loc = os.path.join(self.local_base, log_relative_path) + location = os.path.join(self.local_base, log_relative_path) + log = "" - if os.path.exists(loc): + if os.path.exists(location): try: - with open(loc) as f: + with open(location) as f: log += "*** Reading local log.\n" + "".join(f.readlines()) except Exception as e: - log = "*** Failed to load local log file: {}. {}\n".format(loc, str(e)) + log = "*** Failed to load local log file: {}. 
{}\n".format(location, str(e)) else: - url = os.path.join("http://{ti.hostname}:{worker_log_server_port}/log", - log_relative_path).format( + url = os.path.join( + "http://{ti.hostname}:{worker_log_server_port}/log", log_relative_path + ).format( ti=ti, - worker_log_server_port=conf.get('celery', 'WORKER_LOG_SERVER_PORT')) + worker_log_server_port=conf.get('celery', 'WORKER_LOG_SERVER_PORT') + ) log += "*** Log file isn't local.\n" log += "*** Fetching here: {url}\n".format(**locals()) try: - import requests timeout = None # No timeout try: timeout = conf.getint('webserver', 'log_fetch_timeout_sec') except (AirflowConfigException, ValueError): pass - response = requests.get(url, timeout=timeout) + response = requests.get(url, timeout=self.timeout) + + # Check if the resource was properly fetched response.raise_for_status() + log += '\n' + response.text except Exception as e: log += "*** Failed to fetch log file from worker. {}\n".format(str(e)) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/utils/log/gcs_task_handler.py ---------------------------------------------------------------------- diff --git a/airflow/utils/log/gcs_task_handler.py b/airflow/utils/log/gcs_task_handler.py index c340f10..0bc0b5e 100644 --- a/airflow/utils/log/gcs_task_handler.py +++ b/airflow/utils/log/gcs_task_handler.py @@ -11,28 +11,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os -import warnings -from airflow import configuration as conf -from airflow.utils import logging as logging_utils +from airflow import configuration +from airflow.exceptions import AirflowException +from airflow.utils.log.LoggingMixin import LoggingMixin from airflow.utils.log.file_task_handler import FileTaskHandler -class GCSTaskHandler(FileTaskHandler): +class GCSTaskHandler(FileTaskHandler, LoggingMixin): """ GCSTaskHandler is a python log handler that handles and reads task instance logs. It extends airflow FileTaskHandler and uploads to and reads from GCS remote storage. Upon log reading failure, it reads from host machine's local disk. """ - def __init__(self, base_log_folder, gcs_log_folder, filename_template): super(GCSTaskHandler, self).__init__(base_log_folder, filename_template) self.remote_base = gcs_log_folder self.log_relative_path = '' - self.closed = False + self._hook = None + + def _build_hook(self): + remote_conn_id = configuration.get('core', 'REMOTE_LOG_CONN_ID') + try: + from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook + return GoogleCloudStorageHook( + google_cloud_storage_conn_id=remote_conn_id + ) + except: + self.logger.error( + 'Could not create a GoogleCloudStorageHook with connection id ' + '"%s". Please make sure that airflow[gcp_api] is installed ' + 'and the GCS connection exists.', remote_conn_id + ) + + @property + def hook(self): + if self._hook is None: + self._hook = self._build_hook() + return self._hook def set_context(self, ti): super(GCSTaskHandler, self).set_context(ti) @@ -49,7 +67,7 @@ class GCSTaskHandler(FileTaskHandler): # calling close method. Here we check if logger is already # closed to prevent uploading the log to remote storage multiple # times when `logging.shutdown` is called. 
- if self.closed: + if self._hook is None: return super(GCSTaskHandler, self).close() @@ -60,9 +78,10 @@ class GCSTaskHandler(FileTaskHandler): # read log and remove old logs to get just the latest additions with open(local_loc, 'r') as logfile: log = logfile.read() - logging_utils.GCSLog().write(log, remote_loc) + self.gcs_write(log, remote_loc) - self.closed = True + # Unset variable + self._hook = None def _read(self, ti, try_number): """ @@ -77,15 +96,95 @@ class GCSTaskHandler(FileTaskHandler): log_relative_path = self._render_filename(ti, try_number + 1) remote_loc = os.path.join(self.remote_base, log_relative_path) - gcs_log = logging_utils.GCSLog() - if gcs_log.log_exists(remote_loc): + if self.gcs_log_exists(remote_loc): # If GCS remote file exists, we do not fetch logs from task instance # local machine even if there are errors reading remote logs, as # remote_log will contain error message. - remote_log = gcs_log.read(remote_loc, return_error=True) + remote_log = self.gcs_read(remote_loc, return_error=True) log = '*** Reading remote log from {}.\n{}\n'.format( remote_loc, remote_log) else: log = super(GCSTaskHandler, self)._read(ti, try_number) return log + + def gcs_log_exists(self, remote_log_location): + """ + Check if remote_log_location exists in remote storage + :param remote_log_location: log's location in remote storage + :return: True if location exists else False + """ + try: + bkt, blob = self.parse_gcs_url(remote_log_location) + return self.hook.exists(bkt, blob) + except Exception: + pass + return False + + def gcs_read(self, remote_log_location, return_error=False): + """ + Returns the log found at the remote_log_location. + :param remote_log_location: the log's location in remote storage + :type remote_log_location: string (path) + :param return_error: if True, returns a string error message if an + error occurs. Otherwise returns '' when an error occurs. + :type return_error: bool + """ + try: + bkt, blob = self.parse_gcs_url(remote_log_location) + return self.hook.download(bkt, blob).decode() + except: + # return error if needed + if return_error: + msg = 'Could not read logs from {}'.format(remote_log_location) + self.logger.error(msg) + return msg + + def gcs_write(self, log, remote_log_location, append=True): + """ + Writes the log to the remote_log_location. Fails silently if no hook + was created. + :param log: the log to write to the remote_log_location + :type log: string + :param remote_log_location: the log's location in remote storage + :type remote_log_location: string (path) + :param append: if False, any existing log file is overwritten. If True, + the new log is appended to any existing logs. + :type append: bool + """ + if append: + old_log = self.read(remote_log_location) + log = '\n'.join([old_log, log]) + + try: + bkt, blob = self.parse_gcs_url(remote_log_location) + from tempfile import NamedTemporaryFile + with NamedTemporaryFile(mode='w+') as tmpfile: + tmpfile.write(log) + # Force the file to be flushed, since we're doing the + # upload from within the file context (it hasn't been + # closed). + tmpfile.flush() + self.hook.upload(bkt, blob, tmpfile.name) + except: + self.logger.error('Could not write logs to %s', remote_log_location) + + def parse_gcs_url(self, gsurl): + """ + Given a Google Cloud Storage URL (gs://<bucket>/<blob>), returns a + tuple containing the corresponding bucket and blob. 
+ """ + # Python 3 + try: + from urllib.parse import urlparse + # Python 2 + except ImportError: + from urlparse import urlparse + + parsed_url = urlparse(gsurl) + if not parsed_url.netloc: + raise AirflowException('Please provide a bucket name') + else: + bucket = parsed_url.netloc + blob = parsed_url.path.strip('/') + return bucket, blob http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/utils/log/s3_task_handler.py ---------------------------------------------------------------------- diff --git a/airflow/utils/log/s3_task_handler.py b/airflow/utils/log/s3_task_handler.py index 51baaac..71fc149 100644 --- a/airflow/utils/log/s3_task_handler.py +++ b/airflow/utils/log/s3_task_handler.py @@ -11,25 +11,42 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os -from airflow.utils import logging as logging_utils +from airflow import configuration +from airflow.utils.log.LoggingMixin import LoggingMixin from airflow.utils.log.file_task_handler import FileTaskHandler -class S3TaskHandler(FileTaskHandler): +class S3TaskHandler(FileTaskHandler, LoggingMixin): """ S3TaskHandler is a python log handler that handles and reads task instance logs. It extends airflow FileTaskHandler and uploads to and reads from S3 remote storage. """ - def __init__(self, base_log_folder, s3_log_folder, filename_template): super(S3TaskHandler, self).__init__(base_log_folder, filename_template) self.remote_base = s3_log_folder self.log_relative_path = '' - self.closed = False + self._hook = None + + def _build_hook(self): + remote_conn_id = configuration.get('core', 'REMOTE_LOG_CONN_ID') + try: + from airflow.hooks.S3_hook import S3Hook + return S3Hook(remote_conn_id) + except: + self.logger.error( + 'Could not create an S3Hook with connection id "%s". ' + 'Please make sure that airflow[s3] is installed and ' + 'the S3 connection exists.', remote_conn_id + ) + + @property + def hook(self): + if self._hook is None: + self._hook = self._build_hook() + return self._hook def set_context(self, ti): super(S3TaskHandler, self).set_context(ti) @@ -45,7 +62,7 @@ class S3TaskHandler(FileTaskHandler): # calling close method. Here we check if logger is already # closed to prevent uploading the log to remote storage multiple # times when `logging.shutdown` is called. - if self.closed: + if self._hook is None: return super(S3TaskHandler, self).close() @@ -56,9 +73,9 @@ class S3TaskHandler(FileTaskHandler): # read log and remove old logs to get just the latest additions with open(local_loc, 'r') as logfile: log = logfile.read() - logging_utils.S3Log().write(log, remote_loc) + self.s3_write(log, remote_loc) - self.closed = True + self._hook = None def _read(self, ti, try_number): """ @@ -73,15 +90,73 @@ class S3TaskHandler(FileTaskHandler): log_relative_path = self._render_filename(ti, try_number + 1) remote_loc = os.path.join(self.remote_base, log_relative_path) - s3_log = logging_utils.S3Log() - if s3_log.log_exists(remote_loc): + if self.s3_log_exists(remote_loc): # If S3 remote file exists, we do not fetch logs from task instance # local machine even if there are errors reading remote logs, as # returned remote_log will contain error messages. 
- remote_log = s3_log.read(remote_loc, return_error=True) + remote_log = self.s3_log_read(remote_loc, return_error=True) log = '*** Reading remote log from {}.\n{}\n'.format( remote_loc, remote_log) else: log = super(S3TaskHandler, self)._read(ti, try_number) return log + + def s3_log_exists(self, remote_log_location): + """ + Check if remote_log_location exists in remote storage + :param remote_log_location: log's location in remote storage + :return: True if location exists else False + """ + try: + return self.hook.get_key(remote_log_location) is not None + except Exception: + pass + return False + + def s3_log_read(self, remote_log_location, return_error=False): + """ + Returns the log found at the remote_log_location. Returns '' if no + logs are found or there is an error. + :param remote_log_location: the log's location in remote storage + :type remote_log_location: string (path) + :param return_error: if True, returns a string error message if an + error occurs. Otherwise returns '' when an error occurs. + :type return_error: bool + """ + try: + s3_key = self.hook.get_key(remote_log_location) + if s3_key: + return s3_key.get_contents_as_string().decode() + except: + # return error if needed + if return_error: + msg = 'Could not read logs from {}'.format(remote_log_location) + self.logger.error(msg) + return msg + + def s3_write(self, log, remote_log_location, append=True): + """ + Writes the log to the remote_log_location. Fails silently if no hook + was created. + :param log: the log to write to the remote_log_location + :type log: string + :param remote_log_location: the log's location in remote storage + :type remote_log_location: string (path) + :param append: if False, any existing log file is overwritten. If True, + the new log is appended to any existing logs. + :type append: bool + """ + if append: + old_log = self.read(remote_log_location) + log = '\n'.join([old_log, log]) + + try: + self.hook.load_string( + log, + key=remote_log_location, + replace=True, + encrypt=configuration.getboolean('core', 'ENCRYPT_S3_LOGS'), + ) + except: + self.logger.error('Could not write logs to %s', remote_log_location) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/utils/logging.py ---------------------------------------------------------------------- diff --git a/airflow/utils/logging.py b/airflow/utils/logging.py deleted file mode 100644 index c550c88..0000000 --- a/airflow/utils/logging.py +++ /dev/null @@ -1,252 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -from builtins import object - -import dateutil.parser -import logging -import six - -from airflow import configuration -from airflow.exceptions import AirflowException - - -class LoggingMixin(object): - """ - Convenience super-class to have a logger configured with the class name - """ - - @property - def logger(self): - try: - return self._logger - except AttributeError: - self._logger = logging.root.getChild(self.__class__.__module__ + '.' + self.__class__.__name__) - return self._logger - - def set_logger_contexts(self, task_instance): - """ - Set the context for all handlers of current logger. - """ - for handler in self.logger.handlers: - try: - handler.set_context(task_instance) - except AttributeError: - pass - - -class S3Log(object): - """ - Utility class for reading and writing logs in S3. - Requires airflow[s3] and setting the REMOTE_BASE_LOG_FOLDER and - REMOTE_LOG_CONN_ID configuration options in airflow.cfg. - """ - def __init__(self): - remote_conn_id = configuration.get('core', 'REMOTE_LOG_CONN_ID') - try: - from airflow.hooks.S3_hook import S3Hook - self.hook = S3Hook(remote_conn_id) - except: - self.hook = None - logging.error( - 'Could not create an S3Hook with connection id "{}". ' - 'Please make sure that airflow[s3] is installed and ' - 'the S3 connection exists.'.format(remote_conn_id)) - - def log_exists(self, remote_log_location): - """ - Check if remote_log_location exists in remote storage - :param remote_log_location: log's location in remote storage - :return: True if location exists else False - """ - if self.hook: - try: - return self.hook.get_key(remote_log_location) is not None - except Exception: - pass - return False - - def read(self, remote_log_location, return_error=False): - """ - Returns the log found at the remote_log_location. Returns '' if no - logs are found or there is an error. - - :param remote_log_location: the log's location in remote storage - :type remote_log_location: string (path) - :param return_error: if True, returns a string error message if an - error occurs. Otherwise returns '' when an error occurs. - :type return_error: bool - """ - if self.hook: - try: - s3_key = self.hook.get_key(remote_log_location) - if s3_key: - return s3_key.get_contents_as_string().decode() - except: - pass - - # return error if needed - if return_error: - msg = 'Could not read logs from {}'.format(remote_log_location) - logging.error(msg) - return msg - - return '' - - def write(self, log, remote_log_location, append=True): - """ - Writes the log to the remote_log_location. Fails silently if no hook - was created. - - :param log: the log to write to the remote_log_location - :type log: string - :param remote_log_location: the log's location in remote storage - :type remote_log_location: string (path) - :param append: if False, any existing log file is overwritten. If True, - the new log is appended to any existing logs. - :type append: bool - """ - if self.hook: - if append: - old_log = self.read(remote_log_location) - log = '\n'.join([old_log, log]) - - try: - self.hook.load_string( - log, - key=remote_log_location, - replace=True, - encrypt=configuration.getboolean('core', 'ENCRYPT_S3_LOGS'), - ) - except: - logging.error('Could not write logs to {}'.format(remote_log_location)) - - -class GCSLog(object): - """ - Utility class for reading and writing logs in GCS. 
Requires - airflow[gcp_api] and setting the REMOTE_BASE_LOG_FOLDER and - REMOTE_LOG_CONN_ID configuration options in airflow.cfg. - """ - def __init__(self): - """ - Attempt to create hook with airflow[gcp_api]. - """ - remote_conn_id = configuration.get('core', 'REMOTE_LOG_CONN_ID') - self.hook = None - - try: - from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook - self.hook = GoogleCloudStorageHook( - google_cloud_storage_conn_id=remote_conn_id) - except: - logging.error( - 'Could not create a GoogleCloudStorageHook with connection id ' - '"{}". Please make sure that airflow[gcp_api] is installed ' - 'and the GCS connection exists.'.format(remote_conn_id)) - - def log_exists(self, remote_log_location): - """ - Check if remote_log_location exists in remote storage - :param remote_log_location: log's location in remote storage - :return: True if location exists else False - """ - if self.hook: - try: - bkt, blob = self.parse_gcs_url(remote_log_location) - return self.hook.exists(bkt, blob) - except Exception: - pass - return False - - def read(self, remote_log_location, return_error=False): - """ - Returns the log found at the remote_log_location. - - :param remote_log_location: the log's location in remote storage - :type remote_log_location: string (path) - :param return_error: if True, returns a string error message if an - error occurs. Otherwise returns '' when an error occurs. - :type return_error: bool - """ - if self.hook: - try: - bkt, blob = self.parse_gcs_url(remote_log_location) - return self.hook.download(bkt, blob).decode() - except: - pass - - # return error if needed - if return_error: - msg = 'Could not read logs from {}'.format(remote_log_location) - logging.error(msg) - return msg - - return '' - - def write(self, log, remote_log_location, append=True): - """ - Writes the log to the remote_log_location. Fails silently if no hook - was created. - - :param log: the log to write to the remote_log_location - :type log: string - :param remote_log_location: the log's location in remote storage - :type remote_log_location: string (path) - :param append: if False, any existing log file is overwritten. If True, - the new log is appended to any existing logs. - :type append: bool - """ - if self.hook: - if append: - old_log = self.read(remote_log_location) - log = '\n'.join([old_log, log]) - - try: - bkt, blob = self.parse_gcs_url(remote_log_location) - from tempfile import NamedTemporaryFile - with NamedTemporaryFile(mode='w+') as tmpfile: - tmpfile.write(log) - # Force the file to be flushed, since we're doing the - # upload from within the file context (it hasn't been - # closed). - tmpfile.flush() - self.hook.upload(bkt, blob, tmpfile.name) - except: - logging.error('Could not write logs to {}'.format(remote_log_location)) - - def parse_gcs_url(self, gsurl): - """ - Given a Google Cloud Storage URL (gs://<bucket>/<blob>), returns a - tuple containing the corresponding bucket and blob. 
- """ - # Python 3 - try: - from urllib.parse import urlparse - # Python 2 - except ImportError: - from urlparse import urlparse - - parsed_url = urlparse(gsurl) - if not parsed_url.netloc: - raise AirflowException('Please provide a bucket name') - else: - bucket = parsed_url.netloc - blob = parsed_url.path.strip('/') - return (bucket, blob) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/utils/timeout.py ---------------------------------------------------------------------- diff --git a/airflow/utils/timeout.py b/airflow/utils/timeout.py index 62af9db..53f2149 100644 --- a/airflow/utils/timeout.py +++ b/airflow/utils/timeout.py @@ -17,24 +17,23 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals -import logging import signal -from builtins import object - from airflow.exceptions import AirflowTaskTimeout +from airflow.utils.log.LoggingMixin import LoggingMixin -class timeout(object): +class timeout(LoggingMixin): """ To be used in a ``with`` block and timeout its content. """ + def __init__(self, seconds=1, error_message='Timeout'): self.seconds = seconds self.error_message = error_message def handle_timeout(self, signum, frame): - logging.error("Process timed out") + self.logger.error("Process timed out") raise AirflowTaskTimeout(self.error_message) def __enter__(self): @@ -42,12 +41,12 @@ class timeout(object): signal.signal(signal.SIGALRM, self.handle_timeout) signal.alarm(self.seconds) except ValueError as e: - logging.warning("timeout can't be used in the current context") - logging.exception(e) + self.logger.warning("timeout can't be used in the current context") + self.logger.exception(e) def __exit__(self, type, value, traceback): try: signal.alarm(0) except ValueError as e: - logging.warning("timeout can't be used in the current context") - logging.exception(e) + self.logger.warning("timeout can't be used in the current context") + self.logger.exception(e) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/www/api/experimental/endpoints.py ---------------------------------------------------------------------- diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py index 3e0ff46..4e5892d 100644 --- a/airflow/www/api/experimental/endpoints.py +++ b/airflow/www/api/experimental/endpoints.py @@ -11,9 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -import logging - import airflow.api from airflow.api.common.experimental import pool as pool_api @@ -21,6 +18,7 @@ from airflow.api.common.experimental import trigger_dag as trigger from airflow.api.common.experimental.get_task import get_task from airflow.api.common.experimental.get_task_instance import get_task_instance from airflow.exceptions import AirflowException +from airflow.utils.log.LoggingMixin import LoggingMixin from airflow.www.app import csrf from flask import ( @@ -29,7 +27,7 @@ from flask import ( ) from datetime import datetime -_log = logging.getLogger(__name__) +_log = LoggingMixin().logger requires_authentication = airflow.api.api_auth.requires_authentication http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/www/app.py ---------------------------------------------------------------------- diff --git a/airflow/www/app.py b/airflow/www/app.py index 1ae2731..f280713 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import logging import socket import six @@ -23,7 +22,7 @@ from flask_wtf.csrf import CSRFProtect csrf = CSRFProtect() import airflow -from airflow import models +from airflow import models, LoggingMixin from airflow.settings import Session from airflow.www.blueprints import routes @@ -114,16 +113,17 @@ def create_app(config=None, testing=False): def integrate_plugins(): """Integrate plugins to the context""" + log = LoggingMixin().logger from airflow.plugins_manager import ( admin_views, flask_blueprints, menu_links) for v in admin_views: - logging.debug('Adding view ' + v.name) + log.debug('Adding view %s', v.name) admin.add_view(v) for bp in flask_blueprints: - logging.debug('Adding blueprint ' + bp.name) + log.debug('Adding blueprint %s', bp.name) app.register_blueprint(bp) for ml in sorted(menu_links, key=lambda x: x.name): - logging.debug('Adding menu link ' + ml.name) + log.debug('Adding menu link %s', ml.name) admin.add_link(ml) integrate_plugins() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/airflow/www/views.py ---------------------------------------------------------------------- diff --git a/airflow/www/views.py b/airflow/www/views.py index 655d95a..447c19f 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -68,17 +68,14 @@ from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, SCHEDULER_DEPS from airflow.models import BaseOperator from airflow.operators.subdag_operator import SubDagOperator -from airflow.utils.logging import LoggingMixin from airflow.utils.json import json_ser from airflow.utils.state import State from airflow.utils.db import provide_session from airflow.utils.helpers import alchemy_to_dict -from airflow.utils import logging as log_utils from airflow.utils.dates import infer_time_unit, scale_time_units from airflow.www import utils as wwwutils from airflow.www.forms import DateTimeForm, DateTimeWithNumRunsForm from airflow.www.validators import GreaterEqualThan -from airflow.configuration import AirflowConfigException QUERY_LIMIT = 100000 CHART_LIMIT = 200000 @@ -2604,7 +2601,7 @@ class UserModelView(wwwutils.SuperUserMixin, AirflowModelView): column_default_sort = 'username' -class VersionView(wwwutils.SuperUserMixin, LoggingMixin, BaseView): +class VersionView(wwwutils.SuperUserMixin, BaseView): @expose('/') def version(self): # Look at the version from setup.py @@ -2612,7 +2609,7 @@ class 
VersionView(wwwutils.SuperUserMixin, LoggingMixin, BaseView): airflow_version = pkg_resources.require("apache-airflow")[0].version except Exception as e: airflow_version = None - self.logger.error(e) + logging.error(e) # Get the Git repo and git hash git_version = None @@ -2620,7 +2617,7 @@ class VersionView(wwwutils.SuperUserMixin, LoggingMixin, BaseView): with open(os.path.join(*[settings.AIRFLOW_HOME, 'airflow', 'git_version'])) as f: git_version = f.readline() except Exception as e: - self.logger.error(e) + logging.error(e) # Render information title = "Version Info" http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/setup.py ---------------------------------------------------------------------- diff --git a/setup.py b/setup.py index b22e4e8..0ddc0f0 100644 --- a/setup.py +++ b/setup.py @@ -99,16 +99,6 @@ def write_version(filename=os.path.join(*['airflow', with open(filename, 'w') as a: a.write(text) - -def check_previous(): - installed_packages = ([package.project_name for package - in pip.get_installed_distributions()]) - if 'airflow' in installed_packages: - print("An earlier non-apache version of Airflow was installed, " - "please uninstall it first. Then reinstall.") - sys.exit(1) - - async = [ 'greenlet>=0.4.9', 'eventlet>= 0.9.7', @@ -206,7 +196,6 @@ devel_all = devel + all_dbs + doc + samba + s3 + slack + crypto + oracle + docke def do_setup(): - check_previous() write_version() setup( name='apache-airflow', http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/tests/contrib/hooks/test_databricks_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py index 56288a1..e091067 100644 --- a/tests/contrib/hooks/test_databricks_hook.py +++ b/tests/contrib/hooks/test_databricks_hook.py @@ -108,18 +108,17 @@ class DatabricksHookTest(unittest.TestCase): with self.assertRaises(AssertionError): DatabricksHook(retry_limit = 0) - @mock.patch('airflow.contrib.hooks.databricks_hook.logging') @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - def test_do_api_call_with_error_retry(self, mock_requests, mock_logging): + def test_do_api_call_with_error_retry(self, mock_requests): for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]: - mock_requests.reset_mock() - mock_logging.reset_mock() - mock_requests.post.side_effect = exception() + with mock.patch.object(self.hook.logger, 'error') as mock_errors: + mock_requests.reset_mock() + mock_requests.post.side_effect = exception() - with self.assertRaises(AirflowException): - self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + with self.assertRaises(AirflowException): + self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) - self.assertEquals(len(mock_logging.error.mock_calls), self.hook.retry_limit) + self.assertEquals(len(mock_errors.mock_calls), self.hook.retry_limit) @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_do_api_call_with_bad_status_code(self, mock_requests): http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/tests/contrib/operators/test_dataproc_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_dataproc_operator.py b/tests/contrib/operators/test_dataproc_operator.py index 71edf58..89ad258 100644 --- a/tests/contrib/operators/test_dataproc_operator.py +++ b/tests/contrib/operators/test_dataproc_operator.py @@ -27,6 +27,7 @@ 
from copy import deepcopy from mock import Mock from mock import patch + TASK_ID = 'test-dataproc-operator' CLUSTER_NAME = 'test-cluster-name' PROJECT_ID = 'test-project-id' @@ -53,6 +54,7 @@ class DataprocClusterCreateOperatorTest(unittest.TestCase): # instantiate two different test cases with different labels. self.labels = [LABEL1, LABEL2] self.dataproc_operators = [] + self.mock_conn = Mock() for labels in self.labels: self.dataproc_operators.append( DataprocClusterCreateOperator( @@ -120,8 +122,8 @@ class DataprocClusterCreateOperatorTest(unittest.TestCase): self.assertEqual(cluster_data['labels'], merged_labels) def test_cluster_name_log_no_sub(self): - with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') \ - as mock_hook, patch('logging.info') as l: + with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') as mock_hook: + mock_hook.return_value.get_conn = self.mock_conn dataproc_task = DataprocClusterCreateOperator( task_id=TASK_ID, cluster_name=CLUSTER_NAME, @@ -130,14 +132,14 @@ class DataprocClusterCreateOperatorTest(unittest.TestCase): zone=ZONE, dag=self.dag ) - - with self.assertRaises(TypeError) as _: - dataproc_task.execute(None) - l.assert_called_with(('Creating cluster: ' + CLUSTER_NAME)) + with patch.object(dataproc_task.logger, 'info') as mock_info: + with self.assertRaises(TypeError) as _: + dataproc_task.execute(None) + mock_info.assert_called_with('Creating cluster: %s', CLUSTER_NAME) def test_cluster_name_log_sub(self): - with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') \ - as mock_hook, patch('logging.info') as l: + with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') as mock_hook: + mock_hook.return_value.get_conn = self.mock_conn dataproc_task = DataprocClusterCreateOperator( task_id=TASK_ID, cluster_name='smoke-cluster-{{ ts_nodash }}', @@ -146,14 +148,14 @@ class DataprocClusterCreateOperatorTest(unittest.TestCase): zone=ZONE, dag=self.dag ) + with patch.object(dataproc_task.logger, 'info') as mock_info: + context = { 'ts_nodash' : 'testnodash'} - context = { 'ts_nodash' : 'testnodash'} - - rendered = dataproc_task.render_template('cluster_name', getattr(dataproc_task,'cluster_name'), context) - setattr(dataproc_task, 'cluster_name', rendered) - with self.assertRaises(TypeError) as _: - dataproc_task.execute(None) - l.assert_called_with(('Creating cluster: smoke-cluster-testnodash')) + rendered = dataproc_task.render_template('cluster_name', getattr(dataproc_task,'cluster_name'), context) + setattr(dataproc_task, 'cluster_name', rendered) + with self.assertRaises(TypeError) as _: + dataproc_task.execute(None) + mock_info.assert_called_with('Creating cluster: %s', u'smoke-cluster-testnodash') class DataprocClusterDeleteOperatorTest(unittest.TestCase): # Unitest for the DataprocClusterDeleteOperator @@ -180,8 +182,7 @@ class DataprocClusterDeleteOperatorTest(unittest.TestCase): schedule_interval='@daily') def test_cluster_name_log_no_sub(self): - with patch('airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook') \ - as mock_hook, patch('logging.info') as l: + with patch('airflow.contrib.hooks.gcp_dataproc_hook.DataProcHook') as mock_hook: mock_hook.return_value.get_conn = self.mock_conn dataproc_task = DataprocClusterDeleteOperator( task_id=TASK_ID, @@ -189,14 +190,13 @@ class DataprocClusterDeleteOperatorTest(unittest.TestCase): project_id=PROJECT_ID, dag=self.dag ) - - with self.assertRaises(TypeError) as _: - dataproc_task.execute(None) - l.assert_called_with(('Deleting cluster: ' + 
CLUSTER_NAME)) + with patch.object(dataproc_task.logger, 'info') as mock_info: + with self.assertRaises(TypeError) as _: + dataproc_task.execute(None) + mock_info.assert_called_with('Deleting cluster: %s', CLUSTER_NAME) def test_cluster_name_log_sub(self): - with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') \ - as mock_hook, patch('logging.info') as l: + with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') as mock_hook: mock_hook.return_value.get_conn = self.mock_conn dataproc_task = DataprocClusterDeleteOperator( task_id=TASK_ID, @@ -205,10 +205,11 @@ class DataprocClusterDeleteOperatorTest(unittest.TestCase): dag=self.dag ) - context = { 'ts_nodash' : 'testnodash'} + with patch.object(dataproc_task.logger, 'info') as mock_info: + context = { 'ts_nodash' : 'testnodash'} - rendered = dataproc_task.render_template('cluster_name', getattr(dataproc_task,'cluster_name'), context) - setattr(dataproc_task, 'cluster_name', rendered) - with self.assertRaises(TypeError) as _: - dataproc_task.execute(None) - l.assert_called_with(('Deleting cluster: smoke-cluster-testnodash')) + rendered = dataproc_task.render_template('cluster_name', getattr(dataproc_task,'cluster_name'), context) + setattr(dataproc_task, 'cluster_name', rendered) + with self.assertRaises(TypeError) as _: + dataproc_task.execute(None) + mock_info.assert_called_with('Deleting cluster: %s', u'smoke-cluster-testnodash') http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/tests/core.py ---------------------------------------------------------------------- diff --git a/tests/core.py b/tests/core.py index 8eeec82..acc543b 100644 --- a/tests/core.py +++ b/tests/core.py @@ -29,7 +29,6 @@ from datetime import datetime, time, timedelta from email.mime.multipart import MIMEMultipart from email.mime.application import MIMEApplication import signal -from time import time as timetime from time import sleep import warnings @@ -37,7 +36,7 @@ from dateutil.relativedelta import relativedelta import sqlalchemy from airflow import configuration -from airflow.executors import SequentialExecutor, LocalExecutor +from airflow.executors import SequentialExecutor from airflow.models import Variable from tests.test_utils.fake_datetime import FakeDatetime @@ -53,13 +52,11 @@ from airflow.operators.http_operator import SimpleHttpOperator from airflow.operators import sensors from airflow.hooks.base_hook import BaseHook from airflow.hooks.sqlite_hook import SqliteHook -from airflow.hooks.postgres_hook import PostgresHook from airflow.bin import cli from airflow.www import app as application from airflow.settings import Session from airflow.utils.state import State from airflow.utils.dates import infer_time_unit, round_time, scale_time_units -from airflow.utils.logging import LoggingMixin from lxml import html from airflow.exceptions import AirflowException from airflow.configuration import AirflowConfigException, run_command @@ -805,17 +802,6 @@ class CoreTest(unittest.TestCase): # restore the envvar back to the original state del os.environ[key] - def test_class_with_logger_should_have_logger_with_correct_name(self): - - # each class should automatically receive a logger with a correct name - - class Blah(LoggingMixin): - pass - - self.assertEqual("tests.core.Blah", Blah().logger.name) - self.assertEqual("airflow.executors.sequential_executor.SequentialExecutor", SequentialExecutor().logger.name) - self.assertEqual("airflow.executors.local_executor.LocalExecutor", LocalExecutor().logger.name) - def 
test_round_time(self): rt1 = round_time(datetime(2015, 1, 1, 6), timedelta(days=1)) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/tests/operators/sensors.py ---------------------------------------------------------------------- diff --git a/tests/operators/sensors.py b/tests/operators/sensors.py index 9a40a05..9b256e6 100644 --- a/tests/operators/sensors.py +++ b/tests/operators/sensors.py @@ -11,29 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - import logging -import os import sys import time import unittest - -from mock import patch from datetime import datetime, timedelta +from mock import patch -from airflow import DAG, configuration, jobs, settings -from airflow.jobs import BackfillJob, SchedulerJob -from airflow.models import TaskInstance, DagModel, DagBag -from airflow.operators.sensors import HttpSensor, BaseSensorOperator, HdfsSensor, ExternalTaskSensor -from airflow.operators.bash_operator import BashOperator -from airflow.operators.dummy_operator import DummyOperator -from airflow.utils.decorators import apply_defaults +from airflow import DAG, configuration, settings from airflow.exceptions import (AirflowException, AirflowSensorTimeout, AirflowSkipException) +from airflow.models import TaskInstance +from airflow.operators.bash_operator import BashOperator +from airflow.operators.dummy_operator import DummyOperator +from airflow.operators.sensors import HttpSensor, BaseSensorOperator, HdfsSensor, ExternalTaskSensor +from airflow.utils.decorators import apply_defaults from airflow.utils.state import State -from tests.core import TEST_DAG_FOLDER + +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + configuration.load_test_config() DEFAULT_DATE = datetime(2015, 1, 1) @@ -72,7 +75,7 @@ class TimeoutTestSensor(BaseSensorOperator): else: raise AirflowSensorTimeout('Snap. Time is OUT.') time.sleep(self.poke_interval) - logging.info("Success criteria met. Exiting.") + self.logger.info("Success criteria met. Exiting.") class SensorTimeoutTest(unittest.TestCase): @@ -158,10 +161,8 @@ class HttpSensorTests(unittest.TestCase): self.assertTrue(prep_request.method, received_request.method) @patch("airflow.hooks.http_hook.requests.Session.send") - @patch("airflow.hooks.http_hook.logging.error") def test_logging_head_error_request( self, - mock_error_logging, mock_session_send ): @@ -183,13 +184,15 @@ class HttpSensorTests(unittest.TestCase): method='HEAD', response_check=resp_check, timeout=5, - poke_interval=1) + poke_interval=1 + ) - with self.assertRaises(AirflowSensorTimeout): - task.execute(None) + with mock.patch.object(task.hook.logger, 'error') as mock_errors: + with self.assertRaises(AirflowSensorTimeout): + task.execute(None) - self.assertTrue(mock_error_logging.called) - mock_error_logging.assert_called_with('HTTP error: Not Found') + self.assertTrue(mock_errors.called) + mock_errors.assert_called_with('HTTP error: %s', 'Not Found') class HdfsSensorTests(unittest.TestCase): @@ -199,8 +202,6 @@ class HdfsSensorTests(unittest.TestCase): raise unittest.SkipTest('HdfsSensor won\'t work with python3. 
No need to test anything here') from tests.core import FakeHDFSHook self.hook = FakeHDFSHook - self.logger = logging.getLogger() - self.logger.setLevel(logging.DEBUG) def test_legacy_file_exist(self): """ @@ -208,7 +209,7 @@ class HdfsSensorTests(unittest.TestCase): :return: """ # Given - self.logger.info("Test for existing file with the legacy behaviour") + logging.info("Test for existing file with the legacy behaviour") # When task = HdfsSensor(task_id='Should_be_file_legacy', filepath='/datadirectory/datafile', @@ -227,7 +228,7 @@ class HdfsSensorTests(unittest.TestCase): :return: """ # Given - self.logger.info("Test for existing file with the legacy behaviour") + logging.info("Test for existing file with the legacy behaviour") # When task = HdfsSensor(task_id='Should_be_file_legacy', filepath='/datadirectory/datafile', @@ -248,7 +249,7 @@ class HdfsSensorTests(unittest.TestCase): :return: """ # Given - self.logger.info("Test for non existing file with the legacy behaviour") + logging.info("Test for non existing file with the legacy behaviour") task = HdfsSensor(task_id='Should_not_be_file_legacy', filepath='/datadirectory/not_existing_file_or_directory', timeout=1, http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/tests/utils/log/test_logging.py ---------------------------------------------------------------------- diff --git a/tests/utils/log/test_logging.py b/tests/utils/log/test_logging.py new file mode 100644 index 0000000..7e05c7d --- /dev/null +++ b/tests/utils/log/test_logging.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +import unittest + +from airflow.utils.log.s3_task_handler import S3TaskHandler + + +class TestS3TaskHandler(unittest.TestCase): + + def setUp(self): + super(S3TaskHandler, self).setUp() + self.remote_log_location = 'remote/log/location' + self.hook_patcher = mock.patch("airflow.hooks.S3_hook.S3Hook") + self.hook_mock = self.hook_patcher.start() + self.hook_inst_mock = self.hook_mock.return_value + self.hook_key_mock = self.hook_inst_mock.get_key.return_value + self.hook_key_mock.get_contents_as_string.return_value.decode.\ + return_value = 'content' + + def tearDown(self): + self.hook_patcher.stop() + super(S3TaskHandler, self).tearDown() + + def test_init(self): + S3TaskHandler() + self.hook_mock.assert_called_once_with('') + + def test_init_raises(self): + self.hook_mock.side_effect = Exception('Failed to connect') + handler = S3TaskHandler() + with mock.patch.object(handler.logger, 'error') as mock_error: + # Initialize the hook + handler.hook() + mock_error.assert_called_once_with( + 'Could not create an S3Hook with connection id "". Please make ' + 'sure that airflow[s3] is installed and the S3 connection exists.' 
+ ) + + def test_log_exists(self): + self.assertTrue(S3TaskHandler().log_exists(self.remote_log_location)) + + def test_log_exists_none(self): + self.hook_inst_mock.get_key.return_value = None + self.assertFalse(S3TaskHandler().log_exists(self.remote_log_location)) + + def test_log_exists_raises(self): + self.hook_inst_mock.get_key.side_effect = Exception('error') + self.assertFalse(S3TaskHandler().log_exists(self.remote_log_location)) + + def test_log_exists_no_hook(self): + self.hook_mock.side_effect = Exception('Failed to connect') + self.assertFalse(S3TaskHandler().log_exists(self.remote_log_location)) + + def test_read(self): + self.assertEqual( + S3TaskHandler().read(self.remote_log_location), + 'content' + ) + + def test_read_key_empty(self): + self.hook_inst_mock.get_key.return_value = None + self.assertEqual(S3TaskHandler().read(self.remote_log_location), '') + + def test_read_raises(self): + self.hook_inst_mock.get_key.side_effect = Exception('error') + self.assertEqual(S3TaskHandler().read(self.remote_log_location), '') + + def test_read_raises_return_error(self): + self.hook_inst_mock.get_key.side_effect = Exception('error') + handler = S3TaskHandler() + with mock.patch.object(handler.logger, 'error') as mock_error: + result = handler.s3_log_read( + self.remote_log_location, + return_error=True + ) + msg = 'Could not read logs from %s' % self.remote_log_location + self.assertEqual(result, msg) + mock_error.assert_called_once_with(msg) + + def test_write(self): + S3TaskHandler().write('text', self.remote_log_location) + self.hook_inst_mock.load_string.assert_called_once_with( + 'content\ntext', + key=self.remote_log_location, + replace=True, + encrypt=False, + ) + + def test_write_raises(self): + self.hook_inst_mock.load_string.side_effect = Exception('error') + handler = S3TaskHandler() + with mock.patch.object(handler.logger, 'error') as mock_error: + handler.write('text', self.remote_log_location) + msg = 'Could not write logs to %s' % self.remote_log_location + mock_error.assert_called_once_with(msg) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a7a51890/tests/utils/test_logging.py ---------------------------------------------------------------------- diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py deleted file mode 100644 index 72c5d49..0000000 --- a/tests/utils/test_logging.py +++ /dev/null @@ -1,103 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import mock -import unittest - -from airflow.utils import logging -from datetime import datetime - -DEFAULT_DATE = datetime(2016, 1, 1) - - -class TestS3Log(unittest.TestCase): - - def setUp(self): - super(TestS3Log, self).setUp() - self.remote_log_location = 'remote/log/location' - self.hook_patcher = mock.patch("airflow.hooks.S3_hook.S3Hook") - self.hook_mock = self.hook_patcher.start() - self.hook_inst_mock = self.hook_mock.return_value - self.hook_key_mock = self.hook_inst_mock.get_key.return_value - self.hook_key_mock.get_contents_as_string.return_value.decode.\ - return_value = 'content' - self.logging_patcher = mock.patch("airflow.utils.logging.logging") - self.logging_mock = self.logging_patcher.start() - - def tearDown(self): - self.logging_patcher.stop() - self.hook_patcher.stop() - super(TestS3Log, self).tearDown() - - def test_init(self): - logging.S3Log() - self.hook_mock.assert_called_once_with('') - - def test_init_raises(self): - self.hook_mock.side_effect = Exception('Failed to connect') - logging.S3Log() - self.logging_mock.error.assert_called_once_with( - 'Could not create an S3Hook with connection id "". Please make ' - 'sure that airflow[s3] is installed and the S3 connection exists.' - ) - - def test_log_exists(self): - self.assertTrue(logging.S3Log().log_exists(self.remote_log_location)) - - def test_log_exists_none(self): - self.hook_inst_mock.get_key.return_value = None - self.assertFalse(logging.S3Log().log_exists(self.remote_log_location)) - - def test_log_exists_raises(self): - self.hook_inst_mock.get_key.side_effect = Exception('error') - self.assertFalse(logging.S3Log().log_exists(self.remote_log_location)) - - def test_log_exists_no_hook(self): - self.hook_mock.side_effect = Exception('Failed to connect') - self.assertFalse(logging.S3Log().log_exists(self.remote_log_location)) - - def test_read(self): - self.assertEqual(logging.S3Log().read(self.remote_log_location), - 'content') - - def test_read_key_empty(self): - self.hook_inst_mock.get_key.return_value = None - self.assertEqual(logging.S3Log().read(self.remote_log_location), '') - - def test_read_raises(self): - self.hook_inst_mock.get_key.side_effect = Exception('error') - self.assertEqual(logging.S3Log().read(self.remote_log_location), '') - - def test_read_raises_return_error(self): - self.hook_inst_mock.get_key.side_effect = Exception('error') - result = logging.S3Log().read(self.remote_log_location, - return_error=True) - msg = 'Could not read logs from %s' % self.remote_log_location - self.assertEqual(result, msg) - self.logging_mock.error.assert_called_once_with(msg) - - def test_write(self): - logging.S3Log().write('text', self.remote_log_location) - self.hook_inst_mock.load_string.assert_called_once_with( - 'content\ntext', - key=self.remote_log_location, - replace=True, - encrypt=False, - ) - - def test_write_raises(self): - self.hook_inst_mock.load_string.side_effect = Exception('error') - logging.S3Log().write('text', self.remote_log_location) - msg = 'Could not write logs to %s' % self.remote_log_location - self.logging_mock.error.assert_called_once_with(msg)
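
For readers following the test migration in this patch: the updated tests stop patching the global logging module and instead patch the `logger` attribute that `LoggingMixin` exposes on the object under test, then assert on the format template and its arguments. A minimal sketch of that pattern, assuming the `airflow.utils.log.LoggingMixin` module path used elsewhere in this change; `MyComponent`, its `delete` method, and the message are illustrative stand-ins, not code from the patch:

    import unittest

    from mock import patch

    from airflow.utils.log.LoggingMixin import LoggingMixin


    class MyComponent(LoggingMixin):
        """Hypothetical stand-in for an operator/hook/handler under test."""

        def delete(self, cluster_name):
            # Lazy %-style call, the shape the migrated assertions check for.
            self.logger.info('Deleting cluster: %s', cluster_name)


    class MyComponentLogTest(unittest.TestCase):

        def test_logs_cluster_name(self):
            component = MyComponent()
            # Patch the instance logger rather than the logging module,
            # mirroring patch.object(dataproc_task.logger, 'info') above.
            with patch.object(component.logger, 'info') as mock_info:
                component.delete('smoke-cluster-testnodash')
            mock_info.assert_called_with(
                'Deleting cluster: %s', 'smoke-cluster-testnodash')


    if __name__ == '__main__':
        unittest.main()

Because the mock replaces only the attribute on the instance's logger, the tests no longer depend on how the root logging configuration is wired up, and they restore the real logger automatically when the context manager exits.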

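The two-argument assertions above ('Deleting cluster: %s', 'HTTP error: %s') mirror the lazy formatting style of the log calls themselves. A small illustration of the difference, assuming an arbitrary logger name and cluster value (both illustrative):

    import logging

    log = logging.getLogger('illustrative.logger')
    cluster_name = 'smoke-cluster-testnodash'

    # Eager: the message string is concatenated even if INFO is disabled.
    log.info('Deleting cluster: ' + cluster_name)

    # Lazy: the template and argument are passed separately and interpolated
    # only when a handler actually emits the record, which is also what lets
    # the tests assert on the template and the argument as distinct values.
    log.info('Deleting cluster: %s', cluster_name)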