This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7e9b9edd76a snowflake: pass through the ocsp_fail_open setting (#46476)
7e9b9edd76a is described below
commit 7e9b9edd76a9535fd6137ab33142f2c735eb4d4b
Author: Joe Crobak <[email protected]>
AuthorDate: Wed Feb 5 17:02:01 2025 -0500
snowflake: pass through the ocsp_fail_open setting (#46476)
For SQLAlchemy engines, the setting is passed via connect_args rather than the connection URI.
---
providers/snowflake/docs/connections/snowflake.rst | 1 +
.../airflow/providers/snowflake/hooks/snowflake.py | 10 +++
.../snowflake/hooks/test_snowflake.py | 71 ++++++++++++++++++++++
3 files changed, 82 insertions(+)
diff --git a/providers/snowflake/docs/connections/snowflake.rst
b/providers/snowflake/docs/connections/snowflake.rst
index 741d73a62e3..2d7076d120f 100644
--- a/providers/snowflake/docs/connections/snowflake.rst
+++ b/providers/snowflake/docs/connections/snowflake.rst
@@ -64,6 +64,7 @@ Extra (optional)
* ``insecure_mode``: Turn off OCSP certificate checks. For details, see:
`How To: Turn Off OCSP Checking in Snowflake Client Drivers - Snowflake
Community
<https://community.snowflake.com/s/article/How-to-turn-off-OCSP-checking-in-Snowflake-client-drivers>`_.
* ``host``: Target Snowflake hostname to connect to (e.g., for local
testing with LocalStack).
* ``port``: Target Snowflake port to connect to (e.g., for local testing
with LocalStack).
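+ * ``ocsp_fail_open``: Whether OCSP certificate checks fail open (``true``) or fail closed (``false``). For details, see: `Choosing fail-open or fail-close mode
<https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#label-python-ocsp-choosing-fail-open-or-fail-close-mode>`_.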
URI format example
^^^^^^^^^^^^^^^^^^
diff --git
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index 45e12666b88..5777968b8d8 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -299,6 +299,12 @@ class SnowflakeHook(DbApiHook):
if snowflake_port:
conn_config["port"] = snowflake_port
+ # if a value for ocsp_fail_open is set, pass it along.
+ # Note the check is for `is not None` so that we can pass along
`False` as a value.
+ ocsp_fail_open = extra_dict.get("ocsp_fail_open")
+ if ocsp_fail_open is not None:
+ conn_config["ocsp_fail_open"] = _try_to_boolean(ocsp_fail_open)
+
return conn_config
def get_uri(self) -> str:
@@ -320,6 +326,7 @@ class SnowflakeHook(DbApiHook):
"client_request_mfa_token",
"client_store_temporary_credential",
"json_result_force_utf8_decoding",
+ "ocsp_fail_open",
]
}
)
@@ -345,6 +352,9 @@ class SnowflakeHook(DbApiHook):
if "json_result_force_utf8_decoding" in conn_params:
engine_kwargs.setdefault("connect_args", {})
engine_kwargs["connect_args"]["json_result_force_utf8_decoding"] =
True
+ if "ocsp_fail_open" in conn_params:
+ engine_kwargs.setdefault("connect_args", {})
+ engine_kwargs["connect_args"]["ocsp_fail_open"] =
conn_params["ocsp_fail_open"]
for key in ["session_parameters", "private_key"]:
if conn_params.get(key):
engine_kwargs.setdefault("connect_args", {})
diff --git
a/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py
b/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py
index 775e9382729..b1a65b4293b 100644
--- a/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/provider_tests/snowflake/hooks/test_snowflake.py
@@ -277,6 +277,60 @@ class TestPytestSnowflakeHook:
"json_result_force_utf8_decoding": True,
},
),
+ (
+ {
+ **BASE_CONNECTION_KWARGS,
+ "extra": {
+ **BASE_CONNECTION_KWARGS["extra"],
+ "ocsp_fail_open": True,
+ },
+ },
+ (
+ "snowflake://user:[email protected]_region/db/public?"
+
"application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh"
+ ),
+ {
+ "account": "airflow",
+ "application": "AIRFLOW",
+ "authenticator": "snowflake",
+ "database": "db",
+ "password": "pw",
+ "region": "af_region",
+ "role": "af_role",
+ "schema": "public",
+ "session_parameters": None,
+ "user": "user",
+ "warehouse": "af_wh",
+ "ocsp_fail_open": True,
+ },
+ ),
+ (
+ {
+ **BASE_CONNECTION_KWARGS,
+ "extra": {
+ **BASE_CONNECTION_KWARGS["extra"],
+ "ocsp_fail_open": False,
+ },
+ },
+ (
+ "snowflake://user:[email protected]_region/db/public?"
+
"application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh"
+ ),
+ {
+ "account": "airflow",
+ "application": "AIRFLOW",
+ "authenticator": "snowflake",
+ "database": "db",
+ "password": "pw",
+ "region": "af_region",
+ "role": "af_role",
+ "schema": "public",
+ "session_parameters": None,
+ "user": "user",
+ "warehouse": "af_wh",
+ "ocsp_fail_open": False,
+ },
+ ),
],
)
def test_hook_should_support_prepare_basic_conn_params_and_uri(
@@ -530,6 +584,23 @@ class TestPytestSnowflakeHook:
assert "private_key" in
mock_create_engine.call_args.kwargs["connect_args"]
assert mock_create_engine.return_value == conn
+ def test_get_sqlalchemy_engine_should_support_ocsp_fail_open(self):  # ocsp_fail_open must reach create_engine via connect_args, not the URI
+ connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+ connection_kwargs["extra"]["ocsp_fail_open"] = "False"  # a string on purpose, so the hook's boolean coercion is exercised
+
+ with (
+ mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
+
mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as
mock_create_engine,
+ ):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ conn = hook.get_sqlalchemy_engine()
+ mock_create_engine.assert_called_once_with(
+ "snowflake://user:[email protected]_region/db/public"
+
"?application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh",
+ connect_args={"ocsp_fail_open": False},  # expected as a real bool, not the "False" string set above
+ )
+ assert mock_create_engine.return_value == conn
+
def test_hook_parameters_should_take_precedence(self):
with mock.patch.dict(
"os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri()