SameerMesiah97 commented on code in PR #62378:
URL: https://github.com/apache/airflow/pull/62378#discussion_r2843260344


##########
providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py:
##########
@@ -411,7 +411,17 @@ def _get_static_conn_params(self) -> dict[str, str | None]:
                 raise ValueError("The private_key_file size is too big. Please keep it less than 4 KB.")
             private_key_pem = Path(private_key_file_path).read_bytes()
         elif private_key_content:
-            private_key_pem = base64.b64decode(private_key_content)
+            if any(
+                private_key_content.startswith(header)
+                for header in [
+                    "-----BEGIN ENCRYPTED PRIVATE KEY-----\n",
+                    "-----BEGIN RSA PRIVATE KEY-----\n",
+                    "-----BEGIN PRIVATE KEY-----\n",

Review Comment:
   Matching these three headers is probably good enough, but it seems brittle. 
What about cases where the header does not end with a newline, or where the 
content has leading or trailing whitespace? You could argue that private key 
headers are expected to conform to the set you are validating against, but 
reading RFC 7468, there seem to be no strict requirements about including 
newlines or excluding whitespace. Have you considered stripping whitespace 
from `private_key_content` and just checking whether it starts with 
`'-----BEGIN'`? Unless there is a need to restrict it to these three PEM 
formats, I don't think you need to validate the labels either. 
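
   A minimal sketch of the looser check, assuming any PEM label is acceptable. 
The helper name `looks_like_pem` is hypothetical and not part of the hook:

```python
def looks_like_pem(private_key_content: str) -> bool:
    """Return True if the content starts with a PEM "BEGIN" boundary.

    Surrounding whitespace is ignored and the label itself is not validated.
    """
    return private_key_content.strip().startswith("-----BEGIN")
```

   The standard base64 alphabet does not contain `-`, so a key encoded with 
it should not be mistaken for PEM by this check.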



##########
providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py:
##########
@@ -411,7 +411,17 @@ def _get_static_conn_params(self) -> dict[str, str | None]:
                 raise ValueError("The private_key_file size is too big. Please keep it less than 4 KB.")
             private_key_pem = Path(private_key_file_path).read_bytes()
         elif private_key_content:
-            private_key_pem = base64.b64decode(private_key_content)
+            if any(
+                private_key_content.startswith(header)
+                for header in [
+                    "-----BEGIN ENCRYPTED PRIVATE KEY-----\n",
+                    "-----BEGIN RSA PRIVATE KEY-----\n",
+                    "-----BEGIN PRIVATE KEY-----\n",
+                ]
+            ):
+                private_key_pem = private_key_content.encode()
+            else:
+                private_key_pem = base64.b64decode(private_key_content)

Review Comment:
   Assuming that `private_key_content` is valid base64 just because it does 
not start with one of the headers you enumerated is a heuristic. It is likely 
true in practice, but this part can be made more robust by wrapping the 
`base64.b64decode` call (line 424) in a try/except, catching 
`binascii.Error`, and raising a clearer error. 
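
   Something along these lines, as a sketch only. The function name and the 
error message are hypothetical, and the decode call keeps the default 
(non-strict) behavior used in the diff:

```python
import base64
import binascii


def decode_private_key_content(private_key_content: str) -> bytes:
    """Decode base64 key content, turning decode failures into a clear error."""
    try:
        # Default mode ignores characters outside the base64 alphabet, so this
        # mostly catches padding and length problems.
        return base64.b64decode(private_key_content)
    except binascii.Error as err:
        # Hypothetical message; the wording is an assumption.
        raise ValueError(
            "private_key_content is neither PEM text nor valid base64."
        ) from err
```

   Passing `validate=True` would be stricter, but it would also reject 
base64 content that contains line breaks, so it may be too aggressive here.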



##########
providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py:
##########
@@ -372,6 +372,47 @@ def test_hook_should_support_prepare_basic_conn_params_and_uri(
             assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() == expected_uri
             assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == expected_conn_params
 
+    def test_get_private_key_is_not_base64_encoded_with_headers(
+        self, unencrypted_temporary_private_key: Path, encrypted_temporary_private_key: Path
+    ):
+        """Test get_private_key function skips base64 encoding if private key headers are present."""
+        for pwd, key_file in [
+            (None, unencrypted_temporary_private_key),
+            (_PASSWORD, encrypted_temporary_private_key),
+        ]:
+            private_key_content = key_file.read_text()
+
+            p_key = serialization.load_pem_private_key(
+                private_key_content.encode(),
+                password=(pwd.encode() if pwd is not None else None),
+                backend=default_backend(),
+            )
+
+            pkb = p_key.private_bytes(
+                encoding=serialization.Encoding.DER,
+                format=serialization.PrivateFormat.PKCS8,
+                encryption_algorithm=serialization.NoEncryption(),
+            )
+
+            connection_kwargs: Any = {
+                **BASE_CONNECTION_KWARGS,
+                "password": pwd,
+                "extra": {
+                    "database": "db",
+                    "account": "airflow",
+                    "warehouse": "af_wh",
+                    "region": "af_region",
+                    "role": "af_role",
+                    "private_key_content": private_key_content,
+                },
+            }
+            with mock.patch.dict(
+                "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
+            ):
+                conn_params = SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+                assert "private_key" in conn_params
+                assert pkb == conn_params["private_key"]

Review Comment:
   I would also assert that `base64.b64decode()` is not called. Do not patch 
it globally; patch it where it is imported in the `SnowflakeHook` module 
instead.
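
   A self-contained sketch of the idea. `hook_module` and `to_pem_bytes` are 
hypothetical stand-ins for the hook module's namespace and the branch under 
review; in the real test the target would presumably be 
`airflow.providers.snowflake.hooks.snowflake.base64`:

```python
import base64
import types
from unittest import mock

# Stand-in for a module that does `import base64` (hypothetical).
hook_module = types.SimpleNamespace(base64=base64)


def to_pem_bytes(private_key_content: str) -> bytes:
    """Simplified stand-in for the PEM-or-base64 branch in the hook."""
    if private_key_content.strip().startswith("-----BEGIN"):
        return private_key_content.encode()
    return hook_module.base64.b64decode(private_key_content)


pem = "-----BEGIN PRIVATE KEY-----\nAAAA\n-----END PRIVATE KEY-----\n"

# Replace only the name the code under test looks up, not base64 itself.
with mock.patch.object(hook_module, "base64") as mock_base64:
    result = to_pem_bytes(pem)
    mock_base64.b64decode.assert_not_called()
```

   Note that patching the `b64decode` attribute through the module's 
`base64` reference would still modify the shared `base64` module, so 
replacing the module-level name is the variant that stays local.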



##########
providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py:
##########
@@ -372,6 +372,47 @@ def test_hook_should_support_prepare_basic_conn_params_and_uri(
             assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() == expected_uri
             assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == expected_conn_params
 

Review Comment:
   I think you should add another test for the second branch of your 
implementation, i.e. when `private_key_content` does not match any of the 
headers in the list. It would be best to assert that `base64.b64decode()` is 
called in that test. Optionally, you could add another test, or add a case to 
the above test, for invalid base64 in the same code path. 
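
   Roughly like this, as a sketch. `hook_module` and `to_pem_bytes` are 
hypothetical stand-ins rather than the real hook, and the `ValueError` 
assumes the decode call gets wrapped in a try/except:

```python
import base64
import binascii
import types
from unittest import mock

# Stand-in for a module that does `import base64` (hypothetical).
hook_module = types.SimpleNamespace(base64=base64)


def to_pem_bytes(private_key_content: str) -> bytes:
    """Simplified stand-in for the PEM-or-base64 branch in the hook."""
    if private_key_content.strip().startswith("-----BEGIN"):
        return private_key_content.encode()
    try:
        return hook_module.base64.b64decode(private_key_content)
    except binascii.Error as err:
        raise ValueError("private_key_content is not valid base64.") from err


PEM = "-----BEGIN PRIVATE KEY-----\nAAAA\n-----END PRIVATE KEY-----\n"


def test_base64_content_is_decoded():
    encoded = base64.b64encode(PEM.encode()).decode()
    # `wraps` keeps the real behavior while recording the calls.
    with mock.patch.object(hook_module, "base64", wraps=base64) as spy:
        assert to_pem_bytes(encoded) == PEM.encode()
    spy.b64decode.assert_called_once_with(encoded)


def test_invalid_base64_raises():
    try:
        to_pem_bytes("abc")  # wrong length, so padding is incorrect
    except ValueError:
        return
    raise AssertionError("expected ValueError for invalid base64")
```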



##########
providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py:
##########
@@ -372,6 +372,47 @@ def test_hook_should_support_prepare_basic_conn_params_and_uri(
             assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() == expected_uri
             assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == expected_conn_params
 
+    def test_get_private_key_is_not_base64_encoded_with_headers(
+        self, unencrypted_temporary_private_key: Path, encrypted_temporary_private_key: Path
+    ):
+        """Test get_private_key function skips base64 encoding if private key headers are present."""
+        for pwd, key_file in [
+            (None, unencrypted_temporary_private_key),
+            (_PASSWORD, encrypted_temporary_private_key),
+        ]:

Review Comment:
   I think it would be better to parametrize this test instead of using a 
loop. The loop works, but it makes it difficult to see which case failed 
during test runs. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
