This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v1-10-test by this push:
new 8d23325 [AIRFLOW-5906] Add authenticator parameter to snowflake_hook (#8642)
8d23325 is described below
commit 8d233254894c8916a8e165d2fdb925ba74f89300
Author: Peter Kosztolanyi <[email protected]>
AuthorDate: Sun May 10 02:02:53 2020 +0100
[AIRFLOW-5906] Add authenticator parameter to snowflake_hook (#8642)
(cherry picked from commit cd635dd7d57cab2f41efac2d3d94e8f80a6c96d6)
---
airflow/contrib/hooks/snowflake_hook.py | 8 +++++---
airflow/contrib/operators/snowflake_operator.py | 13 +++++++++++--
tests/contrib/hooks/test_snowflake_hook.py | 6 ++++--
3 files changed, 20 insertions(+), 7 deletions(-)
diff --git a/airflow/contrib/hooks/snowflake_hook.py b/airflow/contrib/hooks/snowflake_hook.py
index cd6c1c9..8574336 100644
--- a/airflow/contrib/hooks/snowflake_hook.py
+++ b/airflow/contrib/hooks/snowflake_hook.py
@@ -44,6 +44,7 @@ class SnowflakeHook(DbApiHook):
self.region = kwargs.pop("region", None)
self.role = kwargs.pop("role", None)
self.schema = kwargs.pop("schema", None)
+ self.authenticator = kwargs.pop("authenticator", None)
def _get_conn_params(self):
"""
@@ -56,6 +57,7 @@ class SnowflakeHook(DbApiHook):
database = conn.extra_dejson.get('database', None)
region = conn.extra_dejson.get("region", None)
role = conn.extra_dejson.get('role', None)
+ authenticator = conn.extra_dejson.get('authenticator', 'snowflake')
conn_config = {
"user": conn.login,
@@ -65,8 +67,8 @@ class SnowflakeHook(DbApiHook):
"account": self.account or account or '',
"warehouse": self.warehouse or warehouse or '',
"region": self.region or region or '',
- "role": self.role or role or ''
-
+ "role": self.role or role,
+ "authenticator": self.authenticator or authenticator
}
"""
@@ -103,7 +105,7 @@ class SnowflakeHook(DbApiHook):
"""
conn_config = self._get_conn_params()
uri = 'snowflake://{user}:{password}@{account}/{database}/'
- uri += '{schema}?warehouse={warehouse}&role={role}'
+ uri += '{schema}?warehouse={warehouse}&role={role}&authenticator={authenticator}'
return uri.format(**conn_config)
def get_conn(self):
diff --git a/airflow/contrib/operators/snowflake_operator.py b/airflow/contrib/operators/snowflake_operator.py
index f115fc3..caea8190 100644
--- a/airflow/contrib/operators/snowflake_operator.py
+++ b/airflow/contrib/operators/snowflake_operator.py
@@ -42,6 +42,14 @@ class SnowflakeOperator(BaseOperator):
:type schema: str
:param role: name of role (will overwrite any role defined in
connection's extra JSON)
+ :param authenticator: authenticator for Snowflake.
+ 'snowflake' (default) to use the internal Snowflake authenticator
+ 'externalbrowser' to authenticate using your web browser and
+ Okta, ADFS or any other SAML 2.0-compliant identity provider
+ (IdP) that has been defined for your account
+ 'https://<your_okta_account_name>.okta.com' to authenticate
+ through native Okta.
+ :type authenticator: str
"""
template_fields = ('sql',)
@@ -52,7 +60,7 @@ class SnowflakeOperator(BaseOperator):
def __init__(
self, sql, snowflake_conn_id='snowflake_default', parameters=None,
autocommit=True, warehouse=None, database=None, role=None,
- schema=None, *args, **kwargs):
+ schema=None, authenticator=None, *args, **kwargs):
super(SnowflakeOperator, self).__init__(*args, **kwargs)
self.snowflake_conn_id = snowflake_conn_id
self.sql = sql
@@ -62,11 +70,12 @@ class SnowflakeOperator(BaseOperator):
self.database = database
self.role = role
self.schema = schema
+ self.authenticator = authenticator
def get_hook(self):
return SnowflakeHook(snowflake_conn_id=self.snowflake_conn_id,
warehouse=self.warehouse, database=self.database,
- role=self.role, schema=self.schema)
+ role=self.role, schema=self.schema, authenticator=self.authenticator)
def execute(self, context):
self.log.info('Executing: %s', self.sql)
diff --git a/tests/contrib/hooks/test_snowflake_hook.py b/tests/contrib/hooks/test_snowflake_hook.py
index 19ae138..158f598 100644
--- a/tests/contrib/hooks/test_snowflake_hook.py
+++ b/tests/contrib/hooks/test_snowflake_hook.py
@@ -92,7 +92,8 @@ class TestSnowflakeHook(unittest.TestCase):
os.remove(self.nonEncryptedPrivateKey)
def test_get_uri(self):
- uri_shouldbe = 'snowflake://user:pw@airflow/db/public?warehouse=af_wh&role=af_role'
+ uri_shouldbe = 'snowflake://user:pw@airflow/db/public?warehouse=af_wh&role=af_role' \
+ '&authenticator=snowflake'
self.assertEqual(uri_shouldbe, self.db_hook.get_uri())
def test_get_conn_params(self):
@@ -103,7 +104,8 @@ class TestSnowflakeHook(unittest.TestCase):
'account': 'airflow',
'warehouse': 'af_wh',
'region': 'af_region',
- 'role': 'af_role'}
+ 'role': 'af_role',
+ 'authenticator': 'snowflake'}
self.assertEqual(conn_params_shouldbe, self.db_hook._get_conn_params())
def test_get_conn(self):