This is an automated email from the ASF dual-hosted git repository.

arivero pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new ae99b194225 feat(mcp): add detailed JWT error messages and default auth factory fallback (#37972)
ae99b194225 is described below

commit ae99b19422517fce5a1b4f50fb473ea4991ac170
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Thu Feb 26 08:21:40 2026 -0500

    feat(mcp): add detailed JWT error messages and default auth factory fallback (#37972)
    
    Co-authored-by: Claude Opus 4.6 <[email protected]>
---
 superset/mcp_service/auth.py                      |  18 +-
 superset/mcp_service/jwt_verifier.py              | 320 ++++++++++
 superset/mcp_service/mcp_config.py                |  68 +-
 superset/mcp_service/server.py                    |  50 +-
 tests/unit_tests/mcp_service/test_jwt_verifier.py | 726 ++++++++++++++++++++++
 5 files changed, 1138 insertions(+), 44 deletions(-)

diff --git a/superset/mcp_service/auth.py b/superset/mcp_service/auth.py
index ac11ecd62d0..8675947097f 100644
--- a/superset/mcp_service/auth.py
+++ b/superset/mcp_service/auth.py
@@ -111,9 +111,23 @@ def get_user_from_request() -> User:
     username = current_app.config.get("MCP_DEV_USERNAME")
 
     if not username:
+        auth_enabled = current_app.config.get("MCP_AUTH_ENABLED", False)
+        jwt_configured = bool(
+            current_app.config.get("MCP_JWKS_URI")
+            or current_app.config.get("MCP_JWT_PUBLIC_KEY")
+            or current_app.config.get("MCP_JWT_SECRET")
+        )
+        details = []
+        details.append(
+            f"g.user was not set by JWT middleware "
+            f"(MCP_AUTH_ENABLED={auth_enabled}, "
+            f"JWT keys configured={jwt_configured})"
+        )
+        details.append("MCP_DEV_USERNAME is not configured")
         raise ValueError(
-            "No authenticated user found. "
-            "Either pass a valid JWT bearer token or configure "
+            "No authenticated user found. Tried:\n"
+            + "\n".join(f"  - {d}" for d in details)
+            + "\n\nEither pass a valid JWT bearer token or configure "
             "MCP_DEV_USERNAME for development."
         )
 
diff --git a/superset/mcp_service/jwt_verifier.py b/superset/mcp_service/jwt_verifier.py
new file mode 100644
index 00000000000..abd8019d9e4
--- /dev/null
+++ b/superset/mcp_service/jwt_verifier.py
@@ -0,0 +1,320 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Detailed JWT verification for the MCP service.
+
+Provides step-by-step JWT validation with tiered server-side logging:
+- WARNING level: generic failure categories only (e.g. "Issuer mismatch")
+- DEBUG level: detailed claim values and config for troubleshooting
+- Secrets (e.g. HS256 keys) are NEVER logged at any level
+
+HTTP responses always return generic errors per RFC 6750 Section 3.1.
+"""
+
+import base64
+import logging
+import time
+from contextvars import ContextVar
+from typing import Any, cast
+
+from authlib.jose.errors import (
+    BadSignatureError,
+    DecodeError,
+    ExpiredTokenError,
+    JoseError,
+)
+from fastmcp.server.auth.auth import AccessToken
+from fastmcp.server.auth.providers.jwt import JWTVerifier
+from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
+from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
+from starlette.authentication import AuthenticationError
+from starlette.middleware import Middleware
+from starlette.middleware.authentication import AuthenticationMiddleware
+from starlette.requests import HTTPConnection
+from starlette.responses import JSONResponse
+
+from superset.utils import json
+
+logger = logging.getLogger(__name__)
+
+# Thread-safe storage for the specific JWT failure reason.
+# Set by DetailedJWTVerifier.load_access_token() on failure,
+# read by DetailedBearerAuthBackend.authenticate() to raise
+# an AuthenticationError with the specific reason.
+# SECURITY: Must ALWAYS contain generic failure categories only.
+# Claim values and secrets must NEVER be stored here.
+_jwt_failure_reason: ContextVar[str | None] = ContextVar(
+    "_jwt_failure_reason", default=None
+)
+
+
+def _json_auth_error_handler(
+    conn: HTTPConnection, exc: AuthenticationError
+) -> JSONResponse:
+    """JSON 401 error handler for authentication failures.
+
+    Per RFC 6750 Section 3.1, error responses MUST NOT leak server
+    configuration or token claim values. Only generic error codes are
+    returned to clients. Detailed failure reasons are logged server-side
+    only for debugging.
+
+    References:
+        - RFC 6750 Section 3.1: https://datatracker.ietf.org/doc/html/rfc6750#section-3.1
+        - CVE-2022-29266, CVE-2019-7644: verbose JWT errors led to exploits
+    """
+    # Log detailed reason server-side only
+    logger.warning("JWT authentication failed: %s", exc)
+
+    return JSONResponse(
+        status_code=401,
+        content={
+            "error": "invalid_token",
+            "error_description": "Authentication failed",
+        },
+        headers={
+            "WWW-Authenticate": 'Bearer error="invalid_token"',
+        },
+    )
+
+
+class DetailedBearerAuthBackend(BearerAuthBackend):
+    """
+    Bearer auth backend that raises AuthenticationError with specific
+    JWT failure reasons instead of silently returning None.
+    """
+
+    async def authenticate(self, conn: HTTPConnection) -> tuple[Any, Any] | None:
+        result = await super().authenticate(conn)
+
+        if result is not None:
+            # Clear any stale failure reason on success
+            _jwt_failure_reason.set(None)
+            return result
+
+        # Check if there's a Bearer token present - if so, there was a
+        # validation failure we can report with a specific reason
+        auth_header = next(
+            (
+                conn.headers.get(key)
+                for key in conn.headers
+                if key.lower() == "authorization"
+            ),
+            None,
+        )
+        if auth_header and auth_header.lower().startswith("bearer "):
+            reason = _jwt_failure_reason.get()
+            if reason:
+                _jwt_failure_reason.set(None)
+                raise AuthenticationError(reason)
+
+        return None
+
+
+class DetailedJWTVerifier(JWTVerifier):
+    """
+    JWT verifier with tiered server-side logging for each validation step.
+
+    Logging tiers:
+    - WARNING: generic failure categories only (via _jwt_failure_reason ContextVar
+      and the error handler). No claim values, no server config.
+    - DEBUG: detailed values (issuer, audience, client_id, scopes, exceptions)
+      for operator troubleshooting.
+    - Secrets (e.g. HS256 signing keys) are NEVER logged at any level.
+
+    HTTP responses always return generic errors per RFC 6750 Section 3.1.
+    Controlled by MCP_JWT_DEBUG_ERRORS config flag.
+    """
+
+    async def load_access_token(self, token: str) -> AccessToken | None:  # noqa: C901
+        """
+        Validate a JWT bearer token with detailed error reporting.
+
+        Each validation step stores a specific failure reason in the
+        _jwt_failure_reason ContextVar before returning None.
+        """
+        # Reset any previous failure reason
+        _jwt_failure_reason.set(None)
+
+        try:
+            # Step 1: Decode header and check algorithm
+            try:
+                header = self._decode_token_header(token)
+            except (ValueError, DecodeError) as e:
+                reason = "Malformed token header"
+                _jwt_failure_reason.set(reason)
+                logger.debug("Malformed token header: %s", e)
+                return None
+
+            token_alg = header.get("alg")
+            if self.algorithm and token_alg != self.algorithm:
+                reason = "Algorithm mismatch"
+                _jwt_failure_reason.set(reason)
+                logger.debug(
+                    "Algorithm mismatch: token uses '%s', expected '%s'",
+                    token_alg,
+                    self.algorithm,
+                )
+                return None
+
+            # Step 2: Get verification key (static or JWKS)
+            try:
+                verification_key = await self._get_verification_key(token)
+            except ValueError as e:
+                reason = "Failed to get verification key"
+                _jwt_failure_reason.set(reason)
+                logger.debug("Failed to get verification key: %s", e)
+                return None
+
+            # Step 3: Decode and verify signature
+            try:
+                claims = self.jwt.decode(token, verification_key)
+            except BadSignatureError:
+                reason = "Signature verification failed"
+                _jwt_failure_reason.set(reason)
+                return None
+            except ExpiredTokenError:
+                reason = "Token has expired (detected during decode)"
+                _jwt_failure_reason.set(reason)
+                return None
+            except JoseError as e:
+                reason = "Token decode failed"
+                _jwt_failure_reason.set(reason)
+                logger.debug("Token decode failed: %s", e)
+                return None
+
+            # Extract client ID for logging
+            client_id = (
+                claims.get("client_id")
+                or claims.get("azp")
+                or claims.get("sub")
+                or "unknown"
+            )
+
+            # Step 4: Check expiration
+            exp = claims.get("exp")
+            if exp and exp < time.time():
+                reason = "Token expired"
+                _jwt_failure_reason.set(reason)
+                logger.debug("Token expired for client '%s'", client_id)
+                return None
+
+            # Step 5: Validate issuer
+            if self.issuer:
+                iss = claims.get("iss")
+                if isinstance(self.issuer, list):
+                    issuer_valid = iss in self.issuer
+                else:
+                    issuer_valid = iss == self.issuer
+
+                if not issuer_valid:
+                    reason = "Issuer mismatch"
+                    _jwt_failure_reason.set(reason)
+                    logger.debug(
+                        "Issuer mismatch: token has '%s', expected '%s'",
+                        iss,
+                        self.issuer,
+                    )
+                    return None
+
+            # Step 6: Validate audience
+            if self.audience:
+                aud = claims.get("aud")
+                if isinstance(self.audience, list):
+                    if isinstance(aud, list):
+                        audience_valid = any(
+                            expected in aud
+                            for expected in cast(list[str], self.audience)
+                        )
+                    else:
+                        audience_valid = aud in cast(list[str], self.audience)
+                else:
+                    if isinstance(aud, list):
+                        audience_valid = self.audience in aud
+                    else:
+                        audience_valid = aud == self.audience
+
+                if not audience_valid:
+                    reason = "Audience mismatch"
+                    _jwt_failure_reason.set(reason)
+                    logger.debug(
+                        "Audience mismatch: token has '%s', expected '%s'",
+                        aud,
+                        self.audience,
+                    )
+                    return None
+
+            # Step 7: Check required scopes
+            scopes = self._extract_scopes(claims)
+            if self.required_scopes:
+                token_scopes = set(scopes)
+                required = set(self.required_scopes)
+                if not required.issubset(token_scopes):
+                    missing = required - token_scopes
+                    reason = "Missing required scopes"
+                    _jwt_failure_reason.set(reason)
+                    logger.debug(
+                        "Missing required scopes: %s. Token has: %s",
+                        missing,
+                        token_scopes,
+                    )
+                    return None
+
+            # All validations passed
+            return AccessToken(
+                token=token,
+                client_id=str(client_id),
+                scopes=scopes,
+                expires_at=int(exp) if exp else None,
+                claims=dict(claims),
+            )
+
+        except (ValueError, JoseError, KeyError, AttributeError, TypeError) as e:
+            reason = "Token validation failed"
+            _jwt_failure_reason.set(reason)
+            logger.debug("Token validation failed: %s", e)
+            return None
+
+    def get_middleware(self) -> list[Any]:
+        """
+        Get middleware with detailed server-side error logging.
+
+        Uses DetailedBearerAuthBackend which raises AuthenticationError
+        with specific reasons for server-side logging. The error handler
+        always returns generic 401 responses per RFC 6750.
+        """
+        return [
+            Middleware(
+                AuthenticationMiddleware,
+                backend=DetailedBearerAuthBackend(self),
+                on_error=_json_auth_error_handler,
+            ),
+            Middleware(AuthContextMiddleware),
+        ]
+
+    @staticmethod
+    def _decode_token_header(token: str) -> dict[str, Any]:
+        """Decode the JWT header without verifying the signature."""
+        parts = token.split(".")
+        if len(parts) != 3:
+            raise ValueError(
+                f"Token must have 3 parts (header.payload.signature), got {len(parts)}"
+            )
+        header_b64 = parts[0]
+        # Add padding only if needed
+        header_b64 += "=" * (-len(header_b64) % 4)
+        header_bytes = base64.urlsafe_b64decode(header_b64)
+        return json.loads(header_bytes)
diff --git a/superset/mcp_service/mcp_config.py b/superset/mcp_service/mcp_config.py
index 86a772aa483..dab549b9b06 100644
--- a/superset/mcp_service/mcp_config.py
+++ b/superset/mcp_service/mcp_config.py
@@ -45,6 +45,16 @@ MCP_SERVICE_PORT = 5008
 # MCP Debug mode - shows suppressed initialization output in stdio mode
 MCP_DEBUG = False
 
+# MCP JWT Debug Errors - controls server-side JWT debug logging.
+# When False (default), uses the default JWTVerifier with minimal logging.
+# When True, uses DetailedJWTVerifier with tiered logging:
+#   - WARNING level: generic failure categories only (e.g. "Issuer mismatch")
+#   - DEBUG level: detailed claim values for troubleshooting
+#   - Secrets (e.g. HS256 keys) are NEVER logged at any level
+# HTTP responses ALWAYS return generic errors regardless of this setting,
+# per RFC 6750 Section 3.1. This flag NEVER affects client-facing output.
+MCP_JWT_DEBUG_ERRORS = False
+
 # Enable parse_request decorator for MCP tools.
 # When True (default), tool requests are automatically parsed from JSON strings
 # to Pydantic models, working around a Claude Code double-serialization bug
@@ -231,47 +241,47 @@ def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]:
         return None
 
     try:
-        from fastmcp.server.auth.providers.jwt import JWTVerifier
+        debug_errors = app.config.get("MCP_JWT_DEBUG_ERRORS", False)
+
+        common_kwargs: dict[str, Any] = {
+            "issuer": app.config.get("MCP_JWT_ISSUER"),
+            "audience": app.config.get("MCP_JWT_AUDIENCE"),
+            "required_scopes": app.config.get("MCP_REQUIRED_SCOPES", []),
+        }
 
         # For HS256 (symmetric), use the secret as the public_key parameter
         if app.config.get("MCP_JWT_ALGORITHM") == "HS256" and secret:
-            auth_provider = JWTVerifier(
-                public_key=secret,  # HS256 uses secret as key
-                issuer=app.config.get("MCP_JWT_ISSUER"),
-                audience=app.config.get("MCP_JWT_AUDIENCE"),
-                algorithm="HS256",
-                required_scopes=app.config.get("MCP_REQUIRED_SCOPES", []),
-            )
-            logger.info("Created JWTVerifier with HS256 secret")
+            common_kwargs["public_key"] = secret
+            common_kwargs["algorithm"] = "HS256"
         else:
             # For RS256 (asymmetric), use public key or JWKS
-            auth_provider = JWTVerifier(
-                jwks_uri=jwks_uri,
-                public_key=public_key,
-                issuer=app.config.get("MCP_JWT_ISSUER"),
-                audience=app.config.get("MCP_JWT_AUDIENCE"),
-                algorithm=app.config.get("MCP_JWT_ALGORITHM", "RS256"),
-                required_scopes=app.config.get("MCP_REQUIRED_SCOPES", []),
-            )
-            logger.info(
-                "Created JWTVerifier with jwks_uri=%s, public_key=%s",
-                jwks_uri,
-                "***" if public_key else None,
-            )
+            common_kwargs["jwks_uri"] = jwks_uri
+            common_kwargs["public_key"] = public_key
+            common_kwargs["algorithm"] = app.config.get("MCP_JWT_ALGORITHM", "RS256")
+
+        if debug_errors:
+            # DetailedJWTVerifier: detailed server-side logging of JWT
+            # validation failures. HTTP responses are always generic per
+            # RFC 6750 Section 3.1.
+            from superset.mcp_service.jwt_verifier import DetailedJWTVerifier
+
+            auth_provider = DetailedJWTVerifier(**common_kwargs)
+        else:
+            # Default JWTVerifier: minimal logging, generic error responses.
+            from fastmcp.server.auth.providers.jwt import JWTVerifier
+
+            auth_provider = JWTVerifier(**common_kwargs)
 
         return auth_provider
-    except Exception as e:
-        logger.error("Failed to create MCP auth provider: %s", e)
+    except Exception:
+        # Do not log the exception — it may contain the HS256 secret
+        # from common_kwargs["public_key"]
+        logger.error("Failed to create MCP auth provider")
         return None
 
 
 def default_user_resolver(app: Any, access_token: Any) -> Optional[str]:
     """Extract username from JWT token claims."""
-    logger.info(
-        "Resolving user from token: type=%s, token=%s",
-        type(access_token),
-        access_token,
-    )
     if hasattr(access_token, "subject"):
         return access_token.subject
     if hasattr(access_token, "client_id"):
diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py
index a0d7e7f5d42..9e9af7dd47b 100644
--- a/superset/mcp_service/server.py
+++ b/superset/mcp_service/server.py
@@ -33,6 +33,8 @@ from superset.mcp_service.mcp_config import get_mcp_factory_config, MCP_STORE_CO
 from superset.mcp_service.middleware import create_response_size_guard_middleware
 from superset.mcp_service.storage import _create_redis_store
 
+logger = logging.getLogger(__name__)
+
 
 def configure_logging(debug: bool = False) -> None:
     """Configure logging for the MCP service."""
@@ -119,6 +121,40 @@ def create_event_store(config: dict[str, Any] | None = None) -> Any | None:
         return None
 
 
+def _create_auth_provider(flask_app: Any) -> Any | None:
+    """Create an auth provider from Flask app config.
+
+    Tries MCP_AUTH_FACTORY first, then falls back to the default factory
+    when MCP_AUTH_ENABLED is True.
+    """
+    auth_provider = None
+    if auth_factory := flask_app.config.get("MCP_AUTH_FACTORY"):
+        try:
+            auth_provider = auth_factory(flask_app)
+            logger.info(
+                "Auth provider created from MCP_AUTH_FACTORY: %s",
+                type(auth_provider).__name__ if auth_provider else "None",
+            )
+        except Exception:
+            # Do not log the exception — it may contain secrets
+            logger.error("Failed to create auth provider from MCP_AUTH_FACTORY")
+    elif flask_app.config.get("MCP_AUTH_ENABLED", False):
+        from superset.mcp_service.mcp_config import (
+            create_default_mcp_auth_factory,
+        )
+
+        try:
+            auth_provider = create_default_mcp_auth_factory(flask_app)
+            logger.info(
+                "Auth provider created from default factory: %s",
+                type(auth_provider).__name__ if auth_provider else "None",
+            )
+        except Exception:
+            # Do not log the exception — it may contain secrets
+            logger.error("Failed to create auth provider from default factory")
+    return auth_provider
+
+
 def run_server(
     host: str = "127.0.0.1",
     port: int = 5008,
@@ -158,19 +194,7 @@ def run_server(
         from superset.mcp_service.flask_singleton import get_flask_app
 
         flask_app = get_flask_app()
-
-        # Get auth factory from config and create auth provider
-        auth_provider = None
-        auth_factory = flask_app.config.get("MCP_AUTH_FACTORY")
-        if auth_factory:
-            try:
-                auth_provider = auth_factory(flask_app)
-                logging.info(
-                    "Auth provider created: %s",
-                    type(auth_provider).__name__ if auth_provider else "None",
-                )
-            except Exception as e:
-                logging.error("Failed to create auth provider: %s", e)
+        auth_provider = _create_auth_provider(flask_app)
 
         # Build middleware list
         middleware_list = []
diff --git a/tests/unit_tests/mcp_service/test_jwt_verifier.py b/tests/unit_tests/mcp_service/test_jwt_verifier.py
new file mode 100644
index 00000000000..b6ba58aa7f2
--- /dev/null
+++ b/tests/unit_tests/mcp_service/test_jwt_verifier.py
@@ -0,0 +1,726 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for DetailedJWTVerifier and related middleware."""
+
+import base64
+import logging
+import time
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from authlib.jose.errors import BadSignatureError, DecodeError, ExpiredTokenError
+
+from superset.mcp_service.jwt_verifier import (
+    _json_auth_error_handler,
+    _jwt_failure_reason,
+    DetailedBearerAuthBackend,
+    DetailedJWTVerifier,
+)
+from superset.utils import json
+
+
+def _make_token(
+    header: dict[str, str], payload: dict[str, object], signature: str = "sig"
+) -> str:
+    """Build a fake JWT string from header + payload dicts."""
+    h = base64.urlsafe_b64encode(json.dumps(header).encode()).rstrip(b"=").decode()
+    p = base64.urlsafe_b64encode(json.dumps(payload).encode()).rstrip(b"=").decode()
+    return f"{h}.{p}.{signature}"
+
+
+@pytest.fixture
+def hs256_verifier():
+    """Create a DetailedJWTVerifier configured for HS256."""
+    return DetailedJWTVerifier(
+        public_key="test-secret-key-for-hs256-tokens",
+        issuer="test-issuer",
+        audience="test-audience",
+        algorithm="HS256",
+        required_scopes=[],
+    )
+
+
+@pytest.fixture(autouse=True)
+def _reset_contextvar():
+    """Reset the failure reason contextvar before each test."""
+    _jwt_failure_reason.set(None)
+    yield
+    _jwt_failure_reason.set(None)
+
+
+@pytest.mark.asyncio
+async def test_algorithm_mismatch(hs256_verifier):
+    """Token with wrong algorithm should report algorithm mismatch."""
+    token = _make_token(
+        {"alg": "RS256", "typ": "JWT"},
+        {"sub": "user1", "iss": "test-issuer", "aud": "test-audience"},
+    )
+
+    result = await hs256_verifier.load_access_token(token)
+
+    assert result is None
+    reason = _jwt_failure_reason.get()
+    assert reason == "Algorithm mismatch"
+    # Claim values must not leak into the contextvar reason
+    assert "RS256" not in reason
+    assert "HS256" not in reason
+
+
+@pytest.mark.asyncio
+async def test_malformed_token_header(hs256_verifier):
+    """Token with invalid header should report malformed header."""
+    # A token with only 2 parts (missing signature)
+    result = await hs256_verifier.load_access_token("part1.part2")
+
+    assert result is None
+    reason = _jwt_failure_reason.get()
+    assert reason == "Malformed token header"
+
+
+@pytest.mark.asyncio
+async def test_signature_verification_failed(hs256_verifier):
+    """Token with bad signature should report signature failure."""
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "test-issuer",
+            "aud": "test-audience",
+            "exp": int(time.time()) + 3600,
+        },
+    )
+
+    with patch.object(
+        hs256_verifier.jwt,
+        "decode",
+        side_effect=BadSignatureError(result=None),
+    ):
+        result = await hs256_verifier.load_access_token(token)
+
+    assert result is None
+    reason = _jwt_failure_reason.get()
+    assert reason == "Signature verification failed"
+
+
+@pytest.mark.asyncio
+async def test_expired_token(hs256_verifier):
+    """Expired token should report token expired."""
+    expired_time = int(time.time()) - 3600
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "test-issuer",
+            "aud": "test-audience",
+            "exp": expired_time,
+        },
+    )
+    claims = {
+        "sub": "user1",
+        "iss": "test-issuer",
+        "aud": "test-audience",
+        "exp": expired_time,
+    }
+
+    with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+        result = await hs256_verifier.load_access_token(token)
+
+    assert result is None
+    reason = _jwt_failure_reason.get()
+    assert reason == "Token expired"
+    # Claim values must not leak into the contextvar reason
+    assert "user1" not in reason
+
+
+@pytest.mark.asyncio
+async def test_issuer_mismatch(hs256_verifier):
+    """Token with wrong issuer should report issuer mismatch."""
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "wrong-issuer",
+            "aud": "test-audience",
+            "exp": int(time.time()) + 3600,
+        },
+    )
+    claims = {
+        "sub": "user1",
+        "iss": "wrong-issuer",
+        "aud": "test-audience",
+        "exp": int(time.time()) + 3600,
+    }
+
+    with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+        result = await hs256_verifier.load_access_token(token)
+
+    assert result is None
+    reason = _jwt_failure_reason.get()
+    assert reason == "Issuer mismatch"
+    # Claim values must not leak into the contextvar reason
+    assert "wrong-issuer" not in reason
+    assert "test-issuer" not in reason
+
+
+@pytest.mark.asyncio
+async def test_audience_mismatch(hs256_verifier):
+    """Token with wrong audience should report audience mismatch."""
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "test-issuer",
+            "aud": "wrong-audience",
+            "exp": int(time.time()) + 3600,
+        },
+    )
+    claims = {
+        "sub": "user1",
+        "iss": "test-issuer",
+        "aud": "wrong-audience",
+        "exp": int(time.time()) + 3600,
+    }
+
+    with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+        result = await hs256_verifier.load_access_token(token)
+
+    assert result is None
+    reason = _jwt_failure_reason.get()
+    assert reason == "Audience mismatch"
+    # Claim values must not leak into the contextvar reason
+    assert "wrong-audience" not in reason
+    assert "test-audience" not in reason
+
+
+@pytest.mark.asyncio
+async def test_missing_required_scopes(hs256_verifier):
+    """Token missing required scopes should report missing scopes."""
+    hs256_verifier.required_scopes = ["admin", "read"]
+
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "test-issuer",
+            "aud": "test-audience",
+            "exp": int(time.time()) + 3600,
+            "scope": "read",
+        },
+    )
+    claims = {
+        "sub": "user1",
+        "iss": "test-issuer",
+        "aud": "test-audience",
+        "exp": int(time.time()) + 3600,
+        "scope": "read",
+    }
+
+    with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+        result = await hs256_verifier.load_access_token(token)
+
+    assert result is None
+    reason = _jwt_failure_reason.get()
+    assert reason == "Missing required scopes"
+    # Claim values must not leak into the contextvar reason
+    assert "admin" not in reason
+
+
+@pytest.mark.asyncio
+async def test_valid_token(hs256_verifier):
+    """Valid token should return AccessToken and clear contextvar."""
+    future_exp = int(time.time()) + 3600
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "test-issuer",
+            "aud": "test-audience",
+            "exp": future_exp,
+        },
+    )
+    claims = {
+        "sub": "user1",
+        "iss": "test-issuer",
+        "aud": "test-audience",
+        "exp": future_exp,
+    }
+
+    with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+        result = await hs256_verifier.load_access_token(token)
+
+    assert result is not None
+    assert result.client_id == "user1"
+    assert result.expires_at == future_exp
+    # Contextvar should be None on success
+    assert _jwt_failure_reason.get() is None
+
+
+@pytest.mark.asyncio
+async def test_valid_token_no_expiration(hs256_verifier):
+    """Valid token without expiration should still succeed."""
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "test-issuer",
+            "aud": "test-audience",
+        },
+    )
+    claims = {
+        "sub": "user1",
+        "iss": "test-issuer",
+        "aud": "test-audience",
+    }
+
+    with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+        result = await hs256_verifier.load_access_token(token)
+
+    assert result is not None
+    assert result.client_id == "user1"
+    assert result.expires_at is None
+
+
+@pytest.mark.asyncio
+async def test_decode_error(hs256_verifier):
+    """Token that fails to decode should report decode failure."""
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {"sub": "user1"},
+    )
+
+    with patch.object(
+        hs256_verifier.jwt,
+        "decode",
+        side_effect=DecodeError("bad token"),
+    ):
+        result = await hs256_verifier.load_access_token(token)
+
+    assert result is None
+    reason = _jwt_failure_reason.get()
+    assert reason == "Token decode failed"
+
+
+@pytest.mark.asyncio
+async def test_verification_key_failure(hs256_verifier):
+    """Failure to get verification key should report specific error."""
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {"sub": "user1"},
+    )
+
+    with patch.object(
+        hs256_verifier,
+        "_get_verification_key",
+        side_effect=ValueError("JWKS endpoint unreachable"),
+    ):
+        result = await hs256_verifier.load_access_token(token)
+
+    assert result is None
+    reason = _jwt_failure_reason.get()
+    assert reason == "Failed to get verification key"
+    # Exception details must not leak into the contextvar reason
+    assert "JWKS endpoint unreachable" not in reason
+
+
+@pytest.mark.asyncio
+async def test_contextvar_cleared_on_success(hs256_verifier):
+    """Contextvar should be reset to None before successful validation."""
+    # Set a stale failure reason
+    _jwt_failure_reason.set("previous failure")
+
+    future_exp = int(time.time()) + 3600
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "test-issuer",
+            "aud": "test-audience",
+            "exp": future_exp,
+        },
+    )
+    claims = {
+        "sub": "user1",
+        "iss": "test-issuer",
+        "aud": "test-audience",
+        "exp": future_exp,
+    }
+
+    with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+        result = await hs256_verifier.load_access_token(token)
+
+    assert result is not None
+    assert _jwt_failure_reason.get() is None
+
+
+def test_decode_token_header_valid():
+    """_decode_token_header should decode a valid JWT header."""
+    header = {"alg": "RS256", "typ": "JWT", "kid": "key1"}
+    header_b64 = (
+        base64.urlsafe_b64encode(json.dumps(header).encode()).rstrip(b"=").decode()
+    )
+    token = f"{header_b64}.payload.signature"
+
+    result = DetailedJWTVerifier._decode_token_header(token)
+
+    assert result["alg"] == "RS256"
+    assert result["kid"] == "key1"
+
+
+def test_decode_token_header_too_few_parts():
+    """_decode_token_header should raise for tokens with wrong number of parts."""
+    with pytest.raises(ValueError, match="3 parts"):
+        DetailedJWTVerifier._decode_token_header("only.two")
+
+
+def test_get_middleware_returns_custom_components(hs256_verifier):
+    """get_middleware should use DetailedBearerAuthBackend and generic error handler."""
+    middleware_list = hs256_verifier.get_middleware()
+
+    assert len(middleware_list) == 2
+
+    # First middleware should be AuthenticationMiddleware with our custom backend
+    auth_middleware = middleware_list[0]
+    assert (
+        auth_middleware.kwargs["backend"].__class__.__name__
+        == "DetailedBearerAuthBackend"
+    )
+    # on_error should be the RFC 6750-compliant generic handler
+    assert auth_middleware.kwargs["on_error"] is _json_auth_error_handler
+
+
+class _FakeHeaders(dict[str, str]):
+    """A dict subclass that allows overriding .get() for mock connections."""
+
+    def __init__(self, *args: object, **kwargs: object) -> None:
+        super().__init__(*args, **kwargs)
+
+    def get(self, key: str, default: str | None = None) -> str | None:  # type: ignore[override]
+        return super().get(key, default)
+
+
+@pytest.mark.asyncio
+async def test_detailed_bearer_backend_raises_on_failure():
+    """DetailedBearerAuthBackend should raise AuthenticationError with reason."""
+    from starlette.authentication import AuthenticationError
+
+    mock_verifier = MagicMock()
+    mock_verifier.verify_token = AsyncMock(return_value=None)
+
+    backend = DetailedBearerAuthBackend(mock_verifier)
+
+    # Mock connection with Bearer token
+    mock_conn = MagicMock()
+    mock_conn.headers = _FakeHeaders({"authorization": "Bearer some-token"})
+
+    # Set failure reason (generic, no claim values)
+    _jwt_failure_reason.set("Token expired")
+
+    with pytest.raises(AuthenticationError, match="Token expired"):
+        await backend.authenticate(mock_conn)
+
+    # Contextvar should be cleared after raising
+    assert _jwt_failure_reason.get() is None
+
+
+@pytest.mark.asyncio
+async def test_detailed_bearer_backend_passes_through_success():
+    """DetailedBearerAuthBackend should return normally on success."""
+    mock_verifier = MagicMock()
+    mock_token = MagicMock()
+    mock_token.scopes = ["read"]
+    mock_token.expires_at = None
+    mock_verifier.verify_token = AsyncMock(return_value=mock_token)
+
+    backend = DetailedBearerAuthBackend(mock_verifier)
+
+    mock_conn = MagicMock()
+    mock_conn.headers = _FakeHeaders({"authorization": "Bearer valid-token"})
+
+    result = await backend.authenticate(mock_conn)
+
+    assert result is not None
+    assert _jwt_failure_reason.get() is None
+
+
+@pytest.mark.asyncio
+async def test_detailed_bearer_backend_no_bearer_token():
+    """DetailedBearerAuthBackend should return None when no Bearer token."""
+    mock_verifier = MagicMock()
+    mock_verifier.verify_token = AsyncMock(return_value=None)
+
+    backend = DetailedBearerAuthBackend(mock_verifier)
+
+    # Mock connection without auth header
+    mock_conn = MagicMock()
+    mock_conn.headers = _FakeHeaders({})
+
+    result = await backend.authenticate(mock_conn)
+
+    assert result is None
+
+
+def test_error_handler_never_leaks_jwt_details():
+    """Error handler MUST return generic error per RFC 6750 Section 3.1.
+
+    No JWT claim values, server config, or validation details should
+    ever appear in the HTTP response - regardless of the failure type.
+    References: CVE-2022-29266, CVE-2019-7644.
+    """
+    from starlette.authentication import AuthenticationError
+
+    mock_conn = MagicMock()
+
+    # Simulate various failure reasons that contain sensitive claim values
+    sensitive_reasons = [
+        "Algorithm mismatch: token uses 'RS256', expected 'HS256'",
+        "Issuer mismatch: token has 'https://evil.com', expected 'https://good.com'",
+        "Audience mismatch: token has 'wrong-aud', expected 'my-api'",
+        "Token expired for client 'admin-service'",
+        "Missing required scopes: {'admin'}. Token has: {'read'}",
+    ]
+
+    for reason in sensitive_reasons:
+        exc = AuthenticationError(reason)
+        response = _json_auth_error_handler(mock_conn, exc)
+
+        assert response.status_code == 401
+
+        body = json.loads(response.body.decode())
+        # Body must only have generic message
+        assert body["error"] == "invalid_token", f"Wrong error code for: {reason}"
+        assert body["error_description"] == "Authentication failed", (
+            f"Detailed reason leaked for: {reason}"
+        )
+
+        # WWW-Authenticate must not contain any claim values
+        www_auth = response.headers.get("www-authenticate", "")
+        assert www_auth == 'Bearer error="invalid_token"', (
+            f"Detailed reason leaked in header for: {reason}"
+        )
+
+
+@pytest.mark.asyncio
+async def test_audience_mismatch_list_audience():
+    """Token audience not in allowed audience list should fail."""
+    verifier = DetailedJWTVerifier(
+        public_key="test-secret",
+        issuer="test-issuer",
+        audience=["aud1", "aud2"],
+        algorithm="HS256",
+    )
+
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "test-issuer",
+            "aud": "wrong-aud",
+            "exp": int(time.time()) + 3600,
+        },
+    )
+    claims = {
+        "sub": "user1",
+        "iss": "test-issuer",
+        "aud": "wrong-aud",
+        "exp": int(time.time()) + 3600,
+    }
+
+    with patch.object(verifier.jwt, "decode", return_value=claims):
+        result = await verifier.load_access_token(token)
+
+    assert result is None
+    reason = _jwt_failure_reason.get()
+    assert reason == "Audience mismatch"
+
+
+@pytest.mark.asyncio
+async def test_issuer_mismatch_list_issuer():
+    """Token issuer not in allowed issuer list should fail."""
+    verifier = DetailedJWTVerifier(
+        public_key="test-secret",
+        issuer=["iss1", "iss2"],
+        audience="test-audience",
+        algorithm="HS256",
+    )
+
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "wrong-issuer",
+            "aud": "test-audience",
+            "exp": int(time.time()) + 3600,
+        },
+    )
+    claims = {
+        "sub": "user1",
+        "iss": "wrong-issuer",
+        "aud": "test-audience",
+        "exp": int(time.time()) + 3600,
+    }
+
+    with patch.object(verifier.jwt, "decode", return_value=claims):
+        result = await verifier.load_access_token(token)
+
+    assert result is None
+    reason = _jwt_failure_reason.get()
+    assert reason == "Issuer mismatch"
+    # Claim values must not leak into the contextvar reason
+    assert "wrong-issuer" not in reason
+
+
+def test_decode_token_header_padding_multiple_of_4():
+    """_decode_token_header should handle headers whose length is a multiple of 4."""
+    # eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9 is 36 chars (divisible by 4)
+    # This is the standard HS256/JWT header
+    header = {"alg": "HS256", "typ": "JWT"}
+    header_b64 = (
+        base64.urlsafe_b64encode(json.dumps(header).encode()).rstrip(b"=").decode()
+    )
+    token = f"{header_b64}.payload.signature"
+
+    result = DetailedJWTVerifier._decode_token_header(token)
+
+    assert result["alg"] == "HS256"
+    assert result["typ"] == "JWT"
+
+
+@pytest.mark.asyncio
+async def test_warning_logs_never_contain_claim_values(hs256_verifier, caplog):
+    """WARNING logs must contain only generic categories; details go to DEBUG."""
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "wrong-issuer",
+            "aud": "test-audience",
+            "exp": int(time.time()) + 3600,
+        },
+    )
+    claims = {
+        "sub": "user1",
+        "iss": "wrong-issuer",
+        "aud": "test-audience",
+        "exp": int(time.time()) + 3600,
+    }
+
+    with caplog.at_level(logging.DEBUG, logger="superset.mcp_service.jwt_verifier"):
+        with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+            await hs256_verifier.load_access_token(token)
+
+    # WARNING logs must not contain claim values
+    warning_messages = [
+        r.message for r in caplog.records if r.levelno >= logging.WARNING
+    ]
+    for msg in warning_messages:
+        assert "wrong-issuer" not in msg
+        assert "test-issuer" not in msg
+
+    # DEBUG logs should contain the detailed values
+    debug_messages = [r.message for r in caplog.records if r.levelno == logging.DEBUG]
+    assert any("wrong-issuer" in msg for msg in debug_messages)
+
+
+@pytest.mark.asyncio
+async def test_hs256_secret_never_logged(hs256_verifier, caplog):
+    """The HS256 secret key must never appear in any log at any level."""
+    # This matches the public_key value from the hs256_verifier fixture
+    hs256_signing_value = "test-secret-key-for-hs256-tokens"
+
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "wrong-issuer",
+            "aud": "test-audience",
+            "exp": int(time.time()) + 3600,
+        },
+    )
+    claims = {
+        "sub": "user1",
+        "iss": "wrong-issuer",
+        "aud": "test-audience",
+        "exp": int(time.time()) + 3600,
+    }
+
+    with caplog.at_level(logging.DEBUG, logger="superset.mcp_service.jwt_verifier"):
+        with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+            await hs256_verifier.load_access_token(token)
+
+    # The signing value must never appear at ANY log level
+    all_messages = [r.message for r in caplog.records]
+    for msg in all_messages:
+        assert hs256_signing_value not in msg, f"HS256 secret leaked in log: {msg}"
+
+
+@pytest.mark.asyncio
+async def test_expired_token_during_decode(hs256_verifier):
+    """ExpiredTokenError raised by jwt.decode should set generic reason."""
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "test-issuer",
+            "aud": "test-audience",
+            "exp": int(time.time()) - 3600,
+        },
+    )
+
+    with patch.object(
+        hs256_verifier.jwt,
+        "decode",
+        side_effect=ExpiredTokenError(),
+    ):
+        result = await hs256_verifier.load_access_token(token)
+
+    assert result is None
+    reason = _jwt_failure_reason.get()
+    assert reason == "Token has expired (detected during decode)"
+
+
+@pytest.mark.asyncio
+async def test_catch_all_exception_sets_generic_reason(hs256_verifier):
+    """Catch-all handler should set generic reason without exception details."""
+    token = _make_token(
+        {"alg": "HS256", "typ": "JWT"},
+        {
+            "sub": "user1",
+            "iss": "test-issuer",
+            "aud": "test-audience",
+            "exp": int(time.time()) + 3600,
+        },
+    )
+    claims = {
+        "sub": "user1",
+        "iss": "test-issuer",
+        "aud": "test-audience",
+        "exp": int(time.time()) + 3600,
+    }
+
+    with patch.object(hs256_verifier.jwt, "decode", return_value=claims):
+        with patch.object(
+            hs256_verifier,
+            "_extract_scopes",
+            side_effect=TypeError("unexpected type in scopes"),
+        ):
+            result = await hs256_verifier.load_access_token(token)
+
+    assert result is None
+    reason = _jwt_failure_reason.get()
+    assert reason == "Token validation failed"
+    assert "unexpected type" not in reason


Reply via email to