This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f051a82f44 Remove Connection dependency from shared secrets backend (#61523)
2f051a82f44 is described below

commit 2f051a82f440e14b3da9878988ffcb0731cfbcfb
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Mar 2 16:15:46 2026 +0530

    Remove Connection dependency from shared secrets backend (#61523)
---
 airflow-core/src/airflow/configuration.py          |  9 +++-
 airflow-core/src/airflow/secrets/base_secrets.py   | 17 ++++++-
 airflow-core/tests/unit/always/test_connection.py  |  5 +-
 .../tests/unit/always/test_secrets_backends.py     |  2 +
 .../unit/amazon/aws/transfers/test_s3_to_sql.py    | 30 +++++++++---
 .../livy/tests/unit/apache/livy/hooks/test_livy.py |  3 +-
 .../src/airflow_shared/secrets_backend/base.py     | 54 +++++++++++-----------
 .../tests/secrets_backend/test_base.py             | 40 ++++++++++++++++
 task-sdk/src/airflow/sdk/bases/secrets_backend.py  | 16 ++++++-
 task-sdk/src/airflow/sdk/configuration.py          |  9 +++-
 10 files changed, 141 insertions(+), 44 deletions(-)

diff --git a/airflow-core/src/airflow/configuration.py b/airflow-core/src/airflow/configuration.py
index c190c3a656d..1f4fc07e4c0 100644
--- a/airflow-core/src/airflow/configuration.py
+++ b/airflow-core/src/airflow/configuration.py
@@ -870,11 +870,18 @@ def initialize_secrets_backends(
     custom_secret_backend = get_custom_secret_backend(worker_mode)
 
     if custom_secret_backend is not None:
+        from airflow.models import Connection
+
+        custom_secret_backend._set_connection_class(Connection)
         backend_list.append(custom_secret_backend)
 
     for class_name in default_backends:
+        from airflow.models import Connection
+
         secrets_backend_cls = import_string(class_name)
-        backend_list.append(secrets_backend_cls())
+        backend = secrets_backend_cls()
+        backend._set_connection_class(Connection)
+        backend_list.append(backend)
 
     return backend_list
 
diff --git a/airflow-core/src/airflow/secrets/base_secrets.py b/airflow-core/src/airflow/secrets/base_secrets.py
index b144bc9194f..939d993cfcc 100644
--- a/airflow-core/src/airflow/secrets/base_secrets.py
+++ b/airflow-core/src/airflow/secrets/base_secrets.py
@@ -16,8 +16,21 @@
 # under the License.
 from __future__ import annotations
 
-# Re export for compat
-from airflow._shared.secrets_backend.base import BaseSecretsBackend as BaseSecretsBackend
+from airflow._shared.secrets_backend.base import BaseSecretsBackend as _BaseSecretsBackend
+
+
+class BaseSecretsBackend(_BaseSecretsBackend):
+    """Base class for secrets backend with Core Connection as default."""
+
+    def _get_connection_class(self) -> type:
+        conn_class = getattr(self, "_connection_class", None)
+        if conn_class is None:
+            from airflow.models import Connection
+
+            self._connection_class = Connection
+            return Connection
+        return conn_class
+
 
 # Server side default secrets backend search path used by server components (scheduler, API server)
 DEFAULT_SECRETS_SEARCH_PATH = [
diff --git a/airflow-core/tests/unit/always/test_connection.py b/airflow-core/tests/unit/always/test_connection.py
index be05a820d7d..74ab193a1f6 100644
--- a/airflow-core/tests/unit/always/test_connection.py
+++ b/airflow-core/tests/unit/always/test_connection.py
@@ -595,7 +595,8 @@ class TestConnection:
             "AIRFLOW_CONN_TEST_URI": 
"postgresql://username:[email protected]:5432/the_database",
         },
     )
-    def test_using_env_var(self):
+    @mock.patch("airflow.sdk.execution_time.context._mask_connection_secrets")
+    def test_using_env_var(self, mock_mask_conn):
         from airflow.providers.sqlite.hooks.sqlite import SqliteHook
 
         conn = SqliteHook.get_connection(conn_id="test_uri")
@@ -605,7 +606,7 @@ class TestConnection:
         assert conn.password == "password!"
         assert conn.port == 5432
 
-        self.mask_secret.assert_has_calls([mock.call("password!"), mock.call(quote("password!"))])
+        mock_mask_conn.assert_called_once()
 
     @mock.patch.dict(
         "os.environ",
diff --git a/airflow-core/tests/unit/always/test_secrets_backends.py b/airflow-core/tests/unit/always/test_secrets_backends.py
index 532705dbbb4..8d2db408354 100644
--- a/airflow-core/tests/unit/always/test_secrets_backends.py
+++ b/airflow-core/tests/unit/always/test_secrets_backends.py
@@ -67,6 +67,7 @@ class TestBaseSecretsBackend:
     def test_connection_env_secrets_backend(self):
         sample_conn_1 = SampleConn("sample_1", "A")
         env_secrets_backend = EnvironmentVariablesBackend()
+        env_secrets_backend._set_connection_class(Connection)
         os.environ[sample_conn_1.var_name] = sample_conn_1.conn_uri
         conn = env_secrets_backend.get_connection(sample_conn_1.conn_id)
 
@@ -79,6 +80,7 @@ class TestBaseSecretsBackend:
             session.add(sample_conn_2.conn)
             session.commit()
         metastore_backend = MetastoreBackend()
+        metastore_backend._set_connection_class(Connection)
         conn = metastore_backend.get_connection("sample_2")
         assert sample_conn_2.host.lower() == conn.host
 
diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sql.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sql.py
index f050a7a23da..2ccd03cf1f0 100644
--- a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sql.py
+++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sql.py
@@ -75,9 +75,14 @@ class TestS3ToSqlTransfer:
         return bad_hook
 
     @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.NamedTemporaryFile")
-    @patch("airflow.models.connection.Connection.get_hook")
+    @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.BaseHook.get_connection")
     @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.S3Hook.get_key")
-    def test_execute(self, mock_get_key, mock_hook, mock_tempfile, mock_parser):
+    def test_execute(self, mock_get_key, mock_get_connection, mock_tempfile, mock_parser):
+        mock_conn = MagicMock()
+        mock_hook = MagicMock()
+        mock_conn.get_hook.return_value = mock_hook
+        mock_get_connection.return_value = mock_conn
+
         S3ToSqlOperator(parser=mock_parser, **self.s3_to_sql_transfer_kwargs).execute({})
 
         mock_get_key.assert_called_once_with(
@@ -91,7 +96,7 @@ class TestS3ToSqlTransfer:
 
         mock_parser.assert_called_once_with(mock_tempfile.return_value.__enter__.return_value.name)
 
-        mock_hook.return_value.insert_rows.assert_called_once_with(
+        mock_hook.insert_rows.assert_called_once_with(
             table=self.s3_to_sql_transfer_kwargs["table"],
             schema=self.s3_to_sql_transfer_kwargs["schema"],
             target_fields=self.s3_to_sql_transfer_kwargs["column_list"],
@@ -100,13 +105,26 @@ class TestS3ToSqlTransfer:
         )
 
     @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.NamedTemporaryFile")
-    @patch("airflow.models.connection.Connection.get_hook", return_value=mock_bad_hook)
+    @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.BaseHook.get_connection")
     @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.S3Hook.get_key")
-    def test_execute_with_bad_hook(self, mock_get_key, mock_bad_hook, mock_tempfile, mock_parser):
+    def test_execute_with_bad_hook(
+        self, mock_get_key, mock_get_connection, mock_tempfile, mock_parser, mock_bad_hook
+    ):
+        mock_conn = MagicMock()
+        mock_conn.get_hook.return_value = mock_bad_hook
+        mock_get_connection.return_value = mock_conn
+
         with pytest.raises(AirflowException):
             S3ToSqlOperator(parser=mock_parser, **self.s3_to_sql_transfer_kwargs).execute({})
 
-    def test_hook_params(self, mock_parser):
+    @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.BaseHook.get_connection")
+    def test_hook_params(self, mock_get_connection, mock_parser):
+        mock_conn = MagicMock()
+        mock_hook = MagicMock()
+        mock_hook.log_sql = False
+        mock_conn.get_hook.return_value = mock_hook
+        mock_get_connection.return_value = mock_conn
+
         op = S3ToSqlOperator(
             parser=mock_parser,
             sql_hook_params={
diff --git a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
index 3353d796915..3a819c38064 100644
--- a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
+++ b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
@@ -776,8 +776,7 @@ class TestLivyAsyncHook:
 
         for conn_id, expected in connection_url_mapping.items():
             hook = LivyAsyncHook(livy_conn_id=conn_id)
-            response_conn: Connection = hook.get_connection(conn_id=conn_id)
-            assert isinstance(response_conn, Connection)
+            response_conn = hook.get_connection(conn_id=conn_id)
             assert hook._generate_base_url(response_conn) == expected
 
     def test_build_body(self):
diff --git a/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py b/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py
index c5aa62c803f..566ed379c85 100644
--- a/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py
+++ b/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py
@@ -64,32 +64,38 @@ class BaseSecretsBackend(ABC):
         """
         return None
 
-    @staticmethod
-    def _get_connection_class():
-        """
-        Detect which Connection class to use based on execution context.
-
-        Returns SDK Connection in worker context, core Connection in server context.
-        """
-        import os
-
-        process_context = os.environ.get("_AIRFLOW_PROCESS_CONTEXT", "").lower()
-        if process_context == "client":
-            # Client context (worker, dag processor, triggerer)
-            from airflow.sdk.definitions.connection import Connection
-
-            return Connection
+    def _set_connection_class(self, conn_class: type) -> None:
+        if not isinstance(conn_class, type):
+            raise TypeError(f"Connection class must be a type/class, got {type(conn_class).__name__}")
+        self._connection_class = conn_class
+
+    def _get_connection_class(self) -> type:
+        """Get the Connection class to use for deserialization."""
+        conn_class = getattr(self, "_connection_class", None)
+        if conn_class is None:
+            raise RuntimeError(
+                "Connection class not set on backend instance. "
+                "Backends must be instantiated via initialize_secrets_backends() "
+                "or have _connection_class set manually."
+            )
+        return conn_class
 
-        # Server context (scheduler, API server, etc.)
-        from airflow.models.connection import Connection
+    @staticmethod
+    def _deserialize_connection_value(conn_class: type, conn_id: str, value: str):
+        value = value.strip()
+        if value[0] == "{":
+            return conn_class.from_json(value=value, conn_id=conn_id)
 
-        return Connection
+        # TODO: Only sdk has from_uri defined on it. Is it worthwhile developing the core path or not?
+        if hasattr(conn_class, "from_uri"):
+            return conn_class.from_uri(conn_id=conn_id, uri=value)
+        return conn_class(conn_id=conn_id, uri=value)
 
     def deserialize_connection(self, conn_id: str, value: str):
         """
         Given a serialized representation of the airflow Connection, return an instance.
 
-        Auto-detects which Connection class to use based on execution context.
+        Uses the Connection class set on this class (which should be set to the appropriate Connection class for the execution context).
         Uses Connection.from_json() for JSON format, Connection(uri=...) for URI format.
 
         :param conn_id: connection id
@@ -97,15 +103,7 @@ class BaseSecretsBackend(ABC):
         :return: the deserialized Connection
         """
         conn_class = self._get_connection_class()
-
-        value = value.strip()
-        if value[0] == "{":
-            return conn_class.from_json(value=value, conn_id=conn_id)
-
-        # TODO: Only sdk has from_uri defined on it. Is it worthwhile developing the core path or not?
-        if hasattr(conn_class, "from_uri"):
-            return conn_class.from_uri(conn_id=conn_id, uri=value)
-        return conn_class(conn_id=conn_id, uri=value)
+        return self._deserialize_connection_value(conn_class, conn_id, value)
 
     def get_connection(self, conn_id: str, team_name: str | None = None):
         """
diff --git a/shared/secrets_backend/tests/secrets_backend/test_base.py b/shared/secrets_backend/tests/secrets_backend/test_base.py
index d57f272d043..cd2b7a0934a 100644
--- a/shared/secrets_backend/tests/secrets_backend/test_base.py
+++ b/shared/secrets_backend/tests/secrets_backend/test_base.py
@@ -22,6 +22,26 @@ import pytest
 from airflow_shared.secrets_backend.base import BaseSecretsBackend
 
 
+class MockConnection:
+    """Mock Connection class for testing deserialize_connection."""
+
+    def __init__(self, conn_id: str, uri: str | None = None, **kwargs):
+        self.conn_id = conn_id
+        self.uri = uri
+        self._kwargs = kwargs
+
+    @classmethod
+    def from_json(cls, value: str, conn_id: str):
+        import json
+
+        data = json.loads(value)
+        return cls(conn_id=conn_id, **data)
+
+    @classmethod
+    def from_uri(cls, conn_id: str, uri: str):
+        return cls(conn_id=conn_id, uri=uri)
+
+
 class _TestBackend(BaseSecretsBackend):
     def __init__(self, conn_values: dict[str, str] | None = None, variables: dict[str, str] | None = None):
         self.conn_values = conn_values or {}
@@ -93,3 +113,23 @@ class TestBaseSecretsBackend:
         backend = _TestBackend(conn_values={conn_id: f"uri_{expected}"})
         conn_value = backend.get_conn_value(conn_id)
         assert conn_value == f"uri_{expected}"
+
+    def test_deserialize_connection_json(self, sample_conn_json):
+        """Test deserialize_connection with JSON format through _TestBackend."""
+        backend = _TestBackend()
+        backend._set_connection_class(MockConnection)
+
+        conn = backend.deserialize_connection("test_conn", sample_conn_json)
+        assert isinstance(conn, MockConnection)
+        assert conn.conn_id == "test_conn"
+        assert conn._kwargs["conn_type"] == "mysql"
+
+    def test_deserialize_connection_uri(self, sample_conn_uri):
+        """Test deserialize_connection with URI format through _TestBackend."""
+        backend = _TestBackend()
+        backend._set_connection_class(MockConnection)
+
+        conn = backend.deserialize_connection("test_conn", sample_conn_uri)
+        assert isinstance(conn, MockConnection)
+        assert conn.conn_id == "test_conn"
+        assert conn.uri == sample_conn_uri
diff --git a/task-sdk/src/airflow/sdk/bases/secrets_backend.py b/task-sdk/src/airflow/sdk/bases/secrets_backend.py
index 408eca251fb..df176ff8af1 100644
--- a/task-sdk/src/airflow/sdk/bases/secrets_backend.py
+++ b/task-sdk/src/airflow/sdk/bases/secrets_backend.py
@@ -16,5 +16,17 @@
 # under the License.
 from __future__ import annotations
 
-# Re export for compat
-from airflow.sdk._shared.secrets_backend.base import BaseSecretsBackend as BaseSecretsBackend
+from airflow.sdk._shared.secrets_backend.base import BaseSecretsBackend as _BaseSecretsBackend
+
+
+class BaseSecretsBackend(_BaseSecretsBackend):
+    """Base class for secrets backend with SDK Connection as default."""
+
+    def _get_connection_class(self) -> type:
+        conn_class = getattr(self, "_connection_class", None)
+        if conn_class is None:
+            from airflow.sdk.definitions.connection import Connection
+
+            self._connection_class = Connection
+            return Connection
+        return conn_class
diff --git a/task-sdk/src/airflow/sdk/configuration.py b/task-sdk/src/airflow/sdk/configuration.py
index 9d0da594b4f..30e073d9ee3 100644
--- a/task-sdk/src/airflow/sdk/configuration.py
+++ b/task-sdk/src/airflow/sdk/configuration.py
@@ -227,11 +227,18 @@ def initialize_secrets_backends(
     custom_secret_backend = get_custom_secret_backend(worker_mode)
 
     if custom_secret_backend is not None:
+        from airflow.sdk.definitions.connection import Connection
+
+        custom_secret_backend._set_connection_class(Connection)
         backend_list.append(custom_secret_backend)
 
     for class_name in default_backends:
+        from airflow.sdk.definitions.connection import Connection
+
         secrets_backend_cls = import_string(class_name)
-        backend_list.append(secrets_backend_cls())
+        backend = secrets_backend_cls()
+        backend._set_connection_class(Connection)
+        backend_list.append(backend)
 
     return backend_list
 

Reply via email to