This is an automated email from the ASF dual-hosted git repository.

turbaszek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c68e7c  Add Snowflake support to SQL operator and sensor (#9843)
9c68e7c is described below

commit 9c68e7cc6fc1bf7c5a9a0156a2f0cf166cf2dfbe
Author: Andy <[email protected]>
AuthorDate: Fri Jul 17 02:04:14 2020 -0500

    Add Snowflake support to SQL operator and sensor (#9843)
    
    * Add Snowflake support to SQL operator and sensor
    * Add test for conn_type to valid hook mapping
    * Improve code quality for conn type mapping test
---
 airflow/models/connection.py    | 3 ++-
 airflow/operators/sql.py        | 1 +
 airflow/sensors/sql_sensor.py   | 2 +-
 tests/models/test_connection.py | 5 +++++
 4 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index 7144bcb..75174c4 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -72,7 +72,7 @@ CONN_TYPE_TO_HOOK = {
     "jira": ("airflow.providers.jira.hooks.jira.JiraHook", "jira_conn_id"),
     "kubernetes": ("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook", "kubernetes_conn_id"),
     "mongo": ("airflow.providers.mongo.hooks.mongo.MongoHook", "conn_id"),
-    "mssql": ("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook", "mssql_conn_id"),
+    "mssql": ("airflow.providers.odbc.hooks.odbc.OdbcHook", "odbc_conn_id"),
     "mysql": ("airflow.providers.mysql.hooks.mysql.MySqlHook", "mysql_conn_id"),
     "odbc": ("airflow.providers.odbc.hooks.odbc.OdbcHook", "odbc_conn_id"),
     "oracle": ("airflow.providers.oracle.hooks.oracle.OracleHook", "oracle_conn_id"),
@@ -80,6 +80,7 @@ CONN_TYPE_TO_HOOK = {
     "postgres": ("airflow.providers.postgres.hooks.postgres.PostgresHook", "postgres_conn_id"),
     "presto": ("airflow.providers.presto.hooks.presto.PrestoHook", "presto_conn_id"),
     "redis": ("airflow.providers.redis.hooks.redis.RedisHook", "redis_conn_id"),
+    "snowflake": ("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook", "snowflake_conn_id"),
     "sqlite": ("airflow.providers.sqlite.hooks.sqlite.SqliteHook", "sqlite_conn_id"),
     "tableau": ("airflow.providers.salesforce.hooks.tableau.TableauHook", "tableau_conn_id"),
     "vertica": ("airflow.providers.vertica.hooks.vertica.VerticaHook", "vertica_conn_id"),
diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py
index a0427f3..c99933f 100644
--- a/airflow/operators/sql.py
+++ b/airflow/operators/sql.py
@@ -32,6 +32,7 @@ ALLOWED_CONN_TYPE = {
     "oracle",
     "postgres",
     "presto",
+    "snowflake",
     "sqlite",
     "vertica",
 }
diff --git a/airflow/sensors/sql_sensor.py b/airflow/sensors/sql_sensor.py
index e4f8e3a..4996ee5 100644
--- a/airflow/sensors/sql_sensor.py
+++ b/airflow/sensors/sql_sensor.py
@@ -71,7 +71,7 @@ class SqlSensor(BaseSensorOperator):
 
         allowed_conn_type = {'google_cloud_platform', 'jdbc', 'mssql',
                              'mysql', 'odbc', 'oracle', 'postgres',
-                             'presto', 'sqlite', 'vertica'}
+                             'presto', 'snowflake', 'sqlite', 'vertica'}
         if conn.conn_type not in allowed_conn_type:
             raise AirflowException("The connection type is not supported by SqlSensor. " +
                                    "Supported connection types: {}".format(list(allowed_conn_type)))
diff --git a/tests/models/test_connection.py b/tests/models/test_connection.py
index 3b6abb2..df3ace8 100644
--- a/tests/models/test_connection.py
+++ b/tests/models/test_connection.py
@@ -30,6 +30,7 @@ from airflow.hooks.base_hook import BaseHook
 from airflow.models import Connection, crypto
 from airflow.models.connection import CONN_TYPE_TO_HOOK
 from airflow.providers.sqlite.hooks.sqlite import SqliteHook
+from airflow.utils.module_loading import import_string
 from tests.test_utils.config import conf_vars
 
ConnectionParts = namedtuple("ConnectionParts", ["conn_type", "login", "password", "host", "port", "schema"])
@@ -533,3 +534,7 @@ class TestConnTypeToHook(unittest.TestCase):
         expected_keys = sorted(current_keys)
 
         self.assertEqual(expected_keys, current_keys)
+
+    def test_hooks_importable(self):
+        for hook_path, _ in CONN_TYPE_TO_HOOK.values():
+            self.assertTrue(issubclass(import_string(hook_path), BaseHook))

Reply via email to