Repository: incubator-airflow
Updated Branches:
  refs/heads/master 8f9a466de -> b56cb5cc9


[AIRFLOW-219][AIRFLOW-398] Cgroups + impersonation

Submitting on behalf of plypaul

Please accept this PR that addresses the following issues:
- https://issues.apache.org/jira/browse/AIRFLOW-219
- https://issues.apache.org/jira/browse/AIRFLOW-398

Testing Done:
- Running on Airbnb prod (though on a different mergebase) for many months

Credits:
Impersonation work: georgeke did most of the work, but plypaul did quite a bit of work too.
Cgroups: plypaul did most of the work; I just did some touch-up/bug fixes (see the commit history; the cgroups + impersonation commit is actually plypaul's, not mine).

Closes #1934 from aoen/ddavydov/cgroups_and_impersonation_after_rebase


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/b56cb5cc
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/b56cb5cc
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/b56cb5cc

Branch: refs/heads/master
Commit: b56cb5cc97de074bb0e520f66b79e7eb2d913fb1
Parents: 8f9a466
Author: Dan Davydov <dan.davy...@airbnb.com>
Authored: Wed Jan 18 18:11:01 2017 -0800
Committer: Dan Davydov <dan.davy...@airbnb.com>
Committed: Wed Jan 18 18:11:06 2017 -0800

----------------------------------------------------------------------
 .travis.yml                                     |   2 +-
 airflow/bin/cli.py                              |  96 +++++++--
 airflow/configuration.py                        |   7 +
 airflow/contrib/task_runner/__init__.py         |  13 ++
 .../contrib/task_runner/cgroup_task_runner.py   | 202 +++++++++++++++++++
 airflow/jobs.py                                 |  67 +++---
 .../1a5a9e6bf2b5_add_state_index_for_dagruns.py |  37 ++++
 airflow/models.py                               |  92 ++++++---
 airflow/settings.py                             |  23 ++-
 airflow/task_runner/__init__.py                 |  38 ++++
 airflow/task_runner/base_task_runner.py         | 153 ++++++++++++++
 airflow/task_runner/bash_task_runner.py         |  39 ++++
 airflow/utils/file.py                           |  23 +++
 airflow/utils/helpers.py                        |  79 +++++++-
 docs/security.rst                               |  22 ++
 run_unit_tests.sh                               |  14 ++
 scripts/ci/airflow_travis.cfg                   |   1 +
 scripts/ci/requirements.txt                     |   1 +
 setup.py                                        |   4 +
 tests/__init__.py                               |   1 +
 tests/dags/test_default_impersonation.py        |  44 ++++
 tests/dags/test_impersonation.py                |  45 +++++
 tests/dags/test_no_impersonation.py             |  43 ++++
 tests/impersonation.py                          | 111 ++++++++++
 24 files changed, 1061 insertions(+), 96 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/.travis.yml
----------------------------------------------------------------------
diff --git a/.travis.yml b/.travis.yml
index 407e7f9..90f33e3 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -89,7 +89,7 @@ cache:
     - $HOME/.wheelhouse/
     - $HOME/.travis_cache/
 before_install:
-  - ssh-keygen -t rsa -C your_em...@youremail.com -P '' -f ~/.ssh/id_rsa
+  - yes | ssh-keygen -t rsa -C your_em...@youremail.com -P '' -f ~/.ssh/id_rsa
   - cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys
   - ln -s ~/.ssh/authorized_keys ~/.ssh/authorized_keys2
   - chmod 600 ~/.ssh/*

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/airflow/bin/cli.py
----------------------------------------------------------------------
diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index d55fdfc..736df0a 100755
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -22,7 +22,6 @@ import os
 import subprocess
 import textwrap
 import warnings
-from datetime import datetime
 from importlib import import_module
 
 import argparse
@@ -53,7 +52,7 @@ from airflow.models import (DagModel, DagBag, TaskInstance,
 from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS)
 from airflow.utils import db as db_utils
 from airflow.utils import logging as logging_utils
-from airflow.utils.state import State
+from airflow.utils.file import mkdirs
 from airflow.www.app import cached_app
 
 from sqlalchemy import func
@@ -300,6 +299,7 @@ def export_helper(filepath):
         varfile.write(json.dumps(var_dict, sort_keys=True, indent=4))
     print("{} variables successfully exported to {}".format(len(var_dict), 
filepath))
 
+
 def pause(args, dag=None):
     set_is_paused(True, args, dag)
 
@@ -329,19 +329,65 @@ def run(args, dag=None):
     if dag:
         args.dag_id = dag.dag_id
 
-    # Setting up logging
-    log_base = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER'))
-    directory = log_base + "/{args.dag_id}/{args.task_id}".format(args=args)
-    if not os.path.exists(directory):
-        os.makedirs(directory)
-    iso = args.execution_date.isoformat()
-    filename = "{directory}/{iso}".format(**locals())
+    # Load custom airflow config
+    if args.cfg_path:
+        with open(args.cfg_path, 'r') as conf_file:
+            conf_dict = json.load(conf_file)
+
+        if os.path.exists(args.cfg_path):
+            os.remove(args.cfg_path)
+
+        for section, config in conf_dict.items():
+            for option, value in config.items():
+                conf.set(section, option, value)
+        settings.configure_vars()
+        settings.configure_orm()
 
     logging.root.handlers = []
-    logging.basicConfig(
-        filename=filename,
-        level=settings.LOGGING_LEVEL,
-        format=settings.LOG_FORMAT)
+    if args.raw:
+        # Output to STDOUT for the parent process to read and log
+        logging.basicConfig(
+            stream=sys.stdout,
+            level=settings.LOGGING_LEVEL,
+            format=settings.LOG_FORMAT)
+    else:
+        # Setting up logging to a file.
+
+        # To handle log writing when tasks are impersonated, the log files need to
+        # be writable by the user that runs the Airflow command and the user
+        # that is impersonated. This is mainly to handle corner cases with the
+        # SubDagOperator. When the SubDagOperator is run, all of the operators
+        # run under the impersonated user and create appropriate log files
+        # as the impersonated user. However, if the user manually runs tasks
+        # of the SubDagOperator through the UI, then the log files are created
+        # by the user that runs the Airflow command. For example, the Airflow
+        # run command may be run by the `airflow_sudoable` user, but the Airflow
+        # tasks may be run by the `airflow` user. If the log files are not
+        # writable by both users, then it's possible that re-running a task
+        # via the UI (or vice versa) results in a permission error as the task
+        # tries to write to a log file created by the other user.
+        log_base = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER'))
+        directory = log_base + "/{args.dag_id}/{args.task_id}".format(args=args)
+        # Create the log file and give it group writable permissions
+        # TODO(aoen): Make log dirs and logs globally readable for now since the SubDag
+        # operator is not compatible with impersonation (e.g. if a Celery executor is used
+        # for a SubDag operator and the SubDag operator has a different owner than the
+        # parent DAG)
+        if not os.path.exists(directory):
+            # Create the directory as globally writable using custom mkdirs
+            # as os.makedirs doesn't set mode properly.
+            mkdirs(directory, 0o775)
+        iso = args.execution_date.isoformat()
+        filename = "{directory}/{iso}".format(**locals())
+
+        if not os.path.exists(filename):
+            open(filename, "a").close()
+            os.chmod(filename, 0o666)
+
+        logging.basicConfig(
+            filename=filename,
+            level=settings.LOGGING_LEVEL,
+            format=settings.LOG_FORMAT)
 
     if not args.pickle and not dag:
         dag = get_dag(args)
@@ -413,6 +459,10 @@ def run(args, dag=None):
         executor.heartbeat()
         executor.end()
 
+    # Child processes should not flush or upload to remote
+    if args.raw:
+        return
+
     # Force the log to flush, and set the handler to go back to normal so we
     # don't continue logging to the task's log file. The flush is important
     # because we subsequently read from the log to insert into S3 or Google
@@ -626,7 +676,7 @@ def restart_workers(gunicorn_master_proc, num_workers_expected):
     def start_refresh(gunicorn_master_proc):
         batch_size = conf.getint('webserver', 'worker_refresh_batch_size')
         logging.debug('%s doing a refresh of %s workers',
-            state, batch_size)
+                      state, batch_size)
         sys.stdout.flush()
         sys.stderr.flush()
 
@@ -635,11 +685,10 @@ def restart_workers(gunicorn_master_proc, num_workers_expected):
             gunicorn_master_proc.send_signal(signal.SIGTTIN)
             excess += 1
             wait_until_true(lambda: num_workers_expected + excess ==
-                get_num_workers_running(gunicorn_master_proc))
-
+                            get_num_workers_running(gunicorn_master_proc))
 
     wait_until_true(lambda: num_workers_expected ==
-        get_num_workers_running(gunicorn_master_proc))
+                    get_num_workers_running(gunicorn_master_proc))
 
     while True:
         num_workers_running = get_num_workers_running(gunicorn_master_proc)
@@ -662,7 +711,7 @@ def restart_workers(gunicorn_master_proc, num_workers_expected):
                 gunicorn_master_proc.send_signal(signal.SIGTTOU)
                 excess -= 1
                 wait_until_true(lambda: num_workers_expected + excess ==
-                    get_num_workers_running(gunicorn_master_proc))
+                                get_num_workers_running(gunicorn_master_proc))
 
         # Start a new worker by asking gunicorn to increase number of workers
         elif num_workers_running == num_workers_expected:
@@ -761,7 +810,8 @@ def webserver(args):
         if conf.getint('webserver', 'worker_refresh_interval') > 0:
             restart_workers(gunicorn_master_proc, num_workers)
         else:
-            while True: time.sleep(1)
+            while True:
+                time.sleep(1)
 
 
 def scheduler(args):
@@ -920,7 +970,7 @@ def connections(args):
                               Connection.is_encrypted,
                               Connection.is_extra_encrypted,
                               Connection.extra).all()
-        conns = [map(reprlib.repr, conn) for conn in conns] 
+        conns = [map(reprlib.repr, conn) for conn in conns]
         print(tabulate(conns, ['Conn Id', 'Conn Type', 'Host', 'Port',
                                'Is Encrypted', 'Is Extra Encrypted', 'Extra'],
                        tablefmt="fancy_grid"))
@@ -1255,6 +1305,8 @@ class CLIFactory(object):
             ("-p", "--pickle"),
             "Serialized pickle object of the entire dag (used internally)"),
         'job_id': Arg(("-j", "--job_id"), argparse.SUPPRESS),
+        'cfg_path': Arg(
+            ("--cfg_path", ), "Path to config file to use instead of 
airflow.cfg"),
         # webserver
         'port': Arg(
             ("-p", "--port"),
@@ -1433,7 +1485,7 @@ class CLIFactory(object):
             'help': "Run a single task instance",
             'args': (
                 'dag_id', 'task_id', 'execution_date', 'subdir',
-                'mark_success', 'force', 'pool',
+                'mark_success', 'force', 'pool', 'cfg_path',
                 'local', 'raw', 'ignore_all_dependencies', 'ignore_dependencies',
                 'ignore_depends_on_past', 'ship_dag', 'pickle', 'job_id'),
         }, {
@@ -1486,7 +1538,7 @@ class CLIFactory(object):
             'func': upgradedb,
             'help': "Upgrade the metadata database to latest version",
             'args': tuple(),
-        },{
+        }, {
             'func': scheduler,
             'help': "Start a scheduler instance",
             'args': ('dag_id_opt', 'subdir', 'run_duration', 'num_runs',

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/airflow/configuration.py
----------------------------------------------------------------------
diff --git a/airflow/configuration.py b/airflow/configuration.py
index 9b27328..979b071 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -166,6 +166,13 @@ donot_pickle = False
 # How long before timing out a python file import while filling the DagBag
 dagbag_import_timeout = 30
 
+# The class to use for running task instances in a subprocess
+task_runner = BashTaskRunner
+
+# If set, tasks without a `run_as_user` argument will be run with this user
+# Can be used to de-elevate a sudo user running Airflow when executing tasks
+default_impersonation =
+
 # What security module to use (for example kerberos):
 security =
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/airflow/contrib/task_runner/__init__.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/task_runner/__init__.py b/airflow/contrib/task_runner/__init__.py
new file mode 100644
index 0000000..d4cd6f7
--- /dev/null
+++ b/airflow/contrib/task_runner/__init__.py
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/airflow/contrib/task_runner/cgroup_task_runner.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/task_runner/cgroup_task_runner.py b/airflow/contrib/task_runner/cgroup_task_runner.py
new file mode 100644
index 0000000..79aafc8
--- /dev/null
+++ b/airflow/contrib/task_runner/cgroup_task_runner.py
@@ -0,0 +1,202 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import getpass
+import subprocess
+import os
+import uuid
+
+from cgroupspy import trees
+import psutil
+
+from airflow.task_runner.base_task_runner import BaseTaskRunner
+from airflow.utils.helpers import kill_process_tree
+
+
+class CgroupTaskRunner(BaseTaskRunner):
+    """
+    Runs the raw Airflow task in a cgroup that has containment for memory and
+    cpu. It uses the resource requirements defined in the task to construct
+    the settings for the cgroup.
+
+    Note that this task runner will only work if the Airflow user has root privileges,
+    e.g. if the airflow user is called `airflow` then the following entries (or even
+    less restrictive ones) are needed in the sudoers file (replacing
+    /CGROUPS_FOLDER with your system's cgroups folder, e.g. '/sys/fs/cgroup/'):
+    airflow ALL= (root) NOEXEC: /bin/chown /CGROUPS_FOLDER/memory/airflow/*
+    airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/memory/airflow/*..*
+    airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/memory/airflow/* *
+    airflow ALL= (root) NOEXEC: /bin/chown /CGROUPS_FOLDER/cpu/airflow/*
+    airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/cpu/airflow/*..*
+    airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/cpu/airflow/* *
+    airflow ALL= (root) NOEXEC: /bin/chmod /CGROUPS_FOLDER/memory/airflow/*
+    airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/memory/airflow/*..*
+    airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/memory/airflow/* *
+    airflow ALL= (root) NOEXEC: /bin/chmod /CGROUPS_FOLDER/cpu/airflow/*
+    airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/cpu/airflow/*..*
+    airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/cpu/airflow/* *
+    """
+
+    def __init__(self, local_task_job):
+        super(CgroupTaskRunner, self).__init__(local_task_job)
+        self.process = None
+        self._finished_running = False
+        self._cpu_shares = None
+        self._mem_mb_limit = None
+        self._created_cpu_cgroup = False
+        self._created_mem_cgroup = False
+        self._cur_user = getpass.getuser()
+
+    def _create_cgroup(self, path):
+        """
+        Create the specified cgroup.
+
+        :param path: The path of the cgroup to create.
+        E.g. cpu/mygroup/mysubgroup
+        :return: the Node associated with the created cgroup.
+        :rtype: cgroupspy.nodes.Node
+        """
+        node = trees.Tree().root
+        path_split = path.split(os.sep)
+        for path_element in path_split:
+            name_to_node = {x.name: x for x in node.children}
+            if path_element not in name_to_node:
+                self.logger.debug("Creating cgroup {} in {}"
+                                  .format(path_element, node.path))
+                subprocess.check_output("sudo mkdir -p 
{}".format(path_element))
+                subprocess.check_output("sudo chown -R {} {}".format(
+                    self._cur_user, path_element))
+            else:
+                self.logger.debug("Not creating cgroup {} in {} "
+                                  "since it already exists"
+                                  .format(path_element, node.path))
+            node = name_to_node[path_element]
+        return node
+
+    def _delete_cgroup(self, path):
+        """
+        Delete the specified cgroup.
+
+        :param path: The path of the cgroup to delete.
+        E.g. cpu/mygroup/mysubgroup
+        """
+        node = trees.Tree().root
+        path_split = path.split("/")
+        for path_element in path_split:
+            name_to_node = {x.name: x for x in node.children}
+            if path_element not in name_to_node:
+                self.logger.warn("Cgroup does not exist: {}"
+                                 .format(path))
+                return
+            else:
+                node = name_to_node[path_element]
+        # node is now the leaf node
+        parent = node.parent
+        self.logger.debug("Deleting cgroup {}/{}".format(parent, node.name))
+        parent.delete_cgroup(node.name)
+
+    def start(self):
+        # Use bash if it's already in a cgroup
+        cgroups = self._get_cgroup_names()
+        if cgroups["cpu"] != "/" or cgroups["memory"] != "/":
+            self.logger.debug("Already running in a cgroup (cpu: {} memory: {} 
so "
+                              "not creating another one"
+                              .format(cgroups.get("cpu"),
+                                      cgroups.get("memory")))
+            self.process = self.run_command(['bash', '-c'], join_args=True)
+            return
+
+        # Create a unique cgroup name
+        cgroup_name = "airflow/{}/{}".format(datetime.datetime.now().
+                                             strftime("%Y-%m-%d"),
+                                             str(uuid.uuid1()))
+
+        self.mem_cgroup_name = "memory/{}".format(cgroup_name)
+        self.cpu_cgroup_name = "cpu/{}".format(cgroup_name)
+
+        # Get the resource requirements from the task
+        task = self._task_instance.task
+        resources = task.resources
+        cpus = resources.cpus.qty
+        self._cpu_shares = cpus * 1024
+        self._mem_mb_limit = resources.ram.qty
+
+        # Create the memory cgroup
+        mem_cgroup_node = self._create_cgroup(self.mem_cgroup_name)
+        self._created_mem_cgroup = True
+        if self._mem_mb_limit > 0:
+            self.logger.debug("Setting {} with {} MB of memory"
+                              .format(self.mem_cgroup_name, self._mem_mb_limit))
+            mem_cgroup_node.controller.limit_in_bytes = self._mem_mb_limit * 1024 * 1024
+
+        # Create the CPU cgroup
+        cpu_cgroup_node = self._create_cgroup(self.cpu_cgroup_name)
+        self._created_cpu_cgroup = True
+        if self._cpu_shares > 0:
+            self.logger.debug("Setting {} with {} CPU shares"
+                              .format(self.cpu_cgroup_name, self._cpu_shares))
+            cpu_cgroup_node.controller.shares = self._cpu_shares
+
+        # Start the process w/ cgroups
+        self.logger.debug("Starting task process with cgroups cpu,memory:{}"
+                          .format(cgroup_name))
+        self.process = self.run_command(
+            ['cgexec', '-g', 'cpu,memory:{}'.format(cgroup_name)]
+        )
+
+    def return_code(self):
+        return_code = self.process.poll()
+        # TODO(plypaul) Monitoring the control file in the cgroup fs is better than
+        # checking the return code here. The PR to use this is here:
+        # https://github.com/plypaul/airflow/blob/e144e4d41996300ffa93947f136eab7785b114ed/airflow/contrib/task_runner/cgroup_task_runner.py#L43
+        # but there were some issues installing the python butter package and
+        # libseccomp-dev on some hosts for some reason.
+        # I wasn't able to track down the root cause of the package install failures, but
+        # we might want to revisit that approach at some other point.
+        if return_code == 137:
+            self.logger.warn("Task failed with return code of 137. This may 
indicate "
+                             "that it was killed due to excessive memory 
usage. "
+                             "Please consider optimizing your task or using 
the "
+                             "resources argument to reserve more memory for 
your "
+                             "task")
+        return return_code
+
+    def terminate(self):
+        if self.process and psutil.pid_exists(self.process.pid):
+            kill_process_tree(self.logger, self.process.pid)
+
+    def on_finish(self):
+        # Let the OOM watcher thread know we're done to avoid false OOM alarms
+        self._finished_running = True
+        # Clean up the cgroups
+        if self._created_mem_cgroup:
+            self._delete_cgroup(self.mem_cgroup_name)
+        if self._created_cpu_cgroup:
+            self._delete_cgroup(self.cpu_cgroup_name)
+
+    def _get_cgroup_names(self):
+        """
+        :return: a mapping between the subsystem name to the cgroup name
+        :rtype: dict[str, str]
+        """
+        with open("/proc/self/cgroup") as f:
+            lines = f.readlines()
+            d = {}
+            for line in lines:
+                line_split = line.rstrip().split(":")
+                subsystem = line_split[1]
+                group_name = line_split[2]
+                d[subsystem] = group_name
+            return d

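For orientation only: the runner above sizes its cgroup from the task's `resources` argument (cpus are converted to CPU shares, ram is interpreted as MB). Below is a minimal sketch of a task that would pick up such limits when `task_runner = CgroupTaskRunner` is configured; the DAG and operator names and values are illustrative, not part of this commit, and it assumes the Resources helper accepts `cpus`/`ram` keywords as implied by the code above.

    from datetime import datetime
    from airflow import DAG
    from airflow.operators.bash_operator import BashOperator

    dag = DAG('example_cgroup_dag', start_date=datetime(2017, 1, 1))

    # With the CgroupTaskRunner, this task would run in a cgroup capped at
    # roughly 1024 CPU shares (1 cpu) and 512 MB of memory.
    bounded_task = BashOperator(
        task_id='bounded_task',
        bash_command='sleep 10',
        resources={'cpus': 1, 'ram': 512},
        dag=dag)
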
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/airflow/jobs.py
----------------------------------------------------------------------
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 2a6af39..f1de333 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -25,7 +25,6 @@ from datetime import datetime
 import getpass
 import logging
 import socket
-import subprocess
 import multiprocessing
 import os
 import signal
@@ -35,7 +34,7 @@ import time
 from time import sleep
 
 import psutil
-from sqlalchemy import Column, Integer, String, DateTime, func, Index, or_
+from sqlalchemy import Column, Integer, String, DateTime, func, Index, or_, and_
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm.session import make_transient
 from tabulate import tabulate
@@ -45,6 +44,7 @@ from airflow import configuration as conf
 from airflow.exceptions import AirflowException
 from airflow.models import DagRun
 from airflow.settings import Stats
+from airflow.task_runner import get_task_runner
 from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
 from airflow.utils.state import State
 from airflow.utils.db import provide_session, pessimistic_connection_handling
@@ -54,15 +54,12 @@ from airflow.utils.dag_processing import (AbstractDagFileProcessor,
                                           SimpleDagBag,
                                           list_py_file_paths)
 from airflow.utils.email import send_email
-from airflow.utils.helpers import kill_descendant_processes
 from airflow.utils.logging import LoggingMixin
 from airflow.utils import asciiart
 
 
 Base = models.Base
-DagRun = models.DagRun
 ID_LEN = models.ID_LEN
-Stats = settings.Stats
 
 
 class BaseJob(Base, LoggingMixin):
@@ -956,13 +953,18 @@ class SchedulerJob(BaseJob):
         :type states: Tuple[State]
         :return: None
         """
-        # Get all the relevant task instances
+        # Get all the queued task instances associated with scheduled
+        # DagRuns.
         TI = models.TaskInstance
         task_instances_to_examine = (
             session
             .query(TI)
             .filter(TI.dag_id.in_(simple_dag_bag.dag_ids))
             .filter(TI.state.in_(states))
+            .join(DagRun, and_(TI.dag_id == DagRun.dag_id,
+                               TI.execution_date == DagRun.execution_date,
+                               DagRun.state == State.RUNNING,
+                               DagRun.run_id.like(DagRun.ID_PREFIX + '%')))
             .all()
         )
 
@@ -1017,7 +1019,7 @@ class SchedulerJob(BaseJob):
                     self.logger.debug("Not handling task {} as the executor 
reports it is running"
                                       .format(task_instance.key))
                     continue
- 
+
                 if simple_dag_bag.get_dag(task_instance.dag_id).is_paused:
                     self.logger.info("Not executing queued {} since {} is 
paused"
                                      .format(task_instance, 
task_instance.dag_id))
@@ -1054,7 +1056,7 @@ class SchedulerJob(BaseJob):
                                              task_concurrency_limit))
                     continue
 
-                command = TI.generate_command(
+                command = " ".join(TI.generate_command(
                     task_instance.dag_id,
                     task_instance.task_id,
                     task_instance.execution_date,
@@ -1066,7 +1068,7 @@ class SchedulerJob(BaseJob):
                     ignore_ti_state=False,
                     pool=task_instance.pool,
                     file_path=simple_dag_bag.get_dag(task_instance.dag_id).full_filepath,
-                    pickle_id=simple_dag_bag.get_dag(task_instance.dag_id).pickle_id)
+                    pickle_id=simple_dag_bag.get_dag(task_instance.dag_id).pickle_id))
 
                 priority = task_instance.priority_weight
                 queue = task_instance.queue
@@ -1659,7 +1661,7 @@ class BackfillJob(BaseJob):
 
         # consider max_active_runs but ignore when running subdags
         # "parent.child" as a dag_id is by convention a subdag
-        if self.dag.schedule_interval and not "." in self.dag.dag_id:
+        if self.dag.schedule_interval and "." not in self.dag.dag_id:
             active_runs = DagRun.find(
                 dag_id=self.dag.dag_id,
                 state=State.RUNNING,
@@ -1915,7 +1917,6 @@ class BackfillJob(BaseJob):
                                 self.logger.error(msg)
                                 ti.handle_failure(msg)
                                 tasks_to_run.pop(key)
-
                 msg = ' | '.join([
                     "[backfill progress]",
                     "dag run {6} of {7}",
@@ -2026,23 +2027,14 @@ class LocalTaskJob(BaseJob):
         super(LocalTaskJob, self).__init__(*args, **kwargs)
 
     def _execute(self):
+        self.task_runner = get_task_runner(self)
         try:
-            command = self.task_instance.command(
-                raw=True,
-                ignore_all_deps = self.ignore_all_deps,
-                ignore_depends_on_past = self.ignore_depends_on_past,
-                ignore_task_deps = self.ignore_task_deps,
-                ignore_ti_state = self.ignore_ti_state,
-                pickle_id = self.pickle_id,
-                mark_success = self.mark_success,
-                job_id = self.id,
-                pool = self.pool
-            )
-            self.process = subprocess.Popen(['bash', '-c', command])
-            self.logger.info("Subprocess PID is {}".format(self.process.pid))
+            self.task_runner.start()
+
             ti = self.task_instance
             session = settings.Session()
-            ti.pid = self.process.pid
+            if self.task_runner.process:
+                ti.pid = self.task_runner.process.pid
             ti.hostname = socket.getfqdn()
             session.merge(ti)
             session.commit()
@@ -2053,8 +2045,10 @@ class LocalTaskJob(BaseJob):
                                                 'scheduler_zombie_task_threshold')
             while True:
                 # Monitor the task to see if it's done
-                return_code = self.process.poll()
+                return_code = self.task_runner.return_code()
                 if return_code is not None:
+                    self.logger.info("Task exited with return code {}"
+                                     .format(return_code))
                     return
 
                 # Periodically heartbeat so that the scheduler doesn't think this
@@ -2079,11 +2073,11 @@ class LocalTaskJob(BaseJob):
                                            .format(time_since_last_heartbeat,
                                                    heartbeat_time_limit))
         finally:
-            # Kill processes that were left running
-            kill_descendant_processes(self.logger)
+            self.on_kill()
 
     def on_kill(self):
-        self.process.terminate()
+        self.task_runner.terminate()
+        self.task_runner.on_finish()
 
     @provide_session
     def heartbeat_callback(self, session=None):
@@ -2097,23 +2091,24 @@ class LocalTaskJob(BaseJob):
         TI = models.TaskInstance
         ti = self.task_instance
         new_ti = session.query(TI).filter(
-            TI.dag_id==ti.dag_id, TI.task_id==ti.task_id,
-            TI.execution_date==ti.execution_date).scalar()
+            TI.dag_id == ti.dag_id, TI.task_id == ti.task_id,
+            TI.execution_date == ti.execution_date).scalar()
         if new_ti.state == State.RUNNING:
             self.was_running = True
             fqdn = socket.getfqdn()
-            if not (fqdn == new_ti.hostname and self.process.pid == new_ti.pid):
+            if not (fqdn == new_ti.hostname and
+                    self.task_runner.process.pid == new_ti.pid):
                 logging.warning("Recorded hostname and pid of 
{new_ti.hostname} "
                                 "and {new_ti.pid} do not match this instance's 
"
                                 "which are {fqdn} and "
-                                "{self.process.pid}. Taking the poison pill. 
So "
-                                "long."
+                                "{self.task_runner.process.pid}. Taking the 
poison pill. "
+                                "So long."
                                 .format(**locals()))
                 raise AirflowException("Another worker/process is running this 
job")
-        elif self.was_running and hasattr(self, 'process'):
+        elif self.was_running and hasattr(self.task_runner, 'process'):
             logging.warning(
                 "State of this instance has been externally set to "
                 "{self.task_instance.state}. "
                 "Taking the poison pill. So long.".format(**locals()))
-            self.process.terminate()
+            self.task_runner.terminate()
             self.terminating = True

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/airflow/migrations/versions/1a5a9e6bf2b5_add_state_index_for_dagruns.py
----------------------------------------------------------------------
diff --git a/airflow/migrations/versions/1a5a9e6bf2b5_add_state_index_for_dagruns.py b/airflow/migrations/versions/1a5a9e6bf2b5_add_state_index_for_dagruns.py
new file mode 100644
index 0000000..29ffaf1
--- /dev/null
+++ b/airflow/migrations/versions/1a5a9e6bf2b5_add_state_index_for_dagruns.py
@@ -0,0 +1,37 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Add state index for dagruns to allow the quick lookup of active dagruns
+
+Revision ID: 1a5a9e6bf2b5
+Revises: 5e7d17757c7a
+Create Date: 2017-01-17 10:22:53.193711
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = '1a5a9e6bf2b5'
+down_revision = '5e7d17757c7a'
+branch_labels = None
+depends_on = None
+
+from alembic import op
+import sqlalchemy as sa
+
+
+def upgrade():
+    op.create_index('dr_state', 'dag_run', ['state'], unique=False)
+
+
+def downgrade():
+    op.drop_index('dr_state', table_name='dag_run')

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 8682f35..a16603d 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -251,9 +251,9 @@ class DagBag(BaseDagBag, LoggingMixin):
 
             self.logger.debug("Importing {}".format(filepath))
             org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
-            mod_name = ('unusual_prefix_'
-                        + hashlib.sha1(filepath.encode('utf-8')).hexdigest()
-                        + '_' + org_mod_name)
+            mod_name = ('unusual_prefix_' +
+                        hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
+                        '_' + org_mod_name)
 
             if mod_name in sys.modules:
                 del sys.modules[mod_name]
@@ -756,6 +756,7 @@ class TaskInstance(Base):
         self.priority_weight = task.priority_weight_total
         self.try_number = 0
         self.unixname = getpass.getuser()
+        self.run_as_user = task.run_as_user
         if state:
             self.state = state
         self.hostname = ''
@@ -777,7 +778,39 @@ class TaskInstance(Base):
             pickle_id=None,
             raw=False,
             job_id=None,
-            pool=None):
+            pool=None,
+            cfg_path=None):
+        """
+        Returns a command that can be executed anywhere where airflow is
+        installed. This command is part of the message sent to executors by
+        the orchestrator.
+        """
+        return " ".join(self.command_as_list(
+            mark_success=mark_success,
+            ignore_all_deps=ignore_all_deps,
+            ignore_depends_on_past=ignore_depends_on_past,
+            ignore_task_deps=ignore_task_deps,
+            ignore_ti_state=ignore_ti_state,
+            local=local,
+            pickle_id=pickle_id,
+            raw=raw,
+            job_id=job_id,
+            pool=pool,
+            cfg_path=cfg_path))
+
+    def command_as_list(
+            self,
+            mark_success=False,
+            ignore_all_deps=False,
+            ignore_task_deps=False,
+            ignore_depends_on_past=False,
+            ignore_ti_state=False,
+            local=False,
+            pickle_id=None,
+            raw=False,
+            job_id=None,
+            pool=None,
+            cfg_path=None):
         """
         Returns a command that can be executed anywhere where airflow is
         installed. This command is part of the message sent to executors by
@@ -799,15 +832,16 @@ class TaskInstance(Base):
             self.execution_date,
             mark_success=mark_success,
             ignore_all_deps=ignore_all_deps,
-            ignore_depends_on_past=ignore_depends_on_past,
             ignore_task_deps=ignore_task_deps,
+            ignore_depends_on_past=ignore_depends_on_past,
             ignore_ti_state=ignore_ti_state,
             local=local,
             pickle_id=pickle_id,
             file_path=path,
             raw=raw,
             job_id=job_id,
-            pool=pool)
+            pool=pool,
+            cfg_path=cfg_path)
 
     @staticmethod
     def generate_command(dag_id,
@@ -823,7 +857,8 @@ class TaskInstance(Base):
                          file_path=None,
                          raw=False,
                          job_id=None,
-                         pool=None
+                         pool=None,
+                         cfg_path=None
                          ):
         """
         Generates the shell command required to execute this task instance.
@@ -860,19 +895,20 @@ class TaskInstance(Base):
         :return: shell command that can be used to run the task instance
         """
         iso = execution_date.isoformat()
-        cmd = "airflow run {dag_id} {task_id} {iso} "
-        cmd += "--mark_success " if mark_success else ""
-        cmd += "--pickle {pickle_id} " if pickle_id else ""
-        cmd += "--job_id {job_id} " if job_id else ""
-        cmd += "-A " if ignore_all_deps else ""
-        cmd += "-i " if ignore_task_deps else ""
-        cmd += "-I " if ignore_depends_on_past else ""
-        cmd += "--force " if ignore_ti_state else ""
-        cmd += "--local " if local else ""
-        cmd += "--pool {pool} " if pool else ""
-        cmd += "--raw " if raw else ""
-        cmd += "-sd {file_path}" if file_path else ""
-        return cmd.format(**locals())
+        cmd = ["airflow", "run", str(dag_id), str(task_id), str(iso)]
+        cmd.extend(["--mark_success"]) if mark_success else None
+        cmd.extend(["--pickle", str(pickle_id)]) if pickle_id else None
+        cmd.extend(["--job_id", str(job_id)]) if job_id else None
+        cmd.extend(["-A "]) if ignore_all_deps else None
+        cmd.extend(["-i"]) if ignore_task_deps else None
+        cmd.extend(["-I"]) if ignore_depends_on_past else None
+        cmd.extend(["--force"]) if ignore_ti_state else None
+        cmd.extend(["--local"]) if local else None
+        cmd.extend(["--pool", pool]) if pool else None
+        cmd.extend(["--raw"]) if raw else None
+        cmd.extend(["-sd", file_path]) if file_path else None
+        cmd.extend(["--cfg_path", cfg_path]) if cfg_path else None
+        return cmd
 
     @property
     def log_filepath(self):
@@ -1825,6 +1861,8 @@ class BaseOperator(object):
     :param resources: A map of resource parameter names (the argument names of the
         Resources constructor) to their values.
     :type resources: dict
+    :param run_as_user: unix username to impersonate while running the task
+    :type run_as_user: str
     """
 
     # For derived classes to define which fields will get jinjaified
@@ -1866,6 +1904,7 @@ class BaseOperator(object):
             on_retry_callback=None,
             trigger_rule=TriggerRule.ALL_SUCCESS,
             resources=None,
+            run_as_user=None,
             *args,
             **kwargs):
 
@@ -1929,6 +1968,7 @@ class BaseOperator(object):
         self.adhoc = adhoc
         self.priority_weight = priority_weight
         self.resources = Resources(**(resources or {}))
+        self.run_as_user = run_as_user
 
         # Private attributes
         self._upstream_task_ids = []
@@ -2854,13 +2894,7 @@ class DAG(BaseDag, LoggingMixin):
         :param session:
         :return: List of execution dates
         """
-        runs = (
-           session.query(DagRun)
-           .filter(
-           DagRun.dag_id == self.dag_id,
-           DagRun.state == State.RUNNING)
-           .order_by(DagRun.execution_date)
-           .all())
+        runs = DagRun.find(dag_id=self.dag_id, state=State.RUNNING)
 
         active_dates = []
         for run in runs:
@@ -2959,7 +2993,7 @@ class DAG(BaseDag, LoggingMixin):
             self, session, start_date=None, end_date=None, state=None):
         TI = TaskInstance
         if not start_date:
-            start_date = (datetime.today()-timedelta(30)).date()
+            start_date = (datetime.today() - timedelta(30)).date()
             start_date = datetime.combine(start_date, datetime.min.time())
         end_date = end_date or datetime.now()
         tis = session.query(TI).filter(
@@ -3488,7 +3522,6 @@ class Variable(Base):
             else:
                 return obj.val
 
-
     @classmethod
     @provide_session
     def get(cls, key, default_var=None, deserialize_json=False, session=None):
@@ -3695,7 +3728,6 @@ class DagStat(Base):
         :type full_query: bool
         """
         dag_ids = set(dag_ids)
-        ds_ids = set(session.query(DagStat.dag_id).all())
 
         qry = (
             session.query(DagStat)

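Two changes in models.py above are worth a quick illustration: `generate_command` now builds an argv-style token list (with `command`/`command_as_list` wrapping it), and `BaseOperator` gains a `run_as_user` argument for impersonation. A rough sketch with hypothetical dag/task names, date, and path (none of these values come from this commit):

    from datetime import datetime
    from airflow.models import TaskInstance

    cmd = TaskInstance.generate_command(
        'example_dag', 'example_task', datetime(2017, 1, 18),
        local=True, cfg_path='/tmp/airflow_task_cfg.json')
    # cmd is now a list of tokens rather than one formatted string, e.g.
    # ['airflow', 'run', 'example_dag', 'example_task', '2017-01-18T00:00:00',
    #  '--local', '--cfg_path', '/tmp/airflow_task_cfg.json']

    # Per-task impersonation is requested via the new operator argument, e.g.
    # BashOperator(task_id='t', bash_command='whoami', run_as_user='etl_user', dag=dag)
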
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/airflow/settings.py
----------------------------------------------------------------------
diff --git a/airflow/settings.py b/airflow/settings.py
index ce2ca92..4882875 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -68,10 +68,7 @@ ___  ___ |  / _  /   _  __/ _  / / /_/ /_ |/ |/ /
  """
 
 BASE_LOG_URL = '/admin/airflow/log'
-AIRFLOW_HOME = os.path.expanduser(conf.get('core', 'AIRFLOW_HOME'))
-SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN')
 LOGGING_LEVEL = logging.INFO
-DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
 
 # the prefix to append to gunicorn worker processes after init
 GUNICORN_WORKER_READY_PREFIX = "[ready] "
@@ -85,6 +82,13 @@ LOG_FORMAT_WITH_THREAD_NAME = (
     '[%(asctime)s] {%(filename)s:%(lineno)d} %(threadName)s %(levelname)s - %(message)s')
 SIMPLE_LOG_FORMAT = '%(asctime)s %(levelname)s - %(message)s'
 
+AIRFLOW_HOME = None
+SQL_ALCHEMY_CONN = None
+DAGS_FOLDER = None
+
+engine = None
+Session = None
+
 
 def policy(task_instance):
     """
@@ -118,8 +122,14 @@ def configure_logging(log_format=LOG_FORMAT):
     logging.basicConfig(
         format=log_format, stream=sys.stdout, level=LOGGING_LEVEL)
 
-engine = None
-Session = None
+
+def configure_vars():
+    global AIRFLOW_HOME
+    global SQL_ALCHEMY_CONN
+    global DAGS_FOLDER
+    AIRFLOW_HOME = os.path.expanduser(conf.get('core', 'AIRFLOW_HOME'))
+    SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN')
+    DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
 
 
 def configure_orm(disable_connection_pool=False):
@@ -133,7 +143,7 @@ def configure_orm(disable_connection_pool=False):
         engine_args['pool_size'] = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE')
         engine_args['pool_recycle'] = conf.getint('core',
                                                   'SQL_ALCHEMY_POOL_RECYCLE')
-        #engine_args['echo'] = True
+        # engine_args['echo'] = True
 
     engine = create_engine(SQL_ALCHEMY_CONN, **engine_args)
     Session = scoped_session(
@@ -146,6 +156,7 @@ except:
     pass
 
 configure_logging()
+configure_vars()
 configure_orm()
 
 # Const stuff

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/airflow/task_runner/__init__.py
----------------------------------------------------------------------
diff --git a/airflow/task_runner/__init__.py b/airflow/task_runner/__init__.py
new file mode 100644
index 0000000..f134e8e
--- /dev/null
+++ b/airflow/task_runner/__init__.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow import configuration
+from airflow.contrib.task_runner.cgroup_task_runner import CgroupTaskRunner
+from airflow.task_runner.bash_task_runner import BashTaskRunner
+from airflow.exceptions import AirflowException
+
+_TASK_RUNNER = configuration.get('core', 'TASK_RUNNER')
+
+
+def get_task_runner(local_task_job):
+    """
+    Get the task runner that can be used to run the given job.
+
+    :param local_task_job: The LocalTaskJob associated with the TaskInstance
+    that needs to be executed.
+    :type local_task_job: airflow.jobs.LocalTaskJob
+    :return: The task runner to use to run the task.
+    :rtype: airflow.task_runner.base_task_runner.BaseTaskRunner
+    """
+    if _TASK_RUNNER == "BashTaskRunner":
+        return BashTaskRunner(local_task_job)
+    elif _TASK_RUNNER == "CgroupTaskRunner":
+        return CgroupTaskRunner(local_task_job)
+    else:
+        raise AirflowException("Unknown task runner type 
{}".format(_TASK_RUNNER))

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/airflow/task_runner/base_task_runner.py
----------------------------------------------------------------------
diff --git a/airflow/task_runner/base_task_runner.py b/airflow/task_runner/base_task_runner.py
new file mode 100644
index 0000000..69802a8
--- /dev/null
+++ b/airflow/task_runner/base_task_runner.py
@@ -0,0 +1,153 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import getpass
+import os
+import json
+import subprocess
+import threading
+
+from airflow import configuration as conf
+from airflow.utils.logging import LoggingMixin
+from tempfile import mkstemp
+
+
+class BaseTaskRunner(LoggingMixin):
+    """
+    Runs Airflow task instances by invoking the `airflow run` command with raw
+    mode enabled in a subprocess.
+    """
+
+    def __init__(self, local_task_job):
+        """
+        :param local_task_job: The local task job associated with running the
+        associated task instance.
+        :type local_task_job: airflow.jobs.LocalTaskJob
+        """
+        self._task_instance = local_task_job.task_instance
+
+        popen_prepend = []
+        cfg_path = None
+        if self._task_instance.run_as_user:
+            self.run_as_user = self._task_instance.run_as_user
+        else:
+            try:
+                self.run_as_user = conf.get('core', 'default_impersonation')
+            except conf.AirflowConfigException:
+                self.run_as_user = None
+
+        # Add sudo commands to change user if we need to. Needed to handle SubDagOperator
+        # case using a SequentialExecutor.
+        if self.run_as_user and (self.run_as_user != getpass.getuser()):
+            self.logger.debug("Planning to run as the {} 
user".format(self.run_as_user))
+            cfg_dict = conf.as_dict(display_sensitive=True)
+            cfg_subset = {
+                'core': cfg_dict.get('core', {}),
+                'smtp': cfg_dict.get('smtp', {}),
+                'scheduler': cfg_dict.get('scheduler', {}),
+                'webserver': cfg_dict.get('webserver', {}),
+            }
+            temp_fd, cfg_path = mkstemp()
+
+            # Give ownership of file to user; only they can read and write
+            subprocess.call(
+                ['sudo', 'chown', self.run_as_user, cfg_path]
+            )
+            subprocess.call(
+                ['sudo', 'chmod', '600', cfg_path]
+            )
+
+            with os.fdopen(temp_fd, 'w') as temp_file:
+                json.dump(cfg_subset, temp_file)
+
+            popen_prepend = ['sudo', '-H', '-u', self.run_as_user]
+
+        self._cfg_path = cfg_path
+        self._command = popen_prepend + self._task_instance.command_as_list(
+            raw=True,
+            ignore_all_deps=local_task_job.ignore_all_deps,
+            ignore_depends_on_past=local_task_job.ignore_depends_on_past,
+            ignore_ti_state=local_task_job.ignore_ti_state,
+            pickle_id=local_task_job.pickle_id,
+            mark_success=local_task_job.mark_success,
+            job_id=local_task_job.id,
+            pool=local_task_job.pool,
+            cfg_path=cfg_path,
+        )
+        self.process = None
+
+    def _read_task_logs(self, stream):
+        while True:
+            line = stream.readline()
+            if len(line) == 0:
+                break
+            self.logger.info('Subtask: {}'.format(line.rstrip('\n')))
+
+    def run_command(self, run_with, join_args=False):
+        """
+        Run the task command
+
+        :param run_with: list of tokens to run the task command with
+        E.g. ['bash', '-c']
+        :type run_with: list
+        :param join_args: whether to concatenate the list of command tokens
+        E.g. ['airflow', 'run'] vs ['airflow run']
+        :type join_args: bool
+        :return: the process that was run
+        :rtype: subprocess.Popen
+        """
+        cmd = [" ".join(self._command)] if join_args else self._command
+        full_cmd = run_with + cmd
+        self.logger.info('Running: {}'.format(full_cmd))
+        proc = subprocess.Popen(
+            full_cmd,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT
+        )
+
+        # Start daemon thread to read subprocess logging output
+        log_reader = threading.Thread(
+            target=self._read_task_logs,
+            args=(proc.stdout,),
+        )
+        log_reader.daemon = True
+        log_reader.start()
+        return proc
+
+    def start(self):
+        """
+        Start running the task instance in a subprocess.
+        """
+        raise NotImplementedError()
+
+    def return_code(self):
+        """
+        :return: The return code associated with running the task instance or
+        None if the task is not yet done.
+        :rtype: int
+        """
+        raise NotImplementedError()
+
+    def terminate(self):
+        """
+        Kill the running task instance.
+        """
+        raise NotImplementedError()
+
+    def on_finish(self):
+        """
+        A callback that should be called when this is done running.
+        """
+        if self._cfg_path and os.path.isfile(self._cfg_path):
+            subprocess.call(['sudo', 'rm', self._cfg_path])

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/airflow/task_runner/bash_task_runner.py
----------------------------------------------------------------------
diff --git a/airflow/task_runner/bash_task_runner.py b/airflow/task_runner/bash_task_runner.py
new file mode 100644
index 0000000..b73e258
--- /dev/null
+++ b/airflow/task_runner/bash_task_runner.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import psutil
+
+from airflow.task_runner.base_task_runner import BaseTaskRunner
+from airflow.utils.helpers import kill_process_tree
+
+
+class BashTaskRunner(BaseTaskRunner):
+    """
+    Runs the raw Airflow task by invoking through the Bash shell.
+    """
+    def __init__(self, local_task_job):
+        super(BashTaskRunner, self).__init__(local_task_job)
+
+    def start(self):
+        self.process = self.run_command(['bash', '-c'], join_args=True)
+
+    def return_code(self):
+        return self.process.poll()
+
+    def terminate(self):
+        if self.process and psutil.pid_exists(self.process.pid):
+            kill_process_tree(self.logger, self.process.pid)
+
+    def on_finish(self):
+        super(BashTaskRunner, self).on_finish()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/airflow/utils/file.py
----------------------------------------------------------------------
diff --git a/airflow/utils/file.py b/airflow/utils/file.py
index d4526e9..78ddeaa 100644
--- a/airflow/utils/file.py
+++ b/airflow/utils/file.py
@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import unicode_literals
 
 import errno
+import os
 import shutil
 from tempfile import mkdtemp
 
@@ -34,3 +35,25 @@ def TemporaryDirectory(suffix='', prefix=None, dir=None):
             # ENOENT - no such file or directory
             if e.errno != errno.ENOENT:
                 raise e
+
+
+def mkdirs(path, mode):
+    """
+    Creates the directory specified by path, creating intermediate directories
+    as necessary. If directory already exists, this is a no-op.
+
+    :param path: The directory to create
+    :type path: str
+    :param mode: The mode to give to the directory e.g. 0o755
+    :type mode: int
+    :return: A list of directories that were created
+    :rtype: list[str]
+    """
+    if not path or os.path.exists(path):
+        return []
+    (head, _) = os.path.split(path)
+    res = mkdirs(head, mode)
+    os.mkdir(path)
+    os.chmod(path, mode)
+    res += [path]
+    return res

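A short usage sketch for the new helper (the path below is a placeholder): unlike a plain `os.makedirs` call, the requested mode is applied to every directory that gets created, and components that already exist are left untouched.

    from airflow.utils.file import mkdirs

    # Creates any missing path components with mode 0o775 and returns only the
    # directories that were actually created, parents first.
    created = mkdirs('/tmp/airflow-logs/example_dag/example_task', 0o775)
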
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/airflow/utils/helpers.py
----------------------------------------------------------------------
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index 6bd7a64..e66745c 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -22,10 +22,12 @@ import psutil
 from builtins import input
 from past.builtins import basestring
 from datetime import datetime
+import getpass
 import imp
-import logging
 import os
 import re
+import signal
+import subprocess
 import sys
 import warnings
 
@@ -35,6 +37,7 @@ from airflow.exceptions import AirflowException
 # SIGKILL.
 TIME_TO_WAIT_AFTER_SIGTERM = 5
 
+
 def validate_key(k, max_length=250):
     if not isinstance(k, basestring):
         raise TypeError("The key has to be a string")
@@ -179,6 +182,80 @@ def pprinttable(rows):
     return s
 
 
+def kill_using_shell(pid, signal=signal.SIGTERM):
+    process = psutil.Process(pid)
+    # Use sudo only when necessary - consider SubDagOperator and SequentialExecutor case.
+    if process.username() != getpass.getuser():
+        args = ["sudo", "kill", "-{}".format(int(signal)), str(pid)]
+    else:
+        args = ["kill", "-{}".format(int(signal)), str(pid)]
+    # PID may not exist and return a non-zero error code
+    subprocess.call(args)
+
+
+def kill_process_tree(logger, pid):
+    """
+    Kills the process and all of the descendants. Kills using the `kill`
+    shell command so that it can change users. Note: killing via PIDs
+    has the potential to kill the wrong process if the process dies and the
+    PID gets recycled in a narrow time window.
+
+    :param logger: logger
+    :type logger: logging.Logger
+    """
+    try:
+        root_process = psutil.Process(pid)
+    except psutil.NoSuchProcess:
+        logger.warn("PID: {} does not exist".format(pid))
+        return
+
+    # Check child processes to reduce cases where a child process died but
+    # the PID got reused.
+    descendant_processes = [x for x in root_process.children(recursive=True)
+                            if x.is_running()]
+
+    if len(descendant_processes) != 0:
+        logger.warn("Terminating descendant processes of {} PID: {}"
+                    .format(root_process.cmdline(),
+                            root_process.pid))
+        temp_processes = descendant_processes[:]
+        for descendant in temp_processes:
+            logger.warn("Terminating descendant process {} PID: {}"
+                        .format(descendant.cmdline(), descendant.pid))
+            try:
+                kill_using_shell(descendant.pid, signal.SIGTERM)
+            except psutil.NoSuchProcess:
+                descendant_processes.remove(descendant)
+
+        logger.warn("Waiting up to {}s for processes to exit..."
+                    .format(TIME_TO_WAIT_AFTER_SIGTERM))
+        try:
+            psutil.wait_procs(descendant_processes, TIME_TO_WAIT_AFTER_SIGTERM)
+            logger.warn("Done waiting")
+        except psutil.TimeoutExpired:
+            logger.warn("Ran out of time while waiting for "
+                        "processes to exit")
+        # Then SIGKILL
+        descendant_processes = [x for x in root_process.children(recursive=True)
+                                if x.is_running()]
+
+        if len(descendant_processes) > 0:
+            temp_processes = descendant_processes[:]
+            for descendant in temp_processes:
+                logger.warn("Killing descendant process {} PID: {}"
+                            .format(descendant.cmdline(), descendant.pid))
+                try:
+                    kill_using_shell(descendant.pid, signal.SIGKILL)
+                    descendant.wait()
+                except psutil.NoSuchProcess:
+                    descendant_processes.remove(descendant)
+            logger.warn("Killed all descendant processes of {} PID: {}"
+                        .format(root_process.cmdline(),
+                                root_process.pid))
+    else:
+        logger.debug("There are no descendant processes to kill")
+
+
 def kill_descendant_processes(logger, pids_to_kill=None):
     """
     Kills all descendant processes of this process.

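A minimal sketch of how the new helper could be exercised (the child process here is illustrative; note that, as implemented above, only descendants of the given PID are signalled, not the PID itself):

    import logging
    import subprocess
    import time

    from airflow.utils.helpers import kill_process_tree

    logger = logging.getLogger(__name__)

    # Spawn a shell whose background child ('sleep') becomes a descendant.
    proc = subprocess.Popen(['bash', '-c', 'sleep 600 & wait'])
    time.sleep(1)  # give bash a moment to fork the child

    # SIGTERM the descendants, wait up to TIME_TO_WAIT_AFTER_SIGTERM seconds,
    # then escalate for anything still running.
    kill_process_tree(logger, proc.pid)
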
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/docs/security.rst
----------------------------------------------------------------------
diff --git a/docs/security.rst b/docs/security.rst
index 29f228d..70db606 100644
--- a/docs/security.rst
+++ b/docs/security.rst
@@ -310,3 +310,25 @@ standard port 443, you'll need to configure that too. Be aware that super user p
     # Optionally, set the server to listen on the standard SSL port.
     web_server_port = 443
     base_url = http://<hostname or IP>:443
+
+Impersonation
+'''''''''''''
+
+Airflow has the ability to impersonate a unix user while running task
+instances based on the task's ``run_as_user`` parameter, which takes a user's name.
+
+*NOTE* For impersonation to work, Airflow must be run with `sudo` as subtasks are run
+with `sudo -u` and permissions of files are changed. Furthermore, the unix user needs to
+exist on the worker. Here is what a simple sudoers file entry could look like to achieve
+this, assuming airflow is running as the `airflow` user. Note that this means that
+the airflow user must be trusted and treated the same way as the root user.
+
+.. code-block:: none
+
+    airflow ALL=(ALL) NOPASSWD: ALL
+
+Subtasks with impersonation will still log to the same folder, except that the files they
+log to will have permissions changed such that only the unix user can write to them.
+
+*Default impersonation* To prevent tasks that don't use impersonation from being run with
+`sudo` privileges, you can set the `default_impersonation` config in `core` which sets a
+default user to impersonate if `run_as_user` is not set.

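To illustrate the doc above, a minimal DAG sketch using the new parameter (the user name is an assumption and must exist on the worker, as described above):

    from datetime import datetime

    from airflow.models import DAG
    from airflow.operators.bash_operator import BashOperator

    dag = DAG(
        dag_id='impersonation_example',
        default_args={'owner': 'airflow', 'start_date': datetime(2017, 1, 1)},
    )

    # The task's subprocess is launched via `sudo -u etl_user`, so `whoami`
    # prints the impersonated user rather than the worker's airflow user.
    print_user = BashOperator(
        task_id='print_effective_user',
        bash_command='whoami',
        run_as_user='etl_user',  # hypothetical unix user
        dag=dag,
    )
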
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/run_unit_tests.sh
----------------------------------------------------------------------
diff --git a/run_unit_tests.sh b/run_unit_tests.sh
index c291292..c922a55 100755
--- a/run_unit_tests.sh
+++ b/run_unit_tests.sh
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+set -x
 
 # environment
 export AIRFLOW_HOME=${AIRFLOW_HOME:=~/airflow}
@@ -48,6 +49,19 @@ echo "Initializing the DB"
 yes | airflow resetdb
 airflow initdb
 
+if [ "${TRAVIS}" ]; then
+  # For impersonation tests running on SQLite on Travis, make the database world readable
+  # and writable so other users can update it
+  AIRFLOW_DB="/home/travis/airflow/airflow.db"
+  if [ -f "${AIRFLOW_DB}" ]; then
+    sudo chmod a+rw "${AIRFLOW_DB}"
+  fi
+
+  # For impersonation tests on Travis, make airflow accessible to other users via the global PATH
+  # (which contains /usr/local/bin)
+  sudo ln -s "${VIRTUAL_ENV}/bin/airflow" /usr/local/bin/
+fi
+
 echo "Starting the unit tests with the following nose arguments: "$nose_args
 nosetests $nose_args
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/scripts/ci/airflow_travis.cfg
----------------------------------------------------------------------
diff --git a/scripts/ci/airflow_travis.cfg b/scripts/ci/airflow_travis.cfg
index 505bc0e..2834ad4 100644
--- a/scripts/ci/airflow_travis.cfg
+++ b/scripts/ci/airflow_travis.cfg
@@ -22,6 +22,7 @@ load_examples = True
 donot_pickle = False
 dag_concurrency = 16
 dags_are_paused_at_creation = False
+default_impersonation =
 fernet_key = af7CN0q6ag5U3g08IsPsw3K45U7Xa0axgVFhoh-3zB8=
 
 [webserver]

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/scripts/ci/requirements.txt
----------------------------------------------------------------------
diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt
index 9e503f9..a5786f6 100644
--- a/scripts/ci/requirements.txt
+++ b/scripts/ci/requirements.txt
@@ -2,6 +2,7 @@ alembic
 bcrypt
 boto
 celery
+cgroupspy
 chartkick
 cloudant
 coverage

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index aad9984..b8fe677 100644
--- a/setup.py
+++ b/setup.py
@@ -108,6 +108,9 @@ celery = [
     'celery>=3.1.17',
     'flower>=0.7.3'
 ]
+cgroups = [
+    'cgroupspy>=0.1.4',
+]
 crypto = ['cryptography>=0.9.3']
 datadog = ['datadog>=0.14.0']
 doc = [
@@ -227,6 +230,7 @@ def do_setup():
             'all_dbs': all_dbs,
             'async': async,
             'celery': celery,
+            'cgroups': cgroups,
             'cloudant': cloudant,
             'crypto': crypto,
             'datadog': datadog,

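Assuming the standard setuptools extras mechanism, the optional cgroups dependencies registered above would typically be pulled in with something like pip install "airflow[cgroups]" (the exact distribution name depends on how the package is published).
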
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/tests/__init__.py
----------------------------------------------------------------------
diff --git a/tests/__init__.py b/tests/__init__.py
index 69abb33..e1e8551 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -18,6 +18,7 @@ from .configuration import *
 from .contrib import *
 from .core import *
 from .jobs import *
+from .impersonation import *
 from .models import *
 from .operators import *
 from .utils import *

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/tests/dags/test_default_impersonation.py
----------------------------------------------------------------------
diff --git a/tests/dags/test_default_impersonation.py b/tests/dags/test_default_impersonation.py
new file mode 100644
index 0000000..41cca00
--- /dev/null
+++ b/tests/dags/test_default_impersonation.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow.models import DAG
+from airflow.operators.bash_operator import BashOperator
+from datetime import datetime
+from textwrap import dedent
+
+
+DEFAULT_DATE = datetime(2016, 1, 1)
+
+args = {
+    'owner': 'airflow',
+    'start_date': DEFAULT_DATE,
+}
+
+dag = DAG(dag_id='test_default_impersonation', default_args=args)
+
+deelevated_user = 'airflow_test_user'
+
+test_command = dedent(
+    """\
+    if [ '{user}' != "$(whoami)" ]; then
+        echo current user $(whoami) is not {user}!
+        exit 1
+    fi
+    """.format(user=deelevated_user))
+
+task = BashOperator(
+    task_id='test_deelevated_user',
+    bash_command=test_command,
+    dag=dag,
+)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/tests/dags/test_impersonation.py
----------------------------------------------------------------------
diff --git a/tests/dags/test_impersonation.py b/tests/dags/test_impersonation.py
new file mode 100644
index 0000000..3727903
--- /dev/null
+++ b/tests/dags/test_impersonation.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow.models import DAG
+from airflow.operators.bash_operator import BashOperator
+from datetime import datetime
+from textwrap import dedent
+
+
+DEFAULT_DATE = datetime(2016, 1, 1)
+
+args = {
+    'owner': 'airflow',
+    'start_date': DEFAULT_DATE,
+}
+
+dag = DAG(dag_id='test_impersonation', default_args=args)
+
+run_as_user = 'airflow_test_user'
+
+test_command = dedent(
+    """\
+    if [ '{user}' != "$(whoami)" ]; then
+        echo current user is not {user}!
+        exit 1
+    fi
+    """.format(user=run_as_user))
+
+task = BashOperator(
+    task_id='test_impersonated_user',
+    bash_command=test_command,
+    dag=dag,
+    run_as_user=run_as_user,
+)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/tests/dags/test_no_impersonation.py
----------------------------------------------------------------------
diff --git a/tests/dags/test_no_impersonation.py b/tests/dags/test_no_impersonation.py
new file mode 100644
index 0000000..0fc63da
--- /dev/null
+++ b/tests/dags/test_no_impersonation.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow.models import DAG
+from airflow.operators.bash_operator import BashOperator
+from datetime import datetime
+from textwrap import dedent
+
+
+DEFAULT_DATE = datetime(2016, 1, 1)
+
+args = {
+    'owner': 'airflow',
+    'start_date': DEFAULT_DATE,
+}
+
+dag = DAG(dag_id='test_no_impersonation', default_args=args)
+
+test_command = dedent(
+    """\
+    sudo ls
+    if [ $? -ne 0 ]; then
+        echo 'current uid does not have root privileges!'
+        exit 1
+    fi
+    """)
+
+task = BashOperator(
+    task_id='test_superuser',
+    bash_command=test_command,
+    dag=dag,
+)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b56cb5cc/tests/impersonation.py
----------------------------------------------------------------------
diff --git a/tests/impersonation.py b/tests/impersonation.py
new file mode 100644
index 0000000..0777def
--- /dev/null
+++ b/tests/impersonation.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# from __future__ import print_function
+import errno
+import os
+import subprocess
+import unittest
+
+from airflow import jobs, models
+from airflow.utils.state import State
+from datetime import datetime
+
+DEV_NULL = '/dev/null'
+TEST_DAG_FOLDER = os.path.join(
+    os.path.dirname(os.path.realpath(__file__)), 'dags')
+DEFAULT_DATE = datetime(2015, 1, 1)
+TEST_USER = 'airflow_test_user'
+
+
+# TODO(aoen): Adding/removing a user as part of a test is very bad (especially if the user
+# already existed to begin with on the OS); this logic should be moved into a test
+# that is wrapped in a container like docker so that the user can be safely added/removed.
+# When this is done we can also modify the sudoers file to ensure that useradd will work
+# without any manual modification of the sudoers file by the agent that is running these
+# tests.
+
+class ImpersonationTest(unittest.TestCase):
+    def setUp(self):
+        self.dagbag = models.DagBag(
+            dag_folder=TEST_DAG_FOLDER,
+            include_examples=False,
+        )
+        try:
+            subprocess.check_output(['sudo', 'useradd', '-m', TEST_USER, '-g',
+                                     str(os.getegid())])
+        except (OSError, subprocess.CalledProcessError) as e:
+            if isinstance(e, OSError) and e.errno == errno.ENOENT:
+                raise unittest.SkipTest(
+                    "The 'useradd' command did not exist so unable to test "
+                    "impersonation; Skipping Test. These tests can only be run on a "
+                    "linux host that supports 'useradd'."
+                )
+            else:
+                raise unittest.SkipTest(
+                    "The 'useradd' command exited non-zero; Skipping tests. Does the "
+                    "current user have permission to run 'useradd' without a password "
+                    "prompt (check sudoers file)?"
+                )
+
+    def tearDown(self):
+        subprocess.check_output(['sudo', 'userdel', '-r', TEST_USER])
+
+    def run_backfill(self, dag_id, task_id):
+        dag = self.dagbag.get_dag(dag_id)
+        dag.clear()
+
+        jobs.BackfillJob(
+            dag=dag,
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE).run()
+
+        ti = models.TaskInstance(
+            task=dag.get_task(task_id),
+            execution_date=DEFAULT_DATE)
+        ti.refresh_from_db()
+        self.assertEqual(ti.state, State.SUCCESS)
+
+    def test_impersonation(self):
+        """
+        Tests that impersonating a unix user works
+        """
+        self.run_backfill(
+            'test_impersonation',
+            'test_impersonated_user'
+        )
+
+    def test_no_impersonation(self):
+        """
+        If default_impersonation=None, tests that the job is run
+        as the current user (which will be a sudoer)
+        """
+        self.run_backfill(
+            'test_no_impersonation',
+            'test_superuser',
+        )
+
+    def test_default_impersonation(self):
+        """
+        If default_impersonation=TEST_USER, tests that the job defaults
+        to running as TEST_USER for a test without run_as_user set
+        """
+        os.environ['AIRFLOW__CORE__DEFAULT_IMPERSONATION'] = TEST_USER
+
+        try:
+            self.run_backfill(
+                'test_default_impersonation',
+                'test_deelevated_user'
+            )
+        finally:
+            del os.environ['AIRFLOW__CORE__DEFAULT_IMPERSONATION']
