This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a47b69e6c06 reformat add proxy support commit (#60432)
a47b69e6c06 is described below

commit a47b69e6c06ba906fb481d799089b59f5f09926d
Author: mansfieldj3 <[email protected]>
AuthorDate: Thu Jan 22 01:34:01 2026 +1100

    reformat add proxy support commit (#60432)
    
    * reformat add proxy support commit
    
    * static errors fix, fixes get_conn_params() causing method is not subscriptable error
    
    * static
    
    * add proxy_password as default sensitive field
---
 .../airflow/providers/snowflake/hooks/snowflake.py |  33 +++++-
 .../tests/unit/snowflake/hooks/test_snowflake.py   | 131 +++++++++++++++++++++
 .../secrets_masker/secrets_masker.py               |   1 +
 3 files changed, 164 insertions(+), 1 deletion(-)

diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index e5bf83558de..96725472493 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -115,7 +115,7 @@ class SnowflakeHook(DbApiHook):
             BS3TextFieldWidget,
         )
         from flask_babel import lazy_gettext
-        from wtforms import BooleanField, PasswordField, StringField
+        from wtforms import BooleanField, IntegerField, PasswordField, 
StringField
 
         return {
             "account": StringField(lazy_gettext("Account"), 
widget=BS3TextFieldWidget()),
@@ -130,6 +130,10 @@ class SnowflakeHook(DbApiHook):
             "insecure_mode": BooleanField(
                 label=lazy_gettext("Insecure mode"), description="Turns off 
OCSP certificate checks"
             ),
+            "proxy_host": StringField(lazy_gettext("Proxy Host"), 
widget=BS3TextFieldWidget()),
+            "proxy_port": IntegerField(lazy_gettext("Proxy Port")),
+            "proxy_user": StringField(lazy_gettext("Proxy User"), 
widget=BS3TextFieldWidget()),
+            "proxy_password": PasswordField(lazy_gettext("Proxy Password"), 
widget=BS3PasswordFieldWidget()),
         }
 
     @classmethod
@@ -152,6 +156,10 @@ class SnowflakeHook(DbApiHook):
                         "token_endpoint": "token endpoint",
                         "refresh_token": "refresh token",
                         "scope": "scope",
+                        "proxy_host": "proxy.example.com",
+                        "proxy_port": "8080",
+                        "proxy_user": "proxy_username",
+                        "proxy_password": "proxy_password",
                     },
                     indent=1,
                 ),
@@ -166,6 +174,10 @@ class SnowflakeHook(DbApiHook):
                 "private_key_file": "Path of snowflake private key (PEM 
Format)",
                 "private_key_content": "Content to snowflake private key (PEM 
format)",
                 "insecure_mode": "insecure mode",
+                "proxy_host": "Proxy server hostname",
+                "proxy_port": "Proxy server port",
+                "proxy_user": "Proxy username (optional)",
+                "proxy_password": "Proxy password (optional)",
             },
         }
 
@@ -431,6 +443,21 @@ class SnowflakeHook(DbApiHook):
         if ocsp_fail_open is not None:
             conn_config["ocsp_fail_open"] = _try_to_boolean(ocsp_fail_open)
 
+        # Add proxy configuration if specified
+        proxy_host = self._get_field(extra_dict, "proxy_host")
+        proxy_port = self._get_field(extra_dict, "proxy_port")
+        proxy_user = self._get_field(extra_dict, "proxy_user")
+        proxy_password = self._get_field(extra_dict, "proxy_password")
+
+        if proxy_host:
+            conn_config["proxy_host"] = proxy_host
+        if proxy_port:
+            conn_config["proxy_port"] = int(proxy_port) if 
isinstance(proxy_port, str) else proxy_port
+        if proxy_user:
+            conn_config["proxy_user"] = proxy_user
+        if proxy_password:
+            conn_config["proxy_password"] = proxy_password
+
         return conn_config
 
     def _get_valid_oauth_token(
@@ -524,6 +551,10 @@ class SnowflakeHook(DbApiHook):
                     "client_store_temporary_credential",
                     "json_result_force_utf8_decoding",
                     "ocsp_fail_open",
+                    "proxy_host",
+                    "proxy_port",
+                    "proxy_user",
+                    "proxy_password",
                 ]
             }
         )
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index c5eef78787d..c28fd895d12 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -1309,3 +1309,134 @@ class TestPytestSnowflakeHook:
 
         # Ensure refresh actually happened
         assert mock_requests_post.call_count == 2
+
+    def test_get_conn_params_with_proxy_host_only(self):
+        """Test proxy configuration with only host specified."""
+        connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+        connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+
+        with mock.patch.dict("os.environ", 
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            conn_params = hook._get_conn_params()
+
+            assert conn_params["proxy_host"] == "proxy.example.com"
+            assert "proxy_port" not in conn_params
+            assert "proxy_user" not in conn_params
+            assert "proxy_password" not in conn_params
+
+    def test_get_conn_params_with_proxy_host_and_port(self):
+        """Test proxy configuration with host and port."""
+        connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+        connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+        connection_kwargs["extra"]["proxy_port"] = "8080"
+
+        with mock.patch.dict("os.environ", 
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            conn_params = hook._get_conn_params()
+
+            assert conn_params["proxy_host"] == "proxy.example.com"
+            assert conn_params["proxy_port"] == 8080
+            assert "proxy_user" not in conn_params
+            assert "proxy_password" not in conn_params
+
+    def test_get_conn_params_with_proxy_port_as_int(self):
+        """Test proxy configuration with port as integer."""
+        connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+        connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+        connection_kwargs["extra"]["proxy_port"] = 8080  # Integer instead of 
string
+
+        with mock.patch.dict("os.environ", 
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            conn_params = hook._get_conn_params()
+
+            assert conn_params["proxy_host"] == "proxy.example.com"
+            assert conn_params["proxy_port"] == 8080
+            assert isinstance(conn_params["proxy_port"], int)
+
+    def test_get_conn_params_with_proxy_full_config(self):
+        """Test proxy configuration with all parameters."""
+        connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+        connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+        connection_kwargs["extra"]["proxy_port"] = "8080"
+        connection_kwargs["extra"]["proxy_user"] = "proxy_username"
+        connection_kwargs["extra"]["proxy_password"] = "proxy_password"
+
+        with mock.patch.dict("os.environ", 
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            conn_params = hook._get_conn_params()
+
+            assert conn_params["proxy_host"] == "proxy.example.com"
+            assert conn_params["proxy_port"] == 8080
+            assert conn_params["proxy_user"] == "proxy_username"
+            assert conn_params["proxy_password"] == "proxy_password"
+
+    def test_get_conn_params_with_proxy_backcompat_prefix(self):
+        """Test proxy configuration with backcompat prefix."""
+        connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+        connection_kwargs["extra"]["extra__snowflake__proxy_host"] = 
"proxy.example.com"
+        connection_kwargs["extra"]["extra__snowflake__proxy_port"] = "8080"
+        connection_kwargs["extra"]["extra__snowflake__proxy_user"] = 
"proxy_username"
+        connection_kwargs["extra"]["extra__snowflake__proxy_password"] = 
"proxy_password"
+
+        with mock.patch.dict("os.environ", 
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            conn_params = hook._get_conn_params()
+
+            assert conn_params["proxy_host"] == "proxy.example.com"
+            assert conn_params["proxy_port"] == 8080
+            assert conn_params["proxy_user"] == "proxy_username"
+            assert conn_params["proxy_password"] == "proxy_password"
+
+    def test_get_conn_with_proxy_should_call_connect(self):
+        """Test that proxy parameters are passed to connector.connect()."""
+        connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+        connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+        connection_kwargs["extra"]["proxy_port"] = "8080"
+        connection_kwargs["extra"]["proxy_user"] = "proxy_user"
+        connection_kwargs["extra"]["proxy_password"] = "proxy_pass"
+
+        with (
+            mock.patch.dict("os.environ", 
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
+            
mock.patch("airflow.providers.snowflake.hooks.snowflake.connector") as 
mock_connector,
+        ):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            hook.get_conn()
+
+            call_args = mock_connector.connect.call_args[1]
+            assert call_args["proxy_host"] == "proxy.example.com"
+            assert call_args["proxy_port"] == 8080
+            assert call_args["proxy_user"] == "proxy_user"
+            assert call_args["proxy_password"] == "proxy_pass"
+
+    def test_sqlalchemy_uri_excludes_proxy_params(self):
+        """Test that proxy parameters are excluded from SQLAlchemy URI."""
+        connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+        connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+        connection_kwargs["extra"]["proxy_port"] = "8080"
+
+        with mock.patch.dict("os.environ", 
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            uri = hook.get_uri()
+
+            # Proxy parameters should NOT appear in the URI
+            assert "proxy_host" not in uri
+            assert "proxy_port" not in uri
+            assert "proxy.example.com" not in uri
+            assert "8080" not in uri
+
+    def test_get_sqlalchemy_engine_with_proxy(self):
+        """Test get_sqlalchemy_engine does not include proxy params in URI but 
passes to connect_args if needed."""
+        connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
+        connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
+        connection_kwargs["extra"]["proxy_port"] = "8080"
+
+        with (
+            mock.patch.dict("os.environ", 
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
+            
mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as 
mock_create_engine,
+        ):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            hook.get_sqlalchemy_engine()
+
+            # Check that the URI doesn't contain proxy params
+            called_uri = mock_create_engine.call_args[0][0]
+            assert "proxy_host" not in str(called_uri)
diff --git a/shared/secrets_masker/src/airflow_shared/secrets_masker/secrets_masker.py b/shared/secrets_masker/src/airflow_shared/secrets_masker/secrets_masker.py
index 6e8d556eb6d..c99ad568e1c 100644
--- a/shared/secrets_masker/src/airflow_shared/secrets_masker/secrets_masker.py
+++ b/shared/secrets_masker/src/airflow_shared/secrets_masker/secrets_masker.py
@@ -59,6 +59,7 @@ DEFAULT_SENSITIVE_FIELDS = frozenset(
         "password",
         "private_key",
         "proxy",
+        "proxy_password",
         "proxies",
         "secret",
         "token",

Reply via email to