This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 518d394119a Allow `json_result_force_utf8_decoding` specification in
`providers.snowflake.hooks.SnowflakeHook` extra dict (#44264)
518d394119a is described below
commit 518d394119af0afe302a2b5b4f406af330e5078f
Author: Tim Zhou <[email protected]>
AuthorDate: Wed Nov 27 22:50:18 2024 -0500
Allow `json_result_force_utf8_decoding` specification in
`providers.snowflake.hooks.SnowflakeHook` extra dict (#44264)
* Allow json_result_force_utf8_decoding specification in SnowflakeHook
extra dict
* Use a set instead of a list for the `not in` exclusion check in `get_uri`
---
.../airflow/providers/snowflake/hooks/snowflake.py | 17 ++++++++++++++++-
providers/tests/snowflake/hooks/test_snowflake.py | 22 ++++++++++++++++++++++
2 files changed, 38 insertions(+), 1 deletion(-)
diff --git a/providers/src/airflow/providers/snowflake/hooks/snowflake.py
b/providers/src/airflow/providers/snowflake/hooks/snowflake.py
index e957c0623cb..fdf75939bda 100644
--- a/providers/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -201,6 +201,9 @@ class SnowflakeHook(DbApiHook):
region = self._get_field(extra_dict, "region") or ""
role = self._get_field(extra_dict, "role") or ""
insecure_mode = _try_to_boolean(self._get_field(extra_dict,
"insecure_mode"))
+ json_result_force_utf8_decoding = _try_to_boolean(
+ self._get_field(extra_dict, "json_result_force_utf8_decoding")
+ )
schema = conn.schema or ""
client_request_mfa_token = _try_to_boolean(self._get_field(extra_dict,
"client_request_mfa_token"))
@@ -225,6 +228,9 @@ class SnowflakeHook(DbApiHook):
if insecure_mode:
conn_config["insecure_mode"] = insecure_mode
+ if json_result_force_utf8_decoding:
+ conn_config["json_result_force_utf8_decoding"] =
json_result_force_utf8_decoding
+
if client_request_mfa_token:
conn_config["client_request_mfa_token"] = client_request_mfa_token
@@ -302,7 +308,13 @@ class SnowflakeHook(DbApiHook):
for k, v in conn_params.items()
if v
and k
- not in ["session_parameters", "insecure_mode", "private_key",
"client_request_mfa_token"]
+ not in {
+ "session_parameters",
+ "insecure_mode",
+ "private_key",
+ "client_request_mfa_token",
+ "json_result_force_utf8_decoding",
+ }
}
)
@@ -324,6 +336,9 @@ class SnowflakeHook(DbApiHook):
if "insecure_mode" in conn_params:
engine_kwargs.setdefault("connect_args", {})
engine_kwargs["connect_args"]["insecure_mode"] = True
+ if "json_result_force_utf8_decoding" in conn_params:
+ engine_kwargs.setdefault("connect_args", {})
+ engine_kwargs["connect_args"]["json_result_force_utf8_decoding"] =
True
for key in ["session_parameters", "private_key"]:
if conn_params.get(key):
engine_kwargs.setdefault("connect_args", {})
diff --git a/providers/tests/snowflake/hooks/test_snowflake.py
b/providers/tests/snowflake/hooks/test_snowflake.py
index b7c9382654b..d75f1a4baf1 100644
--- a/providers/tests/snowflake/hooks/test_snowflake.py
+++ b/providers/tests/snowflake/hooks/test_snowflake.py
@@ -138,6 +138,7 @@ class TestPytestSnowflakeHook:
"extra__snowflake__region": "af_region",
"extra__snowflake__role": "af_role",
"extra__snowflake__insecure_mode": "True",
+ "extra__snowflake__json_result_force_utf8_decoding":
"True",
"extra__snowflake__client_request_mfa_token": "True",
},
},
@@ -158,6 +159,7 @@ class TestPytestSnowflakeHook:
"user": "user",
"warehouse": "af_wh",
"insecure_mode": True,
+ "json_result_force_utf8_decoding": True,
"client_request_mfa_token": True,
},
),
@@ -171,6 +173,7 @@ class TestPytestSnowflakeHook:
"extra__snowflake__region": "af_region",
"extra__snowflake__role": "af_role",
"extra__snowflake__insecure_mode": "False",
+ "extra__snowflake__json_result_force_utf8_decoding":
"False",
"extra__snowflake__client_request_mfa_token": "False",
},
},
@@ -247,6 +250,7 @@ class TestPytestSnowflakeHook:
"extra": {
**BASE_CONNECTION_KWARGS["extra"],
"extra__snowflake__insecure_mode": False,
+ "extra__snowflake__json_result_force_utf8_decoding":
True,
"extra__snowflake__client_request_mfa_token": False,
},
},
@@ -266,6 +270,7 @@ class TestPytestSnowflakeHook:
"session_parameters": None,
"user": "user",
"warehouse": "af_wh",
+ "json_result_force_utf8_decoding": True,
},
),
],
@@ -473,6 +478,23 @@ class TestPytestSnowflakeHook:
)
assert mock_create_engine.return_value == conn
+ def
test_get_sqlalchemy_engine_should_support_json_result_force_utf8_decoding(self):
+ connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+
connection_kwargs["extra"]["extra__snowflake__json_result_force_utf8_decoding"]
= "True"
+
+ with (
+ mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
+
mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as
mock_create_engine,
+ ):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ conn = hook.get_sqlalchemy_engine()
+ mock_create_engine.assert_called_once_with(
+ "snowflake://user:[email protected]_region/db/public"
+
"?application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh",
+ connect_args={"json_result_force_utf8_decoding": True},
+ )
+ assert mock_create_engine.return_value == conn
+
def test_get_sqlalchemy_engine_should_support_session_parameters(self):
connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
connection_kwargs["extra"]["session_parameters"] = {"TEST_PARAM":
"AA", "TEST_PARAM_B": 123}