This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a47b69e6c06 reformat add proxy support commit (#60432)
a47b69e6c06 is described below
commit a47b69e6c06ba906fb481d799089b59f5f09926d
Author: mansfieldj3 <[email protected]>
AuthorDate: Thu Jan 22 01:34:01 2026 +1100
reformat add proxy support commit (#60432)
* Reformat the "add proxy support" commit
* Fix static-check errors; fix get_conn_params() raising a
"'method' object is not subscriptable" TypeError
* Further static-check fixes
* Add proxy_password to the default sensitive fields (DEFAULT_SENSITIVE_FIELDS)
---
.../airflow/providers/snowflake/hooks/snowflake.py | 33 +++++-
.../tests/unit/snowflake/hooks/test_snowflake.py | 131 +++++++++++++++++++++
.../secrets_masker/secrets_masker.py | 1 +
3 files changed, 164 insertions(+), 1 deletion(-)
diff --git
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index e5bf83558de..96725472493 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -115,7 +115,7 @@ class SnowflakeHook(DbApiHook):
BS3TextFieldWidget,
)
from flask_babel import lazy_gettext
- from wtforms import BooleanField, PasswordField, StringField
+ from wtforms import BooleanField, IntegerField, PasswordField,
StringField
return {
"account": StringField(lazy_gettext("Account"),
widget=BS3TextFieldWidget()),
@@ -130,6 +130,10 @@ class SnowflakeHook(DbApiHook):
"insecure_mode": BooleanField(
label=lazy_gettext("Insecure mode"), description="Turns off
OCSP certificate checks"
),
+ "proxy_host": StringField(lazy_gettext("Proxy Host"),
widget=BS3TextFieldWidget()),
+ "proxy_port": IntegerField(lazy_gettext("Proxy Port")),
+ "proxy_user": StringField(lazy_gettext("Proxy User"),
widget=BS3TextFieldWidget()),
+ "proxy_password": PasswordField(lazy_gettext("Proxy Password"),
widget=BS3PasswordFieldWidget()),
}
@classmethod
@@ -152,6 +156,10 @@ class SnowflakeHook(DbApiHook):
"token_endpoint": "token endpoint",
"refresh_token": "refresh token",
"scope": "scope",
+ "proxy_host": "proxy.example.com",
+ "proxy_port": "8080",
+ "proxy_user": "proxy_username",
+ "proxy_password": "proxy_password",
},
indent=1,
),
@@ -166,6 +174,10 @@ class SnowflakeHook(DbApiHook):
"private_key_file": "Path of snowflake private key (PEM
Format)",
"private_key_content": "Content to snowflake private key (PEM
format)",
"insecure_mode": "insecure mode",
+ "proxy_host": "Proxy server hostname",
+ "proxy_port": "Proxy server port",
+ "proxy_user": "Proxy username (optional)",
+ "proxy_password": "Proxy password (optional)",
},
}
@@ -431,6 +443,21 @@ class SnowflakeHook(DbApiHook):
if ocsp_fail_open is not None:
conn_config["ocsp_fail_open"] = _try_to_boolean(ocsp_fail_open)
+ # Add proxy configuration if specified
+ proxy_host = self._get_field(extra_dict, "proxy_host")
+ proxy_port = self._get_field(extra_dict, "proxy_port")
+ proxy_user = self._get_field(extra_dict, "proxy_user")
+ proxy_password = self._get_field(extra_dict, "proxy_password")
+
+ if proxy_host:
+ conn_config["proxy_host"] = proxy_host
+ if proxy_port:
+ conn_config["proxy_port"] = int(proxy_port) if
isinstance(proxy_port, str) else proxy_port
+ if proxy_user:
+ conn_config["proxy_user"] = proxy_user
+ if proxy_password:
+ conn_config["proxy_password"] = proxy_password
+
return conn_config
def _get_valid_oauth_token(
@@ -524,6 +551,10 @@ class SnowflakeHook(DbApiHook):
"client_store_temporary_credential",
"json_result_force_utf8_decoding",
"ocsp_fail_open",
+ "proxy_host",
+ "proxy_port",
+ "proxy_user",
+ "proxy_password",
]
}
)
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index c5eef78787d..c28fd895d12 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -1309,3 +1309,134 @@ class TestPytestSnowflakeHook:
# Ensure refresh actually happened
assert mock_requests_post.call_count == 2
+
+ def test_get_conn_params_with_proxy_host_only(self):
+ """Test proxy configuration with only host specified."""
+ connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+ connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+
+ with mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ conn_params = hook._get_conn_params()
+
+ assert conn_params["proxy_host"] == "proxy.example.com"
+ assert "proxy_port" not in conn_params
+ assert "proxy_user" not in conn_params
+ assert "proxy_password" not in conn_params
+
+ def test_get_conn_params_with_proxy_host_and_port(self):
+ """Test proxy configuration with host and port."""
+ connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+ connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+ connection_kwargs["extra"]["proxy_port"] = "8080"
+
+ with mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ conn_params = hook._get_conn_params()
+
+ assert conn_params["proxy_host"] == "proxy.example.com"
+ assert conn_params["proxy_port"] == 8080
+ assert "proxy_user" not in conn_params
+ assert "proxy_password" not in conn_params
+
+ def test_get_conn_params_with_proxy_port_as_int(self):
+ """Test proxy configuration with port as integer."""
+ connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+ connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+ connection_kwargs["extra"]["proxy_port"] = 8080 # Integer instead of
string
+
+ with mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ conn_params = hook._get_conn_params()
+
+ assert conn_params["proxy_host"] == "proxy.example.com"
+ assert conn_params["proxy_port"] == 8080
+ assert isinstance(conn_params["proxy_port"], int)
+
+ def test_get_conn_params_with_proxy_full_config(self):
+ """Test proxy configuration with all parameters."""
+ connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+ connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+ connection_kwargs["extra"]["proxy_port"] = "8080"
+ connection_kwargs["extra"]["proxy_user"] = "proxy_username"
+ connection_kwargs["extra"]["proxy_password"] = "proxy_password"
+
+ with mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ conn_params = hook._get_conn_params()
+
+ assert conn_params["proxy_host"] == "proxy.example.com"
+ assert conn_params["proxy_port"] == 8080
+ assert conn_params["proxy_user"] == "proxy_username"
+ assert conn_params["proxy_password"] == "proxy_password"
+
+ def test_get_conn_params_with_proxy_backcompat_prefix(self):
+ """Test proxy configuration with backcompat prefix."""
+ connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+ connection_kwargs["extra"]["extra__snowflake__proxy_host"] =
"proxy.example.com"
+ connection_kwargs["extra"]["extra__snowflake__proxy_port"] = "8080"
+ connection_kwargs["extra"]["extra__snowflake__proxy_user"] =
"proxy_username"
+ connection_kwargs["extra"]["extra__snowflake__proxy_password"] =
"proxy_password"
+
+ with mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ conn_params = hook._get_conn_params()
+
+ assert conn_params["proxy_host"] == "proxy.example.com"
+ assert conn_params["proxy_port"] == 8080
+ assert conn_params["proxy_user"] == "proxy_username"
+ assert conn_params["proxy_password"] == "proxy_password"
+
+ def test_get_conn_with_proxy_should_call_connect(self):
+ """Test that proxy parameters are passed to connector.connect()."""
+ connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+ connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+ connection_kwargs["extra"]["proxy_port"] = "8080"
+ connection_kwargs["extra"]["proxy_user"] = "proxy_user"
+ connection_kwargs["extra"]["proxy_password"] = "proxy_pass"
+
+ with (
+ mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
+
mock.patch("airflow.providers.snowflake.hooks.snowflake.connector") as
mock_connector,
+ ):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ hook.get_conn()
+
+ call_args = mock_connector.connect.call_args[1]
+ assert call_args["proxy_host"] == "proxy.example.com"
+ assert call_args["proxy_port"] == 8080
+ assert call_args["proxy_user"] == "proxy_user"
+ assert call_args["proxy_password"] == "proxy_pass"
+
+ def test_sqlalchemy_uri_excludes_proxy_params(self):
+ """Test that proxy parameters are excluded from SQLAlchemy URI."""
+ connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+ connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+ connection_kwargs["extra"]["proxy_port"] = "8080"
+
+ with mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ uri = hook.get_uri()
+
+ # Proxy parameters should NOT appear in the URI
+ assert "proxy_host" not in uri
+ assert "proxy_port" not in uri
+ assert "proxy.example.com" not in uri
+ assert "8080" not in uri
+
+ def test_get_sqlalchemy_engine_with_proxy(self):
+ """Test get_sqlalchemy_engine does not include proxy params in URI but
passes to connect_args if needed."""
+ connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+ connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+ connection_kwargs["extra"]["proxy_port"] = "8080"
+
+ with (
+ mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
+
mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as
mock_create_engine,
+ ):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ hook.get_sqlalchemy_engine()
+
+ # Check that the URI doesn't contain proxy params
+ called_uri = mock_create_engine.call_args[0][0]
+ assert "proxy_host" not in str(called_uri)
diff --git
a/shared/secrets_masker/src/airflow_shared/secrets_masker/secrets_masker.py
b/shared/secrets_masker/src/airflow_shared/secrets_masker/secrets_masker.py
index 6e8d556eb6d..c99ad568e1c 100644
--- a/shared/secrets_masker/src/airflow_shared/secrets_masker/secrets_masker.py
+++ b/shared/secrets_masker/src/airflow_shared/secrets_masker/secrets_masker.py
@@ -59,6 +59,7 @@ DEFAULT_SENSITIVE_FIELDS = frozenset(
"password",
"private_key",
"proxy",
+ "proxy_password",
"proxies",
"secret",
"token",