This is an automated email from the ASF dual-hosted git repository.
arivero pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new ae99b194225 feat(mcp): add detailed JWT error messages and default
auth factory fallback (#37972)
ae99b194225 is described below
commit ae99b19422517fce5a1b4f50fb473ea4991ac170
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Thu Feb 26 08:21:40 2026 -0500
feat(mcp): add detailed JWT error messages and default auth factory
fallback (#37972)
Co-authored-by: Claude Opus 4.6 <[email protected]>
---
superset/mcp_service/auth.py | 18 +-
superset/mcp_service/jwt_verifier.py | 320 ++++++++++
superset/mcp_service/mcp_config.py | 68 +-
superset/mcp_service/server.py | 50 +-
tests/unit_tests/mcp_service/test_jwt_verifier.py | 726 ++++++++++++++++++++++
5 files changed, 1138 insertions(+), 44 deletions(-)
diff --git a/superset/mcp_service/auth.py b/superset/mcp_service/auth.py
index ac11ecd62d0..8675947097f 100644
--- a/superset/mcp_service/auth.py
+++ b/superset/mcp_service/auth.py
@@ -111,9 +111,23 @@ def get_user_from_request() -> User:
username = current_app.config.get("MCP_DEV_USERNAME")
if not username:
+ auth_enabled = current_app.config.get("MCP_AUTH_ENABLED", False)
+ jwt_configured = bool(
+ current_app.config.get("MCP_JWKS_URI")
+ or current_app.config.get("MCP_JWT_PUBLIC_KEY")
+ or current_app.config.get("MCP_JWT_SECRET")
+ )
+ details = []
+ details.append(
+ f"g.user was not set by JWT middleware "
+ f"(MCP_AUTH_ENABLED={auth_enabled}, "
+ f"JWT keys configured={jwt_configured})"
+ )
+ details.append("MCP_DEV_USERNAME is not configured")
raise ValueError(
- "No authenticated user found. "
- "Either pass a valid JWT bearer token or configure "
+ "No authenticated user found. Tried:\n"
+ + "\n".join(f" - {d}" for d in details)
+ + "\n\nEither pass a valid JWT bearer token or configure "
"MCP_DEV_USERNAME for development."
)
diff --git a/superset/mcp_service/jwt_verifier.py b/superset/mcp_service/jwt_verifier.py
new file mode 100644
index 00000000000..abd8019d9e4
--- /dev/null
+++ b/superset/mcp_service/jwt_verifier.py
@@ -0,0 +1,320 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Detailed JWT verification for the MCP service.
+
+Provides step-by-step JWT validation with tiered server-side logging:
+- WARNING level: generic failure categories only (e.g. "Issuer mismatch")
+- DEBUG level: detailed claim values and config for troubleshooting
+- Secrets (e.g. HS256 keys) are NEVER logged at any level
+
+HTTP responses always return generic errors per RFC 6750 Section 3.1.
+"""
+
+import base64
+import logging
+import time
+from contextvars import ContextVar
+from typing import Any, cast
+
+from authlib.jose.errors import (
+ BadSignatureError,
+ DecodeError,
+ ExpiredTokenError,
+ JoseError,
+)
+from fastmcp.server.auth.auth import AccessToken
+from fastmcp.server.auth.providers.jwt import JWTVerifier
+from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
+from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
+from starlette.authentication import AuthenticationError
+from starlette.middleware import Middleware
+from starlette.middleware.authentication import AuthenticationMiddleware
+from starlette.requests import HTTPConnection
+from starlette.responses import JSONResponse
+
+from superset.utils import json
+
+logger = logging.getLogger(__name__)
+
+# Thread-safe storage for the specific JWT failure reason.
+# Set by DetailedJWTVerifier.load_access_token() on failure,
+# read by DetailedBearerAuthBackend.authenticate() to raise
+# an AuthenticationError with the specific reason.
+# SECURITY: Must ALWAYS contain generic failure categories only.
+# Claim values and secrets must NEVER be stored here.
+_jwt_failure_reason: ContextVar[str | None] = ContextVar(
+ "_jwt_failure_reason", default=None
+)
+
+
+def _json_auth_error_handler(
+ conn: HTTPConnection, exc: AuthenticationError
+) -> JSONResponse:
+ """JSON 401 error handler for authentication failures.
+
+ Per RFC 6750 Section 3.1, error responses MUST NOT leak server
+ configuration or token claim values. Only generic error codes are
+ returned to clients. Detailed failure reasons are logged server-side
+ only for debugging.
+
+ References:
+ - RFC 6750 Section 3.1: https://datatracker.ietf.org/doc/html/rfc6750#section-3.1
+ - CVE-2022-29266, CVE-2019-7644: verbose JWT errors led to exploits
+ """
+ # Log detailed reason server-side only
+ logger.warning("JWT authentication failed: %s", exc)
+
+ return JSONResponse(
+ status_code=401,
+ content={
+ "error": "invalid_token",
+ "error_description": "Authentication failed",
+ },
+ headers={
+ "WWW-Authenticate": 'Bearer error="invalid_token"',
+ },
+ )
+
+
+class DetailedBearerAuthBackend(BearerAuthBackend):
+ """
+ Bearer auth backend that raises AuthenticationError with specific
+ JWT failure reasons instead of silently returning None.
+ """
+
+ async def authenticate(self, conn: HTTPConnection) -> tuple[Any, Any] | None:
+ result = await super().authenticate(conn)
+
+ if result is not None:
+ # Clear any stale failure reason on success
+ _jwt_failure_reason.set(None)
+ return result
+
+ # Check if there's a Bearer token present - if so, there was a
+ # validation failure we can report with a specific reason
+ auth_header = next(
+ (
+ conn.headers.get(key)
+ for key in conn.headers
+ if key.lower() == "authorization"
+ ),
+ None,
+ )
+ if auth_header and auth_header.lower().startswith("bearer "):
+ reason = _jwt_failure_reason.get()
+ if reason:
+ _jwt_failure_reason.set(None)
+ raise AuthenticationError(reason)
+
+ return None
+
+
+class DetailedJWTVerifier(JWTVerifier):
+ """
+ JWT verifier with tiered server-side logging for each validation step.
+
+ Logging tiers:
+ - WARNING: generic failure categories only (via _jwt_failure_reason ContextVar
+ and the error handler). No claim values, no server config.
+ - DEBUG: detailed values (issuer, audience, client_id, scopes, exceptions)
+ for operator troubleshooting.
+ - Secrets (e.g. HS256 signing keys) are NEVER logged at any level.
+
+ HTTP responses always return generic errors per RFC 6750 Section 3.1.
+ Controlled by MCP_JWT_DEBUG_ERRORS config flag.
+ """
+
+ async def load_access_token(self, token: str) -> AccessToken | None: # noqa: C901
+ """
+ Validate a JWT bearer token with detailed error reporting.
+
+ Each validation step stores a specific failure reason in the
+ _jwt_failure_reason ContextVar before returning None.
+ """
+ # Reset any previous failure reason
+ _jwt_failure_reason.set(None)
+
+ try:
+ # Step 1: Decode header and check algorithm
+ try:
+ header = self._decode_token_header(token)
+ except (ValueError, DecodeError) as e:
+ reason = "Malformed token header"
+ _jwt_failure_reason.set(reason)
+ logger.debug("Malformed token header: %s", e)
+ return None
+
+ token_alg = header.get("alg")
+ if self.algorithm and token_alg != self.algorithm:
+ reason = "Algorithm mismatch"
+ _jwt_failure_reason.set(reason)
+ logger.debug(
+ "Algorithm mismatch: token uses '%s', expected '%s'",
+ token_alg,
+ self.algorithm,
+ )
+ return None
+
+ # Step 2: Get verification key (static or JWKS)
+ try:
+ verification_key = await self._get_verification_key(token)
+ except ValueError as e:
+ reason = "Failed to get verification key"
+ _jwt_failure_reason.set(reason)
+ logger.debug("Failed to get verification key: %s", e)
+ return None
+
+ # Step 3: Decode and verify signature
+ try:
+ claims = self.jwt.decode(token, verification_key)
+ except BadSignatureError:
+ reason = "Signature verification failed"
+ _jwt_failure_reason.set(reason)
+ return None
+ except ExpiredTokenError:
+ reason = "Token has expired (detected during decode)"
+ _jwt_failure_reason.set(reason)
+ return None
+ except JoseError as e:
+ reason = "Token decode failed"
+ _jwt_failure_reason.set(reason)
+ logger.debug("Token decode failed: %s", e)
+ return None
+
+ # Extract client ID for logging
+ client_id = (
+ claims.get("client_id")
+ or claims.get("azp")
+ or claims.get("sub")
+ or "unknown"
+ )
+
+ # Step 4: Check expiration
+ exp = claims.get("exp")
+ if exp and exp < time.time():
+ reason = "Token expired"
+ _jwt_failure_reason.set(reason)
+ logger.debug("Token expired for client '%s'", client_id)
+ return None
+
+ # Step 5: Validate issuer
+ if self.issuer:
+ iss = claims.get("iss")
+ if isinstance(self.issuer, list):
+ issuer_valid = iss in self.issuer
+ else:
+ issuer_valid = iss == self.issuer
+
+ if not issuer_valid:
+ reason = "Issuer mismatch"
+ _jwt_failure_reason.set(reason)
+ logger.debug(
+ "Issuer mismatch: token has '%s', expected '%s'",
+ iss,
+ self.issuer,
+ )
+ return None
+
+ # Step 6: Validate audience
+ if self.audience:
+ aud = claims.get("aud")
+ if isinstance(self.audience, list):
+ if isinstance(aud, list):
+ audience_valid = any(
+ expected in aud
+ for expected in cast(list[str], self.audience)
+ )
+ else:
+ audience_valid = aud in cast(list[str], self.audience)
+ else:
+ if isinstance(aud, list):
+ audience_valid = self.audience in aud
+ else:
+ audience_valid = aud == self.audience
+
+ if not audience_valid:
+ reason = "Audience mismatch"
+ _jwt_failure_reason.set(reason)
+ logger.debug(
+ "Audience mismatch: token has '%s', expected '%s'",
+ aud,
+ self.audience,
+ )
+ return None
+
+ # Step 7: Check required scopes
+ scopes = self._extract_scopes(claims)
+ if self.required_scopes:
+ token_scopes = set(scopes)
+ required = set(self.required_scopes)
+ if not required.issubset(token_scopes):
+ missing = required - token_scopes
+ reason = "Missing required scopes"
+ _jwt_failure_reason.set(reason)
+ logger.debug(
+ "Missing required scopes: %s. Token has: %s",
+ missing,
+ token_scopes,
+ )
+ return None
+
+ # All validations passed
+ return AccessToken(
+ token=token,
+ client_id=str(client_id),
+ scopes=scopes,
+ expires_at=int(exp) if exp else None,
+ claims=dict(claims),
+ )
+
+ except (ValueError, JoseError, KeyError, AttributeError, TypeError) as e:
+ reason = "Token validation failed"
+ _jwt_failure_reason.set(reason)
+ logger.debug("Token validation failed: %s", e)
+ return None
+
+ def get_middleware(self) -> list[Any]:
+ """
+ Get middleware with detailed server-side error logging.
+
+ Uses DetailedBearerAuthBackend which raises AuthenticationError
+ with specific reasons for server-side logging. The error handler
+ always returns generic 401 responses per RFC 6750.
+ """
+ return [
+ Middleware(
+ AuthenticationMiddleware,
+ backend=DetailedBearerAuthBackend(self),
+ on_error=_json_auth_error_handler,
+ ),
+ Middleware(AuthContextMiddleware),
+ ]
+
+ @staticmethod
+ def _decode_token_header(token: str) -> dict[str, Any]:
+ """Decode the JWT header without verifying the signature."""
+ parts = token.split(".")
+ if len(parts) != 3:
+ raise ValueError(
+ f"Token must have 3 parts (header.payload.signature), got {len(parts)}"
+ )
+ header_b64 = parts[0]
+ # Add padding only if needed
+ header_b64 += "=" * (-len(header_b64) % 4)
+ header_bytes = base64.urlsafe_b64decode(header_b64)
+ return json.loads(header_bytes)
diff --git a/superset/mcp_service/mcp_config.py b/superset/mcp_service/mcp_config.py
index 86a772aa483..dab549b9b06 100644
--- a/superset/mcp_service/mcp_config.py
+++ b/superset/mcp_service/mcp_config.py
@@ -45,6 +45,16 @@ MCP_SERVICE_PORT = 5008
# MCP Debug mode - shows suppressed initialization output in stdio mode
MCP_DEBUG = False
+# MCP JWT Debug Errors - controls server-side JWT debug logging.
+# When False (default), uses the default JWTVerifier with minimal logging.
+# When True, uses DetailedJWTVerifier with tiered logging:
+# - WARNING level: generic failure categories only (e.g. "Issuer mismatch")
+# - DEBUG level: detailed claim values for troubleshooting
+# - Secrets (e.g. HS256 keys) are NEVER logged at any level
+# HTTP responses ALWAYS return generic errors regardless of this setting,
+# per RFC 6750 Section 3.1. This flag NEVER affects client-facing output.
+MCP_JWT_DEBUG_ERRORS = False
+
# Enable parse_request decorator for MCP tools.
# When True (default), tool requests are automatically parsed from JSON strings
# to Pydantic models, working around a Claude Code double-serialization bug
@@ -231,47 +241,47 @@ def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]:
return None
try:
- from fastmcp.server.auth.providers.jwt import JWTVerifier
+ debug_errors = app.config.get("MCP_JWT_DEBUG_ERRORS", False)
+
+ common_kwargs: dict[str, Any] = {
+ "issuer": app.config.get("MCP_JWT_ISSUER"),
+ "audience": app.config.get("MCP_JWT_AUDIENCE"),
+ "required_scopes": app.config.get("MCP_REQUIRED_SCOPES", []),
+ }
# For HS256 (symmetric), use the secret as the public_key parameter
if app.config.get("MCP_JWT_ALGORITHM") == "HS256" and secret:
- auth_provider = JWTVerifier(
- public_key=secret, # HS256 uses secret as key
- issuer=app.config.get("MCP_JWT_ISSUER"),
- audience=app.config.get("MCP_JWT_AUDIENCE"),
- algorithm="HS256",
- required_scopes=app.config.get("MCP_REQUIRED_SCOPES", []),
- )
- logger.info("Created JWTVerifier with HS256 secret")
+ common_kwargs["public_key"] = secret
+ common_kwargs["algorithm"] = "HS256"
else:
# For RS256 (asymmetric), use public key or JWKS
- auth_provider = JWTVerifier(
- jwks_uri=jwks_uri,
- public_key=public_key,
- issuer=app.config.get("MCP_JWT_ISSUER"),
- audience=app.config.get("MCP_JWT_AUDIENCE"),
- algorithm=app.config.get("MCP_JWT_ALGORITHM", "RS256"),
- required_scopes=app.config.get("MCP_REQUIRED_SCOPES", []),
- )
- logger.info(
- "Created JWTVerifier with jwks_uri=%s, public_key=%s",
- jwks_uri,
- "***" if public_key else None,
- )
+ common_kwargs["jwks_uri"] = jwks_uri
+ common_kwargs["public_key"] = public_key
+ common_kwargs["algorithm"] = app.config.get("MCP_JWT_ALGORITHM", "RS256")
+
+ if debug_errors:
+ # DetailedJWTVerifier: detailed server-side logging of JWT
+ # validation failures. HTTP responses are always generic per
+ # RFC 6750 Section 3.1.
+ from superset.mcp_service.jwt_verifier import DetailedJWTVerifier
+
+ auth_provider = DetailedJWTVerifier(**common_kwargs)
+ else:
+ # Default JWTVerifier: minimal logging, generic error responses.
+ from fastmcp.server.auth.providers.jwt import JWTVerifier
+
+ auth_provider = JWTVerifier(**common_kwargs)
return auth_provider
- except Exception as e:
- logger.error("Failed to create MCP auth provider: %s", e)
+ except Exception:
+ # Do not log the exception — it may contain the HS256 secret
+ # from common_kwargs["public_key"]
+ logger.error("Failed to create MCP auth provider")
return None
def default_user_resolver(app: Any, access_token: Any) -> Optional[str]:
"""Extract username from JWT token claims."""
- logger.info(
- "Resolving user from token: type=%s, token=%s",
- type(access_token),
- access_token,
- )
if hasattr(access_token, "subject"):
return access_token.subject
if hasattr(access_token, "client_id"):
diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py
index a0d7e7f5d42..9e9af7dd47b 100644
--- a/superset/mcp_service/server.py
+++ b/superset/mcp_service/server.py
@@ -33,6 +33,8 @@ from superset.mcp_service.mcp_config import get_mcp_factory_config, MCP_STORE_CO
from superset.mcp_service.middleware import create_response_size_guard_middleware
from superset.mcp_service.storage import _create_redis_store
+logger = logging.getLogger(__name__)
+
def configure_logging(debug: bool = False) -> None:
"""Configure logging for the MCP service."""
@@ -119,6 +121,40 @@ def create_event_store(config: dict[str, Any] | None = None) -> Any | None:
return None
+def _create_auth_provider(flask_app: Any) -> Any | None:
+ """Create an auth provider from Flask app config.
+
+ Tries MCP_AUTH_FACTORY first, then falls back to the default factory
+ when MCP_AUTH_ENABLED is True.
+ """
+ auth_provider = None
+ if auth_factory := flask_app.config.get("MCP_AUTH_FACTORY"):
+ try:
+ auth_provider = auth_factory(flask_app)
+ logger.info(
+ "Auth provider created from MCP_AUTH_FACTORY: %s",
+ type(auth_provider).__name__ if auth_provider else "None",
+ )
+ except Exception:
+ # Do not log the exception — it may contain secrets
+ logger.error("Failed to create auth provider from MCP_AUTH_FACTORY")
+ elif flask_app.config.get("MCP_AUTH_ENABLED", False):
+ from superset.mcp_service.mcp_config import (
+ create_default_mcp_auth_factory,
+ )
+
+ try:
+ auth_provider = create_default_mcp_auth_factory(flask_app)
+ logger.info(
+ "Auth provider created from default factory: %s",
+ type(auth_provider).__name__ if auth_provider else "None",
+ )
+ except Exception:
+ # Do not log the exception — it may contain secrets
+ logger.error("Failed to create auth provider from default factory")
+ return auth_provider
+
+
def run_server(
host: str = "127.0.0.1",
port: int = 5008,
@@ -158,19 +194,7 @@ def run_server(
from superset.mcp_service.flask_singleton import get_flask_app
flask_app = get_flask_app()
-
- # Get auth factory from config and create auth provider
- auth_provider = None
- auth_factory = flask_app.config.get("MCP_AUTH_FACTORY")
- if auth_factory:
- try:
- auth_provider = auth_factory(flask_app)
- logging.info(
- "Auth provider created: %s",
- type(auth_provider).__name__ if auth_provider else "None",
- )
- except Exception as e:
- logging.error("Failed to create auth provider: %s", e)
+ auth_provider = _create_auth_provider(flask_app)
# Build middleware list
middleware_list = []
diff --git a/tests/unit_tests/mcp_service/test_jwt_verifier.py b/tests/unit_tests/mcp_service/test_jwt_verifier.py
new file mode 100644
index 00000000000..b6ba58aa7f2
--- /dev/null
+++ b/tests/unit_tests/mcp_service/test_jwt_verifier.py
@@ -0,0 +1,726 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for DetailedJWTVerifier and related middleware."""
+
+import base64
+import logging
+import time
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from authlib.jose.errors import BadSignatureError, DecodeError, ExpiredTokenError
+
+from superset.mcp_service.jwt_verifier import (
+ _json_auth_error_handler,
+ _jwt_failure_reason,
+ DetailedBearerAuthBackend,
+ DetailedJWTVerifier,
+)
+from superset.utils import json
+
+
+def _make_token(
+ header: dict[str, str], payload: dict[str, object], signature: str = "sig"
+) -> str:
+ """Build a fake JWT string from header + payload dicts."""
+ h = base64.urlsafe_b64encode(json.dumps(header).encode()).rstrip(b"=").decode()
+ p = base64.urlsafe_b64encode(json.dumps(payload).encode()).rstrip(b"=").decode()
+ return f"{h}.{p}.{signature}"
+
+
[email protected]
+def hs256_verifier():
+ """Create a DetailedJWTVerifier configured for HS256."""
+ return DetailedJWTVerifier(
+ public_key="test-secret-key-for-hs256-tokens",
+ issuer="test-issuer",
+ audience="test-audience",
+ algorithm="HS256",
+ required_scopes=[],
+ )
+
+
[email protected](autouse=True)
+def _reset_contextvar():
+ """Reset the failure reason contextvar before each test."""
+ _jwt_failure_reason.set(None)
+ yield
+ _jwt_failure_reason.set(None)
+
+
[email protected]
+async def test_algorithm_mismatch(hs256_verifier):
+ """Token with wrong algorithm should report algorithm mismatch."""
+ token = _make_token(
+ {"alg": "RS256", "typ": "JWT"},
+ {"sub": "user1", "iss": "test-issuer", "aud": "test-audience"},
+ )
+
+ result = await hs256_verifier.load_access_token(token)
+
+ assert result is None
+ reason = _jwt_failure_reason.get()
+ assert reason == "Algorithm mismatch"
+ # Claim values must not leak into the contextvar reason
+ assert "RS256" not in reason
+ assert "HS256" not in reason
+
+
[email protected]
+async def test_malformed_token_header(hs256_verifier):
+ """Token with invalid header should report malformed header."""
+ # A token with only 2 parts (missing signature)
+ result = await hs256_verifier.load_access_token("part1.part2")
+
+ assert result is None
+ reason = _jwt_failure_reason.get()
+ assert reason == "Malformed token header"
+
+
[email protected]
+async def test_signature_verification_failed(hs256_verifier):
+ """Token with bad signature should report signature failure."""
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) + 3600,
+ },
+ )
+
+ with patch.object(
+ hs256_verifier.jwt,
+ "decode",
+ side_effect=BadSignatureError(result=None),
+ ):
+ result = await hs256_verifier.load_access_token(token)
+
+ assert result is None
+ reason = _jwt_failure_reason.get()
+ assert reason == "Signature verification failed"
+
+
[email protected]
+async def test_expired_token(hs256_verifier):
+ """Expired token should report token expired."""
+ expired_time = int(time.time()) - 3600
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ "exp": expired_time,
+ },
+ )
+ claims = {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ "exp": expired_time,
+ }
+
+ with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+ result = await hs256_verifier.load_access_token(token)
+
+ assert result is None
+ reason = _jwt_failure_reason.get()
+ assert reason == "Token expired"
+ # Claim values must not leak into the contextvar reason
+ assert "user1" not in reason
+
+
[email protected]
+async def test_issuer_mismatch(hs256_verifier):
+ """Token with wrong issuer should report issuer mismatch."""
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "wrong-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) + 3600,
+ },
+ )
+ claims = {
+ "sub": "user1",
+ "iss": "wrong-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) + 3600,
+ }
+
+ with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+ result = await hs256_verifier.load_access_token(token)
+
+ assert result is None
+ reason = _jwt_failure_reason.get()
+ assert reason == "Issuer mismatch"
+ # Claim values must not leak into the contextvar reason
+ assert "wrong-issuer" not in reason
+ assert "test-issuer" not in reason
+
+
[email protected]
+async def test_audience_mismatch(hs256_verifier):
+ """Token with wrong audience should report audience mismatch."""
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "wrong-audience",
+ "exp": int(time.time()) + 3600,
+ },
+ )
+ claims = {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "wrong-audience",
+ "exp": int(time.time()) + 3600,
+ }
+
+ with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+ result = await hs256_verifier.load_access_token(token)
+
+ assert result is None
+ reason = _jwt_failure_reason.get()
+ assert reason == "Audience mismatch"
+ # Claim values must not leak into the contextvar reason
+ assert "wrong-audience" not in reason
+ assert "test-audience" not in reason
+
+
[email protected]
+async def test_missing_required_scopes(hs256_verifier):
+ """Token missing required scopes should report missing scopes."""
+ hs256_verifier.required_scopes = ["admin", "read"]
+
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) + 3600,
+ "scope": "read",
+ },
+ )
+ claims = {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) + 3600,
+ "scope": "read",
+ }
+
+ with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+ result = await hs256_verifier.load_access_token(token)
+
+ assert result is None
+ reason = _jwt_failure_reason.get()
+ assert reason == "Missing required scopes"
+ # Claim values must not leak into the contextvar reason
+ assert "admin" not in reason
+
+
[email protected]
+async def test_valid_token(hs256_verifier):
+ """Valid token should return AccessToken and clear contextvar."""
+ future_exp = int(time.time()) + 3600
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ "exp": future_exp,
+ },
+ )
+ claims = {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ "exp": future_exp,
+ }
+
+ with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+ result = await hs256_verifier.load_access_token(token)
+
+ assert result is not None
+ assert result.client_id == "user1"
+ assert result.expires_at == future_exp
+ # Contextvar should be None on success
+ assert _jwt_failure_reason.get() is None
+
+
[email protected]
+async def test_valid_token_no_expiration(hs256_verifier):
+ """Valid token without expiration should still succeed."""
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ },
+ )
+ claims = {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ }
+
+ with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+ result = await hs256_verifier.load_access_token(token)
+
+ assert result is not None
+ assert result.client_id == "user1"
+ assert result.expires_at is None
+
+
[email protected]
+async def test_decode_error(hs256_verifier):
+ """Token that fails to decode should report decode failure."""
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {"sub": "user1"},
+ )
+
+ with patch.object(
+ hs256_verifier.jwt,
+ "decode",
+ side_effect=DecodeError("bad token"),
+ ):
+ result = await hs256_verifier.load_access_token(token)
+
+ assert result is None
+ reason = _jwt_failure_reason.get()
+ assert reason == "Token decode failed"
+
+
[email protected]
+async def test_verification_key_failure(hs256_verifier):
+ """Failure to get verification key should report specific error."""
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {"sub": "user1"},
+ )
+
+ with patch.object(
+ hs256_verifier,
+ "_get_verification_key",
+ side_effect=ValueError("JWKS endpoint unreachable"),
+ ):
+ result = await hs256_verifier.load_access_token(token)
+
+ assert result is None
+ reason = _jwt_failure_reason.get()
+ assert reason == "Failed to get verification key"
+ # Exception details must not leak into the contextvar reason
+ assert "JWKS endpoint unreachable" not in reason
+
+
[email protected]
+async def test_contextvar_cleared_on_success(hs256_verifier):
+ """Contextvar should be reset to None before successful validation."""
+ # Set a stale failure reason
+ _jwt_failure_reason.set("previous failure")
+
+ future_exp = int(time.time()) + 3600
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ "exp": future_exp,
+ },
+ )
+ claims = {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ "exp": future_exp,
+ }
+
+ with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+ result = await hs256_verifier.load_access_token(token)
+
+ assert result is not None
+ assert _jwt_failure_reason.get() is None
+
+
+def test_decode_token_header_valid():
+ """_decode_token_header should decode a valid JWT header."""
+ header = {"alg": "RS256", "typ": "JWT", "kid": "key1"}
+ header_b64 = (
+ base64.urlsafe_b64encode(json.dumps(header).encode()).rstrip(b"=").decode()
+ )
+ token = f"{header_b64}.payload.signature"
+
+ result = DetailedJWTVerifier._decode_token_header(token)
+
+ assert result["alg"] == "RS256"
+ assert result["kid"] == "key1"
+
+
+def test_decode_token_header_too_few_parts():
+ """_decode_token_header should raise for tokens with wrong number of parts."""
+ with pytest.raises(ValueError, match="3 parts"):
+ DetailedJWTVerifier._decode_token_header("only.two")
+
+
+def test_get_middleware_returns_custom_components(hs256_verifier):
+ """get_middleware should use DetailedBearerAuthBackend and generic error handler."""
+ middleware_list = hs256_verifier.get_middleware()
+
+ assert len(middleware_list) == 2
+
+ # First middleware should be AuthenticationMiddleware with our custom backend
+ auth_middleware = middleware_list[0]
+ assert (
+ auth_middleware.kwargs["backend"].__class__.__name__
+ == "DetailedBearerAuthBackend"
+ )
+ # on_error should be the RFC 6750-compliant generic handler
+ assert auth_middleware.kwargs["on_error"] is _json_auth_error_handler
+
+
+class _FakeHeaders(dict[str, str]):
+ """A dict subclass that allows overriding .get() for mock connections."""
+
+ def __init__(self, *args: object, **kwargs: object) -> None:
+ super().__init__(*args, **kwargs)
+
+ def get(self, key: str, default: str | None = None) -> str | None: # type: ignore[override]
+ return super().get(key, default)
+
+
[email protected]
+async def test_detailed_bearer_backend_raises_on_failure():
+ """DetailedBearerAuthBackend should raise AuthenticationError with reason."""
+ from starlette.authentication import AuthenticationError
+
+ mock_verifier = MagicMock()
+ mock_verifier.verify_token = AsyncMock(return_value=None)
+
+ backend = DetailedBearerAuthBackend(mock_verifier)
+
+ # Mock connection with Bearer token
+ mock_conn = MagicMock()
+ mock_conn.headers = _FakeHeaders({"authorization": "Bearer some-token"})
+
+ # Set failure reason (generic, no claim values)
+ _jwt_failure_reason.set("Token expired")
+
+ with pytest.raises(AuthenticationError, match="Token expired"):
+ await backend.authenticate(mock_conn)
+
+ # Contextvar should be cleared after raising
+ assert _jwt_failure_reason.get() is None
+
+
[email protected]
+async def test_detailed_bearer_backend_passes_through_success():
+ """DetailedBearerAuthBackend should return normally on success."""
+ mock_verifier = MagicMock()
+ mock_token = MagicMock()
+ mock_token.scopes = ["read"]
+ mock_token.expires_at = None
+ mock_verifier.verify_token = AsyncMock(return_value=mock_token)
+
+ backend = DetailedBearerAuthBackend(mock_verifier)
+
+ mock_conn = MagicMock()
+ mock_conn.headers = _FakeHeaders({"authorization": "Bearer valid-token"})
+
+ result = await backend.authenticate(mock_conn)
+
+ assert result is not None
+ assert _jwt_failure_reason.get() is None
+
+
[email protected]
+async def test_detailed_bearer_backend_no_bearer_token():
+ """DetailedBearerAuthBackend should return None when no Bearer token."""
+ mock_verifier = MagicMock()
+ mock_verifier.verify_token = AsyncMock(return_value=None)
+
+ backend = DetailedBearerAuthBackend(mock_verifier)
+
+ # Mock connection without auth header
+ mock_conn = MagicMock()
+ mock_conn.headers = _FakeHeaders({})
+
+ result = await backend.authenticate(mock_conn)
+
+ assert result is None
+
+
+def test_error_handler_never_leaks_jwt_details():
+ """Error handler MUST return generic error per RFC 6750 Section 3.1.
+
+ No JWT claim values, server config, or validation details should
+ ever appear in the HTTP response - regardless of the failure type.
+ References: CVE-2022-29266, CVE-2019-7644.
+ """
+ from starlette.authentication import AuthenticationError
+
+ mock_conn = MagicMock()
+
+ # Simulate various failure reasons that contain sensitive claim values
+ sensitive_reasons = [
+ "Algorithm mismatch: token uses 'RS256', expected 'HS256'",
+ "Issuer mismatch: token has 'https://evil.com', expected 'https://good.com'",
+ "Audience mismatch: token has 'wrong-aud', expected 'my-api'",
+ "Token expired for client 'admin-service'",
+ "Missing required scopes: {'admin'}. Token has: {'read'}",
+ ]
+
+ for reason in sensitive_reasons:
+ exc = AuthenticationError(reason)
+ response = _json_auth_error_handler(mock_conn, exc)
+
+ assert response.status_code == 401
+
+ body = json.loads(response.body.decode())
+ # Body must only have generic message
+ assert body["error"] == "invalid_token", f"Wrong error code for: {reason}"
+ assert body["error_description"] == "Authentication failed", (
+ f"Detailed reason leaked for: {reason}"
+ )
+
+ # WWW-Authenticate must not contain any claim values
+ www_auth = response.headers.get("www-authenticate", "")
+ assert www_auth == 'Bearer error="invalid_token"', (
+ f"Detailed reason leaked in header for: {reason}"
+ )
+
+
[email protected]
+async def test_audience_mismatch_list_audience():
+ """Token audience not in allowed audience list should fail."""
+ verifier = DetailedJWTVerifier(
+ public_key="test-secret",
+ issuer="test-issuer",
+ audience=["aud1", "aud2"],
+ algorithm="HS256",
+ )
+
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "wrong-aud",
+ "exp": int(time.time()) + 3600,
+ },
+ )
+ claims = {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "wrong-aud",
+ "exp": int(time.time()) + 3600,
+ }
+
+ with patch.object(verifier.jwt, "decode", return_value=claims):
+ result = await verifier.load_access_token(token)
+
+ assert result is None
+ reason = _jwt_failure_reason.get()
+ assert reason == "Audience mismatch"
+
+
[email protected]
+async def test_issuer_mismatch_list_issuer():
+ """Token issuer not in allowed issuer list should fail."""
+ verifier = DetailedJWTVerifier(
+ public_key="test-secret",
+ issuer=["iss1", "iss2"],
+ audience="test-audience",
+ algorithm="HS256",
+ )
+
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "wrong-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) + 3600,
+ },
+ )
+ claims = {
+ "sub": "user1",
+ "iss": "wrong-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) + 3600,
+ }
+
+ with patch.object(verifier.jwt, "decode", return_value=claims):
+ result = await verifier.load_access_token(token)
+
+ assert result is None
+ reason = _jwt_failure_reason.get()
+ assert reason == "Issuer mismatch"
+ # Claim values must not leak into the contextvar reason
+ assert "wrong-issuer" not in reason
+
+
+def test_decode_token_header_padding_multiple_of_4():
+ """_decode_token_header should handle headers whose length is a multiple of 4."""
+ # eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9 is 36 chars (divisible by 4)
+ # This is the standard HS256/JWT header
+ header = {"alg": "HS256", "typ": "JWT"}
+ header_b64 = (
+ base64.urlsafe_b64encode(json.dumps(header).encode()).rstrip(b"=").decode()
+ )
+ token = f"{header_b64}.payload.signature"
+
+ result = DetailedJWTVerifier._decode_token_header(token)
+
+ assert result["alg"] == "HS256"
+ assert result["typ"] == "JWT"
+
+
[email protected]
+async def test_warning_logs_never_contain_claim_values(hs256_verifier, caplog):
+ """WARNING logs must contain only generic categories; details go to DEBUG."""
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "wrong-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) + 3600,
+ },
+ )
+ claims = {
+ "sub": "user1",
+ "iss": "wrong-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) + 3600,
+ }
+
+ with caplog.at_level(logging.DEBUG, logger="superset.mcp_service.jwt_verifier"):
+ with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+ await hs256_verifier.load_access_token(token)
+
+ # WARNING logs must not contain claim values
+ warning_messages = [
+ r.message for r in caplog.records if r.levelno >= logging.WARNING
+ ]
+ for msg in warning_messages:
+ assert "wrong-issuer" not in msg
+ assert "test-issuer" not in msg
+
+ # DEBUG logs should contain the detailed values
+ debug_messages = [r.message for r in caplog.records if r.levelno == logging.DEBUG]
+ assert any("wrong-issuer" in msg for msg in debug_messages)
+
+
[email protected]
+async def test_hs256_secret_never_logged(hs256_verifier, caplog):
+ """The HS256 secret key must never appear in any log at any level."""
+ # This matches the public_key value from the hs256_verifier fixture
+ hs256_signing_value = "test-secret-key-for-hs256-tokens"
+
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "wrong-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) + 3600,
+ },
+ )
+ claims = {
+ "sub": "user1",
+ "iss": "wrong-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) + 3600,
+ }
+
+ with caplog.at_level(logging.DEBUG, logger="superset.mcp_service.jwt_verifier"):
+ with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+ await hs256_verifier.load_access_token(token)
+
+ # The signing value must never appear at ANY log level
+ all_messages = [r.message for r in caplog.records]
+ for msg in all_messages:
+ assert hs256_signing_value not in msg, f"HS256 secret leaked in log: {msg}"
+
+
[email protected]
+async def test_expired_token_during_decode(hs256_verifier):
+ """ExpiredTokenError raised by jwt.decode should set generic reason."""
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) - 3600,
+ },
+ )
+
+ with patch.object(
+ hs256_verifier.jwt,
+ "decode",
+ side_effect=ExpiredTokenError(),
+ ):
+ result = await hs256_verifier.load_access_token(token)
+
+ assert result is None
+ reason = _jwt_failure_reason.get()
+ assert reason == "Token has expired (detected during decode)"
+
+
[email protected]
+async def test_catch_all_exception_sets_generic_reason(hs256_verifier):
+ """Catch-all handler should set generic reason without exception details."""
+ token = _make_token(
+ {"alg": "HS256", "typ": "JWT"},
+ {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) + 3600,
+ },
+ )
+ claims = {
+ "sub": "user1",
+ "iss": "test-issuer",
+ "aud": "test-audience",
+ "exp": int(time.time()) + 3600,
+ }
+
+ with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+ with patch.object(
+ hs256_verifier,
+ "_extract_scopes",
+ side_effect=TypeError("unexpected type in scopes"),
+ ):
+ result = await hs256_verifier.load_access_token(token)
+
+ assert result is None
+ reason = _jwt_failure_reason.get()
+ assert reason == "Token validation failed"
+ assert "unexpected type" not in reason