This is an automated email from the ASF dual-hosted git repository.

vatsrahul1001 pushed a commit to branch v3-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v3-2-test by this push:
     new 35d1f2cfee5 Validate task identity token claims with a typed schema 
(#63604) (#66988)
35d1f2cfee5 is described below

commit 35d1f2cfee5189e074a6fd7e6aa746593b4e3d98
Author: Rahul Vats <[email protected]>
AuthorDate: Mon May 18 09:48:30 2026 +0530

    Validate task identity token claims with a typed schema (#63604) (#66988)
    
    * Validate task identity token claims with a typed schema.
    
    * Add unit test test_rejects_invalid_uuid and 
test_rejects_missing_required_claims
    
    * delete redundant parameters in TIClaims.
    
    * fix mypy errors
    
    * remove sub
    
    (cherry picked from commit 7e85e7bea5ac072aef7072a117de857e59243c99)
    
    Co-authored-by: Henry Chen <[email protected]>
---
 .../src/airflow/api_fastapi/execution_api/app.py   |  5 +++--
 .../api_fastapi/execution_api/datamodels/token.py  | 23 ++++++++++++++++++----
 .../airflow/api_fastapi/execution_api/security.py  | 22 ++++++++++++---------
 .../unit/api_fastapi/execution_api/conftest.py     |  8 +++++---
 .../api_fastapi/execution_api/test_security.py     | 20 ++++++++++++++++---
 .../execution_api/versions/head/test_dag_runs.py   |  4 ++--
 .../versions/head/test_task_instances.py           |  9 ++++++++-
 7 files changed, 67 insertions(+), 24 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
index 635d04e47f1..4f1f671cd83 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
@@ -313,7 +313,7 @@ class InProcessExecutionAPI:
         if not self._app:
             from airflow.api_fastapi.common.dagbag import create_dag_bag
             from airflow.api_fastapi.execution_api.app import 
create_task_execution_api_app
-            from airflow.api_fastapi.execution_api.datamodels.token import 
TIToken
+            from airflow.api_fastapi.execution_api.datamodels.token import 
TIClaims, TIToken
             from airflow.api_fastapi.execution_api.routes.connections import 
has_connection_access
             from airflow.api_fastapi.execution_api.routes.variables import 
has_variable_access
             from airflow.api_fastapi.execution_api.routes.xcoms import 
has_xcom_access
@@ -330,7 +330,8 @@ class InProcessExecutionAPI:
                 ti_id = UUID(
                     request.path_params.get("task_instance_id", 
"00000000-0000-0000-0000-000000000000")
                 )
-                return TIToken(id=ti_id, claims={"scope": "execution"})
+                claims = TIClaims(scope="execution")
+                return TIToken(id=ti_id, claims=claims)
 
             self._app.dependency_overrides[_jwt_bearer] = always_allow
             self._app.dependency_overrides[has_connection_access] = 
always_allow
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/token.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/token.py
index 43e52914fea..4c3b935f5aa 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/token.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/token.py
@@ -17,16 +17,31 @@
 
 from __future__ import annotations
 
-from typing import Any
+from typing import Literal
 from uuid import UUID
 
+from pydantic import ConfigDict
+
 from airflow.api_fastapi.core_api.base import BaseModel
 
+TokenScope = Literal["execution", "workload"]
+
+
+class TIClaims(BaseModel):
+    """
+    Validated JWT claims for a task identity token.
+
+    Only the field used by the Execution API (scope) is explicitly typed.
+    JWTValidator already validates exp/iat/nbf/aud/etc. Extra claims are 
allowed.
+    """
+
+    model_config = ConfigDict(extra="allow")
+
+    scope: TokenScope = "execution"
+
 
-# TODO: This is a placeholder for Task Identity Token schema.
 class TIToken(BaseModel):
     """Task Identity Token."""
 
     id: UUID
-
-    claims: dict[str, Any]
+    claims: TIClaims
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/security.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/security.py
index 19a29f7405d..1100674b727 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/security.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/security.py
@@ -67,23 +67,23 @@ Why ``ExecutionAPIRoute`` is needed:
 # Disable future annotations in this file to work around 
https://github.com/fastapi/fastapi/issues/13056
 # ruff: noqa: I002
 
-from typing import Any, Literal, get_args
+from typing import Any, get_args
 
 import structlog
 from fastapi import Depends, HTTPException, Request, status
 from fastapi.params import Security as SecurityParam
 from fastapi.routing import APIRoute
 from fastapi.security import HTTPBearer, SecurityScopes
+from pydantic import ValidationError
+from sqlalchemy import select
 
 from airflow.api_fastapi.auth.tokens import JWTValidator
-from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.datamodels.token import TIClaims, 
TIToken, TokenScope
 from airflow.api_fastapi.execution_api.deps import DepContainer
 
 log = structlog.get_logger(logger_name=__name__)
 
-TokenType = Literal["execution", "workload"]
-
-VALID_TOKEN_TYPES: frozenset[str] = frozenset(get_args(TokenType))
+VALID_TOKEN_TYPES: frozenset[str] = frozenset(get_args(TokenScope))
 
 _REQUEST_SCOPE_TOKEN_KEY = "ti_token"
 
@@ -129,7 +129,13 @@ class JWTBearer(HTTPBearer):
 
         claims.setdefault("scope", "execution")
 
-        token = TIToken(id=claims["sub"], claims=claims)
+        try:
+            claim_model = TIClaims(**claims)
+        except ValidationError as err:
+            log.warning("JWT claims did not match task identity token schema", 
exc_info=True)
+            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, 
detail=f"Invalid auth token: {err}")
+
+        token = TIToken(id=claims["sub"], claims=claim_model)
         request.scope[_REQUEST_SCOPE_TOKEN_KEY] = token
         return token
 
@@ -151,7 +157,7 @@ async def require_auth(
     Token type enforcement reads ``route.allowed_token_types`` (precomputed
     by ``ExecutionAPIRoute``) or defaults to ``{"execution"}``.
     """
-    token_scope = token.claims.get("scope", "execution")
+    token_scope = token.claims.scope
 
     if token_scope not in VALID_TOKEN_TYPES:
         log.warning("Invalid token scope in claims", token_scope=token_scope, 
path=request.url.path)
@@ -227,8 +233,6 @@ async def get_team_name_dep(token=CurrentTIToken) -> str | 
None:
     if not conf.getboolean("core", "multi_team"):
         return None
 
-    from sqlalchemy import select
-
     from airflow.models import DagModel, TaskInstance
     from airflow.models.dagbundle import DagBundleModel
     from airflow.models.team import Team
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py 
b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
index 78bd0548df9..20eed734701 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
@@ -19,16 +19,17 @@ from __future__ import annotations
 import pytest
 from fastapi import FastAPI, Request
 from fastapi.testclient import TestClient
+from starlette.routing import Mount
 
 from airflow.api_fastapi.app import cached_app
-from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.datamodels.token import TIClaims, 
TIToken
 from airflow.api_fastapi.execution_api.security import _jwt_bearer
 
 
 def _get_execution_api_app(root_app: FastAPI) -> FastAPI:
     """Find the mounted execution API sub-app."""
     for route in root_app.routes:
-        if hasattr(route, "path") and route.path == "/execution":
+        if isinstance(route, Mount) and route.path == "/execution" and 
isinstance(route.app, FastAPI):
             return route.app
     raise RuntimeError("Execution API sub-app not found")
 
@@ -48,7 +49,8 @@ def client(request: pytest.FixtureRequest):
         from uuid import UUID
 
         ti_id = UUID(request.path_params.get("task_instance_id", 
"00000000-0000-0000-0000-000000000000"))
-        return TIToken(id=ti_id, claims={"sub": str(ti_id), "scope": 
"execution"})
+        claims = TIClaims(scope="execution")
+        return TIToken(id=ti_id, claims=claims)
 
     exec_app.dependency_overrides[_jwt_bearer] = mock_jwt_bearer
 
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/test_security.py 
b/airflow-core/tests/unit/api_fastapi/execution_api/test_security.py
index 29ca8901efb..985dfd32599 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/test_security.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/test_security.py
@@ -23,7 +23,7 @@ import pytest
 from fastapi import APIRouter, FastAPI, Request, Security
 from fastapi.testclient import TestClient
 
-from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.datamodels.token import TIClaims, 
TIToken, TokenScope
 from airflow.api_fastapi.execution_api.security import (
     ExecutionAPIRoute,
     _jwt_bearer,
@@ -32,6 +32,19 @@ from airflow.api_fastapi.execution_api.security import (
 )
 
 
+class TestTIClaims:
+    def test_defaults_scope_and_retains_extra(self):
+        claims = TIClaims(team="data")
+
+        assert claims.scope == "execution"
+        assert claims.team == "data"
+
+    def test_accepts_sub_as_extra_claim(self):
+        claims = TIClaims(sub="not-a-uuid")
+
+        assert claims.sub == "not-a-uuid"
+
+
 class TestExecutionAPIRoute:
     """Unit tests for ExecutionAPIRoute precomputing allowed_token_types from 
Security scopes."""
 
@@ -111,11 +124,12 @@ class TestTokenTypeScopeEnforcement:
 
     TI_ID = "00000000-0000-0000-0000-000000000001"
 
-    def _override_jwt(self, app, scope: str):
+    def _override_jwt(self, app, scope: TokenScope):
         ti_id = self.TI_ID
 
         async def mock_jwt(request: Request):
-            return TIToken(id=UUID(ti_id), claims={"scope": scope})
+            claims = TIClaims(scope=scope)
+            return TIToken(id=UUID(ti_id), claims=claims)
 
         app.dependency_overrides[_jwt_bearer] = mock_jwt
 
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py
index 8d25936f77b..5191770c5f3 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py
@@ -23,7 +23,7 @@ from fastapi import Request
 from sqlalchemy import select, update
 
 from airflow._shared.timezones import timezone
-from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.datamodels.token import TIClaims, 
TIToken
 from airflow.api_fastapi.execution_api.security import require_auth
 from airflow.models import DagModel
 from airflow.models.dagrun import DagRun
@@ -217,7 +217,7 @@ class TestDagRunTrigger:
         session.commit()
 
         async def auth_as_parent_ti(request: Request) -> TIToken:
-            return TIToken(id=parent_ti.id, claims={"scope": "execution"})
+            return TIToken(id=parent_ti.id, claims=TIClaims(scope="execution"))
 
         exec_app.dependency_overrides[require_auth] = auth_as_parent_ti
         try:
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 17591f696d9..6066dba5dc1 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -130,6 +130,9 @@ def test_id_matches_sub_claim(client, session, 
create_task_instance):
     validator.avalidated_claims.return_value = {
         "sub": str(ti.id),
         "scope": "execution",
+        "exp": 9999999999,
+        "iat": 1000000000,
+        "nbf": 1000000000,
     }
     lifespan.registry.register_value(JWTValidator, validator)
 
@@ -3447,6 +3450,7 @@ class TestTokenTypeValidation:
             "scope": "workload",
             "exp": 9999999999,
             "iat": 1000000000,
+            "nbf": 1000000000,
         }
         lifespan.registry.register_value(JWTValidator, validator)
 
@@ -3466,6 +3470,7 @@ class TestTokenTypeValidation:
             "scope": "execution",
             "exp": 9999999999,
             "iat": 1000000000,
+            "nbf": 1000000000,
         }
         lifespan.registry.register_value(JWTValidator, validator)
 
@@ -3484,6 +3489,7 @@ class TestTokenTypeValidation:
             "scope": "bogus:scope",
             "exp": 9999999999,
             "iat": 1000000000,
+            "nbf": 1000000000,
         }
         lifespan.registry.register_value(JWTValidator, validator)
 
@@ -3497,7 +3503,7 @@ class TestTokenTypeValidation:
 
         resp = client.patch(f"/execution/task-instances/{ti.id}/run", 
json=payload)
         assert resp.status_code == 403
-        assert "Invalid token scope" in resp.json()["detail"]
+        assert "Invalid auth token" in resp.json()["detail"]
 
     def test_no_scope_defaults_to_execution(self, client, session, 
create_task_instance):
         """Tokens without scope claim should default to 'execution'."""
@@ -3509,6 +3515,7 @@ class TestTokenTypeValidation:
             "sub": str(ti.id),
             "exp": 9999999999,
             "iat": 1000000000,
+            "nbf": 1000000000,
         }
         lifespan.registry.register_value(JWTValidator, validator)
 

Reply via email to