Repository: incubator-airflow
Updated Branches:
  refs/heads/master 87a1774cf -> 8c9d3befb
[AIRFLOW-2200] Add snowflake operator with tests

Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/c4ba1051
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/c4ba1051
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/c4ba1051

Branch: refs/heads/master
Commit: c4ba1051a750cd0a0906eba652bc21e848aa8454
Parents: 9c0c426
Author: devinXL8 <dev...@exelate.com>
Authored: Wed Apr 4 15:10:00 2018 -0400
Committer: devinXL8 <dev...@exelate.com>
Committed: Wed Apr 4 15:10:00 2018 -0400

----------------------------------------------------------------------
 airflow/contrib/hooks/snowflake_hook.py         | 93 ++++++++++++++++++++
 airflow/contrib/operators/snowflake_operator.py | 62 +++++++++++++
 airflow/models.py                               |  1 +
 setup.py                                        |  9 +-
 tests/contrib/hooks/test_snowflake_hook.py      | 64 ++++++++++++++
 .../operators/test_snowflake_operator.py        | 62 +++++++++++++
 6 files changed, 288 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c4ba1051/airflow/contrib/hooks/snowflake_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/snowflake_hook.py b/airflow/contrib/hooks/snowflake_hook.py
new file mode 100644
index 0000000..9d007d5
--- /dev/null
+++ b/airflow/contrib/hooks/snowflake_hook.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import snowflake.connector
+
+from airflow.hooks.dbapi_hook import DbApiHook
+
+
+class SnowflakeHook(DbApiHook):
+    """
+    Interact with Snowflake.
+
+    get_sqlalchemy_engine() depends on snowflake-sqlalchemy
+
+    """
+
+    conn_name_attr = 'snowflake_conn_id'
+    default_conn_name = 'snowflake_default'
+    supports_autocommit = True
+
+    def __init__(self, *args, **kwargs):
+        super(SnowflakeHook, self).__init__(*args, **kwargs)
+        self.account = kwargs.pop("account", None)
+        self.warehouse = kwargs.pop("warehouse", None)
+        self.database = kwargs.pop("database", None)
+
+    def _get_conn_params(self):
+        '''
+        one method to fetch connection params as a dict
+        used in get_uri() and get_connection()
+        '''
+        conn = self.get_connection(self.snowflake_conn_id)
+        account = conn.extra_dejson.get('account', None)
+        warehouse = conn.extra_dejson.get('warehouse', None)
+        database = conn.extra_dejson.get('database', None)
+
+        conn_config = {
+            "user": conn.login,
+            "password": conn.password or '',
+            "schema": conn.schema or '',
+            "database": self.database or database or '',
+            "account": self.account or account or '',
+            "warehouse": self.warehouse or warehouse or ''
+        }
+        return conn_config
+
+    def get_uri(self):
+        '''
+        override DbApiHook get_uri method for get_sqlalchemy_engine()
+        '''
+        conn_config = self._get_conn_params()
+        uri = 'snowflake://{user}:{password}@{account}/{database}/'
+        uri += '{schema}?warehouse={warehouse}'
+        return uri.format(
+            **conn_config)
+
+    def get_conn(self):
+        """
+        Returns a snowflake.connection object
+        """
+        conn_config = self._get_conn_params()
+        conn = snowflake.connector.connect(**conn_config)
+        return conn
+
+    def _get_aws_credentials(self):
+        '''
+        returns aws_access_key_id, aws_secret_access_key
+        from extra
+
+        intended to be used by external import and export statements
+        '''
+        if self.snowflake_conn_id:
+            connection_object = self.get_connection(self.snowflake_conn_id)
+            if 'aws_secret_access_key' in connection_object.extra_dejson:
+                aws_access_key_id = connection_object.extra_dejson.get(
+                    'aws_access_key_id')
+                aws_secret_access_key = connection_object.extra_dejson.get(
+                    'aws_secret_access_key')
+        return aws_access_key_id, aws_secret_access_key
+
+    def set_autocommit(self, conn, autocommit):
+        conn.autocommit(autocommit)
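
A minimal usage sketch of the new hook (not part of this diff): it assumes a 'snowflake_default'
connection whose extra JSON supplies the account, warehouse and database keys read by
_get_conn_params() above; the query text and the keyword overrides are placeholders.

    from airflow.contrib.hooks.snowflake_hook import SnowflakeHook

    # Keyword arguments take precedence over the connection's extra JSON,
    # mirroring what SnowflakeOperator.get_hook() passes in below.
    hook = SnowflakeHook(snowflake_conn_id='snowflake_default',
                         warehouse='af_wh',   # placeholder warehouse
                         database='db')       # placeholder database

    # DbApiHook supplies the usual helpers on top of get_conn():
    rows = hook.get_records('SELECT CURRENT_TIMESTAMP')
    hook.run('CREATE TABLE IF NOT EXISTS test_airflow (dummy VARCHAR(50))',
             autocommit=True)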

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c4ba1051/airflow/contrib/operators/snowflake_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/snowflake_operator.py b/airflow/contrib/operators/snowflake_operator.py
new file mode 100644
index 0000000..4947287
--- /dev/null
+++ b/airflow/contrib/operators/snowflake_operator.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow.contrib.hooks.snowflake_hook import SnowflakeHook
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class SnowflakeOperator(BaseOperator):
+    """
+    Executes sql code in a Snowflake database
+
+    :param snowflake_conn_id: reference to specific snowflake connection id
+    :type snowflake_conn_id: string
+    :param sql: the sql code to be executed
+    :type sql: Can receive a str representing a sql statement,
+        a list of str (sql statements), or reference to a template file.
+        Template reference are recognized by str ending in '.sql'
+    :param warehouse: name of warehouse which overwrite defined
+        one in connection
+    :type warehouse: string
+    :param database: name of database which overwrite defined one in connection
+    :type database: string
+    """
+
+    template_fields = ('sql',)
+    template_ext = ('.sql',)
+    ui_color = '#ededed'
+
+    @apply_defaults
+    def __init__(
+            self, sql, snowflake_conn_id='snowflake_default', parameters=None,
+            autocommit=True, warehouse=None, database=None, *args, **kwargs):
+        super(SnowflakeOperator, self).__init__(*args, **kwargs)
+        self.snowflake_conn_id = snowflake_conn_id
+        self.sql = sql
+        self.autocommit = autocommit
+        self.parameters = parameters
+        self.warehouse = warehouse
+        self.database = database
+
+    def get_hook(self):
+        return SnowflakeHook(snowflake_conn_id=self.snowflake_conn_id,
+                             warehouse=self.warehouse, database=self.database)
+
+    def execute(self, context):
+        self.log.info('Executing: %s', self.sql)
+        hook = self.get_hook()
+        hook.run(
+            self.sql,
+            autocommit=self.autocommit,
+            parameters=self.parameters)


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c4ba1051/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index b08a3b1..e89e776 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -593,6 +593,7 @@ class Connection(Base, LoggingMixin):
         ('databricks', 'Databricks',),
         ('aws', 'Amazon Web Services',),
         ('emr', 'Elastic MapReduce',),
+        ('snowflake', 'Snowflake',),
     ]
 
     def __init__(


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c4ba1051/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index 7279ed3..0c0b237 100644
--- a/setup.py
+++ b/setup.py
@@ -166,7 +166,8 @@ cloudant = ['cloudant>=0.5.9,<2.0'] # major update coming soon, clamp to 0.x
 redis = ['redis>=2.10.5']
 kubernetes = ['kubernetes>=3.0.0',
               'cryptography>=2.0.0']
-
+snowflake = ['snowflake-connector-python>=1.5.2',
+             'snowflake-sqlalchemy>=1.1.0']
 zendesk = ['zdesk']
 
 all_dbs = postgres + mysql + hive + mssql + hdfs + vertica + cloudant + druid
@@ -191,7 +192,8 @@ devel_minreq = devel + kubernetes + mysql + doc + password + s3 + cgroups
 devel_hadoop = devel_minreq + hive + hdfs + webhdfs + kerberos
 devel_all = (sendgrid + devel + all_dbs + doc + samba + s3 + slack + crypto + oracle +
              docker + ssh + kubernetes + celery + azure + redis + gcp_api + datadog +
-             zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins + druid)
+             zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins +
+             druid + snowflake)
 
 # Snakebite & Google Cloud Dataflow are not Python 3 compatible :'(
 if PY3:
@@ -298,7 +300,8 @@ def do_setup():
             'webhdfs': webhdfs,
             'jira': jira,
             'redis': redis,
-            'kubernetes': kubernetes
+            'kubernetes': kubernetes,
+            'snowflake': snowflake
         },
         classifiers=[
             'Development Status :: 5 - Production/Stable',
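
Taken together, the operator, the new connection type and the 'snowflake' setup.py extra could
be exercised with a DAG along these lines (a sketch, not part of this diff; the DAG id, schedule
and SQL are illustrative, and the dependencies would come from something like
pip install apache-airflow[snowflake]):

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.snowflake_operator import SnowflakeOperator

    dag = DAG('example_snowflake',              # illustrative DAG id
              start_date=datetime(2018, 4, 1),
              schedule_interval='@daily')

    create = SnowflakeOperator(
        task_id='create_table',
        snowflake_conn_id='snowflake_default',  # the hook's default_conn_name
        sql="CREATE TABLE IF NOT EXISTS test_airflow (dummy VARCHAR(50))",
        warehouse='af_wh',                      # overrides the connection extra, via get_hook()
        database='db',
        dag=dag)

    populate = SnowflakeOperator(
        task_id='populate_table',
        sql='populate_table.sql',               # '.sql' strings are rendered as templates (template_ext)
        dag=dag)

    create >> populate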

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c4ba1051/tests/contrib/hooks/test_snowflake_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_snowflake_hook.py b/tests/contrib/hooks/test_snowflake_hook.py
new file mode 100644
index 0000000..00cb6f5
--- /dev/null
+++ b/tests/contrib/hooks/test_snowflake_hook.py
@@ -0,0 +1,64 @@
+
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import mock
+import unittest
+
+from airflow.contrib.hooks.snowflake_hook import SnowflakeHook
+
+
+class TestSnowflakeHook(unittest.TestCase):
+
+    def setUp(self):
+        super(TestSnowflakeHook, self).setUp()
+
+        self.cur = mock.MagicMock()
+        self.conn = conn = mock.MagicMock()
+        self.conn.cursor.return_value = self.cur
+
+        self.conn.login = 'user'
+        self.conn.password = 'pw'
+        self.conn.schema = 'public'
+        self.conn.extra_dejson = {'database': 'db',
+                                  'account': 'airflow',
+                                  'warehouse': 'af_wh'}
+
+        class UnitTestSnowflakeHook(SnowflakeHook):
+            conn_name_attr = 'snowflake_conn_id'
+
+            def get_conn(self):
+                return conn
+
+            def get_connection(self, connection_id):
+                return conn
+
+        self.db_hook = UnitTestSnowflakeHook()
+
+    def test_get_uri(self):
+        uri_shouldbe = 'snowflake://user:pw@airflow/db/public?warehouse=af_wh'
+        self.assertEqual(uri_shouldbe, self.db_hook.get_uri())
+
+    def test_get_conn_params(self):
+        conn_params_shouldbe = {'user': 'user',
+                                'password': 'pw',
+                                'schema': 'public',
+                                'database': 'db',
+                                'account': 'airflow',
+                                'warehouse': 'af_wh'}
+        self.assertEqual(conn_params_shouldbe, self.db_hook._get_conn_params())
+
+    def test_get_conn(self):
+        self.assertEqual(self.db_hook.get_conn(), self.conn)


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c4ba1051/tests/contrib/operators/test_snowflake_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_snowflake_operator.py b/tests/contrib/operators/test_snowflake_operator.py
new file mode 100644
index 0000000..4febe90
--- /dev/null
+++ b/tests/contrib/operators/test_snowflake_operator.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+
+from airflow import DAG, configuration
+from airflow.utils import timezone
+
+from airflow.contrib.operators.snowflake_operator import SnowflakeOperator
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
+DEFAULT_DATE = timezone.datetime(2015, 1, 1)
+DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
+DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
+TEST_DAG_ID = 'unit_test_dag'
+LONG_MOCK_PATH = 'airflow.contrib.operators.snowflake_operator.'
+LONG_MOCK_PATH += 'SnowflakeOperator.get_hook'
+
+
+class TestSnowflakeOperator(unittest.TestCase):
+
+    def setUp(self):
+        super(TestSnowflakeOperator, self).setUp()
+        configuration.load_test_config()
+        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
+        dag = DAG(TEST_DAG_ID, default_args=args)
+        self.dag = dag
+
+    @mock.patch(LONG_MOCK_PATH)
+    def test_snowflake_operator(self, mock_get_hook):
+        sql = """
+        CREATE TABLE IF NOT EXISTS test_airflow (
+            dummy VARCHAR(50)
+        );
+        """
+        t = SnowflakeOperator(
+            task_id='basic_snowflake',
+            sql=sql,
+            dag=self.dag)
+        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+              ignore_ti_state=True)
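
To try the new operator end to end, a connection of the shape the hook expects could be created
along these lines (a sketch, not part of this diff; every value is a placeholder, and only the
extra JSON keys, account, warehouse, database and the optional AWS credentials, are dictated by
the hook above):

    from airflow import settings
    from airflow.models import Connection

    snowflake_conn = Connection(
        conn_id='snowflake_default',
        conn_type='snowflake',          # the connection type registered in airflow/models.py above
        login='user',                   # placeholder credentials
        password='pw',
        schema='public',
        extra='{"account": "airflow", "warehouse": "af_wh", "database": "db", '
              '"aws_access_key_id": "placeholder", '
              '"aws_secret_access_key": "placeholder"}')

    session = settings.Session()
    session.add(snowflake_conn)
    session.commit()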