Repository: incubator-airflow
Updated Branches:
  refs/heads/master 08a18395e -> 3245d1745
[AIRFLOW-2534] Fix bug in HiveServer2Hook

This commit also adds numerous tests for HiveServer2 and switches
Impyla for PyHive (0.6.0), making HiveServer2 Python 2 compatible.

Closes #3432 from gglanzani/AIRFLOW-2534

Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/3245d174
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/3245d174
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/3245d174

Branch: refs/heads/master
Commit: 3245d1745d08d4b3a7a7aad0b54aae734a2c1c53
Parents: 08a1839
Author: Giovanni Lanzani <[email protected]>
Authored: Fri Jun 15 14:19:25 2018 +0200
Committer: Fokko Driesprong <[email protected]>
Committed: Fri Jun 15 14:19:25 2018 +0200

----------------------------------------------------------------------
 airflow/hooks/hive_hooks.py   | 141 ++++++++++++++++++++++---------------
 setup.py                      |   5 +-
 tests/hooks/test_hive_hook.py |  86 ++++++++++++++++++++++
 tests/operators/operators.py  |   6 +-
 4 files changed, 174 insertions(+), 64 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3245d174/airflow/hooks/hive_hooks.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py
index 7f0f068..93e1f45 100644
--- a/airflow/hooks/hive_hooks.py
+++ b/airflow/hooks/hive_hooks.py
@@ -18,6 +18,10 @@
 # under the License.
from __future__ import print_function, unicode_literals + +import contextlib +import os + from six.moves import zip from past.builtins import basestring, unicode @@ -731,7 +735,7 @@ class HiveMetastoreHook(BaseHook): class HiveServer2Hook(BaseHook): """ - Wrapper around the impyla library + Wrapper around the pyhive library Note that the default authMechanism is PLAIN, to override it you can specify it in the ``extra`` of your connection in the UI as in @@ -741,56 +745,74 @@ class HiveServer2Hook(BaseHook): def get_conn(self, schema=None): db = self.get_connection(self.hiveserver2_conn_id) - auth_mechanism = db.extra_dejson.get('authMechanism', 'PLAIN') + auth_mechanism = db.extra_dejson.get('authMechanism', 'NONE') + if auth_mechanism == 'NONE' and db.login is None: + # we need to give a username + username = 'airflow' kerberos_service_name = None if configuration.conf.get('core', 'security') == 'kerberos': - auth_mechanism = db.extra_dejson.get('authMechanism', 'GSSAPI') + auth_mechanism = db.extra_dejson.get('authMechanism', 'KERBEROS') kerberos_service_name = db.extra_dejson.get('kerberos_service_name', 'hive') - # impyla uses GSSAPI instead of KERBEROS as a auth_mechanism identifier - if auth_mechanism == 'KERBEROS': + # pyhive uses GSSAPI instead of KERBEROS as a auth_mechanism identifier + if auth_mechanism == 'GSSAPI': self.log.warning( - "Detected deprecated 'KERBEROS' for " - "authMechanism for %s. Please use 'GSSAPI' instead", + "Detected deprecated 'GSSAPI' for authMechanism " + "for %s. 
Please use 'KERBEROS' instead", self.hiveserver2_conn_id ) - auth_mechanism = 'GSSAPI' + auth_mechanism = 'KERBEROS' - from impala.dbapi import connect + from pyhive.hive import connect return connect( host=db.host, port=db.port, - auth_mechanism=auth_mechanism, + auth=auth_mechanism, kerberos_service_name=kerberos_service_name, - user=db.login, + username=db.login or username, database=schema or db.schema or 'default') - def get_results(self, hql, schema='default', arraysize=1000): - from impala.error import ProgrammingError - with self.get_conn(schema) as conn: - if isinstance(hql, basestring): - hql = [hql] - results = { - 'data': [], - 'header': [], - } - cur = conn.cursor() + def _get_results(self, hql, schema='default', fetch_size=None): + from pyhive.exc import ProgrammingError + if isinstance(hql, basestring): + hql = [hql] + previous_description = None + with contextlib.closing(self.get_conn(schema)) as conn, \ + contextlib.closing(conn.cursor()) as cur: + cur.arraysize = fetch_size or 1000 for statement in hql: cur.execute(statement) - records = [] - try: - # impala Lib raises when no results are returned - # we're silencing here as some statements in the list - # may be `SET` or DDL - records = cur.fetchall() - except ProgrammingError: - self.log.debug("get_results returned no records") - if records: - results = { - 'data': records, - 'header': cur.description, - } - return results + # we only get results of statements that returns + lowered_statement = statement.lower().strip() + if (lowered_statement.startswith('select') or + lowered_statement.startswith('with')): + description = [c for c in cur.description] + if previous_description and previous_description != description: + message = '''The statements are producing different descriptions: + Current: {} + Previous: {}'''.format(repr(description), + repr(previous_description)) + raise ValueError(message) + elif not previous_description: + previous_description = description + yield description + try: + 
# DB API 2 raises when no results are returned + # we're silencing here as some statements in the list + # may be `SET` or DDL + for row in cur: + yield row + except ProgrammingError: + self.log.debug("get_results returned no records") + + def get_results(self, hql, schema='default', fetch_size=None): + results_iter = self._get_results(hql, schema, fetch_size=fetch_size) + header = next(results_iter) + results = { + 'data': list(results_iter), + 'header': header + } + return results def to_csv( self, @@ -801,29 +823,34 @@ class HiveServer2Hook(BaseHook): lineterminator='\r\n', output_header=True, fetch_size=1000): - schema = schema or 'default' - with self.get_conn(schema) as conn: - with conn.cursor() as cur: - self.log.info("Running query: %s", hql) - cur.execute(hql) - schema = cur.description - with open(csv_filepath, 'wb') as f: - writer = csv.writer(f, - delimiter=delimiter, - lineterminator=lineterminator, - encoding='utf-8') - if output_header: - writer.writerow([c[0] for c in cur.description]) - i = 0 - while True: - rows = [row for row in cur.fetchmany(fetch_size) if row] - if not rows: - break - - writer.writerows(rows) - i += len(rows) + + results_iter = self._get_results(hql, schema, fetch_size=fetch_size) + header = next(results_iter) + message = None + + with open(csv_filepath, 'wb') as f: + writer = csv.writer(f, + delimiter=delimiter, + lineterminator=lineterminator, + encoding='utf-8') + try: + if output_header: + self.log.debug('Cursor description is %s', header) + writer.writerow([c[0] for c in header]) + + for i, row in enumerate(results_iter): + writer.writerow(row) + if i % fetch_size == 0: self.log.info("Written %s rows so far.", i) - self.log.info("Done. Loaded a total of %s rows.", i) + except ValueError as exception: + message = str(exception) + + if message: + # need to clean up the file first + os.remove(csv_filepath) + raise ValueError(message) + + self.log.info("Done. 
Loaded a total of %s rows.", i) def get_records(self, hql, schema='default'): """ http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3245d174/setup.py ---------------------------------------------------------------------- diff --git a/setup.py b/setup.py index fcae019..368652a 100644 --- a/setup.py +++ b/setup.py @@ -155,9 +155,8 @@ github_enterprise = ['Flask-OAuthlib>=0.9.1'] hdfs = ['snakebite>=2.7.8'] hive = [ 'hmsclient>=0.1.0', - 'pyhive>=0.1.3', - 'impyla>=0.13.3', - 'thrift_sasl==0.2.1', + 'pyhive>=0.6.0', + 'impyla>=0.13.3' ] jdbc = ['jaydebeapi>=1.1.1'] jenkins = ['python-jenkins>=0.4.15'] http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3245d174/tests/hooks/test_hive_hook.py ---------------------------------------------------------------------- diff --git a/tests/hooks/test_hive_hook.py b/tests/hooks/test_hive_hook.py index b0029ea..7d5d103 100644 --- a/tests/hooks/test_hive_hook.py +++ b/tests/hooks/test_hive_hook.py @@ -20,6 +20,8 @@ import datetime import itertools +import os + import pandas as pd import random @@ -320,6 +322,90 @@ class TestHiveMetastoreHook(HiveEnvironmentTest): class TestHiveServer2Hook(unittest.TestCase): + + def _upload_dataframe(self): + df = pd.DataFrame({'a': [1, 2], 'b': [1, 2]}) + self.local_path = '/tmp/TestHiveServer2Hook.csv' + df.to_csv(self.local_path, header=False, index=False) + + def setUp(self): + configuration.load_test_config() + self._upload_dataframe() + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} + self.dag = DAG('test_dag_id', default_args=args) + self.database = 'airflow' + self.table = 'hive_server_hook' + self.hql = """ + CREATE DATABASE IF NOT EXISTS {{ params.database }}; + USE {{ params.database }}; + DROP TABLE IF EXISTS {{ params.table }}; + CREATE TABLE IF NOT EXISTS {{ params.table }} ( + a int, + b int) + ROW FORMAT DELIMITED + FIELDS TERMINATED BY ','; + LOAD DATA LOCAL INPATH '{{ params.csv_path }}' + OVERWRITE INTO TABLE {{ params.table }}; + """ + self.columns 
= ['{}.a'.format(self.table), + '{}.b'.format(self.table)] + self.hook = HiveMetastoreHook() + t = HiveOperator( + task_id='HiveHook_' + str(random.randint(1, 10000)), + params={ + 'database': self.database, + 'table': self.table, + 'csv_path': self.local_path + }, + hive_cli_conn_id='beeline_default', + hql=self.hql, dag=self.dag) + t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, + ignore_ti_state=True) + + def tearDown(self): + hook = HiveMetastoreHook() + with hook.get_conn() as metastore: + metastore.drop_table(self.database, self.table, deleteData=True) + os.remove(self.local_path) + def test_get_conn(self): hook = HiveServer2Hook() hook.get_conn() + + def test_get_records(self): + hook = HiveServer2Hook() + query = "SELECT * FROM {}".format(self.table) + results = hook.get_pandas_df(query, schema=self.database) + self.assertEqual(len(results), 2) + + def test_get_pandas_df(self): + hook = HiveServer2Hook() + query = "SELECT * FROM {}".format(self.table) + df = hook.get_pandas_df(query, schema=self.database) + self.assertEqual(len(df), 2) + self.assertListEqual(df.columns.tolist(), self.columns) + self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2]) + + def test_get_results_header(self): + hook = HiveServer2Hook() + query = "SELECT * FROM {}".format(self.table) + results = hook.get_results(query, schema=self.database) + self.assertListEqual([col[0] for col in results['header']], + self.columns) + + def test_get_results_data(self): + hook = HiveServer2Hook() + query = "SELECT * FROM {}".format(self.table) + results = hook.get_results(query, schema=self.database) + self.assertListEqual(results['data'], [(1, 1), (2, 2)]) + + def test_to_csv(self): + hook = HiveServer2Hook() + query = "SELECT * FROM {}".format(self.table) + csv_filepath = 'query_results.csv' + hook.to_csv(query, csv_filepath, schema=self.database, + delimiter=',', lineterminator='\n', output_header=True) + df = pd.read_csv(csv_filepath, sep=',') + 
self.assertListEqual(df.columns.tolist(), self.columns) + self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2]) + self.assertEqual(len(df), 2) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3245d174/tests/operators/operators.py ---------------------------------------------------------------------- diff --git a/tests/operators/operators.py b/tests/operators/operators.py index c343d5c..ae95325 100644 --- a/tests/operators/operators.py +++ b/tests/operators/operators.py @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -362,8 +362,6 @@ class TransferTests(unittest.TestCase): with m.get_conn() as c: c.execute("DROP TABLE IF EXISTS {}".format(mysql_table)) - @unittest.skipIf(six.PY2, "Skip since HiveServer2Hook doesn't work " - "on Python2 for now. See AIRFLOW-2514.") def test_mysql_to_hive_verify_loaded_values(self): mysql_conn_id = 'airflow_ci' mysql_table = 'test_mysql_to_hive'
