[AIRFLOW-793] Enable compressed loading in S3ToHiveTransfer

Testing Done:
- Added new unit tests for the S3ToHiveTransfer module
Closes #2012 from krishnabhupatiraju/S3ToHiveTransfer_compress_loading

(cherry picked from commit ad15f5efd6c663bd5f0c8cd3f556d08182cc778c)
Signed-off-by: Bolke de Bruin <[email protected]>

Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/1c231333
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/1c231333
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/1c231333

Branch: refs/heads/v1-8-stable
Commit: 1c2313338a586aae4a7752c3fb3b9de4e3564415
Parents: 3658bf3
Author: Krishna Bhupatiraju <[email protected]>
Authored: Mon Feb 6 16:52:11 2017 -0800
Committer: Bolke de Bruin <[email protected]>
Committed: Sat Feb 18 15:56:37 2017 +0100

----------------------------------------------------------------------
 airflow/operators/s3_to_hive_operator.py | 151 ++++++++++++----
 airflow/utils/compression.py             |  38 ++++
 tests/operators/__init__.py              |   1 +
 tests/operators/s3_to_hive_operator.py   | 247 ++++++++++++++++++++++++++
 tests/utils/compression.py               |  97 ++++++++++
 5 files changed, 497 insertions(+), 37 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/airflow/operators/s3_to_hive_operator.py ---------------------------------------------------------------------- diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py index 3e01c29..92340f8 100644 --- a/airflow/operators/s3_to_hive_operator.py +++ b/airflow/operators/s3_to_hive_operator.py @@ -16,13 +16,18 @@ from builtins import next from builtins import zip import logging from tempfile import NamedTemporaryFile +from airflow.utils.file import TemporaryDirectory +import gzip +import bz2 +import tempfile +import os from airflow.exceptions import AirflowException from airflow.hooks.S3_hook import S3Hook from airflow.hooks.hive_hooks import HiveCliHook from
airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults - +from airflow.utils.compression import uncompress_file class S3ToHiveTransfer(BaseOperator): """ @@ -68,8 +73,11 @@ class S3ToHiveTransfer(BaseOperator): :type delimiter: str :param s3_conn_id: source s3 connection :type s3_conn_id: str - :param hive_conn_id: destination hive connection - :type hive_conn_id: str + :param hive_cli_conn_id: destination hive connection + :type hive_cli_conn_id: str + :param input_compressed: Boolean to determine if file decompression is + required to process headers + :type input_compressed: bool """ template_fields = ('s3_key', 'partition', 'hive_table') @@ -91,6 +99,7 @@ class S3ToHiveTransfer(BaseOperator): wildcard_match=False, s3_conn_id='s3_default', hive_cli_conn_id='hive_cli_default', + input_compressed=False, *args, **kwargs): super(S3ToHiveTransfer, self).__init__(*args, **kwargs) self.s3_key = s3_key @@ -105,28 +114,41 @@ class S3ToHiveTransfer(BaseOperator): self.wildcard_match = wildcard_match self.hive_cli_conn_id = hive_cli_conn_id self.s3_conn_id = s3_conn_id + self.input_compressed = input_compressed + + if (self.check_headers and + not (self.field_dict is not None and self.headers)): + raise AirflowException("To check_headers provide " + + "field_dict and headers") def execute(self, context): - self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) + # Downloading file from S3 self.s3 = S3Hook(s3_conn_id=self.s3_conn_id) + self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) logging.info("Downloading S3 file") + if self.wildcard_match: if not self.s3.check_for_wildcard_key(self.s3_key): - raise AirflowException("No key matches {0}".format(self.s3_key)) + raise AirflowException("No key matches {0}" + .format(self.s3_key)) s3_key_object = self.s3.get_wildcard_key(self.s3_key) else: if not self.s3.check_for_key(self.s3_key): raise AirflowException( "The key {0} does not exists".format(self.s3_key)) s3_key_object 
= self.s3.get_key(self.s3_key) - with NamedTemporaryFile("w") as f: + root, file_ext = os.path.splitext(s3_key_object.key) + with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\ + NamedTemporaryFile(mode="w", + dir=tmp_dir, + suffix=file_ext) as f: logging.info("Dumping S3 key {0} contents to local" " file {1}".format(s3_key_object.key, f.name)) s3_key_object.get_contents_to_file(f) f.flush() self.s3.connection.close() if not self.headers: - logging.info("Loading file into Hive") + logging.info("Loading file {0} into Hive".format(f.name)) self.hive.load_file( f.name, self.hive_table, @@ -136,33 +158,88 @@ class S3ToHiveTransfer(BaseOperator): delimiter=self.delimiter, recreate=self.recreate) else: - with open(f.name, 'r') as tmpf: - if self.check_headers: - header_l = tmpf.readline() - header_line = header_l.rstrip() - header_list = header_line.split(self.delimiter) - field_names = list(self.field_dict.keys()) - test_field_match = [h1.lower() == h2.lower() for h1, h2 - in zip(header_list, field_names)] - if not all(test_field_match): - logging.warning("Headers do not match field names" - "File headers:\n {header_list}\n" - "Field names: \n {field_names}\n" - "".format(**locals())) - raise AirflowException("Headers do not match the " - "field_dict keys") - with NamedTemporaryFile("w") as f_no_headers: - tmpf.seek(0) - next(tmpf) - for line in tmpf: - f_no_headers.write(line) - f_no_headers.flush() - logging.info("Loading file without headers into Hive") - self.hive.load_file( - f_no_headers.name, - self.hive_table, - field_dict=self.field_dict, - create=self.create, - partition=self.partition, - delimiter=self.delimiter, - recreate=self.recreate) + # Decompressing file + if self.input_compressed: + logging.info("Uncompressing file {0}".format(f.name)) + fn_uncompressed = uncompress_file(f.name, + file_ext, + tmp_dir) + logging.info("Uncompressed to {0}".format(fn_uncompressed)) + # uncompressed file available now so deleting + # compressed file to save disk 
space + f.close() + else: + fn_uncompressed = f.name + + # Testing if header matches field_dict + if self.check_headers: + logging.info("Matching file header against field_dict") + header_list = self._get_top_row_as_list(fn_uncompressed) + if not self._match_headers(header_list): + raise AirflowException("Header check failed") + + # Deleting top header row + logging.info("Removing header from file {0}". + format(fn_uncompressed)) + headless_file = ( + self._delete_top_row_and_compress(fn_uncompressed, + file_ext, + tmp_dir)) + logging.info("Headless file {0}".format(headless_file)) + logging.info("Loading file {0} into Hive".format(headless_file)) + self.hive.load_file(headless_file, + self.hive_table, + field_dict=self.field_dict, + create=self.create, + partition=self.partition, + delimiter=self.delimiter, + recreate=self.recreate) + + def _get_top_row_as_list(self, file_name): + with open(file_name, 'rt') as f: + header_line = f.readline().strip() + header_list = header_line.split(self.delimiter) + return header_list + + def _match_headers(self, header_list): + if not header_list: + raise AirflowException("Unable to retrieve header row from file") + field_names = self.field_dict.keys() + if len(field_names) != len(header_list): + logging.warning("Headers count mismatch" + "File headers:\n {header_list}\n" + "Field names: \n {field_names}\n" + "".format(**locals())) + return False + test_field_match = [h1.lower() == h2.lower() + for h1, h2 in zip(header_list, field_names)] + if not all(test_field_match): + logging.warning("Headers do not match field names" + "File headers:\n {header_list}\n" + "Field names: \n {field_names}\n" + "".format(**locals())) + return False + else: + return True + + def _delete_top_row_and_compress( + self, + input_file_name, + output_file_ext, + dest_dir): + # When output_file_ext is not defined, file is not compressed + open_fn = open + if output_file_ext.lower() == '.gz': + open_fn = gzip.GzipFile + elif output_file_ext.lower() == 
'.bz2': + open_fn = bz2.BZ2File + + os_fh_output, fn_output = \ + tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir) + with open(input_file_name, 'rb') as f_in,\ + open_fn(fn_output, 'wb') as f_out: + f_in.seek(0) + next(f_in) + for line in f_in: + f_out.write(line) + return fn_output http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/airflow/utils/compression.py ---------------------------------------------------------------------- diff --git a/airflow/utils/compression.py b/airflow/utils/compression.py new file mode 100644 index 0000000..9d0785f --- /dev/null +++ b/airflow/utils/compression.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tempfile import NamedTemporaryFile +import shutil +import gzip +import bz2 + + +def uncompress_file(input_file_name, file_extension, dest_dir): + """ + Uncompress gz and bz2 files + """ + if file_extension.lower() not in ('.gz', '.bz2'): + raise NotImplementedError("Received {} format. Only gz and bz2 " + "files can currently be uncompressed." 
+ .format(file_extension)) + if file_extension.lower() == '.gz': + fmodule = gzip.GzipFile + elif file_extension.lower() == '.bz2': + fmodule = bz2.BZ2File + with fmodule(input_file_name, mode='rb') as f_compressed,\ + NamedTemporaryFile(dir=dest_dir, + mode='wb', + delete=False) as f_uncompressed: + shutil.copyfileobj(f_compressed, f_uncompressed) + return f_uncompressed.name http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/tests/operators/__init__.py ---------------------------------------------------------------------- diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py index 63ff2a0..1fb0e5e 100644 --- a/tests/operators/__init__.py +++ b/tests/operators/__init__.py @@ -17,3 +17,4 @@ from .subdag_operator import * from .operators import * from .sensors import * from .hive_operator import * +from .s3_to_hive_operator import * http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/tests/operators/s3_to_hive_operator.py ---------------------------------------------------------------------- diff --git a/tests/operators/s3_to_hive_operator.py b/tests/operators/s3_to_hive_operator.py new file mode 100644 index 0000000..faab11e --- /dev/null +++ b/tests/operators/s3_to_hive_operator.py @@ -0,0 +1,247 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None +import logging +from itertools import product +from airflow.operators.s3_to_hive_operator import S3ToHiveTransfer +from collections import OrderedDict +from airflow.exceptions import AirflowException +from tempfile import NamedTemporaryFile, mkdtemp +import gzip +import bz2 +import shutil +import filecmp +import errno + + +class S3ToHiveTransferTest(unittest.TestCase): + + def setUp(self): + self.fn = {} + self.task_id = 'S3ToHiveTransferTest' + self.s3_key = 'S32hive_test_file' + self.field_dict = OrderedDict([('Sno', 'BIGINT'), ('Some,Text', 'STRING')]) + self.hive_table = 'S32hive_test_table' + self.delimiter = '\t' + self.create = True + self.recreate = True + self.partition = {'ds': 'STRING'} + self.headers = True + self.check_headers = True + self.wildcard_match = False + self.input_compressed = False + self.kwargs = {'task_id': self.task_id, + 's3_key': self.s3_key, + 'field_dict': self.field_dict, + 'hive_table': self.hive_table, + 'delimiter': self.delimiter, + 'create': self.create, + 'recreate': self.recreate, + 'partition': self.partition, + 'headers': self.headers, + 'check_headers': self.check_headers, + 'wildcard_match': self.wildcard_match, + 'input_compressed': self.input_compressed + } + try: + header = "Sno\tSome,Text \n".encode() + line1 = "1\tAirflow Test\n".encode() + line2 = "2\tS32HiveTransfer\n".encode() + self.tmp_dir = mkdtemp(prefix='test_tmps32hive_') + # create sample txt, gz and bz2 with and without headers + with NamedTemporaryFile(mode='wb+', + dir=self.tmp_dir, + delete=False) as f_txt_h: + self._set_fn(f_txt_h.name, '.txt', True) + f_txt_h.writelines([header, line1, line2]) + fn_gz = self._get_fn('.txt', True) + ".gz" + with gzip.GzipFile(filename=fn_gz, + mode="wb") as f_gz_h: + self._set_fn(fn_gz, '.gz', True) + f_gz_h.writelines([header, line1, line2]) + fn_bz2 = self._get_fn('.txt', True) + '.bz2' 
+ with bz2.BZ2File(filename=fn_bz2, + mode="wb") as f_bz2_h: + self._set_fn(fn_bz2, '.bz2', True) + f_bz2_h.writelines([header, line1, line2]) + # create sample txt, bz and bz2 without header + with NamedTemporaryFile(mode='wb+', + dir=self.tmp_dir, + delete=False) as f_txt_nh: + self._set_fn(f_txt_nh.name, '.txt', False) + f_txt_nh.writelines([line1, line2]) + fn_gz = self._get_fn('.txt', False) + ".gz" + with gzip.GzipFile(filename=fn_gz, + mode="wb") as f_gz_nh: + self._set_fn(fn_gz, '.gz', False) + f_gz_nh.writelines([line1, line2]) + fn_bz2 = self._get_fn('.txt', False) + '.bz2' + with bz2.BZ2File(filename=fn_bz2, + mode="wb") as f_bz2_nh: + self._set_fn(fn_bz2, '.bz2', False) + f_bz2_nh.writelines([line1, line2]) + # Base Exception so it catches Keyboard Interrupt + except BaseException as e: + logging.error(e) + self.tearDown() + + def tearDown(self): + try: + shutil.rmtree(self.tmp_dir) + except OSError as e: + # ENOENT - no such file or directory + if e.errno != errno.ENOENT: + raise e + + # Helper method to create a dictionary of file names and + # file types (file extension and header) + def _set_fn(self, fn, ext, header): + key = self._get_key(ext, header) + self.fn[key] = fn + + # Helper method to fetch a file of a + # certain format (file extension and header) + def _get_fn(self, ext, header): + key = self._get_key(ext, header) + return self.fn[key] + + def _get_key(self, ext, header): + key = ext + "_" + ('h' if header else 'nh') + return key + + def _cp_file_contents(self, fn_src, fn_dest): + with open(fn_src, 'rb') as f_src, open(fn_dest, 'wb') as f_dest: + shutil.copyfileobj(f_src, f_dest) + + def _check_file_equality(self, fn_1, fn_2, ext): + # gz files contain mtime and filename in the header that + # causes filecmp to return False even if contents are identical + # Hence decompress to test for equality + if(ext == '.gz'): + with gzip.GzipFile(fn_1, 'rb') as f_1,\ + NamedTemporaryFile(mode='wb') as f_txt_1,\ + gzip.GzipFile(fn_2, 'rb') as f_2,\ 
+ NamedTemporaryFile(mode='wb') as f_txt_2: + shutil.copyfileobj(f_1, f_txt_1) + shutil.copyfileobj(f_2, f_txt_2) + f_txt_1.flush() + f_txt_2.flush() + return filecmp.cmp(f_txt_1.name, f_txt_2.name, shallow=False) + else: + return filecmp.cmp(fn_1, fn_2, shallow=False) + + def test_bad_parameters(self): + self.kwargs['check_headers'] = True + self.kwargs['headers'] = False + self.assertRaisesRegexp(AirflowException, + "To check_headers.*", + S3ToHiveTransfer, + **self.kwargs) + + def test__get_top_row_as_list(self): + self.kwargs['delimiter'] = '\t' + fn_txt = self._get_fn('.txt', True) + header_list = S3ToHiveTransfer(**self.kwargs).\ + _get_top_row_as_list(fn_txt) + self.assertEqual(header_list, ['Sno', 'Some,Text'], + msg="Top row from file doesnt matched expected value") + + self.kwargs['delimiter'] = ',' + header_list = S3ToHiveTransfer(**self.kwargs).\ + _get_top_row_as_list(fn_txt) + self.assertEqual(header_list, ['Sno\tSome', 'Text'], + msg="Top row from file doesnt matched expected value") + + def test__match_headers(self): + self.kwargs['field_dict'] = OrderedDict([('Sno', 'BIGINT'), + ('Some,Text', 'STRING')]) + self.assertTrue(S3ToHiveTransfer(**self.kwargs). + _match_headers(['Sno', 'Some,Text']), + msg="Header row doesnt match expected value") + # Testing with different column order + self.assertFalse(S3ToHiveTransfer(**self.kwargs). + _match_headers(['Some,Text', 'Sno']), + msg="Header row doesnt match expected value") + # Testing with extra column in header + self.assertFalse(S3ToHiveTransfer(**self.kwargs). 
+ _match_headers(['Sno', 'Some,Text', 'ExtraColumn']), + msg="Header row doesnt match expected value") + + def test__delete_top_row_and_compress(self): + s32hive = S3ToHiveTransfer(**self.kwargs) + # Testing gz file type + fn_txt = self._get_fn('.txt', True) + gz_txt_nh = s32hive._delete_top_row_and_compress(fn_txt, + '.gz', + self.tmp_dir) + fn_gz = self._get_fn('.gz', False) + self.assertTrue(self._check_file_equality(gz_txt_nh, fn_gz, '.gz'), + msg="gz Compressed file not as expected") + # Testing bz2 file type + bz2_txt_nh = s32hive._delete_top_row_and_compress(fn_txt, + '.bz2', + self.tmp_dir) + fn_bz2 = self._get_fn('.bz2', False) + self.assertTrue(self._check_file_equality(bz2_txt_nh, fn_bz2, '.bz2'), + msg="bz2 Compressed file not as expected") + + @unittest.skipIf(mock is None, 'mock package not present') + @mock.patch('airflow.operators.s3_to_hive_operator.HiveCliHook') + @mock.patch('airflow.operators.s3_to_hive_operator.S3Hook') + def test_execute(self, mock_s3hook, mock_hiveclihook): + # Testing txt, zip, bz2 files with and without header row + for test in product(['.txt', '.gz', '.bz2'], [True, False]): + ext = test[0] + has_header = test[1] + self.kwargs['headers'] = has_header + self.kwargs['check_headers'] = has_header + logging.info("Testing {0} format {1} header". 
+ format(ext, + ('with' if has_header else 'without')) + ) + self.kwargs['input_compressed'] = (False if ext == '.txt' else True) + self.kwargs['s3_key'] = self.s3_key + ext + ip_fn = self._get_fn(ext, self.kwargs['headers']) + op_fn = self._get_fn(ext, False) + # Mock s3 object returned by S3Hook + mock_s3_object = mock.Mock(key=self.kwargs['s3_key']) + mock_s3_object.get_contents_to_file.side_effect = \ + lambda dest_file: \ + self._cp_file_contents(ip_fn, dest_file.name) + mock_s3hook().get_key.return_value = mock_s3_object + # file paramter to HiveCliHook.load_file is compared + # against expected file oputput + mock_hiveclihook().load_file.side_effect = \ + lambda *args, **kwargs: \ + self.assertTrue( + self._check_file_equality(args[0], + op_fn, + ext + ), + msg='{0} output file not as expected'.format(ext)) + # Execute S3ToHiveTransfer + s32hive = S3ToHiveTransfer(**self.kwargs) + s32hive.execute(None) + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/tests/utils/compression.py ---------------------------------------------------------------------- diff --git a/tests/utils/compression.py b/tests/utils/compression.py new file mode 100644 index 0000000..f8e0ebb --- /dev/null +++ b/tests/utils/compression.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from airflow.utils import compression +import unittest +from tempfile import NamedTemporaryFile, mkdtemp +import bz2 +import gzip +import shutil +import logging +import errno +import filecmp + + +class Compression(unittest.TestCase): + + def setUp(self): + self.fn = {} + try: + header = "Sno\tSome,Text \n".encode() + line1 = "1\tAirflow Test\n".encode() + line2 = "2\tCompressionUtil\n".encode() + self.tmp_dir = mkdtemp(prefix='test_utils_compression_') + # create sample txt, gz and bz2 files + with NamedTemporaryFile(mode='wb+', + dir=self.tmp_dir, + delete=False) as f_txt: + self._set_fn(f_txt.name, '.txt') + f_txt.writelines([header, line1, line2]) + fn_gz = self._get_fn('.txt') + ".gz" + with gzip.GzipFile(filename=fn_gz, + mode="wb") as f_gz: + self._set_fn(fn_gz, '.gz') + f_gz.writelines([header, line1, line2]) + fn_bz2 = self._get_fn('.txt') + '.bz2' + with bz2.BZ2File(filename=fn_bz2, + mode="wb") as f_bz2: + self._set_fn(fn_bz2, '.bz2') + f_bz2.writelines([header, line1, line2]) + # Base Exception so it catches Keyboard Interrupt + except BaseException as e: + logging.error(e) + self.tearDown() + + def tearDown(self): + try: + shutil.rmtree(self.tmp_dir) + except OSError as e: + # ENOENT - no such file or directory + if e.errno != errno.ENOENT: + raise e + + # Helper method to create a dictionary of file names and + # file extension + def _set_fn(self, fn, ext): + self.fn[ext] = fn + + # Helper method to fetch a file of a + # certain extension + def _get_fn(self, ext): + return self.fn[ext] + + def test_uncompress_file(self): + # Testing txt file type + self.assertRaisesRegexp(NotImplementedError, + "^Received .txt format. 
Only gz and bz2.*", + compression.uncompress_file, + **{'input_file_name': None, + 'file_extension': '.txt', + 'dest_dir': None + }) + # Testing gz file type + fn_txt = self._get_fn('.txt') + fn_gz = self._get_fn('.gz') + txt_gz = compression.uncompress_file(fn_gz, '.gz', self.tmp_dir) + self.assertTrue(filecmp.cmp(txt_gz, fn_txt, shallow=False), + msg="Uncompressed file doest match original") + # Testing bz2 file type + fn_bz2 = self._get_fn('.bz2') + txt_bz2 = compression.uncompress_file(fn_bz2, '.bz2', self.tmp_dir) + self.assertTrue(filecmp.cmp(txt_bz2, fn_txt, shallow=False), + msg="Uncompressed file doest match original") + + +if __name__ == '__main__': + unittest.main()
