SameerMesiah97 commented on code in PR #62378:
URL: https://github.com/apache/airflow/pull/62378#discussion_r2847725086


##########
providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py:
##########
@@ -372,6 +372,47 @@ def 
test_hook_should_support_prepare_basic_conn_params_and_uri(
             assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() == 
expected_uri
             assert 
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == 
expected_conn_params
 
+    def test_get_private_key_is_not_base64_encoded_with_headers(
+        self, unencrypted_temporary_private_key: Path, 
encrypted_temporary_private_key: Path
+    ):
+        """Test get_private_key function skips base64 encoding if private key 
headers are present."""
+        for pwd, key_file in [
+            (None, unencrypted_temporary_private_key),
+            (_PASSWORD, encrypted_temporary_private_key),
+        ]:
+            private_key_content = key_file.read_text()
+
+            p_key = serialization.load_pem_private_key(
+                private_key_content.encode(),
+                password=(pwd.encode() if pwd is not None else None),
+                backend=default_backend(),
+            )
+
+            pkb = p_key.private_bytes(
+                encoding=serialization.Encoding.DER,
+                format=serialization.PrivateFormat.PKCS8,
+                encryption_algorithm=serialization.NoEncryption(),
+            )
+
+            connection_kwargs: Any = {
+                **BASE_CONNECTION_KWARGS,
+                "password": pwd,
+                "extra": {
+                    "database": "db",
+                    "account": "airflow",
+                    "warehouse": "af_wh",
+                    "region": "af_region",
+                    "role": "af_role",
+                    "private_key_content": private_key_content,
+                },
+            }
+            with mock.patch.dict(
+                "os.environ", 
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
+            ):
+                conn_params = 
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+                assert "private_key" in conn_params
+                assert pkb == conn_params["private_key"]

Review Comment:
   > > Are you worried about another base64.b64decode() call being added in 
another part of _get_conn_params?
   > 
   > Yes.
   > 
   > I can assert that it was either not called, or not called with the first
   > arg. If we do `try: load_pem_private_key except: b64decode`, then that works.
   > 
   > But -- and sorry for pushing back a bit -- I guess I'm left wondering why it
   > is so important to use `mock` to listen for a call, when it is more explicit
   > to just assert that the return value is what we expect. Mocking a call and
   > using `assert_called` makes a lot of sense for function calls with side
   > effects, or for proper unit-test isolation. Here the function can be a black
   > box for all we care; we just want to ensure the private key is read correctly
   > into the params, and I am testing that directly.
   
   Okay, I see your point. I agree it is not critical to the test, but it would
make the implementation more robust. The PR should not be blocked because of
this; it's up to you whether to accept the suggestion. As long as the return
value is being checked, we should be fine.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to