dwreeves commented on code in PR #62378:
URL: https://github.com/apache/airflow/pull/62378#discussion_r2847409866


##########
providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py:
##########
@@ -372,6 +372,47 @@ def test_hook_should_support_prepare_basic_conn_params_and_uri(
             assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() == expected_uri
             assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == expected_conn_params
 
+    def test_get_private_key_is_not_base64_encoded_with_headers(
+        self, unencrypted_temporary_private_key: Path, encrypted_temporary_private_key: Path
+    ):
+        """Test get_private_key function skips base64 encoding if private key headers are present."""
+        for pwd, key_file in [
+            (None, unencrypted_temporary_private_key),
+            (_PASSWORD, encrypted_temporary_private_key),
+        ]:
+            private_key_content = key_file.read_text()
+
+            p_key = serialization.load_pem_private_key(
+                private_key_content.encode(),
+                password=(pwd.encode() if pwd is not None else None),
+                backend=default_backend(),
+            )
+
+            pkb = p_key.private_bytes(
+                encoding=serialization.Encoding.DER,
+                format=serialization.PrivateFormat.PKCS8,
+                encryption_algorithm=serialization.NoEncryption(),
+            )
+
+            connection_kwargs: Any = {
+                **BASE_CONNECTION_KWARGS,
+                "password": pwd,
+                "extra": {
+                    "database": "db",
+                    "account": "airflow",
+                    "warehouse": "af_wh",
+                    "region": "af_region",
+                    "role": "af_role",
+                    "private_key_content": private_key_content,
+                },
+            }
+            with mock.patch.dict(
+                "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
+            ):
+                conn_params = SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+                assert "private_key" in conn_params
+                assert pkb == conn_params["private_key"]

Review Comment:
   Maybe this is a little overkill, but I deliberately avoided mocking
`base64.decode()`. I'm not sure whether, in the future, `base64` might also be
called inside the function outside of the private-key flow. So I felt there was
a very small but nonzero chance that asserting on `base64.decode()` calls could
someday make this test fail incorrectly. Comparing the resulting private-key
bytes directly seemed more explicit.
   
   WDYT?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to