Repository: incubator-airflow
Updated Branches:
  refs/heads/master 08a18395e -> 3245d1745


[AIRFLOW-2534] Fix bug in HiveServer2Hook

This commit also adds numerous tests for
HiveServer2Hook and replaces Impyla with
PyHive (0.6.0), making HiveServer2Hook
Python 2 compatible.

Closes #3432 from gglanzani/AIRFLOW-2534


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/3245d174
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/3245d174
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/3245d174

Branch: refs/heads/master
Commit: 3245d1745d08d4b3a7a7aad0b54aae734a2c1c53
Parents: 08a1839
Author: Giovanni Lanzani <[email protected]>
Authored: Fri Jun 15 14:19:25 2018 +0200
Committer: Fokko Driesprong <[email protected]>
Committed: Fri Jun 15 14:19:25 2018 +0200

----------------------------------------------------------------------
 airflow/hooks/hive_hooks.py   | 141 ++++++++++++++++++++++---------------
 setup.py                      |   5 +-
 tests/hooks/test_hive_hook.py |  86 ++++++++++++++++++++++
 tests/operators/operators.py  |   6 +-
 4 files changed, 174 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3245d174/airflow/hooks/hive_hooks.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py
index 7f0f068..93e1f45 100644
--- a/airflow/hooks/hive_hooks.py
+++ b/airflow/hooks/hive_hooks.py
@@ -18,6 +18,10 @@
 # under the License.
 
 from __future__ import print_function, unicode_literals
+
+import contextlib
+import os
+
 from six.moves import zip
 from past.builtins import basestring, unicode
 
@@ -731,7 +735,7 @@ class HiveMetastoreHook(BaseHook):
 
 class HiveServer2Hook(BaseHook):
     """
-    Wrapper around the impyla library
+    Wrapper around the pyhive library
 
-    Note that the default authMechanism is PLAIN, to override it you
+    Note that the default authMechanism is NONE, to override it you
     can specify it in the ``extra`` of your connection in the UI as in
@@ -741,56 +745,92 @@ class HiveServer2Hook(BaseHook):
 
     def get_conn(self, schema=None):
         db = self.get_connection(self.hiveserver2_conn_id)
-        auth_mechanism = db.extra_dejson.get('authMechanism', 'PLAIN')
+        auth_mechanism = db.extra_dejson.get('authMechanism', 'NONE')
+        # bind username on every path so connect() below cannot hit an
+        # unbound name; fall back to 'airflow' for NONE without a login
+        username = db.login
+        if auth_mechanism == 'NONE' and not username:
+            username = 'airflow'
         kerberos_service_name = None
         if configuration.conf.get('core', 'security') == 'kerberos':
-            auth_mechanism = db.extra_dejson.get('authMechanism', 'GSSAPI')
+            auth_mechanism = db.extra_dejson.get('authMechanism', 'KERBEROS')
             kerberos_service_name = db.extra_dejson.get('kerberos_service_name', 'hive')
 
-        # impyla uses GSSAPI instead of KERBEROS as a auth_mechanism identifier
-        if auth_mechanism == 'KERBEROS':
+        # pyhive uses KERBEROS instead of GSSAPI as an auth_mechanism identifier
+        if auth_mechanism == 'GSSAPI':
             self.log.warning(
-                "Detected deprecated 'KERBEROS' for "
-                "authMechanism for %s. Please use 'GSSAPI' instead",
+                "Detected deprecated 'GSSAPI' for authMechanism "
+                "for %s. Please use 'KERBEROS' instead",
                 self.hiveserver2_conn_id
             )
-            auth_mechanism = 'GSSAPI'
+            auth_mechanism = 'KERBEROS'
 
-        from impala.dbapi import connect
+        from pyhive.hive import connect
         return connect(
             host=db.host,
             port=db.port,
-            auth_mechanism=auth_mechanism,
+            auth=auth_mechanism,
             kerberos_service_name=kerberos_service_name,
-            user=db.login,
+            username=username,
             database=schema or db.schema or 'default')
 
-    def get_results(self, hql, schema='default', arraysize=1000):
-        from impala.error import ProgrammingError
-        with self.get_conn(schema) as conn:
-            if isinstance(hql, basestring):
-                hql = [hql]
-            results = {
-                'data': [],
-                'header': [],
-            }
-            cur = conn.cursor()
+    def _get_results(self, hql, schema='default', fetch_size=None):
+        """
+        Execute ``hql`` and lazily yield its header, then its result rows.
+
+        ``hql`` may be a single statement or a list of statements. Only
+        statements starting with ``select`` or ``with`` (case-insensitive)
+        are fetched; the rest (SET, DDL, ...) are just executed. The first
+        item yielded is the description of the first fetched statement,
+        followed by the rows of every fetched statement; nothing is yielded
+        when no statement is fetched.
+
+        :raises ValueError: if fetched statements have different descriptions
+        """
+        from pyhive.exc import ProgrammingError
+        if isinstance(hql, basestring):
+            hql = [hql]
+        previous_description = None
+        with contextlib.closing(self.get_conn(schema)) as conn, \
+                contextlib.closing(conn.cursor()) as cur:
+            cur.arraysize = fetch_size or 1000
             for statement in hql:
                 cur.execute(statement)
-                records = []
-                try:
-                    # impala Lib raises when no results are returned
-                    # we're silencing here as some statements in the list
-                    # may be `SET` or DDL
-                    records = cur.fetchall()
-                except ProgrammingError:
-                    self.log.debug("get_results returned no records")
-                if records:
-                    results = {
-                        'data': records,
-                        'header': cur.description,
-                    }
-            return results
+                # only SELECT/WITH statements are expected to return rows
+                lowered_statement = statement.lower().strip()
+                if lowered_statement.startswith(('select', 'with')):
+                    description = list(cur.description)
+                    if previous_description and previous_description != description:
+                        message = '''The statements are producing different descriptions:
+                                     Current: {}
+                                     Previous: {}'''.format(repr(description),
+                                                            repr(previous_description))
+                        raise ValueError(message)
+                    elif not previous_description:
+                        previous_description = description
+                        yield description
+                    try:
+                        # a ProgrammingError while fetching is treated as
+                        # "no rows" rather than as a failure
+                        for row in cur:
+                            yield row
+                    except ProgrammingError:
+                        self.log.debug("get_results returned no records")
+
+    def get_results(self, hql, schema='default', fetch_size=None):
+        """
+        Execute ``hql`` and return ``{'data': rows, 'header': description}``.
+
+        ``header`` is the description of the first SELECT/WITH statement;
+        both values are empty lists when ``hql`` contains no such statement.
+        """
+        results_iter = self._get_results(hql, schema, fetch_size=fetch_size)
+        header = next(results_iter, [])
+        results = {
+            'data': list(results_iter),
+            'header': header
+        }
+        return results
 
     def to_csv(
             self,
@@ -801,29 +841,43 @@ class HiveServer2Hook(BaseHook):
             lineterminator='\r\n',
             output_header=True,
             fetch_size=1000):
-        schema = schema or 'default'
-        with self.get_conn(schema) as conn:
-            with conn.cursor() as cur:
-                self.log.info("Running query: %s", hql)
-                cur.execute(hql)
-                schema = cur.description
-                with open(csv_filepath, 'wb') as f:
-                    writer = csv.writer(f,
-                                        delimiter=delimiter,
-                                        lineterminator=lineterminator,
-                                        encoding='utf-8')
-                    if output_header:
-                        writer.writerow([c[0] for c in cur.description])
-                    i = 0
-                    while True:
-                        rows = [row for row in cur.fetchmany(fetch_size) if row]
-                        if not rows:
-                            break
-
-                        writer.writerows(rows)
-                        i += len(rows)
+        """
+        Execute ``hql`` and write the rows it returns to ``csv_filepath``.
+
+        A header row of column names is written first when
+        ``output_header`` is true. If writing fails with a ValueError (e.g.
+        the statements produce different descriptions), the partially
+        written file is removed and the error is re-raised.
+        """
+        results_iter = self._get_results(hql, schema, fetch_size=fetch_size)
+        header = next(results_iter, [])
+        # rows written so far; stays 0 when the query returns no rows
+        i = 0
+
+        try:
+            with open(csv_filepath, 'wb') as f:
+                writer = csv.writer(f,
+                                    delimiter=delimiter,
+                                    lineterminator=lineterminator,
+                                    encoding='utf-8')
+                if output_header:
+                    self.log.debug('Cursor description is %s', header)
+                    writer.writerow([c[0] for c in header])
+
+                for i, row in enumerate(results_iter, 1):
+                    writer.writerow(row)
+                    if fetch_size and i % fetch_size == 0:
                         self.log.info("Written %s rows so far.", i)
-                    self.log.info("Done. Loaded a total of %s rows.", i)
+        except ValueError:
+            # don't leave a partially written CSV behind
+            os.remove(csv_filepath)
+            raise
+        finally:
+            # release the connection and cursor held by the generator,
+            # even when writing stops early
+            results_iter.close()
+
+        self.log.info("Done. Loaded a total of %s rows.", i)
 
     def get_records(self, hql, schema='default'):
         """

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3245d174/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index fcae019..368652a 100644
--- a/setup.py
+++ b/setup.py
@@ -155,9 +155,8 @@ github_enterprise = ['Flask-OAuthlib>=0.9.1']
 hdfs = ['snakebite>=2.7.8']
 hive = [
     'hmsclient>=0.1.0',
-    'pyhive>=0.1.3',
-    'impyla>=0.13.3',
-    'thrift_sasl==0.2.1',
+    'pyhive>=0.6.0',
+    'impyla>=0.13.3',
 ]
 jdbc = ['jaydebeapi>=1.1.1']
 jenkins = ['python-jenkins>=0.4.15']
 jdbc = ['jaydebeapi>=1.1.1']
 jenkins = ['python-jenkins>=0.4.15']

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3245d174/tests/hooks/test_hive_hook.py
----------------------------------------------------------------------
diff --git a/tests/hooks/test_hive_hook.py b/tests/hooks/test_hive_hook.py
index b0029ea..7d5d103 100644
--- a/tests/hooks/test_hive_hook.py
+++ b/tests/hooks/test_hive_hook.py
@@ -20,6 +20,8 @@
 
 import datetime
 import itertools
+import os
+
 import pandas as pd
 import random
 
@@ -320,6 +322,91 @@ class TestHiveMetastoreHook(HiveEnvironmentTest):
 
 
 class TestHiveServer2Hook(unittest.TestCase):
+
+    def _upload_dataframe(self):
+        df = pd.DataFrame({'a': [1, 2], 'b': [1, 2]})
+        self.local_path = '/tmp/TestHiveServer2Hook.csv'
+        df.to_csv(self.local_path, header=False, index=False)
+
+    def setUp(self):
+        configuration.load_test_config()
+        self._upload_dataframe()
+        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
+        self.dag = DAG('test_dag_id', default_args=args)
+        self.database = 'airflow'
+        self.table = 'hive_server_hook'
+        self.hql = """
+        CREATE DATABASE IF NOT EXISTS {{ params.database }};
+        USE {{ params.database }};
+        DROP TABLE IF EXISTS {{ params.table }};
+        CREATE TABLE IF NOT EXISTS {{ params.table }} (
+            a int,
+            b int)
+        ROW FORMAT DELIMITED
+        FIELDS TERMINATED BY ',';
+        LOAD DATA LOCAL INPATH '{{ params.csv_path }}'
+        OVERWRITE INTO TABLE {{ params.table }};
+        """
+        self.columns = ['{}.a'.format(self.table),
+                        '{}.b'.format(self.table)]
+        self.hook = HiveMetastoreHook()
+        t = HiveOperator(
+            task_id='HiveHook_' + str(random.randint(1, 10000)),
+            params={
+                'database': self.database,
+                'table': self.table,
+                'csv_path': self.local_path
+            },
+            hive_cli_conn_id='beeline_default',
+            hql=self.hql, dag=self.dag)
+        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+              ignore_ti_state=True)
+
+    def tearDown(self):
+        hook = HiveMetastoreHook()
+        with hook.get_conn() as metastore:
+            metastore.drop_table(self.database, self.table, deleteData=True)
+        os.remove(self.local_path)
+
     def test_get_conn(self):
         hook = HiveServer2Hook()
         hook.get_conn()
+
+    def test_get_records(self):
+        hook = HiveServer2Hook()
+        query = "SELECT * FROM {}".format(self.table)
+        results = hook.get_records(query, schema=self.database)
+        self.assertEqual(len(results), 2)
+
+    def test_get_pandas_df(self):
+        hook = HiveServer2Hook()
+        query = "SELECT * FROM {}".format(self.table)
+        df = hook.get_pandas_df(query, schema=self.database)
+        self.assertEqual(len(df), 2)
+        self.assertListEqual(df.columns.tolist(), self.columns)
+        self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
+
+    def test_get_results_header(self):
+        hook = HiveServer2Hook()
+        query = "SELECT * FROM {}".format(self.table)
+        results = hook.get_results(query, schema=self.database)
+        self.assertListEqual([col[0] for col in results['header']],
+                             self.columns)
+
+    def test_get_results_data(self):
+        hook = HiveServer2Hook()
+        query = "SELECT * FROM {}".format(self.table)
+        results = hook.get_results(query, schema=self.database)
+        self.assertListEqual(results['data'], [(1, 1), (2, 2)])
+
+    def test_to_csv(self):
+        hook = HiveServer2Hook()
+        query = "SELECT * FROM {}".format(self.table)
+        csv_filepath = 'query_results.csv'
+        hook.to_csv(query, csv_filepath, schema=self.database,
+                    delimiter=',', lineterminator='\n', output_header=True)
+        self.addCleanup(os.remove, csv_filepath)
+        df = pd.read_csv(csv_filepath, sep=',')
+        self.assertListEqual(df.columns.tolist(), self.columns)
+        self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
+        self.assertEqual(len(df), 2)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3245d174/tests/operators/operators.py
----------------------------------------------------------------------
diff --git a/tests/operators/operators.py b/tests/operators/operators.py
index c343d5c..ae95325 100644
--- a/tests/operators/operators.py
+++ b/tests/operators/operators.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -362,8 +362,6 @@ class TransferTests(unittest.TestCase):
             with m.get_conn() as c:
                 c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
 
-    @unittest.skipIf(six.PY2, "Skip since HiveServer2Hook doesn't work "
-                              "on Python2 for now. See AIRFLOW-2514.")
     def test_mysql_to_hive_verify_loaded_values(self):
         mysql_conn_id = 'airflow_ci'
         mysql_table = 'test_mysql_to_hive'

Reply via email to