SameerMesiah97 commented on code in PR #62378:
URL: https://github.com/apache/airflow/pull/62378#discussion_r2843260344
##########
providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py:
##########
@@ -411,7 +411,17 @@ def _get_static_conn_params(self) -> dict[str, str | None]:
raise ValueError("The private_key_file size is too big. Please
keep it less than 4 KB.")
private_key_pem = Path(private_key_file_path).read_bytes()
elif private_key_content:
- private_key_pem = base64.b64decode(private_key_content)
+ if any(
+ private_key_content.startswith(header)
+ for header in [
+ "-----BEGIN ENCRYPTED PRIVATE KEY-----\n",
+ "-----BEGIN RSA PRIVATE KEY-----\n",
+ "-----BEGIN PRIVATE KEY-----\n",
Review Comment:
Matching these three headers is probably good enough, but it seems brittle.
What about a header that does not end with a newline, or content with
leading/trailing whitespace? You could argue that private key headers are
expected to conform to the set you are validating against, but reading
RFC 7468, there seem to be no strict requirements about including newlines or
excluding whitespace. Have you considered stripping whitespace from
`private_key_content` and just checking whether it starts with `'-----BEGIN'`?
Unless there is a need to restrict it to these three PEM formats, I don't
think you need to validate the labels either.
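
A minimal sketch of the check I have in mind. `looks_like_pem` is a
hypothetical helper name for illustration, not something that exists in
`SnowflakeHook`:

```python
def looks_like_pem(private_key_content: str) -> bool:
    """Return True if the content appears to be PEM-armored text.

    Hypothetical helper: strips surrounding whitespace and checks only the
    generic PEM prefix, without validating the label that follows it.
    """
    return private_key_content.strip().startswith("-----BEGIN")
```

Since `-` is not in the standard base64 alphabet, a base64-encoded key should
not be mistaken for PEM text by this prefix check.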
##########
providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py:
##########
@@ -411,7 +411,17 @@ def _get_static_conn_params(self) -> dict[str, str | None]:
raise ValueError("The private_key_file size is too big. Please
keep it less than 4 KB.")
private_key_pem = Path(private_key_file_path).read_bytes()
elif private_key_content:
- private_key_pem = base64.b64decode(private_key_content)
+ if any(
+ private_key_content.startswith(header)
+ for header in [
+ "-----BEGIN ENCRYPTED PRIVATE KEY-----\n",
+ "-----BEGIN RSA PRIVATE KEY-----\n",
+ "-----BEGIN PRIVATE KEY-----\n",
+ ]
+ ):
+ private_key_pem = private_key_content.encode()
+ else:
+ private_key_pem = base64.b64decode(private_key_content)
Review Comment:
Assuming that `private_key_content` is valid base64 just because it does not
start with one of the headers you enumerated is a heuristic. It will likely
hold in practice, but this part can be made more robust by wrapping line 424
(the `base64.b64decode` call) in a try/except, catching `binascii.Error`, and
raising a clearer error.
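
Roughly along these lines. This is only a sketch: the function name and the
error message are my own placeholders, and the generic `-----BEGIN` check is
an assumption rather than what the PR currently does:

```python
import base64
import binascii


def decode_private_key_content(private_key_content: str) -> bytes:
    """Return PEM bytes from either raw PEM text or base64-encoded PEM.

    Hypothetical standalone function, for illustration only.
    """
    if private_key_content.strip().startswith("-----BEGIN"):
        # Already PEM text: pass it through unchanged.
        return private_key_content.encode()
    try:
        return base64.b64decode(private_key_content)
    except binascii.Error as err:
        # Surface a clearer message than the raw decoding error.
        raise ValueError(
            "private_key_content is neither PEM text nor valid base64."
        ) from err
```

Note that with the default `validate=False`, `b64decode` silently discards
characters outside the base64 alphabet, so this mostly catches padding and
length problems.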
##########
providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py:
##########
@@ -372,6 +372,47 @@ def
test_hook_should_support_prepare_basic_conn_params_and_uri(
assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() ==
expected_uri
assert
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() ==
expected_conn_params
+ def test_get_private_key_is_not_base64_encoded_with_headers(
+ self, unencrypted_temporary_private_key: Path,
encrypted_temporary_private_key: Path
+ ):
+ """Test get_private_key function skips base64 encoding if private key
headers are present."""
+ for pwd, key_file in [
+ (None, unencrypted_temporary_private_key),
+ (_PASSWORD, encrypted_temporary_private_key),
+ ]:
+ private_key_content = key_file.read_text()
+
+ p_key = serialization.load_pem_private_key(
+ private_key_content.encode(),
+ password=(pwd.encode() if pwd is not None else None),
+ backend=default_backend(),
+ )
+
+ pkb = p_key.private_bytes(
+ encoding=serialization.Encoding.DER,
+ format=serialization.PrivateFormat.PKCS8,
+ encryption_algorithm=serialization.NoEncryption(),
+ )
+
+ connection_kwargs: Any = {
+ **BASE_CONNECTION_KWARGS,
+ "password": pwd,
+ "extra": {
+ "database": "db",
+ "account": "airflow",
+ "warehouse": "af_wh",
+ "region": "af_region",
+ "role": "af_role",
+ "private_key_content": private_key_content,
+ },
+ }
+ with mock.patch.dict(
+ "os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
+ ):
+ conn_params =
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+ assert "private_key" in conn_params
+ assert pkb == conn_params["private_key"]
Review Comment:
I would also assert that `base64.b64decode()` is not called. Do not patch it
globally; instead, patch it where it is imported in the `SnowflakeHook` module.
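
To illustrate the patching pattern without pulling in Airflow, here is a
self-contained sketch. `fake_hook` and `to_pem_bytes` are hypothetical
stand-ins for the hook module and its decoding logic; in the real test the
patch target would presumably be the `base64` name inside
`airflow.providers.snowflake.hooks.snowflake`, assuming the module does a
plain `import base64`:

```python
import base64
import types
from unittest import mock

# Hypothetical stand-in for the hook module, built in memory so the
# example runs on its own.
fake_hook = types.ModuleType("fake_hook")
exec(
    "import base64\n"
    "def to_pem_bytes(content):\n"
    "    if content.strip().startswith('-----BEGIN'):\n"
    "        return content.encode()\n"
    "    return base64.b64decode(content)\n",
    fake_hook.__dict__,
)


def check_pem_input_skips_base64() -> bool:
    pem = "-----BEGIN PRIVATE KEY-----\nAAAA\n-----END PRIVATE KEY-----\n"
    # Replace only the name `base64` as seen from fake_hook; the real
    # base64 module is left untouched for everyone else.
    with mock.patch.object(fake_hook, "base64") as mocked_base64:
        result = fake_hook.to_pem_bytes(pem)
        mocked_base64.b64decode.assert_not_called()
    # The original module is restored once the context manager exits.
    assert fake_hook.base64 is base64
    return result == pem.encode()
```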
##########
providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py:
##########
@@ -372,6 +372,47 @@ def
test_hook_should_support_prepare_basic_conn_params_and_uri(
assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() ==
expected_uri
assert
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() ==
expected_conn_params
Review Comment:
I think you should add another test for the second branch of your
implementation, i.e. when `private_key_content` does not match any of the
headers in the list. It would be best to assert that `base64.b64decode()` is
called in that test. Optionally, you could add another test, or add a case to
the above test, for invalid base64 in the same code path.
##########
providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py:
##########
@@ -372,6 +372,47 @@ def
test_hook_should_support_prepare_basic_conn_params_and_uri(
assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() ==
expected_uri
assert
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() ==
expected_conn_params
+ def test_get_private_key_is_not_base64_encoded_with_headers(
+ self, unencrypted_temporary_private_key: Path,
encrypted_temporary_private_key: Path
+ ):
+ """Test get_private_key function skips base64 encoding if private key
headers are present."""
+ for pwd, key_file in [
+ (None, unencrypted_temporary_private_key),
+ (_PASSWORD, encrypted_temporary_private_key),
+ ]:
Review Comment:
I think it would be better to parametrize this instead of using a loop. The
loop works, but it makes it hard to see which case failed during test runs.