Repository: incubator-airflow
Updated Branches:
  refs/heads/master 87a1774cf -> 8c9d3befb


[AIRFLOW-2200] Add snowflake operator with tests


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/c4ba1051
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/c4ba1051
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/c4ba1051

Branch: refs/heads/master
Commit: c4ba1051a750cd0a0906eba652bc21e848aa8454
Parents: 9c0c426
Author: devinXL8 <dev...@exelate.com>
Authored: Wed Apr 4 15:10:00 2018 -0400
Committer: devinXL8 <dev...@exelate.com>
Committed: Wed Apr 4 15:10:00 2018 -0400

----------------------------------------------------------------------
 airflow/contrib/hooks/snowflake_hook.py         | 93 ++++++++++++++++++++
 airflow/contrib/operators/snowflake_operator.py | 62 +++++++++++++
 airflow/models.py                               |  1 +
 setup.py                                        |  9 +-
 tests/contrib/hooks/test_snowflake_hook.py      | 64 ++++++++++++++
 .../operators/test_snowflake_operator.py        | 62 +++++++++++++
 6 files changed, 288 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c4ba1051/airflow/contrib/hooks/snowflake_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/snowflake_hook.py b/airflow/contrib/hooks/snowflake_hook.py
new file mode 100644
index 0000000..9d007d5
--- /dev/null
+++ b/airflow/contrib/hooks/snowflake_hook.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import snowflake.connector
+
+from airflow.hooks.dbapi_hook import DbApiHook
+
+
class SnowflakeHook(DbApiHook):
    """
    Interact with Snowflake.

    The connection's "extra" JSON may supply ``account``, ``warehouse``
    and ``database``; each can be overridden through the constructor
    kwarg of the same name.

    get_sqlalchemy_engine() depends on snowflake-sqlalchemy

    """

    conn_name_attr = 'snowflake_conn_id'
    default_conn_name = 'snowflake_default'
    supports_autocommit = True

    def __init__(self, *args, **kwargs):
        super(SnowflakeHook, self).__init__(*args, **kwargs)
        # Constructor-level overrides for values normally taken from
        # the connection's "extra" JSON.
        self.account = kwargs.pop("account", None)
        self.warehouse = kwargs.pop("warehouse", None)
        self.database = kwargs.pop("database", None)

    def _get_conn_params(self):
        '''
        One method to fetch connection params as a dict,
        used in get_uri() and get_conn().

        Constructor overrides (self.account/warehouse/database) take
        precedence over the values stored in the connection extras.
        '''
        conn = self.get_connection(self.snowflake_conn_id)
        account = conn.extra_dejson.get('account', None)
        warehouse = conn.extra_dejson.get('warehouse', None)
        database = conn.extra_dejson.get('database', None)

        conn_config = {
            "user": conn.login,
            "password": conn.password or '',
            "schema": conn.schema or '',
            "database": self.database or database or '',
            "account": self.account or account or '',
            "warehouse": self.warehouse or warehouse or ''
        }
        return conn_config

    def get_uri(self):
        '''
        Override DbApiHook get_uri method for get_sqlalchemy_engine().

        :return: snowflake-sqlalchemy style connection URI
        :rtype: str
        '''
        conn_config = self._get_conn_params()
        uri = 'snowflake://{user}:{password}@{account}/{database}/'
        uri += '{schema}?warehouse={warehouse}'
        return uri.format(
            **conn_config)

    def get_conn(self):
        """
        Returns a snowflake.connector connection object.
        """
        conn_config = self._get_conn_params()
        conn = snowflake.connector.connect(**conn_config)
        return conn

    def _get_aws_credentials(self):
        '''
        Returns (aws_access_key_id, aws_secret_access_key) read from
        the connection extras; both are None when not configured.

        Intended to be used by external import and export statements.
        '''
        # Initialize to None so the return statement never raises
        # UnboundLocalError when the conn id is unset or the extras
        # do not contain 'aws_secret_access_key'.
        aws_access_key_id = None
        aws_secret_access_key = None
        if self.snowflake_conn_id:
            connection_object = self.get_connection(self.snowflake_conn_id)
            if 'aws_secret_access_key' in connection_object.extra_dejson:
                aws_access_key_id = connection_object.extra_dejson.get(
                    'aws_access_key_id')
                aws_secret_access_key = connection_object.extra_dejson.get(
                    'aws_secret_access_key')
        return aws_access_key_id, aws_secret_access_key

    def set_autocommit(self, conn, autocommit):
        # The snowflake connector exposes autocommit as a method (not a
        # DB-API attribute), so override DbApiHook's default handling.
        conn.autocommit(autocommit)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c4ba1051/airflow/contrib/operators/snowflake_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/snowflake_operator.py b/airflow/contrib/operators/snowflake_operator.py
new file mode 100644
index 0000000..4947287
--- /dev/null
+++ b/airflow/contrib/operators/snowflake_operator.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from airflow.contrib.hooks.snowflake_hook import SnowflakeHook
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
class SnowflakeOperator(BaseOperator):
    """
    Executes sql code in a Snowflake database.

    :param snowflake_conn_id: reference to specific snowflake connection id
    :type snowflake_conn_id: string
    :param sql: the sql code to be executed
    :type sql: Can receive a str representing a sql statement,
        a list of str (sql statements), or reference to a template file.
        Template reference are recognized by str ending in '.sql'
    :param warehouse: name of warehouse which overwrite defined
        one in connection
    :type warehouse: string
    :param database: name of database which overwrite defined one in connection
    :type database: string
    """

    template_fields = ('sql',)
    template_ext = ('.sql',)
    ui_color = '#ededed'

    @apply_defaults
    def __init__(
            self, sql, snowflake_conn_id='snowflake_default', parameters=None,
            autocommit=True, warehouse=None, database=None, *args, **kwargs):
        super(SnowflakeOperator, self).__init__(*args, **kwargs)
        # Keep everything needed to build the hook lazily in execute().
        self.sql = sql
        self.snowflake_conn_id = snowflake_conn_id
        self.parameters = parameters
        self.autocommit = autocommit
        self.database = database
        self.warehouse = warehouse

    def get_hook(self):
        """Build a SnowflakeHook carrying this task's overrides."""
        hook = SnowflakeHook(
            snowflake_conn_id=self.snowflake_conn_id,
            warehouse=self.warehouse,
            database=self.database)
        return hook

    def execute(self, context):
        """Run self.sql against Snowflake through a fresh hook."""
        self.log.info('Executing: %s', self.sql)
        self.get_hook().run(
            self.sql, autocommit=self.autocommit, parameters=self.parameters)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c4ba1051/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index b08a3b1..e89e776 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -593,6 +593,7 @@ class Connection(Base, LoggingMixin):
         ('databricks', 'Databricks',),
         ('aws', 'Amazon Web Services',),
         ('emr', 'Elastic MapReduce',),
+        ('snowflake', 'Snowflake',),
     ]
 
     def __init__(

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c4ba1051/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index 7279ed3..0c0b237 100644
--- a/setup.py
+++ b/setup.py
@@ -166,7 +166,8 @@ cloudant = ['cloudant>=0.5.9,<2.0'] # major update coming soon, clamp to 0.x
 redis = ['redis>=2.10.5']
 kubernetes = ['kubernetes>=3.0.0',
               'cryptography>=2.0.0']
-
+snowflake = ['snowflake-connector-python>=1.5.2',
+             'snowflake-sqlalchemy>=1.1.0']
 zendesk = ['zdesk']
 
 all_dbs = postgres + mysql + hive + mssql + hdfs + vertica + cloudant + druid
@@ -191,7 +192,8 @@ devel_minreq = devel + kubernetes + mysql + doc + password + s3 + cgroups
 devel_hadoop = devel_minreq + hive + hdfs + webhdfs + kerberos
 devel_all = (sendgrid + devel + all_dbs + doc + samba + s3 + slack + crypto + oracle +
              docker + ssh + kubernetes + celery + azure + redis + gcp_api + datadog +
-             zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins + druid)
+             zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins +
+             druid + snowflake)
 
 # Snakebite & Google Cloud Dataflow are not Python 3 compatible :'(
 if PY3:
@@ -298,7 +300,8 @@ def do_setup():
             'webhdfs': webhdfs,
             'jira': jira,
             'redis': redis,
-            'kubernetes': kubernetes
+            'kubernetes': kubernetes,
+            'snowflake': snowflake
         },
         classifiers=[
             'Development Status :: 5 - Production/Stable',

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c4ba1051/tests/contrib/hooks/test_snowflake_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_snowflake_hook.py b/tests/contrib/hooks/test_snowflake_hook.py
new file mode 100644
index 0000000..00cb6f5
--- /dev/null
+++ b/tests/contrib/hooks/test_snowflake_hook.py
@@ -0,0 +1,64 @@
+
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import mock
+import unittest
+
+from airflow.contrib.hooks.snowflake_hook import SnowflakeHook
+
+
class TestSnowflakeHook(unittest.TestCase):
    """Unit tests for SnowflakeHook using a fully mocked connection."""

    def setUp(self):
        super(TestSnowflakeHook, self).setUp()

        # Mocked DB-API cursor/connection pair.
        self.cur = mock.MagicMock()
        conn = mock.MagicMock()
        conn.cursor.return_value = self.cur
        self.conn = conn

        # Credentials and extras the hook is expected to read.
        conn.login = 'user'
        conn.password = 'pw'
        conn.schema = 'public'
        conn.extra_dejson = {'database': 'db',
                             'account': 'airflow',
                             'warehouse': 'af_wh'}

        class UnitTestSnowflakeHook(SnowflakeHook):
            # Short-circuit both connection lookups to the mock above.
            conn_name_attr = 'snowflake_conn_id'

            def get_conn(self):
                return conn

            def get_connection(self, connection_id):
                return conn

        self.db_hook = UnitTestSnowflakeHook()

    def test_get_uri(self):
        expected_uri = 'snowflake://user:pw@airflow/db/public?warehouse=af_wh'
        self.assertEqual(expected_uri, self.db_hook.get_uri())

    def test_get_conn_params(self):
        expected_params = {'user': 'user',
                           'password': 'pw',
                           'schema': 'public',
                           'database': 'db',
                           'account': 'airflow',
                           'warehouse': 'af_wh'}
        self.assertEqual(expected_params, self.db_hook._get_conn_params())

    def test_get_conn(self):
        self.assertEqual(self.db_hook.get_conn(), self.conn)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c4ba1051/tests/contrib/operators/test_snowflake_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_snowflake_operator.py b/tests/contrib/operators/test_snowflake_operator.py
new file mode 100644
index 0000000..4febe90
--- /dev/null
+++ b/tests/contrib/operators/test_snowflake_operator.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+
+from airflow import DAG, configuration
+from airflow.utils import timezone
+
+from airflow.contrib.operators.snowflake_operator import SnowflakeOperator
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
+DEFAULT_DATE = timezone.datetime(2015, 1, 1)
+DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
+DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
+TEST_DAG_ID = 'unit_test_dag'
+LONG_MOCK_PATH = 'airflow.contrib.operators.snowflake_operator.'
+LONG_MOCK_PATH += 'SnowflakeOperator.get_hook'
+
+
class TestSnowflakeOperator(unittest.TestCase):
    """Smoke test for SnowflakeOperator with the hook patched out."""

    def setUp(self):
        super(TestSnowflakeOperator, self).setUp()
        configuration.load_test_config()
        default_args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        self.dag = DAG(TEST_DAG_ID, default_args=default_args)

    @mock.patch(LONG_MOCK_PATH)
    def test_snowflake_operator(self, mock_get_hook):
        # get_hook is patched, so run() never touches a real Snowflake.
        sql = """
        CREATE TABLE IF NOT EXISTS test_airflow (
            dummy VARCHAR(50)
        );
        """
        operator = SnowflakeOperator(
            task_id='basic_snowflake', sql=sql, dag=self.dag)
        operator.run(
            start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
            ignore_ti_state=True)

Reply via email to