This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d4a5294dcef Add token_provider for short-lived MCP auth in common.ai
(#68104)
d4a5294dcef is described below
commit d4a5294dcef72f61be040db597c4fd685dec42fb
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat Jun 6 01:03:19 2026 +0100
Add token_provider for short-lived MCP auth in common.ai (#68104)
MCPHook built the MCP server with a single static Authorization header from
the connection password, so it could not authenticate to MCP endpoints that
require a freshly minted or short-lived token. The motivating case is a
Snowflake managed MCP server, best authenticated with a key-pair JWT that
expires after about an hour and cannot be stored as a static connection
value; the same limitation blocked OAuth/refresh tokens, Workload Identity
Federation, and GitHub App installation tokens.
MCPHook and MCPToolset now accept an optional token_provider callable. When
set, it is invoked each time the HTTP/SSE server connection is established
and its return value is used as the bearer token, overriding the static
password. The minted token is registered with secret masking (matching the
auto-masking the connection password already receives) so it does not leak
into task logs, and a provider that returns a non-string or empty value
fails loudly instead of silently sending an unauthenticated request.
token_provider is resolved in DAG code, so the signing key never enters the
serialized DAG.
---
providers/common/ai/docs/connections/mcp.rst | 40 +++++++-
providers/common/ai/docs/toolsets.rst | 6 ++
.../src/airflow/providers/common/ai/hooks/mcp.py | 62 ++++++++++++-
.../airflow/providers/common/ai/toolsets/mcp.py | 19 +++-
.../ai/tests/unit/common/ai/hooks/test_mcp.py | 102 ++++++++++++++++++++-
.../ai/tests/unit/common/ai/toolsets/test_mcp.py | 20 +++-
6 files changed, 240 insertions(+), 9 deletions(-)
diff --git a/providers/common/ai/docs/connections/mcp.rst
b/providers/common/ai/docs/connections/mcp.rst
index 48f7749b76c..0eb5f7bfc60 100644
--- a/providers/common/ai/docs/connections/mcp.rst
+++ b/providers/common/ai/docs/connections/mcp.rst
@@ -45,7 +45,9 @@ Host
Examples: ``http://localhost:3001/mcp``, ``https://mcp.example.com/v1``
Auth Token (Password field)
- Optional authentication token for the MCP server.
+ Optional authentication token for the MCP server. Sent as a static
+ ``Authorization: Bearer <token>`` header on HTTP/SSE transports. For
+ short-lived or minted tokens, use a ``token_provider`` instead (see below).
Command (Extra field)
The command to run for ``stdio`` transport. Required when transport is
``stdio``.
@@ -96,3 +98,39 @@ Examples
"conn_type": "mcp",
"extra": "{\"transport\": \"stdio\", \"command\": \"python\",
\"args\": [\"-m\", \"my_server\"], \"timeout\": 30}"
}
+
+Short-lived or minted tokens
+----------------------------
+
+Some MCP endpoints require a freshly minted, short-lived token rather than a
+static one. For example, `Snowflake managed MCP servers
+<https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-agents-mcp>`__
+are best authenticated with a `key-pair JWT
+<https://docs.snowflake.com/en/user-guide/key-pair-auth>`__: the private key never
+leaves your environment and the signed JWT expires after about an hour, so it
+cannot be stored as a static connection ``password``. The same applies to OAuth /
+refresh tokens, Workload Identity Federation, and GitHub App installation tokens.
+
+For these, pass a ``token_provider`` callable to ``MCPHook`` or ``MCPToolset``
+instead of a static token. It is called each time the connection is established
+and its return value is used as the bearer token, so a fresh token is minted (and
+registered with secret masking so it does not leak into task logs):
+
+.. code-block:: python
+
+ from airflow.providers.common.ai.toolsets.mcp import MCPToolset
+
+
+ def mint_snowflake_jwt() -> str:
+ # Sign a short-lived JWT from the Snowflake connection's key-pair.
+ ...
+
+
+ toolset = MCPToolset(
+ mcp_conn_id="snowflake_managed_mcp",
+ token_provider=mint_snowflake_jwt,
+ )
+
+``token_provider`` is resolved in DAG code (it is a Python callable, not a stored
+connection field), so the signing key stays in your environment and is never baked
+into the serialized DAG.
diff --git a/providers/common/ai/docs/toolsets.rst
b/providers/common/ai/docs/toolsets.rst
index ee6d8a85b63..a9f454fbc89 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -272,6 +272,12 @@ Parameters
- ``tool_prefix``: Optional prefix prepended to tool names to avoid
collisions when using multiple MCP servers (e.g. ``"weather"`` produces
``"weather_get_forecast"``).
+- ``token_provider``: Optional zero-argument callable returning a bearer token.
+ When set, it overrides the connection's static ``password`` for the
+ ``Authorization`` header and is called each time the server connection is
+ established -- use it for short-lived or minted tokens (e.g. a Snowflake
+ managed MCP server authenticated with a key-pair JWT). See
+ :ref:`howto/connection:mcp`.
Using Multiple MCP Servers
^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py
b/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py
index b69517ac6f4..f6749d280f3 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py
@@ -16,10 +16,23 @@
# under the License.
from __future__ import annotations
-from typing import Any
+from typing import TYPE_CHECKING, Any
from airflow.providers.common.compat.sdk import BaseHook
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+ from airflow.sdk.execution_time.secrets_masker import mask_secret
+else:
+ try:
+ from airflow.sdk.log import mask_secret
+ except ImportError:
+ try:
+ from airflow.sdk.execution_time.secrets_masker import mask_secret
+ except ImportError:
+ from airflow.utils.log.secrets_masker import mask_secret
+
class MCPHook(BaseHook):
"""
@@ -36,9 +49,22 @@ class MCPHook(BaseHook):
- **Extra.args**: Command arguments for stdio transport (e.g.
``["mcp-run-python"]``)
- **Extra.timeout**: Connection timeout in seconds for stdio (default:
10)
+ For HTTP/SSE transports the ``Authorization`` header is, by default, a
static
+ ``Bearer`` token taken from the connection ``password``. Endpoints that
require
+ a freshly minted or short-lived token (e.g. a Snowflake managed MCP server
+ authenticated with a key-pair JWT, OAuth/refresh tokens, Workload Identity
+ Federation, or GitHub App installation tokens) can pass a
``token_provider``
+ callable instead. It is invoked each time the server connection is
established
+ and its return value is used as the bearer token, so a fresh token is
minted
+ without storing a long-lived secret in the connection.
+
:param mcp_conn_id: Airflow connection ID for the MCP server.
:param tool_prefix: Optional prefix prepended to tool names
(e.g. ``"weather"`` → ``"weather_get_forecast"``).
+ :param token_provider: Optional zero-argument callable returning a bearer
+ token string. When set, it overrides the connection ``password`` for
the
+ ``Authorization`` header on HTTP/SSE transports and is called each
time the
+ server connection is established. Ignored for the ``stdio`` transport.
"""
conn_name_attr = "mcp_conn_id"
@@ -50,11 +76,14 @@ class MCPHook(BaseHook):
self,
mcp_conn_id: str = default_conn_name,
tool_prefix: str | None = None,
+ *,
+ token_provider: Callable[[], str] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.mcp_conn_id = mcp_conn_id
self.tool_prefix = tool_prefix
+ self.token_provider = token_provider
self._server: Any = None
@staticmethod
@@ -68,6 +97,28 @@ class MCPHook(BaseHook):
},
}
+ def _auth_headers(self, conn: Any) -> dict[str, str] | None:
+ """
+ Build the ``Authorization`` header for HTTP/SSE transports.
+
+ Prefers ``token_provider`` (minted per connection) over the static
+ connection ``password``. Returns ``None`` when neither yields a token.
+ """
+ if self.token_provider is not None:
+ token = self.token_provider()
+ if not isinstance(token, str) or not token:
+ raise ValueError(
+ f"token_provider for connection {self.mcp_conn_id!r} must
return a non-empty "
+ f"string token, got {type(token).__name__}."
+ )
+ # The static connection password is masked when the connection is
+ # fetched; mask the minted token too so it never leaks into task
logs.
+ mask_secret(token)
+ return {"Authorization": f"Bearer {token}"}
+ if conn.password:
+ return {"Authorization": f"Bearer {conn.password}"}
+ return None
+
def get_conn(self) -> Any:
"""
Return a configured PydanticAI MCP server instance.
@@ -94,16 +145,19 @@ class MCPHook(BaseHook):
conn = self.get_connection(self.mcp_conn_id)
extra = conn.extra_dejson
transport = extra.get("transport", "http")
- headers = {"Authorization": f"Bearer {conn.password}"} if
conn.password else None
if transport == "http":
if not conn.host:
raise ValueError(f"Connection {self.mcp_conn_id!r} requires a
host URL for HTTP transport.")
- self._server = MCPServerStreamableHTTP(conn.host, headers=headers,
tool_prefix=self.tool_prefix)
+ self._server = MCPServerStreamableHTTP(
+ conn.host, headers=self._auth_headers(conn),
tool_prefix=self.tool_prefix
+ )
elif transport == "sse":
if not conn.host:
raise ValueError(f"Connection {self.mcp_conn_id!r} requires a
host URL for SSE transport.")
- self._server = MCPServerSSE(conn.host, headers=headers,
tool_prefix=self.tool_prefix)
+ self._server = MCPServerSSE(
+ conn.host, headers=self._auth_headers(conn),
tool_prefix=self.tool_prefix
+ )
elif transport == "stdio":
command = extra.get("command")
if not command:
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py
index 0fa085eba0c..d7aae3f9a6a 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py
@@ -24,6 +24,8 @@ from pydantic_ai.toolsets.abstract import AbstractToolset,
ToolsetTool
from typing_extensions import Self
if TYPE_CHECKING:
+ from collections.abc import Callable
+
from pydantic_ai._run_context import RunContext
@@ -46,9 +48,18 @@ class MCPToolset(AbstractToolset[Any]):
:class:`~pydantic_ai.mcp.MCPServerSSE`, and
:class:`~pydantic_ai.mcp.MCPServerStdio` all implement ``AbstractToolset``.
+ For MCP endpoints that need a freshly minted or short-lived token (e.g. a
+ Snowflake managed MCP server authenticated with a key-pair JWT, or OAuth /
+ Workload Identity / GitHub App tokens), pass a ``token_provider`` callable.
+ It is invoked each time the connection is established and its return value
is
+ used as the bearer token, so a fresh token is minted rather than storing a
+ long-lived secret in the connection.
+
:param mcp_conn_id: Airflow connection ID for the MCP server.
:param tool_prefix: Optional prefix prepended to tool names
(e.g. ``"weather"`` → ``"weather_get_forecast"``).
+ :param token_provider: Optional zero-argument callable returning a bearer
+ token string, overriding the connection ``password`` for HTTP/SSE auth.
"""
def __init__(
@@ -56,9 +67,11 @@ class MCPToolset(AbstractToolset[Any]):
mcp_conn_id: str,
*,
tool_prefix: str | None = None,
+ token_provider: Callable[[], str] | None = None,
) -> None:
self._mcp_conn_id = mcp_conn_id
self._tool_prefix = tool_prefix
+ self._token_provider = token_provider
self._server: Any = None
@property
@@ -69,7 +82,11 @@ class MCPToolset(AbstractToolset[Any]):
if self._server is None:
from airflow.providers.common.ai.hooks.mcp import MCPHook
- hook = MCPHook(mcp_conn_id=self._mcp_conn_id,
tool_prefix=self._tool_prefix)
+ hook = MCPHook(
+ mcp_conn_id=self._mcp_conn_id,
+ tool_prefix=self._tool_prefix,
+ token_provider=self._token_provider,
+ )
self._server = hook.get_conn()
return self._server
diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py
b/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py
index cf7e983f542..2928376d699 100644
--- a/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py
+++ b/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import json
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
import pytest
@@ -246,3 +246,103 @@ class TestMCPHookUIFieldBehaviour:
def test_relabeling(self):
behaviour = MCPHook.get_ui_field_behaviour()
assert behaviour["relabeling"]["password"] == "Auth Token"
+
+
+class TestMCPHookTokenProvider:
+ @patch(_MCP_HTTP)
+ def test_http_uses_token_provider(self, mock_server_cls):
+ hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda:
"minted-jwt")
+ conn = Connection(conn_id="test_conn", conn_type="mcp",
host="http://localhost:3001/mcp")
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ mock_server_cls.assert_called_once_with(
+ "http://localhost:3001/mcp",
+ headers={"Authorization": "Bearer minted-jwt"},
+ tool_prefix=None,
+ )
+
+ @patch(_MCP_HTTP)
+ def test_token_provider_overrides_static_password(self, mock_server_cls):
+ hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda: "fresh")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="mcp",
+ host="http://localhost:3001/mcp",
+ password="static-pat",
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ mock_server_cls.assert_called_once_with(
+ "http://localhost:3001/mcp",
+ headers={"Authorization": "Bearer fresh"},
+ tool_prefix=None,
+ )
+
+ @patch(_MCP_HTTP)
+ def test_token_provider_called_when_establishing_connection(self,
mock_server_cls):
+ provider = MagicMock(return_value="tok")
+ hook = MCPHook(mcp_conn_id="test_conn", token_provider=provider)
+ conn = Connection(conn_id="test_conn", conn_type="mcp",
host="http://localhost:3001/mcp")
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ provider.assert_called_once_with()
+
+ @patch(_MCP_HTTP)
+ def test_masks_minted_token(self, mock_server_cls):
+ """The minted token must be registered with secret masking, like
conn.password."""
+ hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda:
"minted-jwt")
+ conn = Connection(conn_id="test_conn", conn_type="mcp",
host="http://localhost:3001/mcp")
+ with (
+ patch.object(hook, "get_connection", return_value=conn),
+ patch("airflow.providers.common.ai.hooks.mcp.mask_secret") as
mock_mask,
+ ):
+ hook.get_conn()
+
+ mock_mask.assert_called_once_with("minted-jwt")
+
+ @pytest.mark.parametrize("bad_token", ["", None], ids=["empty",
"non_string"])
+ @patch(_MCP_HTTP)
+ def test_invalid_token_raises(self, mock_server_cls, bad_token):
+ """A token_provider returning a non-string or empty value fails
loud."""
+ hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda:
bad_token)
+ conn = Connection(conn_id="test_conn", conn_type="mcp",
host="http://localhost:3001/mcp")
+ with patch.object(hook, "get_connection", return_value=conn):
+ with pytest.raises(ValueError, match="must return a non-empty
string token"):
+ hook.get_conn()
+
+ @patch(_MCP_SSE)
+ def test_sse_uses_token_provider(self, mock_server_cls):
+ hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda:
"minted")
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="mcp",
+ host="http://localhost:3001/sse",
+ extra=json.dumps({"transport": "sse"}),
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ mock_server_cls.assert_called_once_with(
+ "http://localhost:3001/sse",
+ headers={"Authorization": "Bearer minted"},
+ tool_prefix=None,
+ )
+
+ @patch(_MCP_STDIO)
+ def test_stdio_does_not_invoke_token_provider(self, mock_server_cls):
+ """stdio has no HTTP headers, so the token provider must not be
called."""
+ provider = MagicMock(return_value="tok")
+ hook = MCPHook(mcp_conn_id="test_conn", token_provider=provider)
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="mcp",
+ extra=json.dumps({"transport": "stdio", "command": "uvx", "args":
["mcp-run-python"]}),
+ )
+ with patch.object(hook, "get_connection", return_value=conn):
+ hook.get_conn()
+
+ provider.assert_not_called()
+ mock_server_cls.assert_called_once_with("uvx",
args=["mcp-run-python"], timeout=10, tool_prefix=None)
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_mcp.py
b/providers/common/ai/tests/unit/common/ai/toolsets/test_mcp.py
index 5391ca143de..bbc3e079d58 100644
--- a/providers/common/ai/tests/unit/common/ai/toolsets/test_mcp.py
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_mcp.py
@@ -43,7 +43,7 @@ class TestMCPToolsetGetServer:
ts = MCPToolset("mcp_conn")
server = ts._get_server()
- mock_hook_cls.assert_called_once_with(mcp_conn_id="mcp_conn",
tool_prefix=None)
+ mock_hook_cls.assert_called_once_with(mcp_conn_id="mcp_conn",
tool_prefix=None, token_provider=None)
mock_hook_cls.return_value.get_conn.assert_called_once()
assert server is mock_server
@@ -54,7 +54,23 @@ class TestMCPToolsetGetServer:
ts = MCPToolset("mcp_conn", tool_prefix="weather")
ts._get_server()
- mock_hook_cls.assert_called_once_with(mcp_conn_id="mcp_conn",
tool_prefix="weather")
+ mock_hook_cls.assert_called_once_with(
+ mcp_conn_id="mcp_conn", tool_prefix="weather", token_provider=None
+ )
+
+ @patch(_HOOK_PATH, autospec=True)
+ def test_passes_token_provider_to_hook(self, mock_hook_cls):
+ mock_hook_cls.return_value.get_conn.return_value = MagicMock()
+
+ def provider() -> str:
+ return "minted"
+
+ ts = MCPToolset("mcp_conn", token_provider=provider)
+ ts._get_server()
+
+ mock_hook_cls.assert_called_once_with(
+ mcp_conn_id="mcp_conn", tool_prefix=None, token_provider=provider
+ )
@patch(_HOOK_PATH, autospec=True)
def test_caches_server(self, mock_hook_cls):