This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3a7ac635802287839c4989a78f07220e153b1fd7
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Sep 23 04:38:23 2025 +0100

    Use ``SecretCache`` for connection and variable access in task sdk (#55972)
    
    (cherry picked from commit 049123e8c730ea471dc5e93b69a21d6b1fddba91)
---
 task-sdk/src/airflow/sdk/definitions/connection.py |  86 +++++-
 task-sdk/src/airflow/sdk/execution_time/context.py |  83 +++++-
 task-sdk/tests/task_sdk/bases/test_hook.py         |   2 +-
 .../{test_connections.py => test_connection.py}    | 124 ++++++++
 .../tests/task_sdk/execution_time/test_context.py  |   2 +-
 .../task_sdk/execution_time/test_context_cache.py  | 332 +++++++++++++++++++++
 .../task_sdk/execution_time/test_task_runner.py    |   2 +-
 7 files changed, 612 insertions(+), 19 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py 
b/task-sdk/src/airflow/sdk/definitions/connection.py
index b0e1372f52f..b1fb1190bc3 100644
--- a/task-sdk/src/airflow/sdk/definitions/connection.py
+++ b/task-sdk/src/airflow/sdk/definitions/connection.py
@@ -22,6 +22,7 @@ import json
 import logging
 from json import JSONDecodeError
 from typing import Any
+from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit
 
 import attrs
 
@@ -31,6 +32,26 @@ from airflow.sdk.exceptions import AirflowRuntimeError, 
ErrorType
 log = logging.getLogger(__name__)
 
 
+def _parse_netloc_to_hostname(uri_parts):
+    """
+    Parse a URI string to get the correct Hostname.
+
+    ``urlparse(...).hostname`` and ``urlsplit(...).hostname`` return the value 
lowercased in most cases; some exclusions exist for specific cases such as 
https://bugs.python.org/issue32323
+    If a path is expected as part of the hostname,
+    the default ``urlparse``/``urlsplit`` behavior is unexpected.
+    """
+    hostname = unquote(uri_parts.hostname or "")
+    if "/" in hostname:
+        hostname = uri_parts.netloc
+        if "@" in hostname:
+            hostname = hostname.rsplit("@", 1)[1]
+        if ":" in hostname:
+            hostname = hostname.split(":", 1)[0]
+        hostname = unquote(hostname)
+    return hostname
+
+
 def _prune_dict(val: Any, mode="strict"):
     """
     Given dict ``val``, returns new dict based on ``val`` with all empty 
elements removed.
@@ -104,7 +125,7 @@ class Connection:
 
     def get_uri(self) -> str:
         """Generate and return connection in URI format."""
-        from urllib.parse import parse_qsl, quote, urlencode
+        from urllib.parse import parse_qsl
 
         if self.conn_type and "_" in self.conn_type:
             log.warning(
@@ -156,10 +177,11 @@ class Connection:
 
         if self.extra:
             try:
-                query: str | None = urlencode(self.extra_dejson)
+                extra_dejson = self.extra_dejson
+                query: str | None = urlencode(extra_dejson)
             except TypeError:
                 query = None
-            if query and self.extra_dejson == dict(parse_qsl(query, 
keep_blank_values=True)):
+            if query and extra_dejson == dict(parse_qsl(query, 
keep_blank_values=True)):
                 uri += ("?" if self.schema else "/?") + query
             else:
                 uri += ("?" if self.schema else "/?") + 
urlencode({self.EXTRA_KEY: self.extra})
@@ -316,6 +338,64 @@ class Connection:
         conn_repr.pop("conn_id", None)
         return json.dumps(conn_repr)
 
+    @classmethod
+    def from_uri(cls, uri: str, conn_id: str) -> Connection:
+        """
+        Create a Connection from a URI string.
+
+        :param uri: URI string to parse
+        :param conn_id: Connection ID to assign to the connection
+        :return: Connection object
+        """
+        schemes_count_in_uri = uri.count("://")
+        if schemes_count_in_uri > 2:
+            raise AirflowException(f"Invalid connection string: {uri}.")
+        host_with_protocol = schemes_count_in_uri == 2
+        uri_parts = urlsplit(uri)
+        conn_type = uri_parts.scheme
+        normalized_conn_type = cls._normalize_conn_type(conn_type)
+        rest_of_the_url = uri.replace(f"{conn_type}://", ("" if 
host_with_protocol else "//"))
+        if host_with_protocol:
+            uri_splits = rest_of_the_url.split("://", 1)
+            if "@" in uri_splits[0] or ":" in uri_splits[0]:
+                raise AirflowException(f"Invalid connection string: {uri}.")
+        uri_parts = urlsplit(rest_of_the_url)
+        protocol = uri_parts.scheme if host_with_protocol else None
+        host = _parse_netloc_to_hostname(uri_parts)
+        parsed_host = cls._create_host(protocol, host)
+        quoted_schema = uri_parts.path[1:]
+        schema = unquote(quoted_schema) if quoted_schema else quoted_schema
+        login = unquote(uri_parts.username) if uri_parts.username else 
uri_parts.username
+        password = unquote(uri_parts.password) if uri_parts.password else 
uri_parts.password
+        port = uri_parts.port
+        extra = None
+        if uri_parts.query:
+            query = dict(parse_qsl(uri_parts.query, keep_blank_values=True))
+            if cls.EXTRA_KEY in query:
+                extra = query[cls.EXTRA_KEY]
+            else:
+                extra = json.dumps(query)
+
+        return cls(
+            conn_id=conn_id,
+            conn_type=normalized_conn_type,
+            host=parsed_host,
+            schema=schema,
+            login=login,
+            password=password,
+            port=port,
+            extra=extra,
+        )
+
+    @staticmethod
+    def _create_host(protocol, host) -> str | None:
+        """Return the connection host with the protocol."""
+        if not host:
+            return host
+        if protocol:
+            return f"{protocol}://{host}"
+        return host
+
     @staticmethod
     def _normalize_conn_type(conn_type):
         if conn_type == "postgresql":
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py 
b/task-sdk/src/airflow/sdk/execution_time/context.py
index 2107cd853ab..4580ef4651e 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -116,6 +116,14 @@ def _process_connection_result_conn(conn_result: 
ReceiveMsgType | None) -> Conne
     return Connection(**conn_result.model_dump(exclude={"type"}, 
by_alias=True))
 
 
+def _mask_connection_secrets(conn: Connection) -> None:
+    """Mask sensitive connection fields from logs."""
+    if conn.password:
+        mask_secret(conn.password)
+    if conn.extra:
+        mask_secret(conn.extra)
+
+
 def _convert_variable_result_to_variable(var_result: VariableResult, 
deserialize_json: bool) -> Variable:
     from airflow.sdk.definitions.variable import Variable
 
@@ -127,22 +135,28 @@ def _convert_variable_result_to_variable(var_result: 
VariableResult, deserialize
 
 
 def _get_connection(conn_id: str) -> Connection:
+    from airflow.sdk.execution_time.cache import SecretCache
     from airflow.sdk.execution_time.supervisor import 
ensure_secrets_backend_loaded
 
-    # TODO: check cache first (also in _async_get_connection)
-    # enabled only if SecretCache.init() has been called first
+    # Check cache first (optional; only on dag processor)
+    try:
+        uri = SecretCache.get_connection_uri(conn_id)
+        from airflow.sdk.definitions.connection import Connection
+
+        conn = Connection.from_uri(uri, conn_id=conn_id)
+        _mask_connection_secrets(conn)
+        return conn
+    except SecretCache.NotPresentException:
+        pass  # continue to backends
 
     # iterate over configured backends if not in cache (or expired)
     backends = ensure_secrets_backend_loaded()
     for secrets_backend in backends:
         try:
-            conn = secrets_backend.get_connection(conn_id=conn_id)
+            conn = secrets_backend.get_connection(conn_id=conn_id)  # type: 
ignore[assignment]
             if conn:
-                # TODO: this should probably be in get conn
-                if conn.password:
-                    mask_secret(conn.password)
-                if conn.extra:
-                    mask_secret(conn.extra)
+                SecretCache.save_connection_uri(conn_id, conn.get_uri())
+                _mask_connection_secrets(conn)
                 return conn
         except Exception:
             log.exception(
@@ -168,12 +182,27 @@ def _get_connection(conn_id: str) -> Connection:
 
     msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id))
 
-    return _process_connection_result_conn(msg)
+    conn = _process_connection_result_conn(msg)
+    SecretCache.save_connection_uri(conn_id, conn.get_uri())
+    return conn
 
 
 async def _async_get_connection(conn_id: str) -> Connection:
     from asgiref.sync import sync_to_async
 
+    from airflow.sdk.execution_time.cache import SecretCache
+
+    # Check cache first
+    try:
+        uri = SecretCache.get_connection_uri(conn_id)
+        from airflow.sdk.definitions.connection import Connection
+
+        conn = Connection.from_uri(uri, conn_id=conn_id)
+        _mask_connection_secrets(conn)
+        return conn
+    except SecretCache.NotPresentException:
+        pass  # continue to API
+
     from airflow.sdk.execution_time.comms import GetConnection
     from airflow.sdk.execution_time.supervisor import 
ensure_secrets_backend_loaded
     from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
@@ -185,7 +214,7 @@ async def _async_get_connection(conn_id: str) -> Connection:
     backends = ensure_secrets_backend_loaded()
     for secrets_backend in backends:
         try:
-            conn = await sync_to_async(secrets_backend.get_connection)(conn_id)
+            conn = await 
sync_to_async(secrets_backend.get_connection)(conn_id)  # type: 
ignore[assignment]
             if conn:
                 # TODO: this should probably be in get conn
                 if conn.password:
@@ -199,20 +228,38 @@ async def _async_get_connection(conn_id: str) -> 
Connection:
 
     # If no secrets backend has the connection, fall back to API server
     msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id))
-    return _process_connection_result_conn(msg)
+    conn = _process_connection_result_conn(msg)
+    SecretCache.save_connection_uri(conn_id, conn.get_uri())
+    _mask_connection_secrets(conn)
+    return conn
 
 
 def _get_variable(key: str, deserialize_json: bool) -> Any:
-    # TODO: check cache first
-    # enabled only if SecretCache.init() has been called first
+    from airflow.sdk.execution_time.cache import SecretCache
     from airflow.sdk.execution_time.supervisor import 
ensure_secrets_backend_loaded
 
+    # Check cache first
+    try:
+        var_val = SecretCache.get_variable(key)
+        if var_val is not None:
+            if deserialize_json:
+                import json
+
+                var_val = json.loads(var_val)
+            if isinstance(var_val, str):
+                mask_secret(var_val, key)
+            return var_val
+    except SecretCache.NotPresentException:
+        pass  # Continue to check backends
+
     backends = ensure_secrets_backend_loaded()
     # iterate over backends if not in cache (or expired)
     for secrets_backend in backends:
         try:
             var_val = secrets_backend.get_variable(key=key)
             if var_val is not None:
+                # Save raw value before deserialization to maintain cache 
consistency
+                SecretCache.save_variable(key, var_val)
                 if deserialize_json:
                     import json
 
@@ -248,6 +295,8 @@ def _get_variable(key: str, deserialize_json: bool) -> Any:
     if TYPE_CHECKING:
         assert isinstance(msg, VariableResult)
     variable = _convert_variable_result_to_variable(msg, deserialize_json)
+    # Save raw value to ensure cache consistency regardless of 
deserialize_json parameter
+    SecretCache.save_variable(key, msg.value)
     return variable.value
 
 
@@ -259,6 +308,7 @@ def _set_variable(key: str, value: Any, description: str | 
None = None, serializ
     #   keep Task SDK as a separate package than execution time mods.
     import json
 
+    from airflow.sdk.execution_time.cache import SecretCache
     from airflow.sdk.execution_time.comms import PutVariable
     from airflow.sdk.execution_time.supervisor import 
ensure_secrets_backend_loaded
     from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
@@ -292,6 +342,9 @@ def _set_variable(key: str, value: Any, description: str | 
None = None, serializ
 
     SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, 
description=description))
 
+    # Invalidate cache after setting the variable
+    SecretCache.invalidate_variable(key)
+
 
 def _delete_variable(key: str) -> None:
     # TODO: This should probably be moved to a separate module like 
`airflow.sdk.execution_time.comms`
@@ -299,6 +352,7 @@ def _delete_variable(key: str) -> None:
     #   A reason to not move it to `airflow.sdk.execution_time.comms` is that 
it
     #   will make that module depend on Task SDK, which is not ideal because 
we intend to
     #   keep Task SDK as a separate package than execution time mods.
+    from airflow.sdk.execution_time.cache import SecretCache
     from airflow.sdk.execution_time.comms import DeleteVariable
     from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
 
@@ -306,6 +360,9 @@ def _delete_variable(key: str) -> None:
     if TYPE_CHECKING:
         assert isinstance(msg, OKResponse)
 
+    # Invalidate cache after deleting the variable
+    SecretCache.invalidate_variable(key)
+
 
 class ConnectionAccessor:
     """Wrapper to access Connection entries in template."""
diff --git a/task-sdk/tests/task_sdk/bases/test_hook.py 
b/task-sdk/tests/task_sdk/bases/test_hook.py
index 4c691266f82..df1c28593e9 100644
--- a/task-sdk/tests/task_sdk/bases/test_hook.py
+++ b/task-sdk/tests/task_sdk/bases/test_hook.py
@@ -55,7 +55,7 @@ class TestBaseHook:
 
         hook = BaseHook(logger_name="")
         hook.get_connection(conn_id="test_conn")
-        mock_supervisor_comms.send.assert_called_once_with(
+        mock_supervisor_comms.send.assert_any_call(
             msg=GetConnection(conn_id="test_conn"),
         )
 
diff --git a/task-sdk/tests/task_sdk/definitions/test_connections.py 
b/task-sdk/tests/task_sdk/definitions/test_connection.py
similarity index 66%
rename from task-sdk/tests/task_sdk/definitions/test_connections.py
rename to task-sdk/tests/task_sdk/definitions/test_connection.py
index 05a80f59c7c..eff5ed1e02b 100644
--- a/task-sdk/tests/task_sdk/definitions/test_connections.py
+++ b/task-sdk/tests/task_sdk/definitions/test_connection.py
@@ -273,3 +273,127 @@ class TestConnectionsFromSecrets:
         # mock_env is only called when LocalFilesystemBackend doesn't have it
         mock_env_get.assert_called()
         assert conn == Connection(conn_id="something", conn_type="some-type")
+
+
+class TestConnectionFromUri:
+    """Test the Connection.from_uri() classmethod."""
+
+    def test_from_uri_basic(self):
+        """Test basic URI parsing."""
+        uri = "postgres://user:pass@host:5432/db"
+        conn = Connection.from_uri(uri, conn_id="test_conn")
+
+        assert conn.conn_id == "test_conn"
+        assert conn.conn_type == "postgres"
+        assert conn.host == "host"
+        assert conn.login == "user"
+        assert conn.password == "pass"
+        assert conn.port == 5432
+        assert conn.schema == "db"
+        assert conn.extra is None
+
+    def test_from_uri_with_query_params(self):
+        """Test URI parsing with query parameters."""
+        uri = "mysql://user:pass@host:3306/db?charset=utf8&timeout=30"
+        conn = Connection.from_uri(uri, conn_id="test_conn")
+
+        assert conn.conn_id == "test_conn"
+        assert conn.conn_type == "mysql"
+        assert conn.host == "host"
+        assert conn.login == "user"
+        assert conn.password == "pass"
+        assert conn.port == 3306
+        assert conn.schema == "db"
+        # Extra should be JSON string with query params
+        extra_dict = json.loads(conn.extra)
+        assert extra_dict == {"charset": "utf8", "timeout": "30"}
+
+    def test_from_uri_with_extra_key(self):
+        """Test URI parsing with __extra__ query parameter."""
+        extra_value = json.dumps({"ssl_mode": "require", "connect_timeout": 
10})
+        uri = f"postgres://user:pass@host:5432/db?__extra__={extra_value}"
+        conn = Connection.from_uri(uri, conn_id="test_conn")
+
+        assert conn.conn_id == "test_conn"
+        assert conn.conn_type == "postgres"
+        assert conn.extra == extra_value
+
+    def test_from_uri_with_protocol_in_host(self):
+        """Test URI parsing with protocol in host (double ://)."""
+        uri = "http://https://example.com:8080/path";
+        conn = Connection.from_uri(uri, conn_id="test_conn")
+
+        assert conn.conn_id == "test_conn"
+        assert conn.conn_type == "http"
+        assert conn.host == "https://example.com";
+        assert conn.port == 8080
+        assert conn.schema == "path"
+
+    def test_from_uri_encoded_credentials(self):
+        """Test URI parsing with URL-encoded credentials."""
+        uri = "postgres://user%40domain:pass%21word@host:5432/db"
+        conn = Connection.from_uri(uri, conn_id="test_conn")
+
+        assert conn.conn_id == "test_conn"
+        assert conn.conn_type == "postgres"
+        assert conn.login == "user@domain"
+        assert conn.password == "pass!word"
+
+    def test_from_uri_minimal(self):
+        """Test URI parsing with minimal information."""
+        uri = "redis://"
+        conn = Connection.from_uri(uri, conn_id="test_conn")
+
+        assert conn.conn_id == "test_conn"
+        assert conn.conn_type == "redis"
+        assert conn.host == ""  # urlsplit returns empty string, not None for 
minimal URI
+        assert conn.login is None
+        assert conn.password is None
+        assert conn.port is None
+        assert conn.schema == ""
+
+    def test_from_uri_conn_type_normalization(self):
+        """Test that connection types are normalized."""
+        # postgresql -> postgres
+        uri = "postgresql://user:pass@host:5432/db"
+        conn = Connection.from_uri(uri, conn_id="test_conn")
+        assert conn.conn_type == "postgres"
+
+        # hyphen to underscore
+        uri = "google-cloud-platform://user:pass@host/db"
+        conn = Connection.from_uri(uri, conn_id="test_conn")
+        assert conn.conn_type == "google_cloud_platform"
+
+    def test_from_uri_too_many_schemes_error(self):
+        """Test that too many schemes in URI raises an error."""
+        uri = "http://https://ftp://example.com";
+        with pytest.raises(AirflowException, match="Invalid connection 
string"):
+            Connection.from_uri(uri, conn_id="test_conn")
+
+    def test_from_uri_invalid_protocol_host_error(self):
+        """Test that invalid protocol host raises an error."""
+        uri = "http://user@host://example.com";
+        with pytest.raises(AirflowException, match="Invalid connection 
string"):
+            Connection.from_uri(uri, conn_id="test_conn")
+
+    def test_from_uri_roundtrip(self):
+        """Test that from_uri and get_uri are inverse operations."""
+        original_uri = 
"postgres://user:pass@host:5432/db?param1=value1&param2=value2"
+        conn = Connection.from_uri(original_uri, conn_id="test_conn")
+        roundtrip_uri = conn.get_uri()
+
+        # Parse both URIs to compare (order of query params might differ)
+        conn_from_original = Connection.from_uri(original_uri, conn_id="test")
+        conn_from_roundtrip = Connection.from_uri(roundtrip_uri, 
conn_id="test")
+
+        assert conn_from_original.conn_type == conn_from_roundtrip.conn_type
+        assert conn_from_original.host == conn_from_roundtrip.host
+        assert conn_from_original.login == conn_from_roundtrip.login
+        assert conn_from_original.password == conn_from_roundtrip.password
+        assert conn_from_original.port == conn_from_roundtrip.port
+        assert conn_from_original.schema == conn_from_roundtrip.schema
+        # Check extra content is equivalent (JSON order might differ)
+        if conn_from_original.extra:
+            original_extra = json.loads(conn_from_original.extra)
+            roundtrip_extra = json.loads(conn_from_roundtrip.extra)
+            assert original_extra == roundtrip_extra
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py 
b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 39aecb9d955..bd264286a5c 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -260,7 +260,7 @@ class TestConnectionAccessor:
         # empty in case of failed deserialising
         assert dejson == {}
 
-        mock_log.exception.assert_called_once_with(
+        mock_log.exception.assert_any_call(
             "Failed to deserialize extra property `extra`, returning empty 
dictionary"
         )
 
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context_cache.py 
b/task-sdk/tests/task_sdk/execution_time/test_context_cache.py
new file mode 100644
index 00000000000..591bdd617b6
--- /dev/null
+++ b/task-sdk/tests/task_sdk/execution_time/test_context_cache.py
@@ -0,0 +1,332 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, call, patch
+
+import pytest
+
+from airflow.sdk.definitions.connection import Connection
+from airflow.sdk.execution_time.cache import SecretCache
+from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult
+from airflow.sdk.execution_time.context import (
+    _delete_variable,
+    _get_connection,
+    _get_variable,
+    _set_variable,
+)
+
+from tests_common.test_utils.config import conf_vars
+
+
+class TestConnectionCacheIntegration:
+    """Test the integration of SecretCache with connection access."""
+
+    @staticmethod
+    @conf_vars({("secrets", "use_cache"): "true"})
+    def setup_method():
+        SecretCache.reset()
+        SecretCache.init()
+
+    @staticmethod
+    def teardown_method():
+        SecretCache.reset()
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_connection_uses_cache_when_available(self, 
mock_ensure_backends):
+        """Test that _get_connection uses cache when connection is cached."""
+        conn_id = "test_conn"
+        uri = "postgres://user:pass@host:5432/db"
+
+        SecretCache.save_connection_uri(conn_id, uri)
+
+        result = _get_connection(conn_id)
+        assert result.conn_id == conn_id
+        assert result.conn_type == "postgres"
+        assert result.host == "host"
+        assert result.login == "user"
+        assert result.password == "pass"
+        assert result.port == 5432
+        assert result.schema == "db"
+
+        mock_ensure_backends.assert_not_called()
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_connection_from_backend_saves_to_cache(self, 
mock_ensure_backends):
+        """Test that connection from secrets backend is retrieved correctly 
and cached."""
+        conn_id = "test_conn"
+        conn = Connection(conn_id=conn_id, conn_type="mysql", host="host", 
port=3306)
+
+        mock_backend = MagicMock(spec=["get_connection"])
+        mock_backend.get_connection.return_value = conn
+        mock_ensure_backends.return_value = [mock_backend]
+
+        result = _get_connection(conn_id)
+        assert result.conn_id == conn_id
+        assert result.conn_type == "mysql"
+        mock_backend.get_connection.assert_called_once_with(conn_id=conn_id)
+
+        cached_uri = SecretCache.get_connection_uri(conn_id)
+        cached_conn = Connection.from_uri(cached_uri, conn_id=conn_id)
+        assert cached_conn.conn_type == "mysql"
+        assert cached_conn.host == "host"
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_connection_from_api(self, mock_ensure_backends, 
mock_supervisor_comms):
+        """Test that connection from API server works correctly."""
+        conn_id = "test_conn"
+        conn_result = ConnectionResult(
+            conn_id=conn_id,
+            conn_type="mysql",
+            host="host",
+            port=3306,
+            login="user",
+            password="pass",
+        )
+
+        mock_ensure_backends.return_value = []
+
+        mock_supervisor_comms.send.return_value = conn_result
+
+        result = _get_connection(conn_id)
+
+        assert result.conn_id == conn_id
+        assert result.conn_type == "mysql"
+        # Called for GetConnection (and possibly MaskSecret)
+        assert mock_supervisor_comms.send.call_count >= 1
+
+        cached_uri = SecretCache.get_connection_uri(conn_id)
+        cached_conn = Connection.from_uri(cached_uri, conn_id=conn_id)
+        assert cached_conn.conn_type == "mysql"
+        assert cached_conn.host == "host"
+
+    @patch("airflow.sdk.execution_time.context.mask_secret")
+    def test_get_connection_masks_secrets(self, mock_mask_secret):
+        """Test that connection secrets are masked from logs."""
+        conn_id = "test_conn"
+        conn = Connection(
+            conn_id=conn_id, conn_type="mysql", login="user", 
password="password", extra='{"key": "value"}'
+        )
+
+        mock_backend = MagicMock(spec=["get_connection"])
+        mock_backend.get_connection.return_value = conn
+
+        with patch(
+            
"airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded", 
return_value=[mock_backend]
+        ):
+            result = _get_connection(conn_id)
+
+            assert result.conn_id == conn_id
+            # Check that password and extra were masked
+            mock_mask_secret.assert_has_calls(
+                [
+                    call("password"),
+                    call('{"key": "value"}'),
+                ],
+                any_order=True,
+            )
+
+
+class TestVariableCacheIntegration:
+    """Test the integration of SecretCache with variable access."""
+
+    @staticmethod
+    @conf_vars({("secrets", "use_cache"): "true"})
+    def setup_method():
+        SecretCache.reset()
+        SecretCache.init()
+
+    @staticmethod
+    def teardown_method():
+        SecretCache.reset()
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_variable_uses_cache_when_available(self, 
mock_ensure_backends):
+        """Test that _get_variable uses cache when variable is cached."""
+        key = "test_key"
+        value = "test_value"
+        SecretCache.save_variable(key, value)
+
+        result = _get_variable(key, deserialize_json=False)
+        assert result == value
+        mock_ensure_backends.assert_not_called()
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_variable_from_backend_saves_to_cache(self, 
mock_ensure_backends):
+        """Test that variable from secrets backend is saved to cache."""
+        key = "test_key"
+        value = "test_value"
+
+        mock_backend = MagicMock(spec=["get_variable"])
+        mock_backend.get_variable.return_value = value
+        mock_ensure_backends.return_value = [mock_backend]
+
+        result = _get_variable(key, deserialize_json=False)
+        assert result == value
+        mock_backend.get_variable.assert_called_once_with(key=key)
+        cached_value = SecretCache.get_variable(key)
+        assert cached_value == value
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_variable_from_api_saves_to_cache(self, mock_ensure_backends, 
mock_supervisor_comms):
+        """Test that variable from API server is saved to cache."""
+        key = "test_key"
+        value = "test_value"
+        var_result = VariableResult(key=key, value=value)
+
+        mock_ensure_backends.return_value = []
+        mock_supervisor_comms.send.return_value = var_result
+
+        result = _get_variable(key, deserialize_json=False)
+
+        assert result == value
+        cached_value = SecretCache.get_variable(key)
+        assert cached_value == value
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_variable_with_json_deserialization(self, 
mock_ensure_backends):
+        """Test that _get_variable handles JSON deserialization correctly with 
cache."""
+        key = "test_key"
+        json_value = '{"key": "value", "number": 42}'
+        SecretCache.save_variable(key, json_value)
+
+        result = _get_variable(key, deserialize_json=True)
+        assert result == {"key": "value", "number": 42}
+        cached_value = SecretCache.get_variable(key)
+        assert cached_value == json_value
+
+    def test_set_variable_invalidates_cache(self, mock_supervisor_comms):
+        """Test that _set_variable invalidates the cache."""
+        key = "test_key"
+        old_value = "old_value"
+        new_value = "new_value"
+        SecretCache.save_variable(key, old_value)
+
+        _set_variable(key, new_value)
+        mock_supervisor_comms.send.assert_called_once()
+        with pytest.raises(SecretCache.NotPresentException):
+            SecretCache.get_variable(key)
+
+    def test_delete_variable_invalidates_cache(self, mock_supervisor_comms):
+        """Test that _delete_variable invalidates the cache."""
+        key = "test_key"
+        value = "test_value"
+        SecretCache.save_variable(key, value)
+
+        from airflow.sdk.execution_time.comms import OKResponse
+
+        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+        _delete_variable(key)
+        mock_supervisor_comms.send.assert_called_once()
+        with pytest.raises(SecretCache.NotPresentException):
+            SecretCache.get_variable(key)
+
+
+class TestAsyncConnectionCache:
+    """Test the integration of SecretCache with async connection access."""
+
+    @staticmethod
+    @conf_vars({("secrets", "use_cache"): "true"})
+    def setup_method():
+        SecretCache.reset()
+        SecretCache.init()
+
+    @staticmethod
+    def teardown_method():
+        SecretCache.reset()
+
+    @pytest.mark.asyncio
+    async def test_async_get_connection_uses_cache(self):
+        """Test that _async_get_connection uses cache when connection is 
cached."""
+        from airflow.sdk.execution_time.context import _async_get_connection
+
+        conn_id = "test_conn"
+        uri = "postgres://user:pass@host:5432/db"
+
+        SecretCache.save_connection_uri(conn_id, uri)
+
+        result = await _async_get_connection(conn_id)
+        assert result.conn_id == conn_id
+        assert result.conn_type == "postgres"
+        assert result.host == "host"
+        assert result.login == "user"
+        assert result.password == "pass"
+        assert result.port == 5432
+        assert result.schema == "db"
+
+    @pytest.mark.asyncio
+    async def test_async_get_connection_from_api(self, mock_supervisor_comms):
+        """Test that async connection from API server works correctly."""
+        from airflow.sdk.execution_time.context import _async_get_connection
+
+        conn_id = "test_conn"
+        conn_result = ConnectionResult(
+            conn_id=conn_id,
+            conn_type="mysql",
+            host="host",
+            port=3306,
+        )
+
+        # Configure asend to return the conn_result when awaited
+        mock_supervisor_comms.asend = AsyncMock(return_value=conn_result)
+
+        result = await _async_get_connection(conn_id)
+
+        assert result.conn_id == conn_id
+        assert result.conn_type == "mysql"
+        mock_supervisor_comms.asend.assert_called_once()
+
+        cached_uri = SecretCache.get_connection_uri(conn_id)
+        cached_conn = Connection.from_uri(cached_uri, conn_id=conn_id)
+        assert cached_conn.conn_type == "mysql"
+        assert cached_conn.host == "host"
+
+
+class TestCacheDisabled:
+    """Test behavior when cache is disabled."""
+
+    @staticmethod
+    @conf_vars({("secrets", "use_cache"): "false"})
+    def setup_method():
+        SecretCache.reset()
+        SecretCache.init()
+
+    @staticmethod
+    def teardown_method():
+        SecretCache.reset()
+
+    
@patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded")
+    def test_get_connection_no_cache_when_disabled(self, mock_ensure_backends, 
mock_supervisor_comms):
+        """Test that cache is not used when disabled."""
+        conn_id = "test_conn"
+        conn_result = ConnectionResult(conn_id=conn_id, conn_type="mysql", 
host="host")
+
+        mock_ensure_backends.return_value = []
+
+        mock_supervisor_comms.send.return_value = conn_result
+
+        result = _get_connection(conn_id)
+
+        assert result.conn_id == conn_id
+        # Called for GetConnection (and possibly MaskSecret)
+        assert mock_supervisor_comms.send.call_count >= 1
+
+        _get_connection(conn_id)
+
+        # Called twice since cache is disabled
+        assert mock_supervisor_comms.send.call_count >= 2
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 525d137a481..fad6e4d66e8 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -1287,7 +1287,7 @@ class TestRuntimeTaskInstance:
         # Access the connection from the context
         conn_from_context = context["conn"].test_conn
 
-        
mock_supervisor_comms.send.assert_called_once_with(GetConnection(conn_id="test_conn"))
+        
mock_supervisor_comms.send.assert_any_call(GetConnection(conn_id="test_conn"))
 
         assert conn_from_context == Connection(
             conn_id="test_conn",


Reply via email to