This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f1325133187 Fix Execution API refresh token (#58782)
f1325133187 is described below

commit f1325133187f830d465dcdf56ebf31a33bbaf340
Author: Pierre Jeambrun <[email protected]>
AuthorDate: Fri Nov 28 16:35:13 2025 +0100

    Fix Execution API refresh token (#58782)
---
 .../src/airflow/api_fastapi/execution_api/app.py   | 37 ++++++++++++-
 .../src/airflow/api_fastapi/execution_api/deps.py  | 62 +---------------------
 .../api_fastapi/execution_api/routes/__init__.py   |  4 +-
 .../execution_api/versions/head/test_router.py     |  8 ++-
 .../versions/head/test_task_instances.py           |  3 +-
 5 files changed, 43 insertions(+), 71 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
index 9e30ffb5af5..9d93f3bf84d 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import json
+import time
 from contextlib import AsyncExitStack
 from functools import cached_property
 from typing import TYPE_CHECKING, Any
@@ -126,6 +127,39 @@ class CorrelationIdMiddleware(BaseHTTPMiddleware):
         return response
 
 
+class JWTReissueMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        from airflow.configuration import conf
+
+        response: Response = await call_next(request)
+
+        refreshed_token: str | None = None
+        auth_header = request.headers.get("authorization")
+        if auth_header and auth_header.lower().startswith("bearer "):
+            token = auth_header.split(" ", 1)[1]
+            try:
+                async with svcs.Container(request.app.state.svcs_registry) as 
services:
+                    validator: JWTValidator = await services.aget(JWTValidator)
+                    claims = await validator.avalidated_claims(token, {})
+
+                    now = int(time.time())
+                    validity = conf.getint("execution_api", 
"jwt_expiration_time")
+                    refresh_when_less_than = max(int(validity * 0.20), 30)
+                    valid_left = int(claims.get("exp", 0)) - now
+                    if valid_left <= refresh_when_less_than:
+                        generator: JWTGenerator = await 
services.aget(JWTGenerator)
+                        refreshed_token = generator.generate(claims)
+            except Exception as err:
+                # Do not block the response if refreshing fails; log a warning 
for visibility
+                logger.warning(
+                    "JWT reissue middleware failed to refresh token", 
error=str(err), exc_info=True
+                )
+
+        if refreshed_token:
+            response.headers["Refreshed-API-Token"] = refreshed_token
+        return response
+
+
 class CadwynWithOpenAPICustomization(Cadwyn):
     # Workaround lack of customzation 
https://github.com/zmievsa/cadwyn/issues/255
     async def openapi_jsons(self, req: Request) -> JSONResponse:
@@ -211,6 +245,7 @@ def create_task_execution_api_app() -> FastAPI:
 
     # Add correlation-id middleware for request tracing
     app.add_middleware(CorrelationIdMiddleware)
+    app.add_middleware(JWTReissueMiddleware)
 
     app.generate_and_include_versioned_routers(execution_api_router)
 
@@ -266,7 +301,6 @@ class InProcessExecutionAPI:
             from airflow.api_fastapi.execution_api.deps import (
                 JWTBearerDep,
                 JWTBearerTIPathDep,
-                JWTRefresherDep,
             )
             from airflow.api_fastapi.execution_api.routes.connections import 
has_connection_access
             from airflow.api_fastapi.execution_api.routes.variables import 
has_variable_access
@@ -281,7 +315,6 @@ class InProcessExecutionAPI:
 
             self._app.dependency_overrides[JWTBearerDep.dependency] = 
always_allow
             self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = 
always_allow
-            self._app.dependency_overrides[JWTRefresherDep.dependency] = 
always_allow
             self._app.dependency_overrides[has_connection_access] = 
always_allow
             self._app.dependency_overrides[has_variable_access] = always_allow
             self._app.dependency_overrides[has_xcom_access] = always_allow
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
index 2648a64ffad..d247a31f5f4 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
@@ -18,17 +18,14 @@
 # Disable future annotations in this file to work around 
https://github.com/fastapi/fastapi/issues/13056
 # ruff: noqa: I002
 
-import sys
-import time
 from typing import Any
 
 import structlog
 import svcs
-from fastapi import Depends, HTTPException, Request, Response, status
+from fastapi import Depends, HTTPException, Request, status
 from fastapi.security import HTTPBearer
-from starlette.exceptions import HTTPException as StarletteHTTPException
 
-from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator
+from airflow.api_fastapi.auth.tokens import JWTValidator
 from airflow.api_fastapi.execution_api.datamodels.token import TIToken
 
 log = structlog.get_logger(logger_name=__name__)
@@ -98,58 +95,3 @@ JWTBearerDep: TIToken = Depends(JWTBearer())
 
 # This checks that the UUID in the url matches the one in the token for us.
 JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id"))
-
-
-class JWTReissuer:
-    """Re-issue JWTs to requests when they are about to run out."""
-
-    def __init__(self):
-        from airflow.configuration import conf
-
-        self.refresh_when_less_than = max(
-            # Issue a new token to a task when the current one is valid for 
only either 20% of the total validity,
-            # or 30s
-            int(conf.getint("execution_api", "jwt_expiration_time") * 0.20),
-            30,
-        )
-
-    async def __call__(
-        self,
-        response: Response,
-        token=JWTBearerDep,
-        services=DepContainer,
-    ):
-        try:
-            yield
-        finally:
-            # We want to run this even in the case of 404 errors etc
-            now = int(time.time())
-
-            try:
-                valid_left = token.claims["exp"] - now
-                if valid_left <= self.refresh_when_less_than:
-                    generator: JWTGenerator = await services.aget(JWTGenerator)
-                    new = generator.generate(token.claims)
-                    response.headers["Refreshed-API-Token"] = new
-                    log.debug(
-                        "Refreshed token issued to Task",
-                        valid_left=valid_left,
-                        refresh_when_less_than=self.refresh_when_less_than,
-                    )
-
-                    exc, val, _ = sys.exc_info()
-                    if val and isinstance(val, StarletteHTTPException):
-                        # If there is an exception thrown, we need to set the 
headers there instead
-                        if val.headers is None:
-                            val.headers = {}
-
-                        # Defined as a "mapping type", but 99.9% of the time 
it's a mutable dict. We catch
-                        # errors if not
-                        val.headers["Refreshed-API-Token"] = new  # type: 
ignore[index]
-
-            except Exception as e:
-                # Don't 500 if there's a problem
-                log.warning("Error refreshing Task JWT", 
err=f"{type(e).__name__}: {e}")
-
-
-JWTRefresherDep = Depends(JWTReissuer())
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
index 89d96083876..562b8588fbf 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 from cadwyn import VersionedAPIRouter
 from fastapi import APIRouter
 
-from airflow.api_fastapi.execution_api.deps import JWTBearerDep, 
JWTRefresherDep
+from airflow.api_fastapi.execution_api.deps import JWTBearerDep
 from airflow.api_fastapi.execution_api.routes import (
     asset_events,
     assets,
@@ -37,7 +37,7 @@ execution_api_router = APIRouter()
 execution_api_router.include_router(health.router, prefix="/health", 
tags=["Health"])
 
 # _Every_ single endpoint under here must be authenticated. Some do further 
checks on top of these
-authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep, 
JWTRefresherDep])  # type: ignore[list-item]
+authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep])  # 
type: ignore[list-item]
 
 authenticated_router.include_router(assets.router, prefix="/assets", 
tags=["Assets"])
 authenticated_router.include_router(asset_events.router, 
prefix="/asset-events", tags=["Asset Events"])
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py
index e4aadcead6c..85f7df46915 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_router.py
@@ -24,7 +24,6 @@ from fastapi import FastAPI
 
 from airflow.api_fastapi.auth.tokens import JWTValidator
 from airflow.api_fastapi.execution_api.app import lifespan
-from airflow.api_fastapi.execution_api.deps import JWTRefresherDep, JWTReissuer
 
 from tests_common.test_utils.config import conf_vars
 
@@ -57,15 +56,14 @@ def test_expiring_token_is_reissued(
         "exp": moment + validity,
     }
 
-    with conf_vars({("execution_api", "jwt_expiration_time"): str(validity)}):
-        exec_app.dependency_overrides[JWTRefresherDep.dependency] = 
JWTReissuer()
-
     time_machine.move_to(moment + age, tick=False)
 
     # Inject our fake JWTValidator object. Can be over-ridden by tests if they 
want
     lifespan.registry.register_value(JWTValidator, auth)
     # In order to test this we need any endpoint to hit. The easiest one to 
use is variable get
-    response = client.get("/execution/variables/key1")
+
+    with conf_vars({("execution_api", "jwt_expiration_time"): str(validity)}):
+        response = client.get("/execution/variables/key1", 
headers={"Authorization": "Bearer dummy"})
 
     if expect_refreshed_token:
         assert "Refreshed-API-Token" in response.headers
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index c410d4e4733..f91ae7f285f 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -98,7 +98,6 @@ def test_id_matches_sub_claim(client, session, 
create_task_instance):
             raise RuntimeError("Fake auth denied")
         return claims
 
-    # validator.avalidated_claims.side_effect = [{}, RuntimeError("fail for 
tests"), claims, claims]
     validator.avalidated_claims.side_effect = side_effect
 
     lifespan.registry.register_value(JWTValidator, validator)
@@ -113,7 +112,7 @@ def test_id_matches_sub_claim(client, session, 
create_task_instance):
 
     resp = 
client.patch("/execution/task-instances/9c230b40-da03-451d-8bd7-be30471be383/run",
 json=payload)
     assert resp.status_code == 403
-    validator.avalidated_claims.assert_called_with(
+    assert validator.avalidated_claims.call_args_list[1] == mock.call(
         mock.ANY, {"sub": {"essential": True, "value": 
"9c230b40-da03-451d-8bd7-be30471be383"}}
     )
     validator.avalidated_claims.reset_mock()

Reply via email to