This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 518d394119a Allow `json_result_force_utf8_decoding` specification in `providers.snowflake.hooks.SnowflakeHook` extra dict (#44264)
518d394119a is described below

commit 518d394119af0afe302a2b5b4f406af330e5078f
Author: Tim Zhou <[email protected]>
AuthorDate: Wed Nov 27 22:50:18 2024 -0500

    Allow `json_result_force_utf8_decoding` specification in `providers.snowflake.hooks.SnowflakeHook` extra dict (#44264)
    
    * Allow json_result_force_utf8_encoding specification in SnowflakeHook 
extra dict
    
    * Use a set for the `not in` membership check
---
 .../airflow/providers/snowflake/hooks/snowflake.py | 17 ++++++++++++++++-
 providers/tests/snowflake/hooks/test_snowflake.py  | 22 ++++++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/providers/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/src/airflow/providers/snowflake/hooks/snowflake.py
index e957c0623cb..fdf75939bda 100644
--- a/providers/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -201,6 +201,9 @@ class SnowflakeHook(DbApiHook):
         region = self._get_field(extra_dict, "region") or ""
         role = self._get_field(extra_dict, "role") or ""
         insecure_mode = _try_to_boolean(self._get_field(extra_dict, "insecure_mode"))
+        json_result_force_utf8_decoding = _try_to_boolean(
+            self._get_field(extra_dict, "json_result_force_utf8_decoding")
+        )
         schema = conn.schema or ""
         client_request_mfa_token = _try_to_boolean(self._get_field(extra_dict, "client_request_mfa_token"))
 
@@ -225,6 +228,9 @@ class SnowflakeHook(DbApiHook):
         if insecure_mode:
             conn_config["insecure_mode"] = insecure_mode
 
+        if json_result_force_utf8_decoding:
+            conn_config["json_result_force_utf8_decoding"] = json_result_force_utf8_decoding
+
         if client_request_mfa_token:
             conn_config["client_request_mfa_token"] = client_request_mfa_token
 
@@ -302,7 +308,13 @@ class SnowflakeHook(DbApiHook):
                 for k, v in conn_params.items()
                 if v
                 and k
-                not in ["session_parameters", "insecure_mode", "private_key", "client_request_mfa_token"]
+                not in {
+                    "session_parameters",
+                    "insecure_mode",
+                    "private_key",
+                    "client_request_mfa_token",
+                    "json_result_force_utf8_decoding",
+                }
             }
         )
 
@@ -324,6 +336,9 @@ class SnowflakeHook(DbApiHook):
         if "insecure_mode" in conn_params:
             engine_kwargs.setdefault("connect_args", {})
             engine_kwargs["connect_args"]["insecure_mode"] = True
+        if "json_result_force_utf8_decoding" in conn_params:
+            engine_kwargs.setdefault("connect_args", {})
+            engine_kwargs["connect_args"]["json_result_force_utf8_decoding"] = True
         for key in ["session_parameters", "private_key"]:
             if conn_params.get(key):
                 engine_kwargs.setdefault("connect_args", {})
diff --git a/providers/tests/snowflake/hooks/test_snowflake.py b/providers/tests/snowflake/hooks/test_snowflake.py
index b7c9382654b..d75f1a4baf1 100644
--- a/providers/tests/snowflake/hooks/test_snowflake.py
+++ b/providers/tests/snowflake/hooks/test_snowflake.py
@@ -138,6 +138,7 @@ class TestPytestSnowflakeHook:
                         "extra__snowflake__region": "af_region",
                         "extra__snowflake__role": "af_role",
                         "extra__snowflake__insecure_mode": "True",
+                        "extra__snowflake__json_result_force_utf8_decoding": "True",
                         "extra__snowflake__client_request_mfa_token": "True",
                     },
                 },
@@ -158,6 +159,7 @@ class TestPytestSnowflakeHook:
                     "user": "user",
                     "warehouse": "af_wh",
                     "insecure_mode": True,
+                    "json_result_force_utf8_decoding": True,
                     "client_request_mfa_token": True,
                 },
             ),
@@ -171,6 +173,7 @@ class TestPytestSnowflakeHook:
                         "extra__snowflake__region": "af_region",
                         "extra__snowflake__role": "af_role",
                         "extra__snowflake__insecure_mode": "False",
+                        "extra__snowflake__json_result_force_utf8_decoding": "False",
                         "extra__snowflake__client_request_mfa_token": "False",
                     },
                 },
@@ -247,6 +250,7 @@ class TestPytestSnowflakeHook:
                     "extra": {
                         **BASE_CONNECTION_KWARGS["extra"],
                         "extra__snowflake__insecure_mode": False,
+                        "extra__snowflake__json_result_force_utf8_decoding": True,
                         "extra__snowflake__client_request_mfa_token": False,
                     },
                 },
@@ -266,6 +270,7 @@ class TestPytestSnowflakeHook:
                     "session_parameters": None,
                     "user": "user",
                     "warehouse": "af_wh",
+                    "json_result_force_utf8_decoding": True,
                 },
             ),
         ],
@@ -473,6 +478,23 @@ class TestPytestSnowflakeHook:
             )
             assert mock_create_engine.return_value == conn
 
+    def test_get_sqlalchemy_engine_should_support_json_result_force_utf8_decoding(self):
+        connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+        connection_kwargs["extra"]["extra__snowflake__json_result_force_utf8_decoding"] = "True"
+
+        with (
+            mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
+            mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as mock_create_engine,
+        ):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            conn = hook.get_sqlalchemy_engine()
+            mock_create_engine.assert_called_once_with(
+                "snowflake://user:pw@airflow.af_region/db/public"
+                "?application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh",
+                connect_args={"json_result_force_utf8_decoding": True},
+            )
+            assert mock_create_engine.return_value == conn
+
     def test_get_sqlalchemy_engine_should_support_session_parameters(self):
         connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
         connection_kwargs["extra"]["session_parameters"] = {"TEST_PARAM": "AA", "TEST_PARAM_B": 123}

Reply via email to