codeant-ai-for-open-source[bot] commented on code in PR #37972:
URL: https://github.com/apache/superset/pull/37972#discussion_r2806684439


##########
superset/mcp_service/jwt_verifier.py:
##########
@@ -0,0 +1,294 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Detailed JWT verification for the MCP service.
+
+Provides step-by-step JWT validation with specific error messages
+instead of the generic "invalid_token" response from the base JWTVerifier.
+"""
+
+import base64
+import logging
+import time
+from contextvars import ContextVar
+from typing import Any, cast
+
+from authlib.jose.errors import (
+    BadSignatureError,
+    DecodeError,
+    ExpiredTokenError,
+    JoseError,
+)
+from fastmcp.server.auth.auth import AccessToken
+from fastmcp.server.auth.providers.jwt import JWTVerifier
+from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
+from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
+from starlette.authentication import AuthenticationError
+from starlette.middleware import Middleware
+from starlette.middleware.authentication import AuthenticationMiddleware
+from starlette.requests import HTTPConnection
+from starlette.responses import JSONResponse
+
+from superset.utils import json
+
+logger = logging.getLogger(__name__)
+
+# Thread-safe storage for the specific JWT failure reason.
+# Set by DetailedJWTVerifier.load_access_token() on failure,
+# read by DetailedBearerAuthBackend.authenticate() to raise
+# an AuthenticationError with the specific reason.
+_jwt_failure_reason: ContextVar[str | None] = ContextVar(
+    "_jwt_failure_reason", default=None
+)
+
+
+def _json_auth_error_handler(
+    conn: HTTPConnection, exc: AuthenticationError
+) -> JSONResponse:
+    """Return a JSON 401 response with the specific JWT failure reason."""
+    return JSONResponse(
+        status_code=401,
+        content={
+            "error": "invalid_token",
+            "error_description": str(exc),
+        },
+        headers={
+            "WWW-Authenticate": f'Bearer error="invalid_token", '
+            f'error_description="{exc}"',

Review Comment:
   **Suggestion:** The error message is interpolated directly into the
   `WWW-Authenticate` header without sanitization. Token values that surface in
   the message (e.g., the header `alg`, or the `iss` and `aud` claims) can carry
   CR/LF characters or double quotes, which may produce a malformed header or,
   depending on the ASGI server, header injection. Sanitize or normalize the
   message before placing it in the header. [security]
   
   <details>
   <summary><b>Severity Level:</b> Critical 🚨</summary>
   
   ```mdx
   - ❌ Attacker-controlled data injected into WWW-Authenticate header.
   - ⚠️ Possible HTTP response splitting or cache poisoning.
   - ⚠️ Risk of 500s from invalid response headers.
   ```
   </details>
   
   ```suggestion
       reason = str(exc)
       # Sanitize reason for use in HTTP header to prevent header injection
       safe_reason = reason.replace("\r", " ").replace("\n", " ").replace('"', "'")
       return JSONResponse(
           status_code=401,
           content={
               "error": "invalid_token",
               "error_description": reason,
           },
           headers={
               "WWW-Authenticate": f'Bearer error="invalid_token", '
               f'error_description="{safe_reason}"',
   ```
   <details>
   <summary><b>Steps of Reproduction ✅ </b></summary>
   
   ```mdx
   1. The MCP service configures authentication using
   `DetailedJWTVerifier.get_middleware()` at
   `superset/mcp_service/jwt_verifier.py:265-280`, which registers Starlette's
   `AuthenticationMiddleware` with `on_error=_json_auth_error_handler`.
   
   2. A client sends an HTTP request to any MCP endpoint protected by this
   middleware with an `Authorization: Bearer <token>` header. If the token's
   signature verifies, a crafted issuer (`iss`) or audience (`aud`) claim
   containing newline characters and quotes (e.g.,
   `"evil\"\r\nX-Injected: value"`) is enough. Without a valid signature, the
   same payload can be placed in the unverified header `alg` field when an
   expected algorithm is configured.
   
   3. During token validation in `DetailedJWTVerifier.load_access_token()`
   (lines 119–263), the algorithm check in Step 1 runs before the signature is
   verified, while the issuer and audience checks at lines 185–207 and 211–233
   run after it. A mismatch in any of them creates an error string like
   `f"Issuer mismatch: token has '{iss}', expected '{self.issuer}'"` that
   includes the attacker-controlled newline and quote characters.
   
   4. `DetailedBearerAuthBackend.authenticate()` at lines 76–106 reads this
   reason from `_jwt_failure_reason` and raises `AuthenticationError(reason)`,
   which is passed as `exc` into `_json_auth_error_handler()` at lines 59–73.
   The handler interpolates `exc` directly into the `WWW-Authenticate` header at
   line 71 without sanitization, allowing the crafted CR/LF and quotes to break
   the header syntax and potentially inject additional HTTP headers or cause
   malformed-header errors, depending on the ASGI server's enforcement.
   ```
   </details>
   <details>
   <summary><b>Prompt for AI Agent 🤖 </b></summary>
   
   ```mdx
   This is a comment left during a code review.
   
   **Path:** superset/mcp_service/jwt_verifier.py
   **Line:** 63:71
   **Comment:**
        *Security: The error message is interpolated directly into the 
WWW-Authenticate header without sanitization, so a crafted token value that 
surfaces in the message (e.g., issuer or algorithm) could inject newline 
characters or quotes and lead to HTTP header injection or malformed headers; 
you should sanitize or normalize the message before placing it in the header.
   
   Validate the correctness of the flagged issue. If it is correct, how can I
   resolve this? If you propose a fix, implement it and keep it concise.
   ```
   </details>
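
A minimal sketch of a stricter normalization, assuming the description should stay within the character set that RFC 6750 permits for `error_description` (printable ASCII excluding the double quote and the backslash). The helper names `sanitize_error_description` and `build_www_authenticate` are hypothetical and are not part of the PR:

```python
import re

# Anything outside %x20-21 / %x23-5B / %x5D-7E (the set RFC 6750 allows
# in error_description) is treated as unsafe for the header value.
_DISALLOWED = re.compile(r"[^\x20-\x21\x23-\x5b\x5d-\x7e]")


def sanitize_error_description(reason: str) -> str:
    """Return a copy of reason that is safe inside a quoted header parameter."""
    # Keep quoted values readable by mapping double quotes to single quotes.
    reason = reason.replace('"', "'")
    # Replace CR, LF, other control characters, backslashes and non-ASCII
    # characters with a space.
    return _DISALLOWED.sub(" ", reason)


def build_www_authenticate(reason: str) -> str:
    """Build the WWW-Authenticate value from an untrusted failure reason."""
    safe_reason = sanitize_error_description(reason)
    return f'Bearer error="invalid_token", error_description="{safe_reason}"'
```

Only the header needs this treatment. The JSON body can keep the original reason, since JSON serialization escapes quotes and control characters on its own.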



##########
superset/mcp_service/jwt_verifier.py:
##########
@@ -0,0 +1,294 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Detailed JWT verification for the MCP service.
+
+Provides step-by-step JWT validation with specific error messages
+instead of the generic "invalid_token" response from the base JWTVerifier.
+"""
+
+import base64
+import logging
+import time
+from contextvars import ContextVar
+from typing import Any, cast
+
+from authlib.jose.errors import (
+    BadSignatureError,
+    DecodeError,
+    ExpiredTokenError,
+    JoseError,
+)
+from fastmcp.server.auth.auth import AccessToken
+from fastmcp.server.auth.providers.jwt import JWTVerifier
+from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
+from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
+from starlette.authentication import AuthenticationError
+from starlette.middleware import Middleware
+from starlette.middleware.authentication import AuthenticationMiddleware
+from starlette.requests import HTTPConnection
+from starlette.responses import JSONResponse
+
+from superset.utils import json
+
+logger = logging.getLogger(__name__)
+
+# Thread-safe storage for the specific JWT failure reason.
+# Set by DetailedJWTVerifier.load_access_token() on failure,
+# read by DetailedBearerAuthBackend.authenticate() to raise
+# an AuthenticationError with the specific reason.
+_jwt_failure_reason: ContextVar[str | None] = ContextVar(
+    "_jwt_failure_reason", default=None
+)
+
+
+def _json_auth_error_handler(
+    conn: HTTPConnection, exc: AuthenticationError
+) -> JSONResponse:
+    """Return a JSON 401 response with the specific JWT failure reason."""
+    return JSONResponse(
+        status_code=401,
+        content={
+            "error": "invalid_token",
+            "error_description": str(exc),
+        },
+        headers={
+            "WWW-Authenticate": f'Bearer error="invalid_token", '
+            f'error_description="{exc}"',
+        },
+    )
+
+
+class DetailedBearerAuthBackend(BearerAuthBackend):
+    """
+    Bearer auth backend that raises AuthenticationError with specific
+    JWT failure reasons instead of silently returning None.
+    """
+
+    async def authenticate(self, conn: HTTPConnection) -> tuple[Any, Any] | None:
+        result = await super().authenticate(conn)
+
+        if result is not None:
+            # Clear any stale failure reason on success
+            _jwt_failure_reason.set(None)
+            return result
+
+        # Check if there's a Bearer token present - if so, there was a
+        # validation failure we can report with a specific reason
+        auth_header = next(
+            (
+                conn.headers.get(key)
+                for key in conn.headers
+                if key.lower() == "authorization"
+            ),
+            None,
+        )
+        if auth_header and auth_header.lower().startswith("bearer "):
+            reason = _jwt_failure_reason.get()
+            if reason:
+                _jwt_failure_reason.set(None)
+                raise AuthenticationError(reason)
+
+        return None
+
+
+class DetailedJWTVerifier(JWTVerifier):
+    """
+    JWT verifier that provides specific error messages for each
+    validation failure instead of generic "invalid_token".
+
+    Overrides load_access_token() to perform step-by-step validation,
+    storing the specific failure reason in a ContextVar that the
+    custom BearerAuthBackend reads to return a descriptive 401 response.
+    """
+
+    async def load_access_token(self, token: str) -> AccessToken | None:  # noqa: C901
+        """
+        Validate a JWT bearer token with detailed error reporting.
+
+        Each validation step stores a specific failure reason in the
+        _jwt_failure_reason ContextVar before returning None.
+        """
+        # Reset any previous failure reason
+        _jwt_failure_reason.set(None)
+
+        try:
+            # Step 1: Decode header and check algorithm
+            try:
+                header = self._decode_token_header(token)
+            except (ValueError, DecodeError) as e:
+                reason = f"Malformed token header: {e}"
+                _jwt_failure_reason.set(reason)
+                logger.warning(reason)
+                return None
+
+            token_alg = header.get("alg")
+            if self.algorithm and token_alg != self.algorithm:
+                reason = (
+                    f"Algorithm mismatch: token uses '{token_alg}', "
+                    f"expected '{self.algorithm}'"
+                )
+                _jwt_failure_reason.set(reason)
+                logger.warning(reason)
+                return None
+
+            # Step 2: Get verification key (static or JWKS)
+            try:
+                verification_key = await self._get_verification_key(token)
+            except ValueError as e:
+                reason = f"Failed to get verification key: {e}"
+                _jwt_failure_reason.set(reason)
+                logger.warning(reason)
+                return None
+
+            # Step 3: Decode and verify signature
+            try:
+                claims = self.jwt.decode(token, verification_key)
+            except BadSignatureError:
+                reason = "Signature verification failed"
+                _jwt_failure_reason.set(reason)
+                logger.warning(reason)
+                return None
+            except ExpiredTokenError:
+                reason = "Token has expired (detected during decode)"
+                _jwt_failure_reason.set(reason)
+                logger.warning(reason)
+                return None
+            except JoseError as e:
+                reason = f"Token decode failed: {e}"
+                _jwt_failure_reason.set(reason)
+                logger.warning(reason)
+                return None
+
+            # Extract client ID for logging
+            client_id = (
+                claims.get("client_id")
+                or claims.get("azp")
+                or claims.get("sub")
+                or "unknown"
+            )
+
+            # Step 4: Check expiration
+            exp = claims.get("exp")
+            if exp and exp < time.time():
+                reason = f"Token expired for client '{client_id}'"
+                _jwt_failure_reason.set(reason)
+                logger.warning(reason)
+                return None
+
+            # Step 5: Validate issuer
+            if self.issuer:
+                iss = claims.get("iss")
+                if isinstance(self.issuer, list):
+                    issuer_valid = iss in self.issuer
+                else:
+                    issuer_valid = iss == self.issuer
+
+                if not issuer_valid:
+                    reason = (
+                        f"Issuer mismatch: token has '{iss}', expected '{self.issuer}'"
+                    )
+                    _jwt_failure_reason.set(reason)
+                    logger.warning(reason)
+                    return None
+
+            # Step 6: Validate audience
+            if self.audience:
+                aud = claims.get("aud")
+                if isinstance(self.audience, list):
+                    if isinstance(aud, list):
+                        audience_valid = any(
+                            expected in aud
+                            for expected in cast(list[str], self.audience)
+                        )
+                    else:
+                        audience_valid = aud in cast(list[str], self.audience)
+                else:
+                    if isinstance(aud, list):
+                        audience_valid = self.audience in aud
+                    else:
+                        audience_valid = aud == self.audience
+
+                if not audience_valid:
+                    reason = (
+                        f"Audience mismatch: token has '{aud}', "
+                        f"expected '{self.audience}'"
+                    )
+                    _jwt_failure_reason.set(reason)
+                    logger.warning(reason)
+                    return None
+
+            # Step 7: Check required scopes
+            scopes = self._extract_scopes(claims)
+            if self.required_scopes:
+                token_scopes = set(scopes)
+                required = set(self.required_scopes)
+                if not required.issubset(token_scopes):
+                    missing = required - token_scopes
+                    reason = (
+                        f"Missing required scopes: {missing}. Token has: {token_scopes}"
+                    )
+                    _jwt_failure_reason.set(reason)
+                    logger.warning(reason)
+                    return None
+
+            # All validations passed
+            logger.info("JWT validated for client '%s'", client_id)
+            return AccessToken(
+                token=token,
+                client_id=str(client_id),
+                scopes=scopes,
+                expires_at=int(exp) if exp else None,
+                claims=dict(claims),
+            )
+
+        except Exception as e:
+            reason = f"Token validation failed: {e}"
+            _jwt_failure_reason.set(reason)
+            logger.warning(reason)
+            return None
+
+    def get_middleware(self) -> list[Any]:
+        """
+        Get middleware with detailed error reporting.
+
+        Uses DetailedBearerAuthBackend which raises AuthenticationError
+        with specific reasons, and a JSON error handler that returns
+        structured 401 responses.
+        """
+        return [
+            Middleware(
+                AuthenticationMiddleware,
+                backend=DetailedBearerAuthBackend(self),
+                on_error=_json_auth_error_handler,
+            ),
+            Middleware(AuthContextMiddleware),
+        ]
+
+    @staticmethod
+    def _decode_token_header(token: str) -> dict[str, Any]:
+        """Decode the JWT header without verifying the signature."""
+        parts = token.split(".")
+        if len(parts) != 3:
+            raise ValueError(
+                f"Token must have 3 parts (header.payload.signature), got {len(parts)}"
+            )
+        header_b64 = parts[0]
+        # Add padding
+        header_b64 += "=" * (4 - len(header_b64) % 4)

Review Comment:
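   **Suggestion:** The JWT header padding calculation appends four '='
   characters, instead of none, when the header length is already a multiple of
   4. CPython's default (non-strict) base64 decoder ignores such surplus
   padding, so decoding usually still succeeds, but the padded value is
   non-canonical and a decoder that validates padding would reject it as a
   malformed header. Compute only the padding that is actually needed.
   [logic error]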
   
   <details>
   <summary><b>Severity Level:</b> Critical 🚨</summary>
   
   ```mdx
   - ⚠️ Headers whose length is a multiple of 4 get four surplus '=' characters.
   - ⚠️ A strict base64 decoder would reject them as "Malformed token header".
   - ⚠️ Affects common headers such as the 36-character HS256 header.
   ```
   </details>
   
   ```suggestion
           # Add padding only if needed
           header_b64 += "=" * (-len(header_b64) % 4)
   ```
   <details>
   <summary><b>Steps of Reproduction ✅ </b></summary>
   
   ```mdx
   1. In a test module, import `DetailedJWTVerifier` from
   `superset/mcp_service/jwt_verifier.py` (the class is defined above line 119;
   its static method `_decode_token_header` is at line 283).
   
   2. Construct a structurally valid JWT whose header segment length is a
   multiple of 4, for example using the common HS256 header
   `"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"` and forming a dummy token
   `"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.foo.bar"`.
   
   3. Call `await DetailedJWTVerifier(...).load_access_token(token)` from
   `superset/mcp_service/jwt_verifier.py:119`, which immediately invokes
   `_decode_token_header(token)` at line 132 and reaches the padding logic at
   line 291 (`header_b64 += "=" * (4 - len(header_b64) % 4)`).
   
   4. Observe that for a header length already divisible by 4, four `"="`
   characters are appended before `base64.urlsafe_b64decode(header_b64)` runs
   at line 293. CPython's default non-strict decoding skips surplus padding, so
   the call normally succeeds. Under a decoder that validates padding it raises
   `binascii.Error` (a `ValueError` subclass), and the inner
   `except (ValueError, DecodeError)` at lines 133–135 in `load_access_token()`
   then treats the header as malformed and returns `None` with a reason starting
   with `"Malformed token header"`, even though the JWT is valid.
   ```
   </details>
   <details>
   <summary><b>Prompt for AI Agent 🤖 </b></summary>
   
   ```mdx
   This is a comment left during a code review.
   
   **Path:** superset/mcp_service/jwt_verifier.py
   **Line:** 292:292
   **Comment:**
        *Logic Error: The JWT header padding calculation always appends four 
'=' characters when the header length is already a multiple of 4, which can 
make otherwise valid JWTs fail to decode with a padding error and be 
incorrectly treated as having a malformed header.
   
   Validate the correctness of the flagged issue. If it is correct, how can I
   resolve this? If you propose a fix, implement it and keep it concise.
   ```
   </details>
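
The padding arithmetic can be checked in isolation. The sketch below is a standalone approximation of the header decoding, not the PR's code: the function name `decode_jwt_header` is hypothetical, and it uses the standard-library `json` module rather than `superset.utils.json`. It relies on `-n % 4` evaluating to 0, 3, 2 or 1, which is exactly the number of '=' characters needed to reach the next multiple of 4:

```python
import base64
import json
from typing import Any


def decode_jwt_header(token: str) -> dict[str, Any]:
    """Decode the JWT header segment without verifying the signature."""
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError(f"Token must have 3 parts, got {len(parts)}")
    header_b64 = parts[0]
    # Append only the padding that is missing. For a length that is
    # already a multiple of 4 this appends nothing, whereas
    # 4 - len(header_b64) % 4 would append four '=' characters.
    header_b64 += "=" * (-len(header_b64) % 4)
    return json.loads(base64.urlsafe_b64decode(header_b64))


# The common HS256 header is 36 characters long, so no padding is added.
header = decode_jwt_header("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.foo.bar")
```

Whether surplus padding is tolerated depends on the decoder, so producing canonical padding avoids relying on that behavior.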



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
