This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v1-10-test by this push:
     new 8d23325  [AIRFLOW-5906] Add authenticator parameter to snowflake_hook (#8642)
8d23325 is described below

commit 8d233254894c8916a8e165d2fdb925ba74f89300
Author: Peter Kosztolanyi <[email protected]>
AuthorDate: Sun May 10 02:02:53 2020 +0100

    [AIRFLOW-5906] Add authenticator parameter to snowflake_hook (#8642)
    
    (cherry picked from commit cd635dd7d57cab2f41efac2d3d94e8f80a6c96d6)
---
 airflow/contrib/hooks/snowflake_hook.py         |  8 +++++---
 airflow/contrib/operators/snowflake_operator.py | 13 +++++++++++--
 tests/contrib/hooks/test_snowflake_hook.py      |  6 ++++--
 3 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/airflow/contrib/hooks/snowflake_hook.py b/airflow/contrib/hooks/snowflake_hook.py
index cd6c1c9..8574336 100644
--- a/airflow/contrib/hooks/snowflake_hook.py
+++ b/airflow/contrib/hooks/snowflake_hook.py
@@ -44,6 +44,7 @@ class SnowflakeHook(DbApiHook):
         self.region = kwargs.pop("region", None)
         self.role = kwargs.pop("role", None)
         self.schema = kwargs.pop("schema", None)
+        self.authenticator = kwargs.pop("authenticator", None)
 
     def _get_conn_params(self):
         """
@@ -56,6 +57,7 @@ class SnowflakeHook(DbApiHook):
         database = conn.extra_dejson.get('database', None)
         region = conn.extra_dejson.get("region", None)
         role = conn.extra_dejson.get('role', None)
+        authenticator = conn.extra_dejson.get('authenticator', 'snowflake')
 
         conn_config = {
             "user": conn.login,
@@ -65,8 +67,8 @@ class SnowflakeHook(DbApiHook):
             "account": self.account or account or '',
             "warehouse": self.warehouse or warehouse or '',
             "region": self.region or region or '',
-            "role": self.role or role or ''
-
+            "role": self.role or role,
+            "authenticator": self.authenticator or authenticator
         }
 
         """
@@ -103,7 +105,7 @@ class SnowflakeHook(DbApiHook):
         """
         conn_config = self._get_conn_params()
         uri = 'snowflake://{user}:{password}@{account}/{database}/'
-        uri += '{schema}?warehouse={warehouse}&role={role}'
+        uri += '{schema}?warehouse={warehouse}&role={role}&authenticator={authenticator}'
         return uri.format(**conn_config)
 
     def get_conn(self):
diff --git a/airflow/contrib/operators/snowflake_operator.py b/airflow/contrib/operators/snowflake_operator.py
index f115fc3..caea8190 100644
--- a/airflow/contrib/operators/snowflake_operator.py
+++ b/airflow/contrib/operators/snowflake_operator.py
@@ -42,6 +42,14 @@ class SnowflakeOperator(BaseOperator):
     :type schema: str
     :param role: name of role (will overwrite any role defined in
         connection's extra JSON)
+    :param authenticator: authenticator for Snowflake.
+        'snowflake' (default) to use the internal Snowflake authenticator
+        'externalbrowser' to authenticate using your web browser and
+        Okta, ADFS or any other SAML 2.0-compliant identity provider
+        (IdP) that has been defined for your account
+        'https://<your_okta_account_name>.okta.com' to authenticate
+        through native Okta.
+    :type authenticator: str
     """
 
     template_fields = ('sql',)
@@ -52,7 +60,7 @@ class SnowflakeOperator(BaseOperator):
     def __init__(
             self, sql, snowflake_conn_id='snowflake_default', parameters=None,
             autocommit=True, warehouse=None, database=None, role=None,
-            schema=None, *args, **kwargs):
+            schema=None, authenticator=None, *args, **kwargs):
         super(SnowflakeOperator, self).__init__(*args, **kwargs)
         self.snowflake_conn_id = snowflake_conn_id
         self.sql = sql
@@ -62,11 +70,12 @@ class SnowflakeOperator(BaseOperator):
         self.database = database
         self.role = role
         self.schema = schema
+        self.authenticator = authenticator
 
     def get_hook(self):
         return SnowflakeHook(snowflake_conn_id=self.snowflake_conn_id,
                              warehouse=self.warehouse, database=self.database,
-                             role=self.role, schema=self.schema)
+                             role=self.role, schema=self.schema, authenticator=self.authenticator)
 
     def execute(self, context):
         self.log.info('Executing: %s', self.sql)
diff --git a/tests/contrib/hooks/test_snowflake_hook.py b/tests/contrib/hooks/test_snowflake_hook.py
index 19ae138..158f598 100644
--- a/tests/contrib/hooks/test_snowflake_hook.py
+++ b/tests/contrib/hooks/test_snowflake_hook.py
@@ -92,7 +92,8 @@ class TestSnowflakeHook(unittest.TestCase):
         os.remove(self.nonEncryptedPrivateKey)
 
     def test_get_uri(self):
-        uri_shouldbe = 'snowflake://user:pw@airflow/db/public?warehouse=af_wh&role=af_role'
+        uri_shouldbe = 'snowflake://user:pw@airflow/db/public?warehouse=af_wh&role=af_role' \
+                       '&authenticator=snowflake'
         self.assertEqual(uri_shouldbe, self.db_hook.get_uri())
 
     def test_get_conn_params(self):
@@ -103,7 +104,8 @@ class TestSnowflakeHook(unittest.TestCase):
                                 'account': 'airflow',
                                 'warehouse': 'af_wh',
                                 'region': 'af_region',
-                                'role': 'af_role'}
+                                'role': 'af_role',
+                                'authenticator': 'snowflake'}
         self.assertEqual(conn_params_shouldbe, self.db_hook._get_conn_params())
 
     def test_get_conn(self):

Reply via email to