This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c2fe5544de7 [Snowflake] [Feat] Allow SnowflakeHook +
SnowflakeSqlApiHook `private_key_content` to use raw key in addition to base64
encoding (#62378)
c2fe5544de7 is described below
commit c2fe5544de70dc70e4ecbd9e84f0284de28af953
Author: Daniel Reeves <[email protected]>
AuthorDate: Mon Mar 9 16:14:40 2026 -0400
[Snowflake] [Feat] Allow SnowflakeHook + SnowflakeSqlApiHook
`private_key_content` to use raw key in addition to base64 encoding (#62378)
* allow SnowflakeHook private_key_content to use raw key instead of only
base64 encoding.
* Make code more DRY by moving get_private_key() to SnowflakeHook.
* Fix get_connection call in SnowflakeHook
* assert private_key is b64decoded when passed in as b64encoded value
* Update Snowflake provider docs to clarify new private_key_content behavior
---
providers/snowflake/docs/connections/snowflake.rst | 2 +-
.../airflow/providers/snowflake/hooks/snowflake.py | 85 ++++++++++++--------
.../providers/snowflake/hooks/snowflake_sql_api.py | 43 +----------
.../tests/unit/snowflake/hooks/test_snowflake.py | 90 +++++++++++++++++++++-
.../unit/snowflake/hooks/test_snowflake_sql_api.py | 16 ++--
5 files changed, 150 insertions(+), 86 deletions(-)
diff --git a/providers/snowflake/docs/connections/snowflake.rst
b/providers/snowflake/docs/connections/snowflake.rst
index 900ed04a1ed..584bbcfa7c4 100644
--- a/providers/snowflake/docs/connections/snowflake.rst
+++ b/providers/snowflake/docs/connections/snowflake.rst
@@ -64,7 +64,7 @@ Extra (optional)
* ``refresh_token``: Specify refresh_token for OAuth connection.
* ``azure_conn_id``: Azure Connection ID to be used for retrieving the
OAuth token using Azure Entra authentication. Login and Password fields aren't
required when using this method. Scope for the Azure OAuth token can be set in
the config option ``azure_oauth_scope`` under the section ``[snowflake]``.
Requires `apache-airflow-providers-microsoft-azure>=12.8.0`.
* ``private_key_file``: Specify the path to the private key file.
- * ``private_key_content``: Specify the content of the private key file in
base64 encoded format. You can use the following Python code to encode the
private key:
+ * ``private_key_content``: Specify the content of the private key file,
either in plain text or base64 encoded. When using the Airflow UI to manage the
Snowflake connection, you should base64 encode the ``private_key_content``. You
can use the following Python code to encode the private key:
.. code-block:: python
diff --git
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index 6199bd6ff54..3a9205fd8f5 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -56,6 +56,7 @@ T = TypeVar("T")
if TYPE_CHECKING:
+ from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
from snowflake.connector import SnowflakeConnection
from airflow.providers.openlineage.extractors import OperatorLineage
@@ -400,40 +401,9 @@ class SnowflakeHook(DbApiHook):
if client_store_temporary_credential:
conn_config["client_store_temporary_credential"] =
client_store_temporary_credential
- # If private_key_file is specified in the extra json, load the
contents of the file as a private key.
- # If private_key_content is specified in the extra json, use it as a
private key.
- # As a next step, specify this private key in the connection
configuration.
- # The connection password then becomes the passphrase for the private
key.
- # If your private key is not encrypted (not recommended), then leave
the password empty.
-
- private_key_file = self._get_field(extra_dict, "private_key_file")
- private_key_content = self._get_field(extra_dict,
"private_key_content")
-
- private_key_pem = None
- if private_key_content and private_key_file:
- raise AirflowException(
- "The private_key_file and private_key_content extra fields are
mutually exclusive. "
- "Please remove one."
- )
- if private_key_file:
- private_key_file_path = Path(private_key_file)
- if not private_key_file_path.is_file() or
private_key_file_path.stat().st_size == 0:
- raise ValueError("The private_key_file path points to an empty
or invalid file.")
- if private_key_file_path.stat().st_size > 4096:
- raise ValueError("The private_key_file size is too big. Please
keep it less than 4 KB.")
- private_key_pem = Path(private_key_file_path).read_bytes()
- elif private_key_content:
- private_key_pem = base64.b64decode(private_key_content)
-
- if private_key_pem:
- passphrase = None
- if conn.password:
- passphrase = conn.password.strip().encode()
-
- p_key = serialization.load_pem_private_key(
- private_key_pem, password=passphrase, backend=default_backend()
- )
+ p_key = self.get_private_key()
+ if p_key:
pkb = p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
@@ -587,6 +557,55 @@ class SnowflakeHook(DbApiHook):
response.raise_for_status()
return response
+ def get_private_key(self) -> PrivateKeyTypes | None:
+ """Get the private key from snowflake connection."""
+ conn = self.get_connection(self.get_conn_id())
+ extra_dict = conn.extra_dejson
+
+ # If private_key_file is specified in the extra json, load the
contents of the file as a private key.
+ # If private_key_content is specified in the extra json, use it as a
private key.
+ # As a next step, specify this private key in the connection
configuration.
+ # The connection password then becomes the passphrase for the private
key.
+ # If your private key is not encrypted (not recommended), then leave
the password empty.
+
+ private_key_file = self._get_field(extra_dict, "private_key_file")
+ private_key_content = self._get_field(extra_dict,
"private_key_content")
+
+ passphrase = None
+ if conn.password:
+ passphrase = conn.password.strip().encode()
+
+ private_key_pem = None
+ p_key = None
+
+ if private_key_content and private_key_file:
+ raise AirflowException(
+ "The private_key_file and private_key_content extra fields are
mutually exclusive. "
+ "Please remove one."
+ )
+ if private_key_file:
+ private_key_file_path = Path(private_key_file)
+ if not private_key_file_path.is_file() or
private_key_file_path.stat().st_size == 0:
+ raise ValueError("The private_key_file path points to an empty
or invalid file.")
+ if private_key_file_path.stat().st_size > 4096:
+ raise ValueError("The private_key_file size is too big. Please
keep it less than 4 KB.")
+ private_key_pem = Path(private_key_file_path).read_bytes()
+ elif private_key_content:
+ try:
+ p_key = serialization.load_pem_private_key(
+ private_key_content.encode(), password=passphrase,
backend=default_backend()
+ )
+ except (TypeError, ValueError):
+ # Assume base64 encoding if the string is not a valid PEM private key
+ private_key_pem = base64.b64decode(private_key_content)
+
+ if private_key_pem:
+ p_key = serialization.load_pem_private_key(
+ private_key_pem, password=passphrase, backend=default_backend()
+ )
+
+ return p_key
+
def get_uri(self) -> str:
"""Override DbApiHook get_uri method for get_sqlalchemy_engine()."""
conn_params = self._get_conn_params()
diff --git
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index efe45ed9d13..2f52be48b88 100644
---
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -17,19 +17,15 @@
from __future__ import annotations
import asyncio
-import base64
import time
import uuid
import warnings
from datetime import timedelta
-from pathlib import Path
from typing import Any
import aiohttp
import requests
from aiohttp import ClientConnectionError, ClientResponseError
-from cryptography.hazmat.backends import default_backend
-from cryptography.hazmat.primitives import serialization
from requests.exceptions import ConnectionError, HTTPError, Timeout
from tenacity import (
AsyncRetrying,
@@ -134,43 +130,6 @@ class SnowflakeSqlApiHook(SnowflakeHook):
self.aiohttp_session_kwargs = aiohttp_session_kwargs or {}
self.aiohttp_request_kwargs = aiohttp_request_kwargs or {}
- def get_private_key(self) -> None:
- """Get the private key from snowflake connection."""
- conn = self.get_connection(self.snowflake_conn_id)
-
- # If private_key_file is specified in the extra json, load the
contents of the file as a private key.
- # If private_key_content is specified in the extra json, use it as a
private key.
- # As a next step, specify this private key in the connection
configuration.
- # The connection password then becomes the passphrase for the private
key.
- # If your private key is not encrypted (not recommended), then leave
the password empty.
-
- private_key_file = conn.extra_dejson.get(
- "extra__snowflake__private_key_file"
- ) or conn.extra_dejson.get("private_key_file")
- private_key_content = conn.extra_dejson.get(
- "extra__snowflake__private_key_content"
- ) or conn.extra_dejson.get("private_key_content")
-
- private_key_pem = None
- if private_key_content and private_key_file:
- raise AirflowException(
- "The private_key_file and private_key_content extra fields are
mutually exclusive. "
- "Please remove one."
- )
- if private_key_file:
- private_key_pem = Path(private_key_file).read_bytes()
- elif private_key_content:
- private_key_pem = base64.b64decode(private_key_content)
-
- if private_key_pem:
- passphrase = None
- if conn.password:
- passphrase = conn.password.strip().encode()
-
- self.private_key = serialization.load_pem_private_key(
- private_key_pem, password=passphrase, backend=default_backend()
- )
-
def execute_query(
self, sql: str, statement_count: int, query_tag: str = "", bindings:
dict[str, Any] | None = None
) -> list[str]:
@@ -272,7 +231,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
# Alternatively, get the JWT token from the connection details and the
private key
if not self.private_key:
- self.get_private_key()
+ self.private_key = self.get_private_key()
token = JWTGenerator(
conn_config["account"], # type: ignore[arg-type]
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index fa3f5a65a03..901ac925c07 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -372,9 +372,93 @@ class TestPytestSnowflakeHook:
assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() ==
expected_uri
assert
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() ==
expected_conn_params
+ def test_plain_text_unencrypted_private_key_is_not_base64_encoded(
+ self, unencrypted_temporary_private_key: Path
+ ):
+ """Test get_private_key function skips base64 decoding if private key
is plain text."""
+ private_key_content = unencrypted_temporary_private_key.read_text()
+
+ p_key = serialization.load_pem_private_key(
+ private_key_content.encode(),
+ password=None,
+ backend=default_backend(),
+ )
+
+ pkb = p_key.private_bytes(
+ encoding=serialization.Encoding.DER,
+ format=serialization.PrivateFormat.PKCS8,
+ encryption_algorithm=serialization.NoEncryption(),
+ )
+
+ connection_kwargs: Any = {
+ **BASE_CONNECTION_KWARGS,
+ "password": None,
+ "extra": {
+ "database": "db",
+ "account": "airflow",
+ "warehouse": "af_wh",
+ "region": "af_region",
+ "role": "af_role",
+ "private_key_content": private_key_content,
+ },
+ }
+ with mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+ conn_params =
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+ assert "private_key" in conn_params
+ assert pkb == conn_params["private_key"]
+
+ def test_plain_text_encrypted_private_key_is_not_base64_encoded(
+ self, encrypted_temporary_private_key: Path
+ ):
+ """Test get_private_key function skips base64 decoding if private key
is plain text."""
+ private_key_content = encrypted_temporary_private_key.read_text()
+
+ p_key = serialization.load_pem_private_key(
+ private_key_content.encode(),
+ password=_PASSWORD.encode(),
+ backend=default_backend(),
+ )
+
+ pkb = p_key.private_bytes(
+ encoding=serialization.Encoding.DER,
+ format=serialization.PrivateFormat.PKCS8,
+ encryption_algorithm=serialization.NoEncryption(),
+ )
+
+ connection_kwargs: Any = {
+ **BASE_CONNECTION_KWARGS,
+ "password": _PASSWORD,
+ "extra": {
+ "database": "db",
+ "account": "airflow",
+ "warehouse": "af_wh",
+ "region": "af_region",
+ "role": "af_role",
+ "private_key_content": private_key_content,
+ },
+ }
+ with mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+ conn_params =
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+ assert "private_key" in conn_params
+ assert pkb == conn_params["private_key"]
+
def test_get_conn_params_should_support_private_auth_in_connection(
- self, base64_encoded_encrypted_private_key: Path
+ self, base64_encoded_encrypted_private_key: str,
encrypted_temporary_private_key: Path
):
+ private_key_content = encrypted_temporary_private_key.read_text()
+
+ p_key = serialization.load_pem_private_key(
+ private_key_content.encode(),
+ password=_PASSWORD.encode(),
+ backend=default_backend(),
+ )
+
+ pkb = p_key.private_bytes(
+ encoding=serialization.Encoding.DER,
+ format=serialization.PrivateFormat.PKCS8,
+ encryption_algorithm=serialization.NoEncryption(),
+ )
+
connection_kwargs: Any = {
**BASE_CONNECTION_KWARGS,
"password": _PASSWORD,
@@ -388,7 +472,9 @@ class TestPytestSnowflakeHook:
},
}
with mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
- assert "private_key" in
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+ conn_params =
SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params()
+ assert "private_key" in conn_params
+ assert conn_params["private_key"] == pkb
@pytest.mark.parametrize("include_params", [True, False])
def test_hook_param_beats_extra(self, include_params):
diff --git
a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py
b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py
index a5e332df41c..32e475fa15b 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py
@@ -561,8 +561,8 @@ class TestSnowflakeSqlApiHook:
"os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
):
hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
- hook.get_private_key()
- assert hook.private_key is not None
+ private_key = hook.get_private_key()
+ assert private_key is not None
def test_get_private_key_raise_exception(
self, encrypted_temporary_private_key: Path,
base64_encoded_encrypted_private_key: str
@@ -617,8 +617,8 @@ class TestSnowflakeSqlApiHook:
"os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
):
hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
- hook.get_private_key()
- assert hook.private_key is not None
+ private_key = hook.get_private_key()
+ assert private_key is not None
def test_get_private_key_should_support_private_auth_with_unencrypted_key(
self,
@@ -640,15 +640,15 @@ class TestSnowflakeSqlApiHook:
"os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
):
hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
- hook.get_private_key()
- assert hook.private_key is not None
+ private_key = hook.get_private_key()
+ assert private_key is not None
connection_kwargs["password"] = ""
with unittest.mock.patch.dict(
"os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
):
hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
- hook.get_private_key()
- assert hook.private_key is not None
+ private_key = hook.get_private_key()
+ assert private_key is not None
connection_kwargs["password"] = _PASSWORD
with (
unittest.mock.patch.dict(