This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v3-1-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 3a7ac635802287839c4989a78f07220e153b1fd7 Author: Kaxil Naik <[email protected]> AuthorDate: Tue Sep 23 04:38:23 2025 +0100 Use ``SecretCache`` for connection and variable access in task sdk (#55972) (cherry picked from commit 049123e8c730ea471dc5e93b69a21d6b1fddba91) --- task-sdk/src/airflow/sdk/definitions/connection.py | 86 +++++- task-sdk/src/airflow/sdk/execution_time/context.py | 83 +++++- task-sdk/tests/task_sdk/bases/test_hook.py | 2 +- .../{test_connections.py => test_connection.py} | 124 ++++++++ .../tests/task_sdk/execution_time/test_context.py | 2 +- .../task_sdk/execution_time/test_context_cache.py | 332 +++++++++++++++++++++ .../task_sdk/execution_time/test_task_runner.py | 2 +- 7 files changed, 612 insertions(+), 19 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py b/task-sdk/src/airflow/sdk/definitions/connection.py index b0e1372f52f..b1fb1190bc3 100644 --- a/task-sdk/src/airflow/sdk/definitions/connection.py +++ b/task-sdk/src/airflow/sdk/definitions/connection.py @@ -22,6 +22,7 @@ import json import logging from json import JSONDecodeError from typing import Any +from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit import attrs @@ -31,6 +32,26 @@ from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType log = logging.getLogger(__name__) +def _parse_netloc_to_hostname(uri_parts): + """ + Parse a URI string to get the correct Hostname. + + ``urlparse(...).hostname`` or ``urlsplit(...).hostname`` returns value into the lowercase in most cases, + there are some exclusion exists for specific cases such as https://bugs.python.org/issue32323 + In case if expected to get a path as part of hostname path, + then default behavior ``urlparse``/``urlsplit`` is unexpected. 
+ """ + hostname = unquote(uri_parts.hostname or "") + if "/" in hostname: + hostname = uri_parts.netloc + if "@" in hostname: + hostname = hostname.rsplit("@", 1)[1] + if ":" in hostname: + hostname = hostname.split(":", 1)[0] + hostname = unquote(hostname) + return hostname + + def _prune_dict(val: Any, mode="strict"): """ Given dict ``val``, returns new dict based on ``val`` with all empty elements removed. @@ -104,7 +125,7 @@ class Connection: def get_uri(self) -> str: """Generate and return connection in URI format.""" - from urllib.parse import parse_qsl, quote, urlencode + from urllib.parse import parse_qsl if self.conn_type and "_" in self.conn_type: log.warning( @@ -156,10 +177,11 @@ class Connection: if self.extra: try: - query: str | None = urlencode(self.extra_dejson) + extra_dejson = self.extra_dejson + query: str | None = urlencode(extra_dejson) except TypeError: query = None - if query and self.extra_dejson == dict(parse_qsl(query, keep_blank_values=True)): + if query and extra_dejson == dict(parse_qsl(query, keep_blank_values=True)): uri += ("?" if self.schema else "/?") + query else: uri += ("?" if self.schema else "/?") + urlencode({self.EXTRA_KEY: self.extra}) @@ -316,6 +338,64 @@ class Connection: conn_repr.pop("conn_id", None) return json.dumps(conn_repr) + @classmethod + def from_uri(cls, uri: str, conn_id: str) -> Connection: + """ + Create a Connection from a URI string. 
+ + :param uri: URI string to parse + :param conn_id: Connection ID to assign to the connection + :return: Connection object + """ + schemes_count_in_uri = uri.count("://") + if schemes_count_in_uri > 2: + raise AirflowException(f"Invalid connection string: {uri}.") + host_with_protocol = schemes_count_in_uri == 2 + uri_parts = urlsplit(uri) + conn_type = uri_parts.scheme + normalized_conn_type = cls._normalize_conn_type(conn_type) + rest_of_the_url = uri.replace(f"{conn_type}://", ("" if host_with_protocol else "//")) + if host_with_protocol: + uri_splits = rest_of_the_url.split("://", 1) + if "@" in uri_splits[0] or ":" in uri_splits[0]: + raise AirflowException(f"Invalid connection string: {uri}.") + uri_parts = urlsplit(rest_of_the_url) + protocol = uri_parts.scheme if host_with_protocol else None + host = _parse_netloc_to_hostname(uri_parts) + parsed_host = cls._create_host(protocol, host) + quoted_schema = uri_parts.path[1:] + schema = unquote(quoted_schema) if quoted_schema else quoted_schema + login = unquote(uri_parts.username) if uri_parts.username else uri_parts.username + password = unquote(uri_parts.password) if uri_parts.password else uri_parts.password + port = uri_parts.port + extra = None + if uri_parts.query: + query = dict(parse_qsl(uri_parts.query, keep_blank_values=True)) + if cls.EXTRA_KEY in query: + extra = query[cls.EXTRA_KEY] + else: + extra = json.dumps(query) + + return cls( + conn_id=conn_id, + conn_type=normalized_conn_type, + host=parsed_host, + schema=schema, + login=login, + password=password, + port=port, + extra=extra, + ) + + @staticmethod + def _create_host(protocol, host) -> str | None: + """Return the connection host with the protocol.""" + if not host: + return host + if protocol: + return f"{protocol}://{host}" + return host + @staticmethod def _normalize_conn_type(conn_type): if conn_type == "postgresql": diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py 
index 2107cd853ab..4580ef4651e 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -116,6 +116,14 @@ def _process_connection_result_conn(conn_result: ReceiveMsgType | None) -> Conne return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True)) +def _mask_connection_secrets(conn: Connection) -> None: + """Mask sensitive connection fields from logs.""" + if conn.password: + mask_secret(conn.password) + if conn.extra: + mask_secret(conn.extra) + + def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable: from airflow.sdk.definitions.variable import Variable @@ -127,22 +135,28 @@ def _convert_variable_result_to_variable(var_result: VariableResult, deserialize def _get_connection(conn_id: str) -> Connection: + from airflow.sdk.execution_time.cache import SecretCache from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded - # TODO: check cache first (also in _async_get_connection) - # enabled only if SecretCache.init() has been called first + # Check cache first (optional; only on dag processor) + try: + uri = SecretCache.get_connection_uri(conn_id) + from airflow.sdk.definitions.connection import Connection + + conn = Connection.from_uri(uri, conn_id=conn_id) + _mask_connection_secrets(conn) + return conn + except SecretCache.NotPresentException: + pass # continue to backends # iterate over configured backends if not in cache (or expired) backends = ensure_secrets_backend_loaded() for secrets_backend in backends: try: - conn = secrets_backend.get_connection(conn_id=conn_id) + conn = secrets_backend.get_connection(conn_id=conn_id) # type: ignore[assignment] if conn: - # TODO: this should probably be in get conn - if conn.password: - mask_secret(conn.password) - if conn.extra: - mask_secret(conn.extra) + SecretCache.save_connection_uri(conn_id, conn.get_uri()) + _mask_connection_secrets(conn) return conn except 
Exception: log.exception( @@ -168,12 +182,27 @@ def _get_connection(conn_id: str) -> Connection: msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id)) - return _process_connection_result_conn(msg) + conn = _process_connection_result_conn(msg) + SecretCache.save_connection_uri(conn_id, conn.get_uri()) + return conn async def _async_get_connection(conn_id: str) -> Connection: from asgiref.sync import sync_to_async + from airflow.sdk.execution_time.cache import SecretCache + + # Check cache first + try: + uri = SecretCache.get_connection_uri(conn_id) + from airflow.sdk.definitions.connection import Connection + + conn = Connection.from_uri(uri, conn_id=conn_id) + _mask_connection_secrets(conn) + return conn + except SecretCache.NotPresentException: + pass # continue to API + from airflow.sdk.execution_time.comms import GetConnection from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS @@ -185,7 +214,7 @@ async def _async_get_connection(conn_id: str) -> Connection: backends = ensure_secrets_backend_loaded() for secrets_backend in backends: try: - conn = await sync_to_async(secrets_backend.get_connection)(conn_id) + conn = await sync_to_async(secrets_backend.get_connection)(conn_id) # type: ignore[assignment] if conn: # TODO: this should probably be in get conn if conn.password: @@ -199,20 +228,38 @@ async def _async_get_connection(conn_id: str) -> Connection: # If no secrets backend has the connection, fall back to API server msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id)) - return _process_connection_result_conn(msg) + conn = _process_connection_result_conn(msg) + SecretCache.save_connection_uri(conn_id, conn.get_uri()) + _mask_connection_secrets(conn) + return conn def _get_variable(key: str, deserialize_json: bool) -> Any: - # TODO: check cache first - # enabled only if SecretCache.init() has been called first + from 
airflow.sdk.execution_time.cache import SecretCache from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded + # Check cache first + try: + var_val = SecretCache.get_variable(key) + if var_val is not None: + if deserialize_json: + import json + + var_val = json.loads(var_val) + if isinstance(var_val, str): + mask_secret(var_val, key) + return var_val + except SecretCache.NotPresentException: + pass # Continue to check backends + backends = ensure_secrets_backend_loaded() # iterate over backends if not in cache (or expired) for secrets_backend in backends: try: var_val = secrets_backend.get_variable(key=key) if var_val is not None: + # Save raw value before deserialization to maintain cache consistency + SecretCache.save_variable(key, var_val) if deserialize_json: import json @@ -248,6 +295,8 @@ def _get_variable(key: str, deserialize_json: bool) -> Any: if TYPE_CHECKING: assert isinstance(msg, VariableResult) variable = _convert_variable_result_to_variable(msg, deserialize_json) + # Save raw value to ensure cache consistency regardless of deserialize_json parameter + SecretCache.save_variable(key, msg.value) return variable.value @@ -259,6 +308,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ # keep Task SDK as a separate package than execution time mods. 
import json + from airflow.sdk.execution_time.cache import SecretCache from airflow.sdk.execution_time.comms import PutVariable from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS @@ -292,6 +342,9 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description)) + # Invalidate cache after setting the variable + SecretCache.invalidate_variable(key) + def _delete_variable(key: str) -> None: # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms` @@ -299,6 +352,7 @@ def _delete_variable(key: str) -> None: # A reason to not move it to `airflow.sdk.execution_time.comms` is that it # will make that module depend on Task SDK, which is not ideal because we intend to # keep Task SDK as a separate package than execution time mods. + from airflow.sdk.execution_time.cache import SecretCache from airflow.sdk.execution_time.comms import DeleteVariable from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS @@ -306,6 +360,9 @@ def _delete_variable(key: str) -> None: if TYPE_CHECKING: assert isinstance(msg, OKResponse) + # Invalidate cache after deleting the variable + SecretCache.invalidate_variable(key) + class ConnectionAccessor: """Wrapper to access Connection entries in template.""" diff --git a/task-sdk/tests/task_sdk/bases/test_hook.py b/task-sdk/tests/task_sdk/bases/test_hook.py index 4c691266f82..df1c28593e9 100644 --- a/task-sdk/tests/task_sdk/bases/test_hook.py +++ b/task-sdk/tests/task_sdk/bases/test_hook.py @@ -55,7 +55,7 @@ class TestBaseHook: hook = BaseHook(logger_name="") hook.get_connection(conn_id="test_conn") - mock_supervisor_comms.send.assert_called_once_with( + mock_supervisor_comms.send.assert_any_call( msg=GetConnection(conn_id="test_conn"), ) diff --git 
a/task-sdk/tests/task_sdk/definitions/test_connections.py b/task-sdk/tests/task_sdk/definitions/test_connection.py similarity index 66% rename from task-sdk/tests/task_sdk/definitions/test_connections.py rename to task-sdk/tests/task_sdk/definitions/test_connection.py index 05a80f59c7c..eff5ed1e02b 100644 --- a/task-sdk/tests/task_sdk/definitions/test_connections.py +++ b/task-sdk/tests/task_sdk/definitions/test_connection.py @@ -273,3 +273,127 @@ class TestConnectionsFromSecrets: # mock_env is only called when LocalFilesystemBackend doesn't have it mock_env_get.assert_called() assert conn == Connection(conn_id="something", conn_type="some-type") + + +class TestConnectionFromUri: + """Test the Connection.from_uri() classmethod.""" + + def test_from_uri_basic(self): + """Test basic URI parsing.""" + uri = "postgres://user:pass@host:5432/db" + conn = Connection.from_uri(uri, conn_id="test_conn") + + assert conn.conn_id == "test_conn" + assert conn.conn_type == "postgres" + assert conn.host == "host" + assert conn.login == "user" + assert conn.password == "pass" + assert conn.port == 5432 + assert conn.schema == "db" + assert conn.extra is None + + def test_from_uri_with_query_params(self): + """Test URI parsing with query parameters.""" + uri = "mysql://user:pass@host:3306/db?charset=utf8&timeout=30" + conn = Connection.from_uri(uri, conn_id="test_conn") + + assert conn.conn_id == "test_conn" + assert conn.conn_type == "mysql" + assert conn.host == "host" + assert conn.login == "user" + assert conn.password == "pass" + assert conn.port == 3306 + assert conn.schema == "db" + # Extra should be JSON string with query params + extra_dict = json.loads(conn.extra) + assert extra_dict == {"charset": "utf8", "timeout": "30"} + + def test_from_uri_with_extra_key(self): + """Test URI parsing with __extra__ query parameter.""" + extra_value = json.dumps({"ssl_mode": "require", "connect_timeout": 10}) + uri = f"postgres://user:pass@host:5432/db?__extra__={extra_value}" + conn = 
Connection.from_uri(uri, conn_id="test_conn") + + assert conn.conn_id == "test_conn" + assert conn.conn_type == "postgres" + assert conn.extra == extra_value + + def test_from_uri_with_protocol_in_host(self): + """Test URI parsing with protocol in host (double ://).""" + uri = "http://https://example.com:8080/path" + conn = Connection.from_uri(uri, conn_id="test_conn") + + assert conn.conn_id == "test_conn" + assert conn.conn_type == "http" + assert conn.host == "https://example.com" + assert conn.port == 8080 + assert conn.schema == "path" + + def test_from_uri_encoded_credentials(self): + """Test URI parsing with URL-encoded credentials.""" + uri = "postgres://user%40domain:pass%21word@host:5432/db" + conn = Connection.from_uri(uri, conn_id="test_conn") + + assert conn.conn_id == "test_conn" + assert conn.conn_type == "postgres" + assert conn.login == "user@domain" + assert conn.password == "pass!word" + + def test_from_uri_minimal(self): + """Test URI parsing with minimal information.""" + uri = "redis://" + conn = Connection.from_uri(uri, conn_id="test_conn") + + assert conn.conn_id == "test_conn" + assert conn.conn_type == "redis" + assert conn.host == "" # urlsplit returns empty string, not None for minimal URI + assert conn.login is None + assert conn.password is None + assert conn.port is None + assert conn.schema == "" + + def test_from_uri_conn_type_normalization(self): + """Test that connection types are normalized.""" + # postgresql -> postgres + uri = "postgresql://user:pass@host:5432/db" + conn = Connection.from_uri(uri, conn_id="test_conn") + assert conn.conn_type == "postgres" + + # hyphen to underscore + uri = "google-cloud-platform://user:pass@host/db" + conn = Connection.from_uri(uri, conn_id="test_conn") + assert conn.conn_type == "google_cloud_platform" + + def test_from_uri_too_many_schemes_error(self): + """Test that too many schemes in URI raises an error.""" + uri = "http://https://ftp://example.com" + with pytest.raises(AirflowException, 
match="Invalid connection string"): + Connection.from_uri(uri, conn_id="test_conn") + + def test_from_uri_invalid_protocol_host_error(self): + """Test that invalid protocol host raises an error.""" + uri = "http://user@host://example.com" + with pytest.raises(AirflowException, match="Invalid connection string"): + Connection.from_uri(uri, conn_id="test_conn") + + def test_from_uri_roundtrip(self): + """Test that from_uri and get_uri are inverse operations.""" + original_uri = "postgres://user:pass@host:5432/db?param1=value1&param2=value2" + conn = Connection.from_uri(original_uri, conn_id="test_conn") + roundtrip_uri = conn.get_uri() + + # Parse both URIs to compare (order of query params might differ) + conn_from_original = Connection.from_uri(original_uri, conn_id="test") + conn_from_roundtrip = Connection.from_uri(roundtrip_uri, conn_id="test") + + assert conn_from_original.conn_type == conn_from_roundtrip.conn_type + assert conn_from_original.host == conn_from_roundtrip.host + assert conn_from_original.login == conn_from_roundtrip.login + assert conn_from_original.password == conn_from_roundtrip.password + assert conn_from_original.port == conn_from_roundtrip.port + assert conn_from_original.schema == conn_from_roundtrip.schema + # Check extra content is equivalent (JSON order might differ) + if conn_from_original.extra: + original_extra = json.loads(conn_from_original.extra) + roundtrip_extra = json.loads(conn_from_roundtrip.extra) + assert original_extra == roundtrip_extra diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index 39aecb9d955..bd264286a5c 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -260,7 +260,7 @@ class TestConnectionAccessor: # empty in case of failed deserialising assert dejson == {} - mock_log.exception.assert_called_once_with( + mock_log.exception.assert_any_call( "Failed to
deserialize extra property `extra`, returning empty dictionary" ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_context_cache.py b/task-sdk/tests/task_sdk/execution_time/test_context_cache.py new file mode 100644 index 00000000000..591bdd617b6 --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_context_cache.py @@ -0,0 +1,332 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, call, patch + +import pytest + +from airflow.sdk.definitions.connection import Connection +from airflow.sdk.execution_time.cache import SecretCache +from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult +from airflow.sdk.execution_time.context import ( + _delete_variable, + _get_connection, + _get_variable, + _set_variable, +) + +from tests_common.test_utils.config import conf_vars + + +class TestConnectionCacheIntegration: + """Test the integration of SecretCache with connection access.""" + + @staticmethod + @conf_vars({("secrets", "use_cache"): "true"}) + def setup_method(): + SecretCache.reset() + SecretCache.init() + + @staticmethod + def teardown_method(): + SecretCache.reset() + + @patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded") + def test_get_connection_uses_cache_when_available(self, mock_ensure_backends): + """Test that _get_connection uses cache when connection is cached.""" + conn_id = "test_conn" + uri = "postgres://user:pass@host:5432/db" + + SecretCache.save_connection_uri(conn_id, uri) + + result = _get_connection(conn_id) + assert result.conn_id == conn_id + assert result.conn_type == "postgres" + assert result.host == "host" + assert result.login == "user" + assert result.password == "pass" + assert result.port == 5432 + assert result.schema == "db" + + mock_ensure_backends.assert_not_called() + + @patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded") + def test_get_connection_from_backend_saves_to_cache(self, mock_ensure_backends): + """Test that connection from secrets backend is retrieved correctly and cached.""" + conn_id = "test_conn" + conn = Connection(conn_id=conn_id, conn_type="mysql", host="host", port=3306) + + mock_backend = MagicMock(spec=["get_connection"]) + mock_backend.get_connection.return_value = conn + mock_ensure_backends.return_value = [mock_backend] + + result = 
_get_connection(conn_id) + assert result.conn_id == conn_id + assert result.conn_type == "mysql" + mock_backend.get_connection.assert_called_once_with(conn_id=conn_id) + + cached_uri = SecretCache.get_connection_uri(conn_id) + cached_conn = Connection.from_uri(cached_uri, conn_id=conn_id) + assert cached_conn.conn_type == "mysql" + assert cached_conn.host == "host" + + @patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded") + def test_get_connection_from_api(self, mock_ensure_backends, mock_supervisor_comms): + """Test that connection from API server works correctly.""" + conn_id = "test_conn" + conn_result = ConnectionResult( + conn_id=conn_id, + conn_type="mysql", + host="host", + port=3306, + login="user", + password="pass", + ) + + mock_ensure_backends.return_value = [] + + mock_supervisor_comms.send.return_value = conn_result + + result = _get_connection(conn_id) + + assert result.conn_id == conn_id + assert result.conn_type == "mysql" + # Called for GetConnection (and possibly MaskSecret) + assert mock_supervisor_comms.send.call_count >= 1 + + cached_uri = SecretCache.get_connection_uri(conn_id) + cached_conn = Connection.from_uri(cached_uri, conn_id=conn_id) + assert cached_conn.conn_type == "mysql" + assert cached_conn.host == "host" + + @patch("airflow.sdk.execution_time.context.mask_secret") + def test_get_connection_masks_secrets(self, mock_mask_secret): + """Test that connection secrets are masked from logs.""" + conn_id = "test_conn" + conn = Connection( + conn_id=conn_id, conn_type="mysql", login="user", password="password", extra='{"key": "value"}' + ) + + mock_backend = MagicMock(spec=["get_connection"]) + mock_backend.get_connection.return_value = conn + + with patch( + "airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded", return_value=[mock_backend] + ): + result = _get_connection(conn_id) + + assert result.conn_id == conn_id + # Check that password and extra were masked + mock_mask_secret.assert_has_calls( + 
[ + call("password"), + call('{"key": "value"}'), + ], + any_order=True, + ) + + +class TestVariableCacheIntegration: + """Test the integration of SecretCache with variable access.""" + + @staticmethod + @conf_vars({("secrets", "use_cache"): "true"}) + def setup_method(): + SecretCache.reset() + SecretCache.init() + + @staticmethod + def teardown_method(): + SecretCache.reset() + + @patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded") + def test_get_variable_uses_cache_when_available(self, mock_ensure_backends): + """Test that _get_variable uses cache when variable is cached.""" + key = "test_key" + value = "test_value" + SecretCache.save_variable(key, value) + + result = _get_variable(key, deserialize_json=False) + assert result == value + mock_ensure_backends.assert_not_called() + + @patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded") + def test_get_variable_from_backend_saves_to_cache(self, mock_ensure_backends): + """Test that variable from secrets backend is saved to cache.""" + key = "test_key" + value = "test_value" + + mock_backend = MagicMock(spec=["get_variable"]) + mock_backend.get_variable.return_value = value + mock_ensure_backends.return_value = [mock_backend] + + result = _get_variable(key, deserialize_json=False) + assert result == value + mock_backend.get_variable.assert_called_once_with(key=key) + cached_value = SecretCache.get_variable(key) + assert cached_value == value + + @patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded") + def test_get_variable_from_api_saves_to_cache(self, mock_ensure_backends, mock_supervisor_comms): + """Test that variable from API server is saved to cache.""" + key = "test_key" + value = "test_value" + var_result = VariableResult(key=key, value=value) + + mock_ensure_backends.return_value = [] + mock_supervisor_comms.send.return_value = var_result + + result = _get_variable(key, deserialize_json=False) + + assert result == value + cached_value = 
SecretCache.get_variable(key) + assert cached_value == value + + @patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded") + def test_get_variable_with_json_deserialization(self, mock_ensure_backends): + """Test that _get_variable handles JSON deserialization correctly with cache.""" + key = "test_key" + json_value = '{"key": "value", "number": 42}' + SecretCache.save_variable(key, json_value) + + result = _get_variable(key, deserialize_json=True) + assert result == {"key": "value", "number": 42} + cached_value = SecretCache.get_variable(key) + assert cached_value == json_value + + def test_set_variable_invalidates_cache(self, mock_supervisor_comms): + """Test that _set_variable invalidates the cache.""" + key = "test_key" + old_value = "old_value" + new_value = "new_value" + SecretCache.save_variable(key, old_value) + + _set_variable(key, new_value) + mock_supervisor_comms.send.assert_called_once() + with pytest.raises(SecretCache.NotPresentException): + SecretCache.get_variable(key) + + def test_delete_variable_invalidates_cache(self, mock_supervisor_comms): + """Test that _delete_variable invalidates the cache.""" + key = "test_key" + value = "test_value" + SecretCache.save_variable(key, value) + + from airflow.sdk.execution_time.comms import OKResponse + + mock_supervisor_comms.send.return_value = OKResponse(ok=True) + + _delete_variable(key) + mock_supervisor_comms.send.assert_called_once() + with pytest.raises(SecretCache.NotPresentException): + SecretCache.get_variable(key) + + +class TestAsyncConnectionCache: + """Test the integration of SecretCache with async connection access.""" + + @staticmethod + @conf_vars({("secrets", "use_cache"): "true"}) + def setup_method(): + SecretCache.reset() + SecretCache.init() + + @staticmethod + def teardown_method(): + SecretCache.reset() + + @pytest.mark.asyncio + async def test_async_get_connection_uses_cache(self): + """Test that _async_get_connection uses cache when connection is cached.""" + from 
airflow.sdk.execution_time.context import _async_get_connection + + conn_id = "test_conn" + uri = "postgres://user:pass@host:5432/db" + + SecretCache.save_connection_uri(conn_id, uri) + + result = await _async_get_connection(conn_id) + assert result.conn_id == conn_id + assert result.conn_type == "postgres" + assert result.host == "host" + assert result.login == "user" + assert result.password == "pass" + assert result.port == 5432 + assert result.schema == "db" + + @pytest.mark.asyncio + async def test_async_get_connection_from_api(self, mock_supervisor_comms): + """Test that async connection from API server works correctly.""" + from airflow.sdk.execution_time.context import _async_get_connection + + conn_id = "test_conn" + conn_result = ConnectionResult( + conn_id=conn_id, + conn_type="mysql", + host="host", + port=3306, + ) + + # Configure asend to return the conn_result when awaited + mock_supervisor_comms.asend = AsyncMock(return_value=conn_result) + + result = await _async_get_connection(conn_id) + + assert result.conn_id == conn_id + assert result.conn_type == "mysql" + mock_supervisor_comms.asend.assert_called_once() + + cached_uri = SecretCache.get_connection_uri(conn_id) + cached_conn = Connection.from_uri(cached_uri, conn_id=conn_id) + assert cached_conn.conn_type == "mysql" + assert cached_conn.host == "host" + + +class TestCacheDisabled: + """Test behavior when cache is disabled.""" + + @staticmethod + @conf_vars({("secrets", "use_cache"): "false"}) + def setup_method(): + SecretCache.reset() + SecretCache.init() + + @staticmethod + def teardown_method(): + SecretCache.reset() + + @patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded") + def test_get_connection_no_cache_when_disabled(self, mock_ensure_backends, mock_supervisor_comms): + """Test that cache is not used when disabled.""" + conn_id = "test_conn" + conn_result = ConnectionResult(conn_id=conn_id, conn_type="mysql", host="host") + + mock_ensure_backends.return_value = 
[] + + mock_supervisor_comms.send.return_value = conn_result + + result = _get_connection(conn_id) + + assert result.conn_id == conn_id + # Called for GetConnection (and possibly MaskSecret) + assert mock_supervisor_comms.send.call_count >= 1 + + _get_connection(conn_id) + + # Called twice since cache is disabled + assert mock_supervisor_comms.send.call_count >= 2 diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 525d137a481..fad6e4d66e8 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -1287,7 +1287,7 @@ class TestRuntimeTaskInstance: # Access the connection from the context conn_from_context = context["conn"].test_conn - mock_supervisor_comms.send.assert_called_once_with(GetConnection(conn_id="test_conn")) + mock_supervisor_comms.send.assert_any_call(GetConnection(conn_id="test_conn")) assert conn_from_context == Connection( conn_id="test_conn",
