This is an automated email from the ASF dual-hosted git repository.

kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d4a5294dcef Add token_provider for short-lived MCP auth in common.ai (#68104)
d4a5294dcef is described below

commit d4a5294dcef72f61be040db597c4fd685dec42fb
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat Jun 6 01:03:19 2026 +0100

    Add token_provider for short-lived MCP auth in common.ai (#68104)
    
    MCPHook built the MCP server with a single static Authorization header from the
    connection password, so it could not authenticate to MCP endpoints that require
    a freshly minted or short-lived token. The motivating case is a Snowflake
    managed MCP server, best authenticated with a key-pair JWT that expires after
    about an hour and cannot be stored as a static connection value; the same limit
    blocked OAuth/refresh tokens, Workload Identity Federation, and GitHub App
    installation tokens.
    
    MCPHook and MCPToolset now accept an optional token_provider callable. When set,
    it is invoked each time the HTTP/SSE server connection is established and its
    return value is used as the bearer token, overriding the static password. The
    minted token is registered with secret masking (matching the auto-masking the
    connection password already receives) so it does not leak into task logs, and a
    provider that returns a non-string or empty value fails loud instead of silently
    sending an unauthenticated request. token_provider is resolved in DAG code, so
    the signing key never enters the serialized DAG.
---
 providers/common/ai/docs/connections/mcp.rst       |  40 +++++++-
 providers/common/ai/docs/toolsets.rst              |   6 ++
 .../src/airflow/providers/common/ai/hooks/mcp.py   |  62 ++++++++++++-
 .../airflow/providers/common/ai/toolsets/mcp.py    |  19 +++-
 .../ai/tests/unit/common/ai/hooks/test_mcp.py      | 102 ++++++++++++++++++++-
 .../ai/tests/unit/common/ai/toolsets/test_mcp.py   |  20 +++-
 6 files changed, 240 insertions(+), 9 deletions(-)

diff --git a/providers/common/ai/docs/connections/mcp.rst b/providers/common/ai/docs/connections/mcp.rst
index 48f7749b76c..0eb5f7bfc60 100644
--- a/providers/common/ai/docs/connections/mcp.rst
+++ b/providers/common/ai/docs/connections/mcp.rst
@@ -45,7 +45,9 @@ Host
     Examples: ``http://localhost:3001/mcp``, ``https://mcp.example.com/v1``
 
 Auth Token (Password field)
-    Optional authentication token for the MCP server.
+    Optional authentication token for the MCP server. Sent as a static
+    ``Authorization: Bearer <token>`` header on HTTP/SSE transports. For
+    short-lived or minted tokens, use a ``token_provider`` instead (see below).
 
 Command (Extra field)
     The command to run for ``stdio`` transport. Required when transport is ``stdio``.
@@ -96,3 +98,39 @@ Examples
         "conn_type": "mcp",
         "extra": "{\"transport\": \"stdio\", \"command\": \"python\", \"args\": [\"-m\", \"my_server\"], \"timeout\": 30}"
     }
+
+Short-lived or minted tokens
+----------------------------
+
+Some MCP endpoints require a freshly minted, short-lived token rather than a
+static one. For example, `Snowflake managed MCP servers
+<https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-agents-mcp>`__
+are best authenticated with a `key-pair JWT
+<https://docs.snowflake.com/en/user-guide/key-pair-auth>`__: the private key never
+leaves your environment and the signed JWT expires after about an hour, so it
+cannot be stored as a static connection ``password``. The same applies to OAuth /
+refresh tokens, Workload Identity Federation, and GitHub App installation tokens.
+
+For these, pass a ``token_provider`` callable to ``MCPHook`` or ``MCPToolset``
+instead of a static token. It is called each time the connection is established
+and its return value is used as the bearer token, so a fresh token is minted (and
+registered with secret masking so it does not leak into task logs):
+
+.. code-block:: python
+
+    from airflow.providers.common.ai.toolsets.mcp import MCPToolset
+
+
+    def mint_snowflake_jwt() -> str:
+        # Sign a short-lived JWT from the Snowflake connection's key-pair.
+        ...
+
+
+    toolset = MCPToolset(
+        mcp_conn_id="snowflake_managed_mcp",
+        token_provider=mint_snowflake_jwt,
+    )
+
+``token_provider`` is resolved in DAG code (it is a Python callable, not a stored
+connection field), so the signing key stays in your environment and is never baked
+into the serialized DAG.
diff --git a/providers/common/ai/docs/toolsets.rst b/providers/common/ai/docs/toolsets.rst
index ee6d8a85b63..a9f454fbc89 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -272,6 +272,12 @@ Parameters
 - ``tool_prefix``: Optional prefix prepended to tool names to avoid
   collisions when using multiple MCP servers (e.g. ``"weather"`` produces
   ``"weather_get_forecast"``).
+- ``token_provider``: Optional zero-argument callable returning a bearer token.
+  When set, it overrides the connection's static ``password`` for the
+  ``Authorization`` header and is called each time the server connection is
+  established -- use it for short-lived or minted tokens (e.g. a Snowflake
+  managed MCP server authenticated with a key-pair JWT). See
+  :ref:`howto/connection:mcp`.
 
 Using Multiple MCP Servers
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py b/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py
index b69517ac6f4..f6749d280f3 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py
@@ -16,10 +16,23 @@
 # under the License.
 from __future__ import annotations
 
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from airflow.providers.common.compat.sdk import BaseHook
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    from airflow.sdk.execution_time.secrets_masker import mask_secret
+else:
+    try:
+        from airflow.sdk.log import mask_secret
+    except ImportError:
+        try:
+            from airflow.sdk.execution_time.secrets_masker import mask_secret
+        except ImportError:
+            from airflow.utils.log.secrets_masker import mask_secret
+
 
 class MCPHook(BaseHook):
     """
@@ -36,9 +49,22 @@ class MCPHook(BaseHook):
         - **Extra.args**: Command arguments for stdio transport (e.g. ``["mcp-run-python"]``)
         - **Extra.timeout**: Connection timeout in seconds for stdio (default: 10)
 
+    For HTTP/SSE transports the ``Authorization`` header is, by default, a static
+    ``Bearer`` token taken from the connection ``password``. Endpoints that require
+    a freshly minted or short-lived token (e.g. a Snowflake managed MCP server
+    authenticated with a key-pair JWT, OAuth/refresh tokens, Workload Identity
+    Federation, or GitHub App installation tokens) can pass a ``token_provider``
+    callable instead. It is invoked each time the server connection is established
+    and its return value is used as the bearer token, so a fresh token is minted
+    without storing a long-lived secret in the connection.
+
     :param mcp_conn_id: Airflow connection ID for the MCP server.
     :param tool_prefix: Optional prefix prepended to tool names
         (e.g. ``"weather"`` → ``"weather_get_forecast"``).
+    :param token_provider: Optional zero-argument callable returning a bearer
+        token string. When set, it overrides the connection ``password`` for the
+        ``Authorization`` header on HTTP/SSE transports and is called each time the
+        server connection is established. Ignored for the ``stdio`` transport.
     """
 
     conn_name_attr = "mcp_conn_id"
@@ -50,11 +76,14 @@ class MCPHook(BaseHook):
         self,
         mcp_conn_id: str = default_conn_name,
         tool_prefix: str | None = None,
+        *,
+        token_provider: Callable[[], str] | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self.mcp_conn_id = mcp_conn_id
         self.tool_prefix = tool_prefix
+        self.token_provider = token_provider
         self._server: Any = None
 
     @staticmethod
@@ -68,6 +97,28 @@ class MCPHook(BaseHook):
             },
         }
 
+    def _auth_headers(self, conn: Any) -> dict[str, str] | None:
+        """
+        Build the ``Authorization`` header for HTTP/SSE transports.
+
+        Prefers ``token_provider`` (minted per connection) over the static
+        connection ``password``. Returns ``None`` when neither yields a token.
+        """
+        if self.token_provider is not None:
+            token = self.token_provider()
+            if not isinstance(token, str) or not token:
+                raise ValueError(
+                    f"token_provider for connection {self.mcp_conn_id!r} must return a non-empty "
+                    f"string token, got {type(token).__name__}."
+                )
+            # The static connection password is masked when the connection is
+            # fetched; mask the minted token too so it never leaks into task logs.
+            mask_secret(token)
+            return {"Authorization": f"Bearer {token}"}
+        if conn.password:
+            return {"Authorization": f"Bearer {conn.password}"}
+        return None
+
     def get_conn(self) -> Any:
         """
         Return a configured PydanticAI MCP server instance.
@@ -94,16 +145,19 @@ class MCPHook(BaseHook):
         conn = self.get_connection(self.mcp_conn_id)
         extra = conn.extra_dejson
         transport = extra.get("transport", "http")
-        headers = {"Authorization": f"Bearer {conn.password}"} if conn.password else None
 
         if transport == "http":
             if not conn.host:
                 raise ValueError(f"Connection {self.mcp_conn_id!r} requires a host URL for HTTP transport.")
-            self._server = MCPServerStreamableHTTP(conn.host, headers=headers, tool_prefix=self.tool_prefix)
+            self._server = MCPServerStreamableHTTP(
+                conn.host, headers=self._auth_headers(conn), tool_prefix=self.tool_prefix
+            )
         elif transport == "sse":
             if not conn.host:
                 raise ValueError(f"Connection {self.mcp_conn_id!r} requires a host URL for SSE transport.")
-            self._server = MCPServerSSE(conn.host, headers=headers, tool_prefix=self.tool_prefix)
+            self._server = MCPServerSSE(
+                conn.host, headers=self._auth_headers(conn), tool_prefix=self.tool_prefix
+            )
         elif transport == "stdio":
             command = extra.get("command")
             if not command:
diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py
index 0fa085eba0c..d7aae3f9a6a 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py
@@ -24,6 +24,8 @@ from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
 from typing_extensions import Self
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from pydantic_ai._run_context import RunContext
 
 
@@ -46,9 +48,18 @@ class MCPToolset(AbstractToolset[Any]):
     :class:`~pydantic_ai.mcp.MCPServerSSE`, and
     :class:`~pydantic_ai.mcp.MCPServerStdio` all implement ``AbstractToolset``.
 
+    For MCP endpoints that need a freshly minted or short-lived token (e.g. a
+    Snowflake managed MCP server authenticated with a key-pair JWT, or OAuth /
+    Workload Identity / GitHub App tokens), pass a ``token_provider`` callable.
+    It is invoked each time the connection is established and its return value is
+    used as the bearer token, so a fresh token is minted rather than storing a
+    long-lived secret in the connection.
+
     :param mcp_conn_id: Airflow connection ID for the MCP server.
     :param tool_prefix: Optional prefix prepended to tool names
         (e.g. ``"weather"`` → ``"weather_get_forecast"``).
+    :param token_provider: Optional zero-argument callable returning a bearer
+        token string, overriding the connection ``password`` for HTTP/SSE auth.
     """
 
     def __init__(
@@ -56,9 +67,11 @@ class MCPToolset(AbstractToolset[Any]):
         mcp_conn_id: str,
         *,
         tool_prefix: str | None = None,
+        token_provider: Callable[[], str] | None = None,
     ) -> None:
         self._mcp_conn_id = mcp_conn_id
         self._tool_prefix = tool_prefix
+        self._token_provider = token_provider
         self._server: Any = None
 
     @property
@@ -69,7 +82,11 @@ class MCPToolset(AbstractToolset[Any]):
         if self._server is None:
             from airflow.providers.common.ai.hooks.mcp import MCPHook
 
-            hook = MCPHook(mcp_conn_id=self._mcp_conn_id, tool_prefix=self._tool_prefix)
+            hook = MCPHook(
+                mcp_conn_id=self._mcp_conn_id,
+                tool_prefix=self._tool_prefix,
+                token_provider=self._token_provider,
+            )
             self._server = hook.get_conn()
         return self._server
 
diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py b/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py
index cf7e983f542..2928376d699 100644
--- a/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py
+++ b/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 import json
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
 
@@ -246,3 +246,103 @@ class TestMCPHookUIFieldBehaviour:
     def test_relabeling(self):
         behaviour = MCPHook.get_ui_field_behaviour()
         assert behaviour["relabeling"]["password"] == "Auth Token"
+
+
+class TestMCPHookTokenProvider:
+    @patch(_MCP_HTTP)
+    def test_http_uses_token_provider(self, mock_server_cls):
+        hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda: "minted-jwt")
+        conn = Connection(conn_id="test_conn", conn_type="mcp", host="http://localhost:3001/mcp")
+        with patch.object(hook, "get_connection", return_value=conn):
+            hook.get_conn()
+
+        mock_server_cls.assert_called_once_with(
+            "http://localhost:3001/mcp",
+            headers={"Authorization": "Bearer minted-jwt"},
+            tool_prefix=None,
+        )
+
+    @patch(_MCP_HTTP)
+    def test_token_provider_overrides_static_password(self, mock_server_cls):
+        hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda: "fresh")
+        conn = Connection(
+            conn_id="test_conn",
+            conn_type="mcp",
+            host="http://localhost:3001/mcp",
+            password="static-pat",
+        )
+        with patch.object(hook, "get_connection", return_value=conn):
+            hook.get_conn()
+
+        mock_server_cls.assert_called_once_with(
+            "http://localhost:3001/mcp",
+            headers={"Authorization": "Bearer fresh"},
+            tool_prefix=None,
+        )
+
+    @patch(_MCP_HTTP)
+    def test_token_provider_called_when_establishing_connection(self, mock_server_cls):
+        provider = MagicMock(return_value="tok")
+        hook = MCPHook(mcp_conn_id="test_conn", token_provider=provider)
+        conn = Connection(conn_id="test_conn", conn_type="mcp", host="http://localhost:3001/mcp")
+        with patch.object(hook, "get_connection", return_value=conn):
+            hook.get_conn()
+
+        provider.assert_called_once_with()
+
+    @patch(_MCP_HTTP)
+    def test_masks_minted_token(self, mock_server_cls):
+        """The minted token must be registered with secret masking, like conn.password."""
+        hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda: "minted-jwt")
+        conn = Connection(conn_id="test_conn", conn_type="mcp", host="http://localhost:3001/mcp")
+        with (
+            patch.object(hook, "get_connection", return_value=conn),
+            patch("airflow.providers.common.ai.hooks.mcp.mask_secret") as mock_mask,
+        ):
+            hook.get_conn()
+
+        mock_mask.assert_called_once_with("minted-jwt")
+
+    @pytest.mark.parametrize("bad_token", ["", None], ids=["empty", "non_string"])
+    @patch(_MCP_HTTP)
+    def test_invalid_token_raises(self, mock_server_cls, bad_token):
+        """A token_provider returning a non-string or empty value fails loud."""
+        hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda: bad_token)
+        conn = Connection(conn_id="test_conn", conn_type="mcp", host="http://localhost:3001/mcp")
+        with patch.object(hook, "get_connection", return_value=conn):
+            with pytest.raises(ValueError, match="must return a non-empty string token"):
+                hook.get_conn()
+
+    @patch(_MCP_SSE)
+    def test_sse_uses_token_provider(self, mock_server_cls):
+        hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda: "minted")
+        conn = Connection(
+            conn_id="test_conn",
+            conn_type="mcp",
+            host="http://localhost:3001/sse",
+            extra=json.dumps({"transport": "sse"}),
+        )
+        with patch.object(hook, "get_connection", return_value=conn):
+            hook.get_conn()
+
+        mock_server_cls.assert_called_once_with(
+            "http://localhost:3001/sse",
+            headers={"Authorization": "Bearer minted"},
+            tool_prefix=None,
+        )
+
+    @patch(_MCP_STDIO)
+    def test_stdio_does_not_invoke_token_provider(self, mock_server_cls):
+        """stdio has no HTTP headers, so the token provider must not be called."""
+        provider = MagicMock(return_value="tok")
+        hook = MCPHook(mcp_conn_id="test_conn", token_provider=provider)
+        conn = Connection(
+            conn_id="test_conn",
+            conn_type="mcp",
+            extra=json.dumps({"transport": "stdio", "command": "uvx", "args": ["mcp-run-python"]}),
+        )
+        with patch.object(hook, "get_connection", return_value=conn):
+            hook.get_conn()
+
+        provider.assert_not_called()
+        mock_server_cls.assert_called_once_with("uvx", args=["mcp-run-python"], timeout=10, tool_prefix=None)
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_mcp.py b/providers/common/ai/tests/unit/common/ai/toolsets/test_mcp.py
index 5391ca143de..bbc3e079d58 100644
--- a/providers/common/ai/tests/unit/common/ai/toolsets/test_mcp.py
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_mcp.py
@@ -43,7 +43,7 @@ class TestMCPToolsetGetServer:
         ts = MCPToolset("mcp_conn")
         server = ts._get_server()
 
-        mock_hook_cls.assert_called_once_with(mcp_conn_id="mcp_conn", tool_prefix=None)
+        mock_hook_cls.assert_called_once_with(mcp_conn_id="mcp_conn", tool_prefix=None, token_provider=None)
         mock_hook_cls.return_value.get_conn.assert_called_once()
         assert server is mock_server
 
@@ -54,7 +54,23 @@ class TestMCPToolsetGetServer:
         ts = MCPToolset("mcp_conn", tool_prefix="weather")
         ts._get_server()
 
-        mock_hook_cls.assert_called_once_with(mcp_conn_id="mcp_conn", tool_prefix="weather")
+        mock_hook_cls.assert_called_once_with(
+            mcp_conn_id="mcp_conn", tool_prefix="weather", token_provider=None
+        )
+
+    @patch(_HOOK_PATH, autospec=True)
+    def test_passes_token_provider_to_hook(self, mock_hook_cls):
+        mock_hook_cls.return_value.get_conn.return_value = MagicMock()
+
+        def provider() -> str:
+            return "minted"
+
+        ts = MCPToolset("mcp_conn", token_provider=provider)
+        ts._get_server()
+
+        mock_hook_cls.assert_called_once_with(
+            mcp_conn_id="mcp_conn", tool_prefix=None, token_provider=provider
+        )
 
     @patch(_HOOK_PATH, autospec=True)
     def test_caches_server(self, mock_hook_cls):

Reply via email to