This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2f051a82f44 Remove Connection dependency from shared secrets backend (#61523)
2f051a82f44 is described below
commit 2f051a82f440e14b3da9878988ffcb0731cfbcfb
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Mar 2 16:15:46 2026 +0530
Remove Connection dependency from shared secrets backend (#61523)
---
airflow-core/src/airflow/configuration.py | 9 +++-
airflow-core/src/airflow/secrets/base_secrets.py | 17 ++++++-
airflow-core/tests/unit/always/test_connection.py | 5 +-
.../tests/unit/always/test_secrets_backends.py | 2 +
.../unit/amazon/aws/transfers/test_s3_to_sql.py | 30 +++++++++---
.../livy/tests/unit/apache/livy/hooks/test_livy.py | 3 +-
.../src/airflow_shared/secrets_backend/base.py | 54 +++++++++++-----------
.../tests/secrets_backend/test_base.py | 40 ++++++++++++++++
task-sdk/src/airflow/sdk/bases/secrets_backend.py | 16 ++++++-
task-sdk/src/airflow/sdk/configuration.py | 9 +++-
10 files changed, 141 insertions(+), 44 deletions(-)
diff --git a/airflow-core/src/airflow/configuration.py b/airflow-core/src/airflow/configuration.py
index c190c3a656d..1f4fc07e4c0 100644
--- a/airflow-core/src/airflow/configuration.py
+++ b/airflow-core/src/airflow/configuration.py
@@ -870,11 +870,18 @@ def initialize_secrets_backends(
custom_secret_backend = get_custom_secret_backend(worker_mode)
if custom_secret_backend is not None:
+ from airflow.models import Connection
+
+ custom_secret_backend._set_connection_class(Connection)
backend_list.append(custom_secret_backend)
for class_name in default_backends:
+ from airflow.models import Connection
+
secrets_backend_cls = import_string(class_name)
- backend_list.append(secrets_backend_cls())
+ backend = secrets_backend_cls()
+ backend._set_connection_class(Connection)
+ backend_list.append(backend)
return backend_list
diff --git a/airflow-core/src/airflow/secrets/base_secrets.py b/airflow-core/src/airflow/secrets/base_secrets.py
index b144bc9194f..939d993cfcc 100644
--- a/airflow-core/src/airflow/secrets/base_secrets.py
+++ b/airflow-core/src/airflow/secrets/base_secrets.py
@@ -16,8 +16,21 @@
# under the License.
from __future__ import annotations
-# Re export for compat
-from airflow._shared.secrets_backend.base import BaseSecretsBackend as BaseSecretsBackend
+from airflow._shared.secrets_backend.base import BaseSecretsBackend as _BaseSecretsBackend
+
+
+class BaseSecretsBackend(_BaseSecretsBackend):
+ """Base class for secrets backend with Core Connection as default."""
+
+ def _get_connection_class(self) -> type:
+ conn_class = getattr(self, "_connection_class", None)
+ if conn_class is None:
+ from airflow.models import Connection
+
+ self._connection_class = Connection
+ return Connection
+ return conn_class
+
# Server side default secrets backend search path used by server components (scheduler, API server)
DEFAULT_SECRETS_SEARCH_PATH = [
diff --git a/airflow-core/tests/unit/always/test_connection.py b/airflow-core/tests/unit/always/test_connection.py
index be05a820d7d..74ab193a1f6 100644
--- a/airflow-core/tests/unit/always/test_connection.py
+++ b/airflow-core/tests/unit/always/test_connection.py
@@ -595,7 +595,8 @@ class TestConnection:
"AIRFLOW_CONN_TEST_URI": "postgresql://username:[email protected]:5432/the_database",
},
)
- def test_using_env_var(self):
+ @mock.patch("airflow.sdk.execution_time.context._mask_connection_secrets")
+ def test_using_env_var(self, mock_mask_conn):
from airflow.providers.sqlite.hooks.sqlite import SqliteHook
conn = SqliteHook.get_connection(conn_id="test_uri")
@@ -605,7 +606,7 @@ class TestConnection:
assert conn.password == "password!"
assert conn.port == 5432
- self.mask_secret.assert_has_calls([mock.call("password!"), mock.call(quote("password!"))])
+ mock_mask_conn.assert_called_once()
@mock.patch.dict(
"os.environ",
diff --git a/airflow-core/tests/unit/always/test_secrets_backends.py b/airflow-core/tests/unit/always/test_secrets_backends.py
index 532705dbbb4..8d2db408354 100644
--- a/airflow-core/tests/unit/always/test_secrets_backends.py
+++ b/airflow-core/tests/unit/always/test_secrets_backends.py
@@ -67,6 +67,7 @@ class TestBaseSecretsBackend:
def test_connection_env_secrets_backend(self):
sample_conn_1 = SampleConn("sample_1", "A")
env_secrets_backend = EnvironmentVariablesBackend()
+ env_secrets_backend._set_connection_class(Connection)
os.environ[sample_conn_1.var_name] = sample_conn_1.conn_uri
conn = env_secrets_backend.get_connection(sample_conn_1.conn_id)
@@ -79,6 +80,7 @@ class TestBaseSecretsBackend:
session.add(sample_conn_2.conn)
session.commit()
metastore_backend = MetastoreBackend()
+ metastore_backend._set_connection_class(Connection)
conn = metastore_backend.get_connection("sample_2")
assert sample_conn_2.host.lower() == conn.host
diff --git a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sql.py b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sql.py
index f050a7a23da..2ccd03cf1f0 100644
--- a/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sql.py
+++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_s3_to_sql.py
@@ -75,9 +75,14 @@ class TestS3ToSqlTransfer:
return bad_hook
@patch("airflow.providers.amazon.aws.transfers.s3_to_sql.NamedTemporaryFile")
- @patch("airflow.models.connection.Connection.get_hook")
+ @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.BaseHook.get_connection")
@patch("airflow.providers.amazon.aws.transfers.s3_to_sql.S3Hook.get_key")
- def test_execute(self, mock_get_key, mock_hook, mock_tempfile, mock_parser):
+ def test_execute(self, mock_get_key, mock_get_connection, mock_tempfile, mock_parser):
+ mock_conn = MagicMock()
+ mock_hook = MagicMock()
+ mock_conn.get_hook.return_value = mock_hook
+ mock_get_connection.return_value = mock_conn
+
S3ToSqlOperator(parser=mock_parser, **self.s3_to_sql_transfer_kwargs).execute({})
mock_get_key.assert_called_once_with(
@@ -91,7 +96,7 @@ class TestS3ToSqlTransfer:
mock_parser.assert_called_once_with(mock_tempfile.return_value.__enter__.return_value.name)
- mock_hook.return_value.insert_rows.assert_called_once_with(
+ mock_hook.insert_rows.assert_called_once_with(
table=self.s3_to_sql_transfer_kwargs["table"],
schema=self.s3_to_sql_transfer_kwargs["schema"],
target_fields=self.s3_to_sql_transfer_kwargs["column_list"],
@@ -100,13 +105,26 @@ class TestS3ToSqlTransfer:
)
@patch("airflow.providers.amazon.aws.transfers.s3_to_sql.NamedTemporaryFile")
- @patch("airflow.models.connection.Connection.get_hook", return_value=mock_bad_hook)
+ @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.BaseHook.get_connection")
@patch("airflow.providers.amazon.aws.transfers.s3_to_sql.S3Hook.get_key")
- def test_execute_with_bad_hook(self, mock_get_key, mock_bad_hook, mock_tempfile, mock_parser):
+ def test_execute_with_bad_hook(
+ self, mock_get_key, mock_get_connection, mock_tempfile, mock_parser, mock_bad_hook
+ ):
+ mock_conn = MagicMock()
+ mock_conn.get_hook.return_value = mock_bad_hook
+ mock_get_connection.return_value = mock_conn
+
with pytest.raises(AirflowException):
S3ToSqlOperator(parser=mock_parser, **self.s3_to_sql_transfer_kwargs).execute({})
- def test_hook_params(self, mock_parser):
+ @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.BaseHook.get_connection")
+ def test_hook_params(self, mock_get_connection, mock_parser):
+ mock_conn = MagicMock()
+ mock_hook = MagicMock()
+ mock_hook.log_sql = False
+ mock_conn.get_hook.return_value = mock_hook
+ mock_get_connection.return_value = mock_conn
+
op = S3ToSqlOperator(
parser=mock_parser,
sql_hook_params={
diff --git a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
index 3353d796915..3a819c38064 100644
--- a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
+++ b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py
@@ -776,8 +776,7 @@ class TestLivyAsyncHook:
for conn_id, expected in connection_url_mapping.items():
hook = LivyAsyncHook(livy_conn_id=conn_id)
- response_conn: Connection = hook.get_connection(conn_id=conn_id)
- assert isinstance(response_conn, Connection)
+ response_conn = hook.get_connection(conn_id=conn_id)
assert hook._generate_base_url(response_conn) == expected
def test_build_body(self):
diff --git a/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py b/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py
index c5aa62c803f..566ed379c85 100644
--- a/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py
+++ b/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py
@@ -64,32 +64,38 @@ class BaseSecretsBackend(ABC):
"""
return None
- @staticmethod
- def _get_connection_class():
- """
- Detect which Connection class to use based on execution context.
-
- Returns SDK Connection in worker context, core Connection in server context.
- """
- import os
-
- process_context = os.environ.get("_AIRFLOW_PROCESS_CONTEXT", "").lower()
- if process_context == "client":
- # Client context (worker, dag processor, triggerer)
- from airflow.sdk.definitions.connection import Connection
-
- return Connection
+ def _set_connection_class(self, conn_class: type) -> None:
+ if not isinstance(conn_class, type):
+ raise TypeError(f"Connection class must be a type/class, got {type(conn_class).__name__}")
+ self._connection_class = conn_class
+
+ def _get_connection_class(self) -> type:
+ """Get the Connection class to use for deserialization."""
+ conn_class = getattr(self, "_connection_class", None)
+ if conn_class is None:
+ raise RuntimeError(
+ "Connection class not set on backend instance. "
+ "Backends must be instantiated via initialize_secrets_backends() "
+ "or have _connection_class set manually."
+ )
+ return conn_class
- # Server context (scheduler, API server, etc.)
- from airflow.models.connection import Connection
+ @staticmethod
+ def _deserialize_connection_value(conn_class: type, conn_id: str, value: str):
+ value = value.strip()
+ if value[0] == "{":
+ return conn_class.from_json(value=value, conn_id=conn_id)
- return Connection
+ # TODO: Only sdk has from_uri defined on it. Is it worthwhile developing the core path or not?
+ if hasattr(conn_class, "from_uri"):
+ return conn_class.from_uri(conn_id=conn_id, uri=value)
+ return conn_class(conn_id=conn_id, uri=value)
def deserialize_connection(self, conn_id: str, value: str):
"""
Given a serialized representation of the airflow Connection, return an instance.
- Auto-detects which Connection class to use based on execution context.
+ Uses the Connection class set on this class (which should be set to the appropriate Connection class for the execution context).
Uses Connection.from_json() for JSON format, Connection(uri=...) for URI format.
:param conn_id: connection id
@@ -97,15 +103,7 @@ class BaseSecretsBackend(ABC):
:return: the deserialized Connection
"""
conn_class = self._get_connection_class()
-
- value = value.strip()
- if value[0] == "{":
- return conn_class.from_json(value=value, conn_id=conn_id)
-
- # TODO: Only sdk has from_uri defined on it. Is it worthwhile developing the core path or not?
- if hasattr(conn_class, "from_uri"):
- return conn_class.from_uri(conn_id=conn_id, uri=value)
- return conn_class(conn_id=conn_id, uri=value)
+ return self._deserialize_connection_value(conn_class, conn_id, value)
def get_connection(self, conn_id: str, team_name: str | None = None):
"""
diff --git a/shared/secrets_backend/tests/secrets_backend/test_base.py b/shared/secrets_backend/tests/secrets_backend/test_base.py
index d57f272d043..cd2b7a0934a 100644
--- a/shared/secrets_backend/tests/secrets_backend/test_base.py
+++ b/shared/secrets_backend/tests/secrets_backend/test_base.py
@@ -22,6 +22,26 @@ import pytest
from airflow_shared.secrets_backend.base import BaseSecretsBackend
+class MockConnection:
+ """Mock Connection class for testing deserialize_connection."""
+
+ def __init__(self, conn_id: str, uri: str | None = None, **kwargs):
+ self.conn_id = conn_id
+ self.uri = uri
+ self._kwargs = kwargs
+
+ @classmethod
+ def from_json(cls, value: str, conn_id: str):
+ import json
+
+ data = json.loads(value)
+ return cls(conn_id=conn_id, **data)
+
+ @classmethod
+ def from_uri(cls, conn_id: str, uri: str):
+ return cls(conn_id=conn_id, uri=uri)
+
+
class _TestBackend(BaseSecretsBackend):
def __init__(self, conn_values: dict[str, str] | None = None, variables: dict[str, str] | None = None):
self.conn_values = conn_values or {}
@@ -93,3 +113,23 @@ class TestBaseSecretsBackend:
backend = _TestBackend(conn_values={conn_id: f"uri_{expected}"})
conn_value = backend.get_conn_value(conn_id)
assert conn_value == f"uri_{expected}"
+
+ def test_deserialize_connection_json(self, sample_conn_json):
+ """Test deserialize_connection with JSON format through _TestBackend."""
+ backend = _TestBackend()
+ backend._set_connection_class(MockConnection)
+
+ conn = backend.deserialize_connection("test_conn", sample_conn_json)
+ assert isinstance(conn, MockConnection)
+ assert conn.conn_id == "test_conn"
+ assert conn._kwargs["conn_type"] == "mysql"
+
+ def test_deserialize_connection_uri(self, sample_conn_uri):
+ """Test deserialize_connection with URI format through _TestBackend."""
+ backend = _TestBackend()
+ backend._set_connection_class(MockConnection)
+
+ conn = backend.deserialize_connection("test_conn", sample_conn_uri)
+ assert isinstance(conn, MockConnection)
+ assert conn.conn_id == "test_conn"
+ assert conn.uri == sample_conn_uri
diff --git a/task-sdk/src/airflow/sdk/bases/secrets_backend.py b/task-sdk/src/airflow/sdk/bases/secrets_backend.py
index 408eca251fb..df176ff8af1 100644
--- a/task-sdk/src/airflow/sdk/bases/secrets_backend.py
+++ b/task-sdk/src/airflow/sdk/bases/secrets_backend.py
@@ -16,5 +16,17 @@
# under the License.
from __future__ import annotations
-# Re export for compat
-from airflow.sdk._shared.secrets_backend.base import BaseSecretsBackend as BaseSecretsBackend
+from airflow.sdk._shared.secrets_backend.base import BaseSecretsBackend as _BaseSecretsBackend
+
+
+class BaseSecretsBackend(_BaseSecretsBackend):
+ """Base class for secrets backend with SDK Connection as default."""
+
+ def _get_connection_class(self) -> type:
+ conn_class = getattr(self, "_connection_class", None)
+ if conn_class is None:
+ from airflow.sdk.definitions.connection import Connection
+
+ self._connection_class = Connection
+ return Connection
+ return conn_class
diff --git a/task-sdk/src/airflow/sdk/configuration.py b/task-sdk/src/airflow/sdk/configuration.py
index 9d0da594b4f..30e073d9ee3 100644
--- a/task-sdk/src/airflow/sdk/configuration.py
+++ b/task-sdk/src/airflow/sdk/configuration.py
@@ -227,11 +227,18 @@ def initialize_secrets_backends(
custom_secret_backend = get_custom_secret_backend(worker_mode)
if custom_secret_backend is not None:
+ from airflow.sdk.definitions.connection import Connection
+
+ custom_secret_backend._set_connection_class(Connection)
backend_list.append(custom_secret_backend)
for class_name in default_backends:
+ from airflow.sdk.definitions.connection import Connection
+
secrets_backend_cls = import_string(class_name)
- backend_list.append(secrets_backend_cls())
+ backend = secrets_backend_cls()
+ backend._set_connection_class(Connection)
+ backend_list.append(backend)
return backend_list