Repository: incubator-airflow Updated Branches: refs/heads/v1-9-test ada7b2555 -> 71400b9d8
[AIRFLOW-1560] Add AWS DynamoDB hook and operator for inserting batch items Closes #2587 from sid88in/feature/dynamodb_hook_and_operator (cherry picked from commit 2f0798fcc9b7d6c0977b3190670d8a2c03818dd5) Signed-off-by: Bolke de Bruin <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/71400b9d Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/71400b9d Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/71400b9d Branch: refs/heads/v1-9-test Commit: 71400b9d89f9faa49b03fccede4df4b85ac1475d Parents: ada7b25 Author: sid.gupta <[email protected]> Authored: Sat Sep 30 08:44:33 2017 +0200 Committer: Bolke de Bruin <[email protected]> Committed: Sat Sep 30 08:45:23 2017 +0200 ---------------------------------------------------------------------- airflow/contrib/hooks/__init__.py | 3 +- airflow/contrib/hooks/aws_dynamodb_hook.py | 60 ++++++++ airflow/contrib/hooks/aws_hook.py | 23 +++ airflow/contrib/operators/__init__.py | 3 +- airflow/contrib/operators/hive_to_dynamodb.py | 101 +++++++++++++ tests/contrib/hooks/test_aws_dynamodb_hook.py | 76 ++++++++++ tests/contrib/hooks/test_aws_hook.py | 37 ++++- .../operators/test_hive_to_dynamodb_operator.py | 144 +++++++++++++++++++ 8 files changed, 444 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/71400b9d/airflow/contrib/hooks/__init__.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/__init__.py b/airflow/contrib/hooks/__init__.py index 977c2ce..2891980 100644 --- a/airflow/contrib/hooks/__init__.py +++ b/airflow/contrib/hooks/__init__.py @@ -47,7 +47,8 @@ _hooks = { 'cloudant_hook': ['CloudantHook'], 'fs_hook': ['FSHook'], 'wasb_hook': ['WasbHook'], - 'gcp_pubsub_hook': ['PubSubHook'] + 'gcp_pubsub_hook': 
['PubSubHook'], + 'aws_dynamodb_hook': ['AwsDynamoDBHook'] } import os as _os http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/71400b9d/airflow/contrib/hooks/aws_dynamodb_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/aws_dynamodb_hook.py b/airflow/contrib/hooks/aws_dynamodb_hook.py new file mode 100644 index 0000000..bb50ada --- /dev/null +++ b/airflow/contrib/hooks/aws_dynamodb_hook.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from airflow.exceptions import AirflowException +from airflow.contrib.hooks.aws_hook import AwsHook + + +class AwsDynamoDBHook(AwsHook): + """ + Interact with AWS DynamoDB. + + :param table_keys: primary-key attribute names (partition key and sort key), used to de-duplicate batch writes + :type table_keys: list + :param table_name: target DynamoDB table + :type table_name: str + :param region_name: aws region name (example: us-east-1) + :type region_name: str + """ + + def __init__(self, table_keys=None, table_name=None, region_name=None, *args, **kwargs): + self.table_keys = table_keys + self.table_name = table_name + self.region_name = region_name + super(AwsDynamoDBHook, self).__init__(*args, **kwargs) + + def get_conn(self): + self.conn = self.get_resource_type('dynamodb', self.region_name) + return self.conn + + def write_batch_data(self, items): + """ + Write each dict in ``items`` to the ``table_name`` table through a boto3 batch writer (de-duplicated on ``table_keys``); errors raised while writing are re-raised as AirflowException.
+ """ + + dynamodb_conn = self.get_conn() + + try: + table = dynamodb_conn.Table(self.table_name) + + with table.batch_writer(overwrite_by_pkeys=self.table_keys) as batch: + for item in items: + batch.put_item(Item=item) + return True + except Exception as general_error: + raise AirflowException( + 'Failed to insert items in dynamodb, error: {error}'.format( + error=str(general_error) + ) + ) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/71400b9d/airflow/contrib/hooks/aws_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py index 3eced28..61d0eb4 100644 --- a/airflow/contrib/hooks/aws_hook.py +++ b/airflow/contrib/hooks/aws_hook.py @@ -24,6 +24,7 @@ class AwsHook(BaseHook): Interact with AWS. This class is a thin wrapper around the boto3 python library. """ + def __init__(self, aws_conn_id='aws_default'): self.aws_conn_id = aws_conn_id @@ -48,3 +49,25 @@ class AwsHook(BaseHook): aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key ) + + def get_resource_type(self, resource_type, region_name=None): + try: + connection_object = self.get_connection(self.aws_conn_id) + aws_access_key_id = connection_object.login + aws_secret_access_key = connection_object.password + + if region_name is None: + region_name = connection_object.extra_dejson.get('region_name') + + except AirflowException: + # No connection found: fall back on boto3's default credential resolution + # http://boto3.readthedocs.io/en/latest/guide/configuration.html + aws_access_key_id = None + aws_secret_access_key = None + + return boto3.resource( + resource_type, + region_name=region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key + ) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/71400b9d/airflow/contrib/operators/__init__.py ---------------------------------------------------------------------- diff --git
a/airflow/contrib/operators/__init__.py b/airflow/contrib/operators/__init__.py index b731373..a761f4f 100644 --- a/airflow/contrib/operators/__init__.py +++ b/airflow/contrib/operators/__init__.py @@ -38,7 +38,8 @@ _operators = { 'qubole_operator': ['QuboleOperator'], 'spark_submit_operator': ['SparkSubmitOperator'], 'file_to_wasb': ['FileToWasbOperator'], - 'fs_operator': ['FileSensor'] + 'fs_operator': ['FileSensor'], + 'hive_to_dynamodb': ['HiveToDynamoDBTransferOperator'] } import os as _os http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/71400b9d/airflow/contrib/operators/hive_to_dynamodb.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/hive_to_dynamodb.py b/airflow/contrib/operators/hive_to_dynamodb.py new file mode 100644 index 0000000..55eca45 --- /dev/null +++ b/airflow/contrib/operators/hive_to_dynamodb.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging + +from airflow.contrib.hooks.aws_dynamodb_hook import AwsDynamoDBHook +from airflow.hooks.hive_hooks import HiveServer2Hook +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + + +class HiveToDynamoDBTransferOperator(BaseOperator): + """ + Moves data from Hive to DynamoDB. 
Note that for now the data is loaded + into memory before being pushed to DynamoDB, so this operator should + only be used for small amounts of data.
+ + :param sql: SQL query to execute against the hive database (templated) + :type sql: str + :param table_name: target DynamoDB table + :type table_name: str + :param table_keys: primary-key attribute names (partition key and sort key) + :type table_keys: list + :param pre_process: optional callable, invoked as pre_process(data=df, args=pre_process_args, kwargs=pre_process_kwargs) where df is the pandas DataFrame returned by the Hive query; its return value is the list of items written. When None, each DataFrame row is written as one item + :type pre_process: function + :param pre_process_args: passed to pre_process as the single keyword argument ``args`` (it is not unpacked) + :type pre_process_args: list + :param pre_process_kwargs: passed to pre_process as the single keyword argument ``kwargs`` (it is not unpacked) + :type pre_process_kwargs: dict + :param region_name: aws region name (example: us-east-1) + :type region_name: str + :param schema: hive database schema + :type schema: str + :param hiveserver2_conn_id: source hive connection + :type hiveserver2_conn_id: str + :param aws_conn_id: aws connection + :type aws_conn_id: str + """ + + template_fields = ('sql',) + template_ext = ('.sql',) + ui_color = '#a0e08c' + + @apply_defaults + def __init__( + self, + sql, + table_name, + table_keys, + pre_process=None, + pre_process_args=None, + pre_process_kwargs=None, + region_name=None, + schema='default', + hiveserver2_conn_id='hiveserver2_default', + aws_conn_id='aws_default', + *args, **kwargs): + super(HiveToDynamoDBTransferOperator, self).__init__(*args, **kwargs) + self.sql = sql + self.table_name = table_name + self.table_keys = table_keys + self.pre_process = pre_process + self.pre_process_args = pre_process_args + self.pre_process_kwargs = pre_process_kwargs + self.region_name = region_name + self.schema = schema + self.hiveserver2_conn_id = hiveserver2_conn_id + self.aws_conn_id = aws_conn_id + + def execute(self, context): + hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id) + + logging.info('Extracting data from Hive') + logging.info(self.sql) + + data = hive.get_pandas_df(self.sql, schema=self.schema) + dynamodb = 
AwsDynamoDBHook(aws_conn_id=self.aws_conn_id, + table_name=self.table_name, table_keys=self.table_keys, region_name=self.region_name) + + logging.info('Inserting rows
into dynamodb') + + if self.pre_process is None: + dynamodb.write_batch_data( + json.loads(data.to_json(orient='records'))) + else: + dynamodb.write_batch_data( + self.pre_process(data=data, args=self.pre_process_args, kwargs=self.pre_process_kwargs)) + + logging.info('Done.') http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/71400b9d/tests/contrib/hooks/test_aws_dynamodb_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_aws_dynamodb_hook.py b/tests/contrib/hooks/test_aws_dynamodb_hook.py new file mode 100644 index 0000000..52ab428 --- /dev/null +++ b/tests/contrib/hooks/test_aws_dynamodb_hook.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# +import unittest +import uuid + +from airflow.contrib.hooks.aws_dynamodb_hook import AwsDynamoDBHook + +try: + from moto import mock_dynamodb2 +except ImportError: + mock_dynamodb2 = None + + +class TestDynamoDBHook(unittest.TestCase): + + @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present') + @mock_dynamodb2 + def test_get_conn_returns_a_boto3_connection(self): + hook = AwsDynamoDBHook(aws_conn_id='aws_default') + self.assertIsNotNone(hook.get_conn()) + + @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present') + @mock_dynamodb2 + def test_insert_batch_items_dynamodb_table(self): + + hook = AwsDynamoDBHook(aws_conn_id='aws_default', + table_name='test_airflow', table_keys=['id'], region_name='us-east-1') + + # moto starts empty, so create the table here; in production it must already exist + table = hook.get_conn().create_table( + TableName='test_airflow', + KeySchema=[ + { + 'AttributeName': 'id', + 'KeyType': 'HASH' + }, + ], + AttributeDefinitions=[ + { + 'AttributeName': 'id', + 'AttributeType': 'S' + } + ], + ProvisionedThroughput={ + 'ReadCapacityUnits': 10, + 'WriteCapacityUnits': 10 + } + ) + + table = hook.get_conn().Table('test_airflow') + + items = [{'id': str(uuid.uuid4()), 'name': 'airflow'} + for _ in range(10)] + + hook.write_batch_data(items) + + table.meta.client.get_waiter( + 'table_exists').wait(TableName='test_airflow') + self.assertEqual(table.item_count, 10) + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/71400b9d/tests/contrib/hooks/test_aws_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_aws_hook.py b/tests/contrib/hooks/test_aws_hook.py index 6f13e58..aa246f0 100644 --- a/tests/contrib/hooks/test_aws_hook.py +++ b/tests/contrib/hooks/test_aws_hook.py @@ -21,9 +21,10 @@ from airflow.contrib.hooks.aws_hook import AwsHook try: - from moto import mock_emr + from moto import mock_emr,
mock_dynamodb2 except ImportError: mock_emr = None + mock_dynamodb2 = None class TestAwsHook(unittest.TestCase): @@ -43,5 +44,39 @@ class TestAwsHook(unittest.TestCase): self.assertEqual(client_from_hook.list_clusters()['Clusters'], []) + @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present') + @mock_dynamodb2 + def test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self): + + hook = AwsHook(aws_conn_id='aws_default') + resource_from_hook = hook.get_resource_type('dynamodb') + + # moto starts empty, so create the table here; in production it must already exist + table = resource_from_hook.create_table( + TableName='test_airflow', + KeySchema=[ + { + 'AttributeName': 'id', + 'KeyType': 'HASH' + }, + ], + AttributeDefinitions=[ + { + 'AttributeName': 'id', + 'AttributeType': 'S' + } + ], + ProvisionedThroughput={ + 'ReadCapacityUnits': 10, + 'WriteCapacityUnits': 10 + } + ) + + table.meta.client.get_waiter( + 'table_exists').wait(TableName='test_airflow') + + self.assertEqual(table.item_count, 0) + + if __name__ == '__main__': unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/71400b9d/tests/contrib/operators/test_hive_to_dynamodb_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_hive_to_dynamodb_operator.py b/tests/contrib/operators/test_hive_to_dynamodb_operator.py new file mode 100644 index 0000000..fe9d1ca --- /dev/null +++ b/tests/contrib/operators/test_hive_to_dynamodb_operator.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import unittest + +import mock +import pandas as pd + +from airflow import configuration, DAG + +configuration.load_test_config() +import datetime +from airflow.contrib.hooks.aws_dynamodb_hook import AwsDynamoDBHook +import airflow.contrib.operators.hive_to_dynamodb + +DEFAULT_DATE = datetime.datetime(2015, 1, 1) +DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat() +DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10] + +try: + from moto import mock_dynamodb2 +except ImportError: + mock_dynamodb2 = None + + +class HiveToDynamoDBTransferOperatorTest(unittest.TestCase): + + def setUp(self): + configuration.load_test_config() + args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} + dag = DAG('test_dag_id', default_args=args) + self.dag = dag + self.sql = 'SELECT 1' + self.hook = AwsDynamoDBHook( + aws_conn_id='aws_default', region_name='us-east-1') + + def process_data(self, data, *args, **kwargs): + return json.loads(data.to_json(orient='records')) + + @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present') + @mock_dynamodb2 + def test_get_conn_returns_a_boto3_connection(self): + hook = AwsDynamoDBHook(aws_conn_id='aws_default') + self.assertIsNotNone(hook.get_conn()) + + @mock.patch('airflow.hooks.hive_hooks.HiveServer2Hook.get_pandas_df', + return_value=pd.DataFrame(data=[('1', 'sid')], columns=['id', 'name'])) + @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present') + @mock_dynamodb2 + def test_get_records_with_schema(self, get_results_mock): + + # moto starts empty, so create the table here; in production it must already 
exist + table =
self.hook.get_conn().create_table( + TableName='test_airflow', + KeySchema=[ + { + 'AttributeName': 'id', + 'KeyType': 'HASH' + }, + ], + AttributeDefinitions=[ + { + 'AttributeName': 'id', + 'AttributeType': 'S' + } + ], + ProvisionedThroughput={ + 'ReadCapacityUnits': 10, + 'WriteCapacityUnits': 10 + } + ) + + operator = airflow.contrib.operators.hive_to_dynamodb.HiveToDynamoDBTransferOperator( + sql=self.sql, + table_name="test_airflow", + task_id='hive_to_dynamodb_check', + table_keys=['id'], + dag=self.dag) + + operator.execute(None) + + table = self.hook.get_conn().Table('test_airflow') + table.meta.client.get_waiter( + 'table_exists').wait(TableName='test_airflow') + self.assertEqual(table.item_count, 1) + + @mock.patch('airflow.hooks.hive_hooks.HiveServer2Hook.get_pandas_df', + return_value=pd.DataFrame(data=[('1', 'sid'), ('1', 'gupta')], columns=['id', 'name'])) + @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present') + @mock_dynamodb2 + def test_pre_process_records_with_schema(self, get_results_mock): + + # moto starts empty, so create the table here; in production it must already exist + table = self.hook.get_conn().create_table( + TableName='test_airflow', + KeySchema=[ + { + 'AttributeName': 'id', + 'KeyType': 'HASH' + }, + ], + AttributeDefinitions=[ + { + 'AttributeName': 'id', + 'AttributeType': 'S' + } + ], + ProvisionedThroughput={ + 'ReadCapacityUnits': 10, + 'WriteCapacityUnits': 10 + } + ) + + operator = airflow.contrib.operators.hive_to_dynamodb.HiveToDynamoDBTransferOperator( + sql=self.sql, + table_name='test_airflow', + task_id='hive_to_dynamodb_check', + table_keys=['id'], + pre_process=self.process_data, + dag=self.dag) + + operator.execute(None) + + table = self.hook.get_conn().Table('test_airflow') + table.meta.client.get_waiter( + 'table_exists').wait(TableName='test_airflow') + self.assertEqual(table.item_count, 1) + + +if __name__ == '__main__': + unittest.main()
