Repository: incubator-airflow
Updated Branches:
  refs/heads/master 8b86ee6a7 -> 92064398c


[AIRFLOW-264] Adding workload management for Hive

Dear Airflow Maintainers,

Please accept this PR that addresses the following issues:
- https://issues.apache.org/jira/browse/AIRFLOW-264

CC: Original PR by Jparks2532
https://github.com/apache/incubator-airflow/pull/1384

Add workload management to the Hive hook and operator.
Edited operator_helpers to avoid a KeyError when retrieving context values.
Refactored the hive_cli command preparation into a separate private
method.
Added a small helper to flatten one level of an iterable into a list.
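
For context, a minimal usage sketch of the new parameters from a DAG (a rough
sketch only: the DAG, HQL and dates below are made-up placeholders, and a
working 'hive_cli_default' connection is assumed):

    from datetime import datetime

    from airflow import DAG
    from airflow.operators.hive_operator import HiveOperator

    # Hypothetical DAG used only to illustrate the new workload-management
    # parameters added by this change.
    dag = DAG('hive_wlm_example',
              default_args={'owner': 'airflow',
                            'start_date': datetime(2016, 7, 6)})

    run_hql = HiveOperator(
        task_id='run_hql',
        hql="SELECT 1;",
        mapred_queue='default',             # Hadoop scheduler queue
        mapred_queue_priority='HIGH',       # VERY_HIGH, HIGH, NORMAL, LOW or VERY_LOW
        mapred_job_name='airflow.run_hql',  # name shown in the jobtracker
        dag=dag)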

Closes #1614 from artwr/artwr_fixing_hive_queue_PR


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/92064398
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/92064398
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/92064398

Branch: refs/heads/master
Commit: 92064398c4c982a310925da376745a1713bf96e2
Parents: 8b86ee6
Author: Arthur Wiedmer <arthur.wied...@gmail.com>
Authored: Wed Jul 6 12:43:12 2016 -0700
Committer: Arthur Wiedmer <arthur.wied...@gmail.com>
Committed: Wed Jul 6 12:43:44 2016 -0700

----------------------------------------------------------------------
 airflow/hooks/hive_hooks.py        | 183 +++++++++++++++++++++++---------
 airflow/operators/hive_operator.py |  22 +++-
 airflow/utils/helpers.py           |  10 ++
 airflow/utils/operator_helpers.py  |  10 +-
 tests/operators/hive_operator.py   |  12 ++-
 tests/operators/operators.py       |   2 +-
 6 files changed, 179 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/92064398/airflow/hooks/hive_hooks.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py
index eaad390..a9fac48 100644
--- a/airflow/hooks/hive_hooks.py
+++ b/airflow/hooks/hive_hooks.py
@@ -17,19 +17,25 @@ from __future__ import print_function
 from builtins import zip
 from past.builtins import basestring
 
+import collections
 import unicodecsv as csv
+import itertools
 import logging
 import re
 import subprocess
+import time
 from tempfile import NamedTemporaryFile
 import hive_metastore
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
+from airflow.utils.helpers import as_flattened_list
 from airflow.utils.file import TemporaryDirectory
 from airflow import configuration
 import airflow.security.utils as utils
 
+HIVE_QUEUE_PRIORITIES = ['VERY_HIGH', 'HIGH', 'NORMAL', 'LOW', 'VERY_LOW']
+
 
 class HiveCliHook(BaseHook):
 
@@ -47,12 +53,24 @@ class HiveCliHook(BaseHook):
 
     The extra connection parameter ``auth`` gets passed as in the ``jdbc``
     connection string as is.
+
+    :param mapred_queue: queue used by the Hadoop Scheduler (Capacity or Fair)
+    :type  mapred_queue: string
+    :param mapred_queue_priority: priority within the job queue.
+        Possible settings include: VERY_HIGH, HIGH, NORMAL, LOW, VERY_LOW
+    :type  mapred_queue_priority: string
+    :param mapred_job_name: This name will appear in the jobtracker.
+        This can make monitoring easier.
+    :type  mapred_job_name: string
     """
 
     def __init__(
             self,
             hive_cli_conn_id="hive_cli_default",
-            run_as=None):
+            run_as=None,
+            mapred_queue=None,
+            mapred_queue_priority=None,
+            mapred_job_name=None):
         conn = self.get_connection(hive_cli_conn_id)
         self.hive_cli_params = conn.extra_dejson.get('hive_cli_params', '')
         self.use_beeline = conn.extra_dejson.get('use_beeline', False)
@@ -60,16 +78,92 @@ class HiveCliHook(BaseHook):
         self.conn = conn
         self.run_as = run_as
 
+        if mapred_queue_priority:
+            mapred_queue_priority = mapred_queue_priority.upper()
+            if mapred_queue_priority not in HIVE_QUEUE_PRIORITIES:
+                raise AirflowException(
+                    "Invalid Mapred Queue Priority.  Valid values are: "
+                    "{}".format(', '.join(HIVE_QUEUE_PRIORITIES)))
+
+        self.mapred_queue = mapred_queue
+        self.mapred_queue_priority = mapred_queue_priority
+        self.mapred_job_name = mapred_job_name
+
+    def _prepare_cli_cmd(self):
+        """
+        This function creates the command list from available information
+        """
+        conn = self.conn
+        hive_bin = 'hive'
+        cmd_extra = []
+
+        if self.use_beeline:
+            hive_bin = 'beeline'
+            jdbc_url = "jdbc:hive2://{conn.host}:{conn.port}/{conn.schema}"
+            if configuration.get('core', 'security') == 'kerberos':
+                template = conn.extra_dejson.get(
+                    'principal', "hive/_h...@example.com")
+                if "_HOST" in template:
+                    template = utils.replace_hostname_pattern(
+                        utils.get_components(template))
+
+                proxy_user = ""  # noqa
+                if conn.extra_dejson.get('proxy_user') == "login" and conn.login:
+                    proxy_user = "hive.server2.proxy.user={0}".format(conn.login)
+                elif conn.extra_dejson.get('proxy_user') == "owner" and self.run_as:
+                    proxy_user = "hive.server2.proxy.user={0}".format(self.run_as)
+
+                jdbc_url += ";principal={template};{proxy_user}"
+            elif self.auth:
+                jdbc_url += ";auth=" + self.auth
+
+            jdbc_url = jdbc_url.format(**locals())
+
+            cmd_extra += ['-u', jdbc_url]
+            if conn.login:
+                cmd_extra += ['-n', conn.login]
+            if conn.password:
+                cmd_extra += ['-p', conn.password]
+
+        hive_params_list = self.hive_cli_params.split()
+
+        return [hive_bin] + cmd_extra + hive_params_list
+
+    def _prepare_hiveconf(self, d):
+        """
+        This function prepares a list of hiveconf params
+        from a dictionary of key value pairs.
+
+        :param d:
+        :type d: dict
+
+        >>> hh = HiveCliHook()
+        >>> hive_conf = {"hive.exec.dynamic.partition": "true",
+        ... "hive.exec.dynamic.partition.mode": "nonstrict"}
+        >>> hh._prepare_hiveconf(hive_conf)
+        ["-hiveconf", "hive.exec.dynamic.partition=true",\
+ "-hiveconf", "hive.exec.dynamic.partition.mode=nonstrict"]
+        """
+        if not d:
+            return []
+        return as_flattened_list(
+            itertools.izip(
+                ["-hiveconf"] * len(d),
+                ["{}={}".format(k, v) for k, v in d.items()]
+                )
+            )
+
     def run_cli(self, hql, schema=None, verbose=True, hive_conf=None):
         """
-        Run an hql statement using the hive cli. If hive_conf is specified it should be a
-        dict and the entries will be set as key/value pairs in HiveConf
+        Run an hql statement using the hive cli. If hive_conf is specified
+        it should be a dict and the entries will be set as key/value pairs
+        in HiveConf
 
 
-        :param hive_conf: if specified these key value pairs will be passed to hive as
-            ``-hiveconf "key"="value"``. Note that they will be passed after the
-            ``hive_cli_params`` and thus will override whatever values are specified in
-            the database.
+        :param hive_conf: if specified these key value pairs will be passed
+            to hive as ``-hiveconf "key"="value"``. Note that they will be
+            passed after the ``hive_cli_params`` and thus will override
+            whatever values are specified in the database.
         :type hive_conf: dict
 
         >>> hh = HiveCliHook()
@@ -86,47 +180,29 @@ class HiveCliHook(BaseHook):
             with NamedTemporaryFile(dir=tmp_dir) as f:
                 f.write(hql.encode('UTF-8'))
                 f.flush()
-                fname = f.name
-                hive_bin = 'hive'
-                cmd_extra = []
-
-                if self.use_beeline:
-                    hive_bin = 'beeline'
-                    jdbc_url = "jdbc:hive2://{conn.host}:{conn.port}/{conn.schema}"
-                    if configuration.get('core', 'security') == 'kerberos':
-                        template = conn.extra_dejson.get(
-                            'principal', "hive/_h...@example.com")
-                        if "_HOST" in template:
-                            template = utils.replace_hostname_pattern(
-                                utils.get_components(template))
-
-                        proxy_user = ""  # noqa
-                        if conn.extra_dejson.get('proxy_user') == "login" and conn.login:
-                            proxy_user = "hive.server2.proxy.user={0}".format(conn.login)
-                        elif conn.extra_dejson.get('proxy_user') == "owner" and self.run_as:
-                            proxy_user = "hive.server2.proxy.user={0}".format(self.run_as)
-
-                        jdbc_url += ";principal={template};{proxy_user}"
-                    elif self.auth:
-                        jdbc_url += ";auth=" + self.auth
-
-                    jdbc_url = jdbc_url.format(**locals())
-
-                    cmd_extra += ['-u', jdbc_url]
-                    if conn.login:
-                        cmd_extra += ['-n', conn.login]
-                    if conn.password:
-                        cmd_extra += ['-p', conn.password]
-
-                hive_conf = hive_conf or {}
-                for key, value in hive_conf.items():
-                    cmd_extra += ['-hiveconf', '{0}={1}'.format(key, value)]
-
-                hive_cmd = [hive_bin, '-f', fname] + cmd_extra
-
-                if self.hive_cli_params:
-                    hive_params_list = self.hive_cli_params.split()
-                    hive_cmd.extend(hive_params_list)
+                hive_cmd = self._prepare_cli_cmd()
+                hive_conf_params = self._prepare_hiveconf(hive_conf)
+                if self.mapred_queue:
+                    hive_conf_params.extend(
+                        ['-hiveconf',
+                         'mapreduce.job.queuename={}'
+                         .format(self.mapred_queue)])
+
+                if self.mapred_queue_priority:
+                    hive_conf_params.extend(
+                        ['-hiveconf',
+                         'mapreduce.job.priority={}'
+                         .format(self.mapred_queue_priority)])
+
+                if self.mapred_job_name:
+                    hive_conf_params.extend(
+                        ['-hiveconf',
+                         'mapred.job.name={}'
+                         .format(self.mapred_job_name)])
+
+                hive_cmd.extend(hive_conf_params)
+                hive_cmd.extend(['-f', f.name])
+
                 if verbose:
                     logging.info(" ".join(hive_cmd))
                 sp = subprocess.Popen(
@@ -260,6 +336,8 @@ class HiveCliHook(BaseHook):
         if hasattr(self, 'sp'):
             if self.sp.poll() is None:
                 print("Killing the Hive job")
+                self.sp.terminate()
+                time.sleep(60)
                 self.sp.kill()
 
 
@@ -561,11 +639,12 @@ class HiveServer2Hook(BaseHook):
                 cur.execute(hql)
                 schema = cur.description
                 with open(csv_filepath, 'wb') as f:
-                    writer = csv.writer(f, delimiter=delimiter,
-                        lineterminator=lineterminator, encoding='utf-8')
+                    writer = csv.writer(f,
+                                        delimiter=delimiter,
+                                        lineterminator=lineterminator,
+                                        encoding='utf-8')
                     if output_header:
-                        writer.writerow([c[0]
-                            for c in cur.description])
+                        writer.writerow([c[0] for c in cur.description])
                     i = 0
                     while True:
                         rows = [row for row in cur.fetchmany(fetch_size) if row]

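For clarity, the new _prepare_hiveconf turns a dict into alternating
"-hiveconf", "key=value" arguments. Below is a standalone sketch of the same
logic (local function names, not the hook's methods); note that
itertools.izip exists only on Python 2, while on Python 3 the built-in zip is
the equivalent (hive_hooks.py already imports zip from builtins):

    def as_flattened_list(iterable):
        # Flatten exactly one level of nesting.
        return [e for i in iterable for e in i]

    def prepare_hiveconf(d):
        # Standalone equivalent of HiveCliHook._prepare_hiveconf, using the
        # built-in zip instead of itertools.izip.
        if not d:
            return []
        return as_flattened_list(
            zip(["-hiveconf"] * len(d),
                ["{}={}".format(k, v) for k, v in d.items()]))

    print(prepare_hiveconf({"hive.exec.dynamic.partition": "true"}))
    # ['-hiveconf', 'hive.exec.dynamic.partition=true']
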
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/92064398/airflow/operators/hive_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/hive_operator.py b/airflow/operators/hive_operator.py
index 3763d6b..06a83e3 100644
--- a/airflow/operators/hive_operator.py
+++ b/airflow/operators/hive_operator.py
@@ -38,6 +38,14 @@ class HiveOperator(BaseOperator):
     :param script_begin_tag: If defined, the operator will get rid of the
         part of the script before the first occurrence of `script_begin_tag`
     :type script_begin_tag: str
+    :param mapred_queue: queue used by the Hadoop CapacityScheduler
+    :type  mapred_queue: string
+    :param mapred_queue_priority: priority within CapacityScheduler queue.
+        Possible settings include: VERY_HIGH, HIGH, NORMAL, LOW, VERY_LOW
+    :type  mapred_queue_priority: string
+    :param mapred_job_name: This name will appear in the jobtracker.
+        This can make monitoring easier.
+    :type  mapred_job_name: string
     """
 
     template_fields = ('hql', 'schema')
@@ -52,6 +60,9 @@ class HiveOperator(BaseOperator):
             hiveconf_jinja_translate=False,
             script_begin_tag=None,
             run_as_owner=False,
+            mapred_queue=None,
+            mapred_queue_priority=None,
+            mapred_job_name=None,
             *args, **kwargs):
 
         super(HiveOperator, self).__init__(*args, **kwargs)
@@ -64,8 +75,17 @@ class HiveOperator(BaseOperator):
         if run_as_owner:
             self.run_as = self.dag.owner
 
+        self.mapred_queue = mapred_queue
+        self.mapred_queue_priority = mapred_queue_priority
+        self.mapred_job_name = mapred_job_name
+
     def get_hook(self):
-        return HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id, run_as=self.run_as)
+        return HiveCliHook(
+                        hive_cli_conn_id=self.hive_cli_conn_id,
+                        run_as=self.run_as,
+                        mapred_queue=self.mapred_queue,
+                        mapred_queue_priority=self.mapred_queue_priority,
+                        mapred_job_name=self.mapred_job_name)
 
     def prepare_template(self):
         if self.hiveconf_jinja_translate:

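The operator forwards these parameters straight to HiveCliHook via get_hook().
Calling the hook directly would look roughly like the sketch below (the
connection id, queue and HQL are placeholders, and a reachable hive/beeline
binary is assumed):

    from airflow.hooks.hive_hooks import HiveCliHook

    hook = HiveCliHook(
        hive_cli_conn_id='hive_cli_default',
        mapred_queue='default',
        mapred_queue_priority='NORMAL',
        mapred_job_name='airflow.example_job')

    # run_cli appends -hiveconf mapreduce.job.queuename/priority and
    # mapred.job.name before the -f <script> argument, per the diff above.
    hook.run_cli(hql="SHOW TABLES;", schema='default')
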
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/92064398/airflow/utils/helpers.py
----------------------------------------------------------------------
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index 7e3426e..5fc8fc1 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -104,6 +104,16 @@ def as_tuple(obj):
         return tuple([obj])
 
 
+def as_flattened_list(iterable):
+    """
+    Return an iterable with one level flattened
+
+    >>> as_flattened_list((('blue', 'red'), ('green', 'yellow', 'pink')))
+    ['blue', 'red', 'green', 'yellow', 'pink']
+    """
+    return [e for i in iterable for e in i]
+
+
 def chain(*tasks):
     """
     Given a number of tasks, builds a dependency chain.

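A quick illustration of the "one level" semantics of as_flattened_list (the
nested tuple is deliberate, to show that deeper nesting is preserved):

    from airflow.utils.helpers import as_flattened_list

    # Only one level is flattened; the inner (4, 5) tuple survives as-is.
    print(as_flattened_list([(1, 2), (3,), ((4, 5),)]))
    # [1, 2, 3, (4, 5)]
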
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/92064398/airflow/utils/operator_helpers.py
----------------------------------------------------------------------
diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py
index 617976e..7381fb3 100644
--- a/airflow/utils/operator_helpers.py
+++ b/airflow/utils/operator_helpers.py
@@ -20,19 +20,19 @@ def context_to_airflow_vars(context):
     externally reconstruct relations between dags, dag_runs, tasks and task_instances.
 
     :param context: The context for the task_instance of interest
-    :type successes: dict
+    :type context: dict
     """
     params = dict()
-    dag = context['dag']
+    dag = context.get('dag')
     if dag and dag.dag_id:
         params['airflow.ctx.dag.dag_id'] = dag.dag_id
-    dag_run = context['dag_run']
+    dag_run = context.get('dag_run')
     if dag_run and dag_run.execution_date:
         params['airflow.ctx.dag_run.execution_date'] = dag_run.execution_date.isoformat()
-    task = context['task']
+    task = context.get('task')
     if task and task.task_id:
         params['airflow.ctx.task.task_id'] = task.task_id
-    task_instance = context['task_instance']
+    task_instance = context.get('task_instance')
     if task_instance and task_instance.execution_date:
         params['airflow.ctx.task_instance.execution_date'] = \
             task_instance.execution_date.isoformat()

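With dict.get() in place, a partial context no longer raises KeyError and
missing keys are simply skipped. A small sketch (the context object below is
made up, not a real task_instance context):

    from airflow.utils.operator_helpers import context_to_airflow_vars

    class FakeDag(object):
        # Stand-in for a DAG; only dag_id is needed here.
        dag_id = 'example_dag'

    # 'dag_run', 'task' and 'task_instance' are absent and are skipped.
    print(context_to_airflow_vars({'dag': FakeDag()}))
    # {'airflow.ctx.dag.dag_id': 'example_dag'}
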
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/92064398/tests/operators/hive_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/hive_operator.py b/tests/operators/hive_operator.py
index f59bbf1..c5a0f0e 100644
--- a/tests/operators/hive_operator.py
+++ b/tests/operators/hive_operator.py
@@ -96,10 +96,20 @@ if 'AIRFLOW_RUNALL_TESTS' in os.environ:
                 task_id='basic_hql', hql=self.hql, dag=self.dag)
             t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
 
+        def test_hive_queues(self):
+            import airflow.operators.hive_operator
+            t = operators.hive_operator.HiveOperator(
+                task_id='test_hive_queues', hql=self.hql,
+                mapred_queue='default', mapred_queue_priority='HIGH',
+                mapred_job_name='airflow.test_hive_queues',
+                dag=self.dag)
+            t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
+
+
         def test_hive_dryrun(self):
             import airflow.operators.hive_operator
             t = operators.hive_operator.HiveOperator(
-                task_id='basic_hql', hql=self.hql, dag=self.dag)
+                task_id='dry_run_basic_hql', hql=self.hql, dag=self.dag)
             t.dry_run()
 
         def test_beeline(self):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/92064398/tests/operators/operators.py
----------------------------------------------------------------------
diff --git a/tests/operators/operators.py b/tests/operators/operators.py
index 2365ba0..8ca47bd 100644
--- a/tests/operators/operators.py
+++ b/tests/operators/operators.py
@@ -167,7 +167,7 @@ class TransferTests(unittest.TestCase):
 
     def setUp(self):
         configuration.test_mode()
-        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE_ISO}
+        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
         dag = DAG(TEST_DAG_ID, default_args=args)
         self.dag = dag
 
