This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 72d94aceea8 AIP-72: Move logic to load Secrets backend on Workers
(#48446)
72d94aceea8 is described below
commit 72d94aceea8f3d298f6270c3859972947349d21c
Author: Kaxil Naik <[email protected]>
AuthorDate: Thu Mar 27 21:19:30 2025 +0530
AIP-72: Move logic to load Secrets backend on Workers (#48446)
Currently, even the default backend is only loaded when the supervise
process calls initialize_secrets_backend_on_workers. This makes it
difficult to test any operator without calling that function explicitly.
The changes in this commit unify how we handle it with the 2.x side and
the API-server side.
https://github.com/apache/airflow/blob/2.10.5/airflow/models/connection.py#L524-L536
---
task-sdk/src/airflow/sdk/execution_time/context.py | 8 ++++----
task-sdk/src/airflow/sdk/execution_time/supervisor.py | 12 ++++--------
task-sdk/tests/task_sdk/definitions/test_connections.py | 4 ----
task-sdk/tests/task_sdk/definitions/test_variables.py | 4 ----
4 files changed, 8 insertions(+), 20 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py
b/task-sdk/src/airflow/sdk/execution_time/context.py
index f49201760c5..d9b865c18b1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -112,12 +112,12 @@ def _convert_variable_result_to_variable(var_result:
VariableResult, deserialize
def _get_connection(conn_id: str) -> Connection:
- from airflow.sdk.execution_time.supervisor import SECRETS_BACKEND
+ from airflow.sdk.execution_time.supervisor import
ensure_secrets_backend_loaded
# TODO: check cache first
# enabled only if SecretCache.init() has been called first
# iterate over configured backends if not in cache (or expired)
- for secrets_backend in SECRETS_BACKEND:
+ for secrets_backend in ensure_secrets_backend_loaded():
try:
conn = secrets_backend.get_connection(conn_id=conn_id)
if conn:
@@ -155,11 +155,11 @@ def _get_connection(conn_id: str) -> Connection:
def _get_variable(key: str, deserialize_json: bool) -> Any:
# TODO: check cache first
# enabled only if SecretCache.init() has been called first
- from airflow.sdk.execution_time.supervisor import SECRETS_BACKEND
+ from airflow.sdk.execution_time.supervisor import
ensure_secrets_backend_loaded
var_val = None
# iterate over backends if not in cache (or expired)
- for secrets_backend in SECRETS_BACKEND:
+ for secrets_backend in ensure_secrets_backend_loaded():
try:
var_val = secrets_backend.get_variable(key=key) # type:
ignore[assignment]
if var_val is not None:
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 27dd1bc4339..deceae7fb43 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -103,7 +103,7 @@ if TYPE_CHECKING:
from airflow.typing_compat import Self
-__all__ = ["ActivitySubprocess", "WatchedSubprocess", "supervise",
"SECRETS_BACKEND"]
+__all__ = ["ActivitySubprocess", "WatchedSubprocess", "supervise"]
log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor")
@@ -124,8 +124,6 @@ STATES_SENT_DIRECTLY = [
TerminalTIState.SUCCESS,
]
-SECRETS_BACKEND: list[BaseSecretsBackend] = []
-
@overload
def mkpipe() -> tuple[socket, socket]: ...
@@ -1070,14 +1068,12 @@ def forward_to_log(
log.log(level, msg, chan=chan)
-def initialize_secrets_backend_on_workers():
+def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]:
"""Initialize the secrets backend on workers."""
from airflow.configuration import ensure_secrets_loaded
from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS
- global SECRETS_BACKEND
- SECRETS_BACKEND =
ensure_secrets_loaded(default_backends=DEFAULT_SECRETS_SEARCH_PATH_WORKERS)
- log.debug("Initialized secrets backend on workers",
secrets_backend=SECRETS_BACKEND)
+ return
ensure_secrets_loaded(default_backends=DEFAULT_SECRETS_SEARCH_PATH_WORKERS)
def register_secrets_masker():
@@ -1145,7 +1141,7 @@ def supervise(
processors = logging_processors(enable_pretty_log=pretty_logs)[0]
logger = structlog.wrap_logger(underlying_logger,
processors=processors, logger_name="task").bind()
- initialize_secrets_backend_on_workers()
+ ensure_secrets_backend_loaded()
register_secrets_masker()
diff --git a/task-sdk/tests/task_sdk/definitions/test_connections.py
b/task-sdk/tests/task_sdk/definitions/test_connections.py
index 2e64355a884..b92090ff81c 100644
--- a/task-sdk/tests/task_sdk/definitions/test_connections.py
+++ b/task-sdk/tests/task_sdk/definitions/test_connections.py
@@ -25,7 +25,6 @@ from airflow.configuration import initialize_secrets_backends
from airflow.exceptions import AirflowException
from airflow.sdk import Connection
from airflow.sdk.execution_time.comms import ConnectionResult
-from airflow.sdk.execution_time.supervisor import
initialize_secrets_backend_on_workers
from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS
from tests_common.test_utils.config import conf_vars
@@ -112,7 +111,6 @@ class TestConnectionsFromSecrets:
("workers", "secrets_backend_kwargs"):
f'{{"connections_file_path": "{path}"}}',
}
):
- initialize_secrets_backend_on_workers()
retrieved_conn = Connection.get(conn_id="CONN_A")
assert retrieved_conn is not None
assert retrieved_conn.conn_id == "CONN_A"
@@ -120,7 +118,6 @@ class TestConnectionsFromSecrets:
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection")
def test_get_connection_env_var(self, mock_env_get, mock_supervisor_comms):
"""Tests getting a connection from environment variable."""
- initialize_secrets_backend_on_workers()
mock_env_get.return_value = Connection(conn_id="something",
conn_type="some-type") # return None
Connection.get("something")
mock_env_get.assert_called_once_with(conn_id="something")
@@ -135,7 +132,6 @@ class TestConnectionsFromSecrets:
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection")
def test_backend_fallback_to_env_var(self, mock_get_connection,
mock_env_get, mock_supervisor_comms):
"""Tests if connection retrieval falls back to environment variable
backend if not found in secrets backend."""
- initialize_secrets_backend_on_workers()
mock_get_connection.return_value = None
mock_env_get.return_value = Connection(conn_id="something",
conn_type="some-type")
diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py
b/task-sdk/tests/task_sdk/definitions/test_variables.py
index 4a88cb1f6eb..6560bdee903 100644
--- a/task-sdk/tests/task_sdk/definitions/test_variables.py
+++ b/task-sdk/tests/task_sdk/definitions/test_variables.py
@@ -24,7 +24,6 @@ import pytest
from airflow.configuration import initialize_secrets_backends
from airflow.sdk import Variable
from airflow.sdk.execution_time.comms import VariableResult
-from airflow.sdk.execution_time.supervisor import
initialize_secrets_backend_on_workers
from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS
from tests_common.test_utils.config import conf_vars
@@ -72,7 +71,6 @@ class TestVariableFromSecrets:
("workers", "secrets_backend_kwargs"):
f'{{"variables_file_path": "{path}"}}',
}
):
- initialize_secrets_backend_on_workers()
retrieved_var = Variable.get(key="VAR_A")
assert retrieved_var is not None
assert retrieved_var == "some_value"
@@ -80,7 +78,6 @@ class TestVariableFromSecrets:
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_variable")
def test_get_variable_env_var(self, mock_env_get, mock_supervisor_comms):
"""Tests getting a variable from environment variable."""
- initialize_secrets_backend_on_workers()
mock_env_get.return_value = "fake_value"
Variable.get(key="fake_var_key")
mock_env_get.assert_called_once_with(key="fake_var_key")
@@ -97,7 +94,6 @@ class TestVariableFromSecrets:
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_variable")
def test_backend_fallback_to_env_var(self, mock_get_variable,
mock_env_get, mock_supervisor_comms):
"""Tests if variable retrieval falls back to environment variable
backend if not found in secrets backend."""
- initialize_secrets_backend_on_workers()
mock_get_variable.return_value = None
mock_env_get.return_value = "fake_value"