This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 72d94aceea8 AIP-72: Move logic to load Secrets backend on Workers (#48446)
72d94aceea8 is described below

commit 72d94aceea8f3d298f6270c3859972947349d21c
Author: Kaxil Naik <[email protected]>
AuthorDate: Thu Mar 27 21:19:30 2025 +0530

    AIP-72: Move logic to load Secrets backend on Workers (#48446)
    
    Currently, even the default secrets backend is only loaded when the `supervise` function calls `initialize_secrets_backend_on_workers`. This makes it difficult to test any operator without calling that function first.
    
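    This commit unifies how secrets backends are loaded with the Airflow 2.x behavior and with the API-server side: the Task SDK now obtains the backends by calling `ensure_secrets_backend_loaded()` whenever a connection or variable is looked up, instead of relying on a module-level list that `supervise` had to populate first. For reference, see the 2.x implementation: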
    
    
https://github.com/apache/airflow/blob/2.10.5/airflow/models/connection.py#L524-L536
---
 task-sdk/src/airflow/sdk/execution_time/context.py      |  8 ++++----
 task-sdk/src/airflow/sdk/execution_time/supervisor.py   | 12 ++++--------
 task-sdk/tests/task_sdk/definitions/test_connections.py |  4 ----
 task-sdk/tests/task_sdk/definitions/test_variables.py   |  4 ----
 4 files changed, 8 insertions(+), 20 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index f49201760c5..d9b865c18b1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -112,12 +112,12 @@ def _convert_variable_result_to_variable(var_result: VariableResult, deserialize
 
 
 def _get_connection(conn_id: str) -> Connection:
-    from airflow.sdk.execution_time.supervisor import SECRETS_BACKEND
+    from airflow.sdk.execution_time.supervisor import 
ensure_secrets_backend_loaded
     # TODO: check cache first
     # enabled only if SecretCache.init() has been called first
 
     # iterate over configured backends if not in cache (or expired)
-    for secrets_backend in SECRETS_BACKEND:
+    for secrets_backend in ensure_secrets_backend_loaded():
         try:
             conn = secrets_backend.get_connection(conn_id=conn_id)
             if conn:
@@ -155,11 +155,11 @@ def _get_connection(conn_id: str) -> Connection:
 def _get_variable(key: str, deserialize_json: bool) -> Any:
     # TODO: check cache first
     # enabled only if SecretCache.init() has been called first
-    from airflow.sdk.execution_time.supervisor import SECRETS_BACKEND
+    from airflow.sdk.execution_time.supervisor import 
ensure_secrets_backend_loaded
 
     var_val = None
     # iterate over backends if not in cache (or expired)
-    for secrets_backend in SECRETS_BACKEND:
+    for secrets_backend in ensure_secrets_backend_loaded():
         try:
             var_val = secrets_backend.get_variable(key=key)  # type: 
ignore[assignment]
             if var_val is not None:
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 27dd1bc4339..deceae7fb43 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -103,7 +103,7 @@ if TYPE_CHECKING:
     from airflow.typing_compat import Self
 
 
-__all__ = ["ActivitySubprocess", "WatchedSubprocess", "supervise", 
"SECRETS_BACKEND"]
+__all__ = ["ActivitySubprocess", "WatchedSubprocess", "supervise"]
 
 log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor")
 
@@ -124,8 +124,6 @@ STATES_SENT_DIRECTLY = [
     TerminalTIState.SUCCESS,
 ]
 
-SECRETS_BACKEND: list[BaseSecretsBackend] = []
-
 
 @overload
 def mkpipe() -> tuple[socket, socket]: ...
@@ -1070,14 +1068,12 @@ def forward_to_log(
             log.log(level, msg, chan=chan)
 
 
-def initialize_secrets_backend_on_workers():
+def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]:
     """Initialize the secrets backend on workers."""
     from airflow.configuration import ensure_secrets_loaded
     from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS
 
-    global SECRETS_BACKEND
-    SECRETS_BACKEND = 
ensure_secrets_loaded(default_backends=DEFAULT_SECRETS_SEARCH_PATH_WORKERS)
-    log.debug("Initialized secrets backend on workers", 
secrets_backend=SECRETS_BACKEND)
+    return 
ensure_secrets_loaded(default_backends=DEFAULT_SECRETS_SEARCH_PATH_WORKERS)
 
 
 def register_secrets_masker():
@@ -1145,7 +1141,7 @@ def supervise(
         processors = logging_processors(enable_pretty_log=pretty_logs)[0]
         logger = structlog.wrap_logger(underlying_logger, 
processors=processors, logger_name="task").bind()
 
-    initialize_secrets_backend_on_workers()
+    ensure_secrets_backend_loaded()
 
     register_secrets_masker()
 
diff --git a/task-sdk/tests/task_sdk/definitions/test_connections.py b/task-sdk/tests/task_sdk/definitions/test_connections.py
index 2e64355a884..b92090ff81c 100644
--- a/task-sdk/tests/task_sdk/definitions/test_connections.py
+++ b/task-sdk/tests/task_sdk/definitions/test_connections.py
@@ -25,7 +25,6 @@ from airflow.configuration import initialize_secrets_backends
 from airflow.exceptions import AirflowException
 from airflow.sdk import Connection
 from airflow.sdk.execution_time.comms import ConnectionResult
-from airflow.sdk.execution_time.supervisor import 
initialize_secrets_backend_on_workers
 from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS
 
 from tests_common.test_utils.config import conf_vars
@@ -112,7 +111,6 @@ class TestConnectionsFromSecrets:
                 ("workers", "secrets_backend_kwargs"): 
f'{{"connections_file_path": "{path}"}}',
             }
         ):
-            initialize_secrets_backend_on_workers()
             retrieved_conn = Connection.get(conn_id="CONN_A")
             assert retrieved_conn is not None
             assert retrieved_conn.conn_id == "CONN_A"
@@ -120,7 +118,6 @@ class TestConnectionsFromSecrets:
     
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection")
     def test_get_connection_env_var(self, mock_env_get, mock_supervisor_comms):
         """Tests getting a connection from environment variable."""
-        initialize_secrets_backend_on_workers()
         mock_env_get.return_value = Connection(conn_id="something", 
conn_type="some-type")  # return None
         Connection.get("something")
         mock_env_get.assert_called_once_with(conn_id="something")
@@ -135,7 +132,6 @@ class TestConnectionsFromSecrets:
     
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection")
     def test_backend_fallback_to_env_var(self, mock_get_connection, 
mock_env_get, mock_supervisor_comms):
         """Tests if connection retrieval falls back to environment variable 
backend if not found in secrets backend."""
-        initialize_secrets_backend_on_workers()
         mock_get_connection.return_value = None
         mock_env_get.return_value = Connection(conn_id="something", 
conn_type="some-type")
 
diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py b/task-sdk/tests/task_sdk/definitions/test_variables.py
index 4a88cb1f6eb..6560bdee903 100644
--- a/task-sdk/tests/task_sdk/definitions/test_variables.py
+++ b/task-sdk/tests/task_sdk/definitions/test_variables.py
@@ -24,7 +24,6 @@ import pytest
 from airflow.configuration import initialize_secrets_backends
 from airflow.sdk import Variable
 from airflow.sdk.execution_time.comms import VariableResult
-from airflow.sdk.execution_time.supervisor import 
initialize_secrets_backend_on_workers
 from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS
 
 from tests_common.test_utils.config import conf_vars
@@ -72,7 +71,6 @@ class TestVariableFromSecrets:
                 ("workers", "secrets_backend_kwargs"): 
f'{{"variables_file_path": "{path}"}}',
             }
         ):
-            initialize_secrets_backend_on_workers()
             retrieved_var = Variable.get(key="VAR_A")
             assert retrieved_var is not None
             assert retrieved_var == "some_value"
@@ -80,7 +78,6 @@ class TestVariableFromSecrets:
     
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_variable")
     def test_get_variable_env_var(self, mock_env_get, mock_supervisor_comms):
         """Tests getting a variable from environment variable."""
-        initialize_secrets_backend_on_workers()
         mock_env_get.return_value = "fake_value"
         Variable.get(key="fake_var_key")
         mock_env_get.assert_called_once_with(key="fake_var_key")
@@ -97,7 +94,6 @@ class TestVariableFromSecrets:
     
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_variable")
     def test_backend_fallback_to_env_var(self, mock_get_variable, 
mock_env_get, mock_supervisor_comms):
         """Tests if variable retrieval falls back to environment variable 
backend if not found in secrets backend."""
-        initialize_secrets_backend_on_workers()
         mock_get_variable.return_value = None
         mock_env_get.return_value = "fake_value"
 

Reply via email to